snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class Lasso(BaseTransformer):
|
70
64
|
r"""Linear Model trained with L1 prior as regularizer (aka the Lasso)
|
71
65
|
For more details on this class, see [sklearn.linear_model.Lasso]
|
@@ -323,20 +317,17 @@ class Lasso(BaseTransformer):
|
|
323
317
|
self,
|
324
318
|
dataset: DataFrame,
|
325
319
|
inference_method: str,
|
326
|
-
) ->
|
327
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
328
|
-
return the available package that exists in the snowflake anaconda channel
|
320
|
+
) -> None:
|
321
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
329
322
|
|
330
323
|
Args:
|
331
324
|
dataset: snowpark dataframe
|
332
325
|
inference_method: the inference method such as predict, score...
|
333
|
-
|
326
|
+
|
334
327
|
Raises:
|
335
328
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
336
329
|
SnowflakeMLException: If the session is None, raise error
|
337
330
|
|
338
|
-
Returns:
|
339
|
-
A list of available package that exists in the snowflake anaconda channel
|
340
331
|
"""
|
341
332
|
if not self._is_fitted:
|
342
333
|
raise exceptions.SnowflakeMLException(
|
@@ -354,9 +345,7 @@ class Lasso(BaseTransformer):
|
|
354
345
|
"Session must not specified for snowpark dataset."
|
355
346
|
),
|
356
347
|
)
|
357
|
-
|
358
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
359
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
348
|
+
|
360
349
|
|
361
350
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
362
351
|
@telemetry.send_api_usage_telemetry(
|
@@ -404,7 +393,8 @@ class Lasso(BaseTransformer):
|
|
404
393
|
|
405
394
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
406
395
|
|
407
|
-
self.
|
396
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
397
|
+
self._deps = self._get_dependencies()
|
408
398
|
assert isinstance(
|
409
399
|
dataset._session, Session
|
410
400
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -487,10 +477,8 @@ class Lasso(BaseTransformer):
|
|
487
477
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
488
478
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
489
479
|
|
490
|
-
self.
|
491
|
-
|
492
|
-
inference_method=inference_method,
|
493
|
-
)
|
480
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
481
|
+
self._deps = self._get_dependencies()
|
494
482
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
495
483
|
|
496
484
|
transform_kwargs = dict(
|
@@ -557,16 +545,40 @@ class Lasso(BaseTransformer):
|
|
557
545
|
self._is_fitted = True
|
558
546
|
return output_result
|
559
547
|
|
548
|
+
|
549
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
550
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
551
|
+
""" Method not supported for this class.
|
560
552
|
|
561
|
-
|
562
|
-
|
563
|
-
|
553
|
+
|
554
|
+
Raises:
|
555
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
556
|
+
|
557
|
+
Args:
|
558
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
559
|
+
Snowpark or Pandas DataFrame.
|
560
|
+
output_cols_prefix: Prefix for the response columns
|
564
561
|
Returns:
|
565
562
|
Transformed dataset.
|
566
563
|
"""
|
567
|
-
self.
|
568
|
-
|
569
|
-
|
564
|
+
self._infer_input_output_cols(dataset)
|
565
|
+
super()._check_dataset_type(dataset)
|
566
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
567
|
+
estimator=self._sklearn_object,
|
568
|
+
dataset=dataset,
|
569
|
+
input_cols=self.input_cols,
|
570
|
+
label_cols=self.label_cols,
|
571
|
+
sample_weight_col=self.sample_weight_col,
|
572
|
+
autogenerated=self._autogenerated,
|
573
|
+
subproject=_SUBPROJECT,
|
574
|
+
)
|
575
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
576
|
+
drop_input_cols=self._drop_input_cols,
|
577
|
+
expected_output_cols_list=self.output_cols,
|
578
|
+
)
|
579
|
+
self._sklearn_object = fitted_estimator
|
580
|
+
self._is_fitted = True
|
581
|
+
return output_result
|
570
582
|
|
571
583
|
|
572
584
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -657,10 +669,8 @@ class Lasso(BaseTransformer):
|
|
657
669
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
658
670
|
|
659
671
|
if isinstance(dataset, DataFrame):
|
660
|
-
self.
|
661
|
-
|
662
|
-
inference_method=inference_method,
|
663
|
-
)
|
672
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
673
|
+
self._deps = self._get_dependencies()
|
664
674
|
assert isinstance(
|
665
675
|
dataset._session, Session
|
666
676
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -725,10 +735,8 @@ class Lasso(BaseTransformer):
|
|
725
735
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
726
736
|
|
727
737
|
if isinstance(dataset, DataFrame):
|
728
|
-
self.
|
729
|
-
|
730
|
-
inference_method=inference_method,
|
731
|
-
)
|
738
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
739
|
+
self._deps = self._get_dependencies()
|
732
740
|
assert isinstance(
|
733
741
|
dataset._session, Session
|
734
742
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -790,10 +798,8 @@ class Lasso(BaseTransformer):
|
|
790
798
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
791
799
|
|
792
800
|
if isinstance(dataset, DataFrame):
|
793
|
-
self.
|
794
|
-
|
795
|
-
inference_method=inference_method,
|
796
|
-
)
|
801
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
802
|
+
self._deps = self._get_dependencies()
|
797
803
|
assert isinstance(
|
798
804
|
dataset._session, Session
|
799
805
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -859,10 +865,8 @@ class Lasso(BaseTransformer):
|
|
859
865
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
860
866
|
|
861
867
|
if isinstance(dataset, DataFrame):
|
862
|
-
self.
|
863
|
-
|
864
|
-
inference_method=inference_method,
|
865
|
-
)
|
868
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
869
|
+
self._deps = self._get_dependencies()
|
866
870
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
867
871
|
transform_kwargs = dict(
|
868
872
|
session=dataset._session,
|
@@ -926,17 +930,15 @@ class Lasso(BaseTransformer):
|
|
926
930
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
927
931
|
|
928
932
|
if isinstance(dataset, DataFrame):
|
929
|
-
self.
|
930
|
-
|
931
|
-
inference_method="score",
|
932
|
-
)
|
933
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
934
|
+
self._deps = self._get_dependencies()
|
933
935
|
selected_cols = self._get_active_columns()
|
934
936
|
if len(selected_cols) > 0:
|
935
937
|
dataset = dataset.select(selected_cols)
|
936
938
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
937
939
|
transform_kwargs = dict(
|
938
940
|
session=dataset._session,
|
939
|
-
dependencies=
|
941
|
+
dependencies=self._deps,
|
940
942
|
score_sproc_imports=['sklearn'],
|
941
943
|
)
|
942
944
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1001,11 +1003,8 @@ class Lasso(BaseTransformer):
|
|
1001
1003
|
|
1002
1004
|
if isinstance(dataset, DataFrame):
|
1003
1005
|
|
1004
|
-
self.
|
1005
|
-
|
1006
|
-
inference_method=inference_method,
|
1007
|
-
|
1008
|
-
)
|
1006
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1007
|
+
self._deps = self._get_dependencies()
|
1009
1008
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1010
1009
|
transform_kwargs = dict(
|
1011
1010
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LassoCV(BaseTransformer):
|
70
64
|
r"""Lasso linear model with iterative fitting along a regularization path
|
71
65
|
For more details on this class, see [sklearn.linear_model.LassoCV]
|
@@ -351,20 +345,17 @@ class LassoCV(BaseTransformer):
|
|
351
345
|
self,
|
352
346
|
dataset: DataFrame,
|
353
347
|
inference_method: str,
|
354
|
-
) ->
|
355
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
356
|
-
return the available package that exists in the snowflake anaconda channel
|
348
|
+
) -> None:
|
349
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
357
350
|
|
358
351
|
Args:
|
359
352
|
dataset: snowpark dataframe
|
360
353
|
inference_method: the inference method such as predict, score...
|
361
|
-
|
354
|
+
|
362
355
|
Raises:
|
363
356
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
364
357
|
SnowflakeMLException: If the session is None, raise error
|
365
358
|
|
366
|
-
Returns:
|
367
|
-
A list of available package that exists in the snowflake anaconda channel
|
368
359
|
"""
|
369
360
|
if not self._is_fitted:
|
370
361
|
raise exceptions.SnowflakeMLException(
|
@@ -382,9 +373,7 @@ class LassoCV(BaseTransformer):
|
|
382
373
|
"Session must not specified for snowpark dataset."
|
383
374
|
),
|
384
375
|
)
|
385
|
-
|
386
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
387
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
376
|
+
|
388
377
|
|
389
378
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
390
379
|
@telemetry.send_api_usage_telemetry(
|
@@ -432,7 +421,8 @@ class LassoCV(BaseTransformer):
|
|
432
421
|
|
433
422
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
434
423
|
|
435
|
-
self.
|
424
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
425
|
+
self._deps = self._get_dependencies()
|
436
426
|
assert isinstance(
|
437
427
|
dataset._session, Session
|
438
428
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -515,10 +505,8 @@ class LassoCV(BaseTransformer):
|
|
515
505
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
516
506
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
517
507
|
|
518
|
-
self.
|
519
|
-
|
520
|
-
inference_method=inference_method,
|
521
|
-
)
|
508
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
509
|
+
self._deps = self._get_dependencies()
|
522
510
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
523
511
|
|
524
512
|
transform_kwargs = dict(
|
@@ -585,16 +573,40 @@ class LassoCV(BaseTransformer):
|
|
585
573
|
self._is_fitted = True
|
586
574
|
return output_result
|
587
575
|
|
576
|
+
|
577
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
578
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
579
|
+
""" Method not supported for this class.
|
588
580
|
|
589
|
-
|
590
|
-
|
591
|
-
|
581
|
+
|
582
|
+
Raises:
|
583
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
584
|
+
|
585
|
+
Args:
|
586
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
587
|
+
Snowpark or Pandas DataFrame.
|
588
|
+
output_cols_prefix: Prefix for the response columns
|
592
589
|
Returns:
|
593
590
|
Transformed dataset.
|
594
591
|
"""
|
595
|
-
self.
|
596
|
-
|
597
|
-
|
592
|
+
self._infer_input_output_cols(dataset)
|
593
|
+
super()._check_dataset_type(dataset)
|
594
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
595
|
+
estimator=self._sklearn_object,
|
596
|
+
dataset=dataset,
|
597
|
+
input_cols=self.input_cols,
|
598
|
+
label_cols=self.label_cols,
|
599
|
+
sample_weight_col=self.sample_weight_col,
|
600
|
+
autogenerated=self._autogenerated,
|
601
|
+
subproject=_SUBPROJECT,
|
602
|
+
)
|
603
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
604
|
+
drop_input_cols=self._drop_input_cols,
|
605
|
+
expected_output_cols_list=self.output_cols,
|
606
|
+
)
|
607
|
+
self._sklearn_object = fitted_estimator
|
608
|
+
self._is_fitted = True
|
609
|
+
return output_result
|
598
610
|
|
599
611
|
|
600
612
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -685,10 +697,8 @@ class LassoCV(BaseTransformer):
|
|
685
697
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
686
698
|
|
687
699
|
if isinstance(dataset, DataFrame):
|
688
|
-
self.
|
689
|
-
|
690
|
-
inference_method=inference_method,
|
691
|
-
)
|
700
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
701
|
+
self._deps = self._get_dependencies()
|
692
702
|
assert isinstance(
|
693
703
|
dataset._session, Session
|
694
704
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -753,10 +763,8 @@ class LassoCV(BaseTransformer):
|
|
753
763
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
754
764
|
|
755
765
|
if isinstance(dataset, DataFrame):
|
756
|
-
self.
|
757
|
-
|
758
|
-
inference_method=inference_method,
|
759
|
-
)
|
766
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
767
|
+
self._deps = self._get_dependencies()
|
760
768
|
assert isinstance(
|
761
769
|
dataset._session, Session
|
762
770
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -818,10 +826,8 @@ class LassoCV(BaseTransformer):
|
|
818
826
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
819
827
|
|
820
828
|
if isinstance(dataset, DataFrame):
|
821
|
-
self.
|
822
|
-
|
823
|
-
inference_method=inference_method,
|
824
|
-
)
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
825
831
|
assert isinstance(
|
826
832
|
dataset._session, Session
|
827
833
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -887,10 +893,8 @@ class LassoCV(BaseTransformer):
|
|
887
893
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
888
894
|
|
889
895
|
if isinstance(dataset, DataFrame):
|
890
|
-
self.
|
891
|
-
|
892
|
-
inference_method=inference_method,
|
893
|
-
)
|
896
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
897
|
+
self._deps = self._get_dependencies()
|
894
898
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
895
899
|
transform_kwargs = dict(
|
896
900
|
session=dataset._session,
|
@@ -954,17 +958,15 @@ class LassoCV(BaseTransformer):
|
|
954
958
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
955
959
|
|
956
960
|
if isinstance(dataset, DataFrame):
|
957
|
-
self.
|
958
|
-
|
959
|
-
inference_method="score",
|
960
|
-
)
|
961
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
962
|
+
self._deps = self._get_dependencies()
|
961
963
|
selected_cols = self._get_active_columns()
|
962
964
|
if len(selected_cols) > 0:
|
963
965
|
dataset = dataset.select(selected_cols)
|
964
966
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
965
967
|
transform_kwargs = dict(
|
966
968
|
session=dataset._session,
|
967
|
-
dependencies=
|
969
|
+
dependencies=self._deps,
|
968
970
|
score_sproc_imports=['sklearn'],
|
969
971
|
)
|
970
972
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1029,11 +1031,8 @@ class LassoCV(BaseTransformer):
|
|
1029
1031
|
|
1030
1032
|
if isinstance(dataset, DataFrame):
|
1031
1033
|
|
1032
|
-
self.
|
1033
|
-
|
1034
|
-
inference_method=inference_method,
|
1035
|
-
|
1036
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1035
|
+
self._deps = self._get_dependencies()
|
1037
1036
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1038
1037
|
transform_kwargs = dict(
|
1039
1038
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class LassoLars(BaseTransformer):
|
70
64
|
r"""Lasso model fit with Least Angle Regression a
|
71
65
|
For more details on this class, see [sklearn.linear_model.LassoLars]
|
@@ -343,20 +337,17 @@ class LassoLars(BaseTransformer):
|
|
343
337
|
self,
|
344
338
|
dataset: DataFrame,
|
345
339
|
inference_method: str,
|
346
|
-
) ->
|
347
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
348
|
-
return the available package that exists in the snowflake anaconda channel
|
340
|
+
) -> None:
|
341
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
349
342
|
|
350
343
|
Args:
|
351
344
|
dataset: snowpark dataframe
|
352
345
|
inference_method: the inference method such as predict, score...
|
353
|
-
|
346
|
+
|
354
347
|
Raises:
|
355
348
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
356
349
|
SnowflakeMLException: If the session is None, raise error
|
357
350
|
|
358
|
-
Returns:
|
359
|
-
A list of available package that exists in the snowflake anaconda channel
|
360
351
|
"""
|
361
352
|
if not self._is_fitted:
|
362
353
|
raise exceptions.SnowflakeMLException(
|
@@ -374,9 +365,7 @@ class LassoLars(BaseTransformer):
|
|
374
365
|
"Session must not specified for snowpark dataset."
|
375
366
|
),
|
376
367
|
)
|
377
|
-
|
378
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
379
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
368
|
+
|
380
369
|
|
381
370
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
382
371
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +413,8 @@ class LassoLars(BaseTransformer):
|
|
424
413
|
|
425
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
426
415
|
|
427
|
-
self.
|
416
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
417
|
+
self._deps = self._get_dependencies()
|
428
418
|
assert isinstance(
|
429
419
|
dataset._session, Session
|
430
420
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -507,10 +497,8 @@ class LassoLars(BaseTransformer):
|
|
507
497
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
508
498
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
509
499
|
|
510
|
-
self.
|
511
|
-
|
512
|
-
inference_method=inference_method,
|
513
|
-
)
|
500
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
501
|
+
self._deps = self._get_dependencies()
|
514
502
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
515
503
|
|
516
504
|
transform_kwargs = dict(
|
@@ -577,16 +565,40 @@ class LassoLars(BaseTransformer):
|
|
577
565
|
self._is_fitted = True
|
578
566
|
return output_result
|
579
567
|
|
568
|
+
|
569
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
570
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
571
|
+
""" Method not supported for this class.
|
580
572
|
|
581
|
-
|
582
|
-
|
583
|
-
|
573
|
+
|
574
|
+
Raises:
|
575
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
576
|
+
|
577
|
+
Args:
|
578
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
579
|
+
Snowpark or Pandas DataFrame.
|
580
|
+
output_cols_prefix: Prefix for the response columns
|
584
581
|
Returns:
|
585
582
|
Transformed dataset.
|
586
583
|
"""
|
587
|
-
self.
|
588
|
-
|
589
|
-
|
584
|
+
self._infer_input_output_cols(dataset)
|
585
|
+
super()._check_dataset_type(dataset)
|
586
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
587
|
+
estimator=self._sklearn_object,
|
588
|
+
dataset=dataset,
|
589
|
+
input_cols=self.input_cols,
|
590
|
+
label_cols=self.label_cols,
|
591
|
+
sample_weight_col=self.sample_weight_col,
|
592
|
+
autogenerated=self._autogenerated,
|
593
|
+
subproject=_SUBPROJECT,
|
594
|
+
)
|
595
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
596
|
+
drop_input_cols=self._drop_input_cols,
|
597
|
+
expected_output_cols_list=self.output_cols,
|
598
|
+
)
|
599
|
+
self._sklearn_object = fitted_estimator
|
600
|
+
self._is_fitted = True
|
601
|
+
return output_result
|
590
602
|
|
591
603
|
|
592
604
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -677,10 +689,8 @@ class LassoLars(BaseTransformer):
|
|
677
689
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
678
690
|
|
679
691
|
if isinstance(dataset, DataFrame):
|
680
|
-
self.
|
681
|
-
|
682
|
-
inference_method=inference_method,
|
683
|
-
)
|
692
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
693
|
+
self._deps = self._get_dependencies()
|
684
694
|
assert isinstance(
|
685
695
|
dataset._session, Session
|
686
696
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -745,10 +755,8 @@ class LassoLars(BaseTransformer):
|
|
745
755
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
746
756
|
|
747
757
|
if isinstance(dataset, DataFrame):
|
748
|
-
self.
|
749
|
-
|
750
|
-
inference_method=inference_method,
|
751
|
-
)
|
758
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
759
|
+
self._deps = self._get_dependencies()
|
752
760
|
assert isinstance(
|
753
761
|
dataset._session, Session
|
754
762
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -810,10 +818,8 @@ class LassoLars(BaseTransformer):
|
|
810
818
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
811
819
|
|
812
820
|
if isinstance(dataset, DataFrame):
|
813
|
-
self.
|
814
|
-
|
815
|
-
inference_method=inference_method,
|
816
|
-
)
|
821
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
822
|
+
self._deps = self._get_dependencies()
|
817
823
|
assert isinstance(
|
818
824
|
dataset._session, Session
|
819
825
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -879,10 +885,8 @@ class LassoLars(BaseTransformer):
|
|
879
885
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
880
886
|
|
881
887
|
if isinstance(dataset, DataFrame):
|
882
|
-
self.
|
883
|
-
|
884
|
-
inference_method=inference_method,
|
885
|
-
)
|
888
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
889
|
+
self._deps = self._get_dependencies()
|
886
890
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
887
891
|
transform_kwargs = dict(
|
888
892
|
session=dataset._session,
|
@@ -946,17 +950,15 @@ class LassoLars(BaseTransformer):
|
|
946
950
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
947
951
|
|
948
952
|
if isinstance(dataset, DataFrame):
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method="score",
|
952
|
-
)
|
953
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
954
|
+
self._deps = self._get_dependencies()
|
953
955
|
selected_cols = self._get_active_columns()
|
954
956
|
if len(selected_cols) > 0:
|
955
957
|
dataset = dataset.select(selected_cols)
|
956
958
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
957
959
|
transform_kwargs = dict(
|
958
960
|
session=dataset._session,
|
959
|
-
dependencies=
|
961
|
+
dependencies=self._deps,
|
960
962
|
score_sproc_imports=['sklearn'],
|
961
963
|
)
|
962
964
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1021,11 +1023,8 @@ class LassoLars(BaseTransformer):
|
|
1021
1023
|
|
1022
1024
|
if isinstance(dataset, DataFrame):
|
1023
1025
|
|
1024
|
-
self.
|
1025
|
-
|
1026
|
-
inference_method=inference_method,
|
1027
|
-
|
1028
|
-
)
|
1026
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1027
|
+
self._deps = self._get_dependencies()
|
1029
1028
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1030
1029
|
transform_kwargs = dict(
|
1031
1030
|
session = dataset._session,
|