snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class HuberRegressor(BaseTransformer):
|
70
64
|
r"""L2-regularized linear regression model that is robust to outliers
|
71
65
|
For more details on this class, see [sklearn.linear_model.HuberRegressor]
|
@@ -293,20 +287,17 @@ class HuberRegressor(BaseTransformer):
|
|
293
287
|
self,
|
294
288
|
dataset: DataFrame,
|
295
289
|
inference_method: str,
|
296
|
-
) ->
|
297
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
298
|
-
return the available package that exists in the snowflake anaconda channel
|
290
|
+
) -> None:
|
291
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
299
292
|
|
300
293
|
Args:
|
301
294
|
dataset: snowpark dataframe
|
302
295
|
inference_method: the inference method such as predict, score...
|
303
|
-
|
296
|
+
|
304
297
|
Raises:
|
305
298
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
306
299
|
SnowflakeMLException: If the session is None, raise error
|
307
300
|
|
308
|
-
Returns:
|
309
|
-
A list of available package that exists in the snowflake anaconda channel
|
310
301
|
"""
|
311
302
|
if not self._is_fitted:
|
312
303
|
raise exceptions.SnowflakeMLException(
|
@@ -324,9 +315,7 @@ class HuberRegressor(BaseTransformer):
|
|
324
315
|
"Session must not specified for snowpark dataset."
|
325
316
|
),
|
326
317
|
)
|
327
|
-
|
328
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
329
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
318
|
+
|
330
319
|
|
331
320
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
332
321
|
@telemetry.send_api_usage_telemetry(
|
@@ -374,7 +363,8 @@ class HuberRegressor(BaseTransformer):
|
|
374
363
|
|
375
364
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
376
365
|
|
377
|
-
self.
|
366
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
367
|
+
self._deps = self._get_dependencies()
|
378
368
|
assert isinstance(
|
379
369
|
dataset._session, Session
|
380
370
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -457,10 +447,8 @@ class HuberRegressor(BaseTransformer):
|
|
457
447
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
458
448
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
459
449
|
|
460
|
-
self.
|
461
|
-
|
462
|
-
inference_method=inference_method,
|
463
|
-
)
|
450
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
451
|
+
self._deps = self._get_dependencies()
|
464
452
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
465
453
|
|
466
454
|
transform_kwargs = dict(
|
@@ -527,16 +515,40 @@ class HuberRegressor(BaseTransformer):
|
|
527
515
|
self._is_fitted = True
|
528
516
|
return output_result
|
529
517
|
|
518
|
+
|
519
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
520
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
521
|
+
""" Method not supported for this class.
|
530
522
|
|
531
|
-
|
532
|
-
|
533
|
-
|
523
|
+
|
524
|
+
Raises:
|
525
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
526
|
+
|
527
|
+
Args:
|
528
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
529
|
+
Snowpark or Pandas DataFrame.
|
530
|
+
output_cols_prefix: Prefix for the response columns
|
534
531
|
Returns:
|
535
532
|
Transformed dataset.
|
536
533
|
"""
|
537
|
-
self.
|
538
|
-
|
539
|
-
|
534
|
+
self._infer_input_output_cols(dataset)
|
535
|
+
super()._check_dataset_type(dataset)
|
536
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
537
|
+
estimator=self._sklearn_object,
|
538
|
+
dataset=dataset,
|
539
|
+
input_cols=self.input_cols,
|
540
|
+
label_cols=self.label_cols,
|
541
|
+
sample_weight_col=self.sample_weight_col,
|
542
|
+
autogenerated=self._autogenerated,
|
543
|
+
subproject=_SUBPROJECT,
|
544
|
+
)
|
545
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
546
|
+
drop_input_cols=self._drop_input_cols,
|
547
|
+
expected_output_cols_list=self.output_cols,
|
548
|
+
)
|
549
|
+
self._sklearn_object = fitted_estimator
|
550
|
+
self._is_fitted = True
|
551
|
+
return output_result
|
540
552
|
|
541
553
|
|
542
554
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -627,10 +639,8 @@ class HuberRegressor(BaseTransformer):
|
|
627
639
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
628
640
|
|
629
641
|
if isinstance(dataset, DataFrame):
|
630
|
-
self.
|
631
|
-
|
632
|
-
inference_method=inference_method,
|
633
|
-
)
|
642
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
643
|
+
self._deps = self._get_dependencies()
|
634
644
|
assert isinstance(
|
635
645
|
dataset._session, Session
|
636
646
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -695,10 +705,8 @@ class HuberRegressor(BaseTransformer):
|
|
695
705
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
696
706
|
|
697
707
|
if isinstance(dataset, DataFrame):
|
698
|
-
self.
|
699
|
-
|
700
|
-
inference_method=inference_method,
|
701
|
-
)
|
708
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
709
|
+
self._deps = self._get_dependencies()
|
702
710
|
assert isinstance(
|
703
711
|
dataset._session, Session
|
704
712
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -760,10 +768,8 @@ class HuberRegressor(BaseTransformer):
|
|
760
768
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
761
769
|
|
762
770
|
if isinstance(dataset, DataFrame):
|
763
|
-
self.
|
764
|
-
|
765
|
-
inference_method=inference_method,
|
766
|
-
)
|
771
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
772
|
+
self._deps = self._get_dependencies()
|
767
773
|
assert isinstance(
|
768
774
|
dataset._session, Session
|
769
775
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -829,10 +835,8 @@ class HuberRegressor(BaseTransformer):
|
|
829
835
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
830
836
|
|
831
837
|
if isinstance(dataset, DataFrame):
|
832
|
-
self.
|
833
|
-
|
834
|
-
inference_method=inference_method,
|
835
|
-
)
|
838
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
839
|
+
self._deps = self._get_dependencies()
|
836
840
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
837
841
|
transform_kwargs = dict(
|
838
842
|
session=dataset._session,
|
@@ -896,17 +900,15 @@ class HuberRegressor(BaseTransformer):
|
|
896
900
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
897
901
|
|
898
902
|
if isinstance(dataset, DataFrame):
|
899
|
-
self.
|
900
|
-
|
901
|
-
inference_method="score",
|
902
|
-
)
|
903
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
904
|
+
self._deps = self._get_dependencies()
|
903
905
|
selected_cols = self._get_active_columns()
|
904
906
|
if len(selected_cols) > 0:
|
905
907
|
dataset = dataset.select(selected_cols)
|
906
908
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
907
909
|
transform_kwargs = dict(
|
908
910
|
session=dataset._session,
|
909
|
-
dependencies=
|
911
|
+
dependencies=self._deps,
|
910
912
|
score_sproc_imports=['sklearn'],
|
911
913
|
)
|
912
914
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -971,11 +973,8 @@ class HuberRegressor(BaseTransformer):
|
|
971
973
|
|
972
974
|
if isinstance(dataset, DataFrame):
|
973
975
|
|
974
|
-
self.
|
975
|
-
|
976
|
-
inference_method=inference_method,
|
977
|
-
|
978
|
-
)
|
976
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
977
|
+
self._deps = self._get_dependencies()
|
979
978
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
980
979
|
transform_kwargs = dict(
|
981
980
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class Lars(BaseTransformer):
|
70
64
|
r"""Least Angle Regression model a
|
71
65
|
For more details on this class, see [sklearn.linear_model.Lars]
|
@@ -322,20 +316,17 @@ class Lars(BaseTransformer):
|
|
322
316
|
self,
|
323
317
|
dataset: DataFrame,
|
324
318
|
inference_method: str,
|
325
|
-
) ->
|
326
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
327
|
-
return the available package that exists in the snowflake anaconda channel
|
319
|
+
) -> None:
|
320
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
328
321
|
|
329
322
|
Args:
|
330
323
|
dataset: snowpark dataframe
|
331
324
|
inference_method: the inference method such as predict, score...
|
332
|
-
|
325
|
+
|
333
326
|
Raises:
|
334
327
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
335
328
|
SnowflakeMLException: If the session is None, raise error
|
336
329
|
|
337
|
-
Returns:
|
338
|
-
A list of available package that exists in the snowflake anaconda channel
|
339
330
|
"""
|
340
331
|
if not self._is_fitted:
|
341
332
|
raise exceptions.SnowflakeMLException(
|
@@ -353,9 +344,7 @@ class Lars(BaseTransformer):
|
|
353
344
|
"Session must not specified for snowpark dataset."
|
354
345
|
),
|
355
346
|
)
|
356
|
-
|
357
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
358
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
347
|
+
|
359
348
|
|
360
349
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
361
350
|
@telemetry.send_api_usage_telemetry(
|
@@ -403,7 +392,8 @@ class Lars(BaseTransformer):
|
|
403
392
|
|
404
393
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
405
394
|
|
406
|
-
self.
|
395
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
396
|
+
self._deps = self._get_dependencies()
|
407
397
|
assert isinstance(
|
408
398
|
dataset._session, Session
|
409
399
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -486,10 +476,8 @@ class Lars(BaseTransformer):
|
|
486
476
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
487
477
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
488
478
|
|
489
|
-
self.
|
490
|
-
|
491
|
-
inference_method=inference_method,
|
492
|
-
)
|
479
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
480
|
+
self._deps = self._get_dependencies()
|
493
481
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
494
482
|
|
495
483
|
transform_kwargs = dict(
|
@@ -556,16 +544,40 @@ class Lars(BaseTransformer):
|
|
556
544
|
self._is_fitted = True
|
557
545
|
return output_result
|
558
546
|
|
547
|
+
|
548
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
549
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
550
|
+
""" Method not supported for this class.
|
559
551
|
|
560
|
-
|
561
|
-
|
562
|
-
|
552
|
+
|
553
|
+
Raises:
|
554
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
555
|
+
|
556
|
+
Args:
|
557
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
558
|
+
Snowpark or Pandas DataFrame.
|
559
|
+
output_cols_prefix: Prefix for the response columns
|
563
560
|
Returns:
|
564
561
|
Transformed dataset.
|
565
562
|
"""
|
566
|
-
self.
|
567
|
-
|
568
|
-
|
563
|
+
self._infer_input_output_cols(dataset)
|
564
|
+
super()._check_dataset_type(dataset)
|
565
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
566
|
+
estimator=self._sklearn_object,
|
567
|
+
dataset=dataset,
|
568
|
+
input_cols=self.input_cols,
|
569
|
+
label_cols=self.label_cols,
|
570
|
+
sample_weight_col=self.sample_weight_col,
|
571
|
+
autogenerated=self._autogenerated,
|
572
|
+
subproject=_SUBPROJECT,
|
573
|
+
)
|
574
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
575
|
+
drop_input_cols=self._drop_input_cols,
|
576
|
+
expected_output_cols_list=self.output_cols,
|
577
|
+
)
|
578
|
+
self._sklearn_object = fitted_estimator
|
579
|
+
self._is_fitted = True
|
580
|
+
return output_result
|
569
581
|
|
570
582
|
|
571
583
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -656,10 +668,8 @@ class Lars(BaseTransformer):
|
|
656
668
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
657
669
|
|
658
670
|
if isinstance(dataset, DataFrame):
|
659
|
-
self.
|
660
|
-
|
661
|
-
inference_method=inference_method,
|
662
|
-
)
|
671
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
672
|
+
self._deps = self._get_dependencies()
|
663
673
|
assert isinstance(
|
664
674
|
dataset._session, Session
|
665
675
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -724,10 +734,8 @@ class Lars(BaseTransformer):
|
|
724
734
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
725
735
|
|
726
736
|
if isinstance(dataset, DataFrame):
|
727
|
-
self.
|
728
|
-
|
729
|
-
inference_method=inference_method,
|
730
|
-
)
|
737
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
738
|
+
self._deps = self._get_dependencies()
|
731
739
|
assert isinstance(
|
732
740
|
dataset._session, Session
|
733
741
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -789,10 +797,8 @@ class Lars(BaseTransformer):
|
|
789
797
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
790
798
|
|
791
799
|
if isinstance(dataset, DataFrame):
|
792
|
-
self.
|
793
|
-
|
794
|
-
inference_method=inference_method,
|
795
|
-
)
|
800
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
801
|
+
self._deps = self._get_dependencies()
|
796
802
|
assert isinstance(
|
797
803
|
dataset._session, Session
|
798
804
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -858,10 +864,8 @@ class Lars(BaseTransformer):
|
|
858
864
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
859
865
|
|
860
866
|
if isinstance(dataset, DataFrame):
|
861
|
-
self.
|
862
|
-
|
863
|
-
inference_method=inference_method,
|
864
|
-
)
|
867
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
868
|
+
self._deps = self._get_dependencies()
|
865
869
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
866
870
|
transform_kwargs = dict(
|
867
871
|
session=dataset._session,
|
@@ -925,17 +929,15 @@ class Lars(BaseTransformer):
|
|
925
929
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
926
930
|
|
927
931
|
if isinstance(dataset, DataFrame):
|
928
|
-
self.
|
929
|
-
|
930
|
-
inference_method="score",
|
931
|
-
)
|
932
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
933
|
+
self._deps = self._get_dependencies()
|
932
934
|
selected_cols = self._get_active_columns()
|
933
935
|
if len(selected_cols) > 0:
|
934
936
|
dataset = dataset.select(selected_cols)
|
935
937
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
936
938
|
transform_kwargs = dict(
|
937
939
|
session=dataset._session,
|
938
|
-
dependencies=
|
940
|
+
dependencies=self._deps,
|
939
941
|
score_sproc_imports=['sklearn'],
|
940
942
|
)
|
941
943
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1000,11 +1002,8 @@ class Lars(BaseTransformer):
|
|
1000
1002
|
|
1001
1003
|
if isinstance(dataset, DataFrame):
|
1002
1004
|
|
1003
|
-
self.
|
1004
|
-
|
1005
|
-
inference_method=inference_method,
|
1006
|
-
|
1007
|
-
)
|
1005
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1006
|
+
self._deps = self._get_dependencies()
|
1008
1007
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1009
1008
|
transform_kwargs = dict(
|
1010
1009
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LarsCV(BaseTransformer):
|
70
64
|
r"""Cross-validated Least Angle Regression model
|
71
65
|
For more details on this class, see [sklearn.linear_model.LarsCV]
|
@@ -330,20 +324,17 @@ class LarsCV(BaseTransformer):
|
|
330
324
|
self,
|
331
325
|
dataset: DataFrame,
|
332
326
|
inference_method: str,
|
333
|
-
) ->
|
334
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
335
|
-
return the available package that exists in the snowflake anaconda channel
|
327
|
+
) -> None:
|
328
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
336
329
|
|
337
330
|
Args:
|
338
331
|
dataset: snowpark dataframe
|
339
332
|
inference_method: the inference method such as predict, score...
|
340
|
-
|
333
|
+
|
341
334
|
Raises:
|
342
335
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
343
336
|
SnowflakeMLException: If the session is None, raise error
|
344
337
|
|
345
|
-
Returns:
|
346
|
-
A list of available package that exists in the snowflake anaconda channel
|
347
338
|
"""
|
348
339
|
if not self._is_fitted:
|
349
340
|
raise exceptions.SnowflakeMLException(
|
@@ -361,9 +352,7 @@ class LarsCV(BaseTransformer):
|
|
361
352
|
"Session must not specified for snowpark dataset."
|
362
353
|
),
|
363
354
|
)
|
364
|
-
|
365
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
355
|
+
|
367
356
|
|
368
357
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
369
358
|
@telemetry.send_api_usage_telemetry(
|
@@ -411,7 +400,8 @@ class LarsCV(BaseTransformer):
|
|
411
400
|
|
412
401
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
413
402
|
|
414
|
-
self.
|
403
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
404
|
+
self._deps = self._get_dependencies()
|
415
405
|
assert isinstance(
|
416
406
|
dataset._session, Session
|
417
407
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -494,10 +484,8 @@ class LarsCV(BaseTransformer):
|
|
494
484
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
495
485
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
496
486
|
|
497
|
-
self.
|
498
|
-
|
499
|
-
inference_method=inference_method,
|
500
|
-
)
|
487
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
488
|
+
self._deps = self._get_dependencies()
|
501
489
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
502
490
|
|
503
491
|
transform_kwargs = dict(
|
@@ -564,16 +552,40 @@ class LarsCV(BaseTransformer):
|
|
564
552
|
self._is_fitted = True
|
565
553
|
return output_result
|
566
554
|
|
555
|
+
|
556
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
557
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
558
|
+
""" Method not supported for this class.
|
567
559
|
|
568
|
-
|
569
|
-
|
570
|
-
|
560
|
+
|
561
|
+
Raises:
|
562
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
563
|
+
|
564
|
+
Args:
|
565
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
566
|
+
Snowpark or Pandas DataFrame.
|
567
|
+
output_cols_prefix: Prefix for the response columns
|
571
568
|
Returns:
|
572
569
|
Transformed dataset.
|
573
570
|
"""
|
574
|
-
self.
|
575
|
-
|
576
|
-
|
571
|
+
self._infer_input_output_cols(dataset)
|
572
|
+
super()._check_dataset_type(dataset)
|
573
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
574
|
+
estimator=self._sklearn_object,
|
575
|
+
dataset=dataset,
|
576
|
+
input_cols=self.input_cols,
|
577
|
+
label_cols=self.label_cols,
|
578
|
+
sample_weight_col=self.sample_weight_col,
|
579
|
+
autogenerated=self._autogenerated,
|
580
|
+
subproject=_SUBPROJECT,
|
581
|
+
)
|
582
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
584
|
+
expected_output_cols_list=self.output_cols,
|
585
|
+
)
|
586
|
+
self._sklearn_object = fitted_estimator
|
587
|
+
self._is_fitted = True
|
588
|
+
return output_result
|
577
589
|
|
578
590
|
|
579
591
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -664,10 +676,8 @@ class LarsCV(BaseTransformer):
|
|
664
676
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
665
677
|
|
666
678
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
inference_method=inference_method,
|
670
|
-
)
|
679
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
680
|
+
self._deps = self._get_dependencies()
|
671
681
|
assert isinstance(
|
672
682
|
dataset._session, Session
|
673
683
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -732,10 +742,8 @@ class LarsCV(BaseTransformer):
|
|
732
742
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
733
743
|
|
734
744
|
if isinstance(dataset, DataFrame):
|
735
|
-
self.
|
736
|
-
|
737
|
-
inference_method=inference_method,
|
738
|
-
)
|
745
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
746
|
+
self._deps = self._get_dependencies()
|
739
747
|
assert isinstance(
|
740
748
|
dataset._session, Session
|
741
749
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +805,8 @@ class LarsCV(BaseTransformer):
|
|
797
805
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
798
806
|
|
799
807
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
808
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
809
|
+
self._deps = self._get_dependencies()
|
804
810
|
assert isinstance(
|
805
811
|
dataset._session, Session
|
806
812
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -866,10 +872,8 @@ class LarsCV(BaseTransformer):
|
|
866
872
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
867
873
|
|
868
874
|
if isinstance(dataset, DataFrame):
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
)
|
875
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
876
|
+
self._deps = self._get_dependencies()
|
873
877
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
874
878
|
transform_kwargs = dict(
|
875
879
|
session=dataset._session,
|
@@ -933,17 +937,15 @@ class LarsCV(BaseTransformer):
|
|
933
937
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
934
938
|
|
935
939
|
if isinstance(dataset, DataFrame):
|
936
|
-
self.
|
937
|
-
|
938
|
-
inference_method="score",
|
939
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
941
|
+
self._deps = self._get_dependencies()
|
940
942
|
selected_cols = self._get_active_columns()
|
941
943
|
if len(selected_cols) > 0:
|
942
944
|
dataset = dataset.select(selected_cols)
|
943
945
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
944
946
|
transform_kwargs = dict(
|
945
947
|
session=dataset._session,
|
946
|
-
dependencies=
|
948
|
+
dependencies=self._deps,
|
947
949
|
score_sproc_imports=['sklearn'],
|
948
950
|
)
|
949
951
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1008,11 +1010,8 @@ class LarsCV(BaseTransformer):
|
|
1008
1010
|
|
1009
1011
|
if isinstance(dataset, DataFrame):
|
1010
1012
|
|
1011
|
-
self.
|
1012
|
-
|
1013
|
-
inference_method=inference_method,
|
1014
|
-
|
1015
|
-
)
|
1013
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1014
|
+
self._deps = self._get_dependencies()
|
1016
1015
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1017
1016
|
transform_kwargs = dict(
|
1018
1017
|
session = dataset._session,
|