snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SequentialFeatureSelector(BaseTransformer):
70
64
  r"""Transformer that performs Sequential Feature Selection
71
65
  For more details on this class, see [sklearn.feature_selection.SequentialFeatureSelector]
@@ -324,20 +318,17 @@ class SequentialFeatureSelector(BaseTransformer):
324
318
  self,
325
319
  dataset: DataFrame,
326
320
  inference_method: str,
327
- ) -> List[str]:
328
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
329
- return the available package that exists in the snowflake anaconda channel
321
+ ) -> None:
322
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
330
323
 
331
324
  Args:
332
325
  dataset: snowpark dataframe
333
326
  inference_method: the inference method such as predict, score...
334
-
327
+
335
328
  Raises:
336
329
  SnowflakeMLException: If the estimator is not fitted, raise error
337
330
  SnowflakeMLException: If the session is None, raise error
338
331
 
339
- Returns:
340
- A list of available package that exists in the snowflake anaconda channel
341
332
  """
342
333
  if not self._is_fitted:
343
334
  raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class SequentialFeatureSelector(BaseTransformer):
355
346
  "Session must not specified for snowpark dataset."
356
347
  ),
357
348
  )
358
- # Validate that key package version in user workspace are supported in snowflake conda channel
359
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
360
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
349
+
361
350
 
362
351
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
363
352
  @telemetry.send_api_usage_telemetry(
@@ -403,7 +392,8 @@ class SequentialFeatureSelector(BaseTransformer):
403
392
 
404
393
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
405
394
 
406
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
395
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
396
+ self._deps = self._get_dependencies()
407
397
  assert isinstance(
408
398
  dataset._session, Session
409
399
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -488,10 +478,8 @@ class SequentialFeatureSelector(BaseTransformer):
488
478
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
489
479
  expected_dtype = convert_sp_to_sf_type(output_types[0])
490
480
 
491
- self._deps = self._batch_inference_validate_snowpark(
492
- dataset=dataset,
493
- inference_method=inference_method,
494
- )
481
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
482
+ self._deps = self._get_dependencies()
495
483
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
496
484
 
497
485
  transform_kwargs = dict(
@@ -558,16 +546,42 @@ class SequentialFeatureSelector(BaseTransformer):
558
546
  self._is_fitted = True
559
547
  return output_result
560
548
 
549
+
550
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
551
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
552
+ """ Fit to data, then transform it
553
+ For more details on this function, see [sklearn.feature_selection.SequentialFeatureSelector.fit_transform]
554
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SequentialFeatureSelector.html#sklearn.feature_selection.SequentialFeatureSelector.fit_transform)
555
+
561
556
 
562
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
563
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
564
- """
557
+ Raises:
558
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
559
+
560
+ Args:
561
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
562
+ Snowpark or Pandas DataFrame.
563
+ output_cols_prefix: Prefix for the response columns
565
564
  Returns:
566
565
  Transformed dataset.
567
566
  """
568
- self.fit(dataset)
569
- assert self._sklearn_object is not None
570
- return self._sklearn_object.embedding_
567
+ self._infer_input_output_cols(dataset)
568
+ super()._check_dataset_type(dataset)
569
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
570
+ estimator=self._sklearn_object,
571
+ dataset=dataset,
572
+ input_cols=self.input_cols,
573
+ label_cols=self.label_cols,
574
+ sample_weight_col=self.sample_weight_col,
575
+ autogenerated=self._autogenerated,
576
+ subproject=_SUBPROJECT,
577
+ )
578
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
579
+ drop_input_cols=self._drop_input_cols,
580
+ expected_output_cols_list=self.output_cols,
581
+ )
582
+ self._sklearn_object = fitted_estimator
583
+ self._is_fitted = True
584
+ return output_result
571
585
 
572
586
 
573
587
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -658,10 +672,8 @@ class SequentialFeatureSelector(BaseTransformer):
658
672
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
659
673
 
660
674
  if isinstance(dataset, DataFrame):
661
- self._deps = self._batch_inference_validate_snowpark(
662
- dataset=dataset,
663
- inference_method=inference_method,
664
- )
675
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
676
+ self._deps = self._get_dependencies()
665
677
  assert isinstance(
666
678
  dataset._session, Session
667
679
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -726,10 +738,8 @@ class SequentialFeatureSelector(BaseTransformer):
726
738
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
727
739
 
728
740
  if isinstance(dataset, DataFrame):
729
- self._deps = self._batch_inference_validate_snowpark(
730
- dataset=dataset,
731
- inference_method=inference_method,
732
- )
741
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
742
+ self._deps = self._get_dependencies()
733
743
  assert isinstance(
734
744
  dataset._session, Session
735
745
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -791,10 +801,8 @@ class SequentialFeatureSelector(BaseTransformer):
791
801
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
792
802
 
793
803
  if isinstance(dataset, DataFrame):
794
- self._deps = self._batch_inference_validate_snowpark(
795
- dataset=dataset,
796
- inference_method=inference_method,
797
- )
804
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
805
+ self._deps = self._get_dependencies()
798
806
  assert isinstance(
799
807
  dataset._session, Session
800
808
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -860,10 +868,8 @@ class SequentialFeatureSelector(BaseTransformer):
860
868
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
861
869
 
862
870
  if isinstance(dataset, DataFrame):
863
- self._deps = self._batch_inference_validate_snowpark(
864
- dataset=dataset,
865
- inference_method=inference_method,
866
- )
871
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
872
+ self._deps = self._get_dependencies()
867
873
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
868
874
  transform_kwargs = dict(
869
875
  session=dataset._session,
@@ -925,17 +931,15 @@ class SequentialFeatureSelector(BaseTransformer):
925
931
  transform_kwargs: ScoreKwargsTypedDict = dict()
926
932
 
927
933
  if isinstance(dataset, DataFrame):
928
- self._deps = self._batch_inference_validate_snowpark(
929
- dataset=dataset,
930
- inference_method="score",
931
- )
934
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
935
+ self._deps = self._get_dependencies()
932
936
  selected_cols = self._get_active_columns()
933
937
  if len(selected_cols) > 0:
934
938
  dataset = dataset.select(selected_cols)
935
939
  assert isinstance(dataset._session, Session) # keep mypy happy
936
940
  transform_kwargs = dict(
937
941
  session=dataset._session,
938
- dependencies=["snowflake-snowpark-python"] + self._deps,
942
+ dependencies=self._deps,
939
943
  score_sproc_imports=['sklearn'],
940
944
  )
941
945
  elif isinstance(dataset, pd.DataFrame):
@@ -1000,11 +1004,8 @@ class SequentialFeatureSelector(BaseTransformer):
1000
1004
 
1001
1005
  if isinstance(dataset, DataFrame):
1002
1006
 
1003
- self._deps = self._batch_inference_validate_snowpark(
1004
- dataset=dataset,
1005
- inference_method=inference_method,
1006
-
1007
- )
1007
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1008
+ self._deps = self._get_dependencies()
1008
1009
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1009
1010
  transform_kwargs = dict(
1010
1011
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class VarianceThreshold(BaseTransformer):
70
64
  r"""Feature selector that removes all low-variance features
71
65
  For more details on this class, see [sklearn.feature_selection.VarianceThreshold]
@@ -257,20 +251,17 @@ class VarianceThreshold(BaseTransformer):
257
251
  self,
258
252
  dataset: DataFrame,
259
253
  inference_method: str,
260
- ) -> List[str]:
261
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
262
- return the available package that exists in the snowflake anaconda channel
254
+ ) -> None:
255
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
263
256
 
264
257
  Args:
265
258
  dataset: snowpark dataframe
266
259
  inference_method: the inference method such as predict, score...
267
-
260
+
268
261
  Raises:
269
262
  SnowflakeMLException: If the estimator is not fitted, raise error
270
263
  SnowflakeMLException: If the session is None, raise error
271
264
 
272
- Returns:
273
- A list of available package that exists in the snowflake anaconda channel
274
265
  """
275
266
  if not self._is_fitted:
276
267
  raise exceptions.SnowflakeMLException(
@@ -288,9 +279,7 @@ class VarianceThreshold(BaseTransformer):
288
279
  "Session must not specified for snowpark dataset."
289
280
  ),
290
281
  )
291
- # Validate that key package version in user workspace are supported in snowflake conda channel
292
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
293
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
282
+
294
283
 
295
284
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
296
285
  @telemetry.send_api_usage_telemetry(
@@ -336,7 +325,8 @@ class VarianceThreshold(BaseTransformer):
336
325
 
337
326
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
338
327
 
339
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
328
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
329
+ self._deps = self._get_dependencies()
340
330
  assert isinstance(
341
331
  dataset._session, Session
342
332
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -421,10 +411,8 @@ class VarianceThreshold(BaseTransformer):
421
411
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
422
412
  expected_dtype = convert_sp_to_sf_type(output_types[0])
423
413
 
424
- self._deps = self._batch_inference_validate_snowpark(
425
- dataset=dataset,
426
- inference_method=inference_method,
427
- )
414
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
415
+ self._deps = self._get_dependencies()
428
416
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
429
417
 
430
418
  transform_kwargs = dict(
@@ -491,16 +479,42 @@ class VarianceThreshold(BaseTransformer):
491
479
  self._is_fitted = True
492
480
  return output_result
493
481
 
482
+
483
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
484
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
485
+ """ Fit to data, then transform it
486
+ For more details on this function, see [sklearn.feature_selection.VarianceThreshold.fit_transform]
487
+ (https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html#sklearn.feature_selection.VarianceThreshold.fit_transform)
488
+
494
489
 
495
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
496
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
497
- """
490
+ Raises:
491
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
492
+
493
+ Args:
494
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
495
+ Snowpark or Pandas DataFrame.
496
+ output_cols_prefix: Prefix for the response columns
498
497
  Returns:
499
498
  Transformed dataset.
500
499
  """
501
- self.fit(dataset)
502
- assert self._sklearn_object is not None
503
- return self._sklearn_object.embedding_
500
+ self._infer_input_output_cols(dataset)
501
+ super()._check_dataset_type(dataset)
502
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
503
+ estimator=self._sklearn_object,
504
+ dataset=dataset,
505
+ input_cols=self.input_cols,
506
+ label_cols=self.label_cols,
507
+ sample_weight_col=self.sample_weight_col,
508
+ autogenerated=self._autogenerated,
509
+ subproject=_SUBPROJECT,
510
+ )
511
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
512
+ drop_input_cols=self._drop_input_cols,
513
+ expected_output_cols_list=self.output_cols,
514
+ )
515
+ self._sklearn_object = fitted_estimator
516
+ self._is_fitted = True
517
+ return output_result
504
518
 
505
519
 
506
520
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -591,10 +605,8 @@ class VarianceThreshold(BaseTransformer):
591
605
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
592
606
 
593
607
  if isinstance(dataset, DataFrame):
594
- self._deps = self._batch_inference_validate_snowpark(
595
- dataset=dataset,
596
- inference_method=inference_method,
597
- )
608
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
609
+ self._deps = self._get_dependencies()
598
610
  assert isinstance(
599
611
  dataset._session, Session
600
612
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -659,10 +671,8 @@ class VarianceThreshold(BaseTransformer):
659
671
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
660
672
 
661
673
  if isinstance(dataset, DataFrame):
662
- self._deps = self._batch_inference_validate_snowpark(
663
- dataset=dataset,
664
- inference_method=inference_method,
665
- )
674
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
675
+ self._deps = self._get_dependencies()
666
676
  assert isinstance(
667
677
  dataset._session, Session
668
678
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -724,10 +734,8 @@ class VarianceThreshold(BaseTransformer):
724
734
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
725
735
 
726
736
  if isinstance(dataset, DataFrame):
727
- self._deps = self._batch_inference_validate_snowpark(
728
- dataset=dataset,
729
- inference_method=inference_method,
730
- )
737
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
738
+ self._deps = self._get_dependencies()
731
739
  assert isinstance(
732
740
  dataset._session, Session
733
741
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -793,10 +801,8 @@ class VarianceThreshold(BaseTransformer):
793
801
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
794
802
 
795
803
  if isinstance(dataset, DataFrame):
796
- self._deps = self._batch_inference_validate_snowpark(
797
- dataset=dataset,
798
- inference_method=inference_method,
799
- )
804
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
805
+ self._deps = self._get_dependencies()
800
806
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
801
807
  transform_kwargs = dict(
802
808
  session=dataset._session,
@@ -858,17 +864,15 @@ class VarianceThreshold(BaseTransformer):
858
864
  transform_kwargs: ScoreKwargsTypedDict = dict()
859
865
 
860
866
  if isinstance(dataset, DataFrame):
861
- self._deps = self._batch_inference_validate_snowpark(
862
- dataset=dataset,
863
- inference_method="score",
864
- )
867
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
868
+ self._deps = self._get_dependencies()
865
869
  selected_cols = self._get_active_columns()
866
870
  if len(selected_cols) > 0:
867
871
  dataset = dataset.select(selected_cols)
868
872
  assert isinstance(dataset._session, Session) # keep mypy happy
869
873
  transform_kwargs = dict(
870
874
  session=dataset._session,
871
- dependencies=["snowflake-snowpark-python"] + self._deps,
875
+ dependencies=self._deps,
872
876
  score_sproc_imports=['sklearn'],
873
877
  )
874
878
  elif isinstance(dataset, pd.DataFrame):
@@ -933,11 +937,8 @@ class VarianceThreshold(BaseTransformer):
933
937
 
934
938
  if isinstance(dataset, DataFrame):
935
939
 
936
- self._deps = self._batch_inference_validate_snowpark(
937
- dataset=dataset,
938
- inference_method=inference_method,
939
-
940
- )
940
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
941
+ self._deps = self._get_dependencies()
941
942
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
942
943
  transform_kwargs = dict(
943
944
  session = dataset._session,
@@ -16,6 +16,7 @@ from snowflake.ml._internal.exceptions import (
16
16
  exceptions,
17
17
  modeling_error_messages,
18
18
  )
19
+ from snowflake.ml._internal.lineage import data_source, lineage_utils
19
20
  from snowflake.ml._internal.utils import identifier, parallelize
20
21
  from snowflake.ml.modeling.framework import _utils
21
22
  from snowflake.snowpark import functions as F
@@ -385,6 +386,7 @@ class BaseEstimator(Base):
385
386
  self.file_names = file_names
386
387
  self.custom_states = custom_states
387
388
  self.sample_weight_col = sample_weight_col
389
+ self._data_sources: Optional[List[data_source.DataSource]] = None
388
390
 
389
391
  self.start_time = datetime.now().strftime(_utils.DATETIME_FORMAT)[:-3]
390
392
 
@@ -419,12 +421,18 @@ class BaseEstimator(Base):
419
421
  """
420
422
  return []
421
423
 
424
+ def _get_data_sources(self) -> Optional[List[data_source.DataSource]]:
425
+ return self._data_sources
426
+
422
427
  @telemetry.send_api_usage_telemetry(
423
428
  project=PROJECT,
424
429
  subproject=SUBPROJECT,
425
430
  )
426
431
  def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
427
432
  """Runs universal logics for all fit implementations."""
433
+ self._data_sources = getattr(dataset, lineage_utils.DATA_SOURCES_ATTR, None)
434
+ if self._data_sources:
435
+ assert all(isinstance(ds, data_source.DataSource) for ds in self._data_sources)
428
436
  return self._fit(dataset)
429
437
 
430
438
  @abstractmethod
@@ -539,58 +547,78 @@ class BaseTransformer(BaseEstimator):
539
547
  ),
540
548
  )
541
549
 
542
- def _infer_input_output_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None:
550
+ def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> List[str]:
543
551
  """
544
- Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
552
+ Infer input_cols from the dataset. Input column are all columns in the input dataset that are not
553
+ designated as label, passthrough, or sample weight columns.
545
554
 
546
555
  Args:
547
556
  dataset: Input dataset.
548
557
 
558
+ Returns:
559
+ The list of input columns.
560
+ """
561
+ cols = [
562
+ c
563
+ for c in dataset.columns
564
+ if (c not in self.get_label_cols() and c not in self.get_passthrough_cols() and c != self.sample_weight_col)
565
+ ]
566
+ return cols
567
+
568
+ def _infer_output_cols(self) -> List[str]:
569
+ """Infer output column names from based on the estimator.
570
+
571
+ Returns:
572
+ The list of output columns.
573
+
549
574
  Raises:
550
575
  SnowflakeMLException: If unable to infer output columns
576
+
551
577
  """
552
- if not self.input_cols:
553
- cols = [
554
- c
555
- for c in dataset.columns
556
- if (
557
- c not in self.get_label_cols()
558
- and c not in self.get_passthrough_cols()
559
- and c != self.sample_weight_col
560
- )
561
- ]
562
- self.set_input_cols(input_cols=cols)
563
578
 
564
- if not self.output_cols:
565
- # keep mypy happy
566
- assert self._sklearn_object is not None
567
-
568
- if hasattr(self._sklearn_object, "_estimator_type"):
569
- # For supervised estimators, infer the output columns from the label columns
570
- if self._sklearn_object._estimator_type in SKLEARN_SUPERVISED_ESTIMATORS:
571
- cols = [identifier.concat_names(["OUTPUT_", c]) for c in self.label_cols]
572
- self.set_output_cols(output_cols=cols)
573
-
574
- # For density estimators, clusterers, and outlier detectors, there is always exactly one output column.
575
- elif self._sklearn_object._estimator_type in SKLEARN_SINGLE_OUTPUT_ESTIMATORS:
576
- self.set_output_cols(output_cols=["OUTPUT_0"])
577
-
578
- else:
579
- raise exceptions.SnowflakeMLException(
580
- error_code=error_codes.INVALID_ARGUMENT,
581
- original_exception=ValueError(
582
- f"Unable to infer output columns for estimator type {self._sklearn_object._estimator_type}."
583
- f"Please include `output_cols` explicitly."
584
- ),
585
- )
579
+ # keep mypy happy
580
+ assert self._sklearn_object is not None
581
+ if hasattr(self._sklearn_object, "_estimator_type"):
582
+ # For supervised estimators, infer the output columns from the label columns
583
+ if self._sklearn_object._estimator_type in SKLEARN_SUPERVISED_ESTIMATORS:
584
+ cols = [identifier.concat_names(["OUTPUT_", c]) for c in self.label_cols]
585
+ return cols
586
+
587
+ # For density estimators, clusterers, and outlier detectors, there is always exactly one output column.
588
+ elif self._sklearn_object._estimator_type in SKLEARN_SINGLE_OUTPUT_ESTIMATORS:
589
+ return ["OUTPUT_0"]
590
+
586
591
  else:
587
592
  raise exceptions.SnowflakeMLException(
588
593
  error_code=error_codes.INVALID_ARGUMENT,
589
594
  original_exception=ValueError(
590
- f"Unable to infer output columns for object {self._sklearn_object}."
595
+ f"Unable to infer output columns for estimator type {self._sklearn_object._estimator_type}."
591
596
  f"Please include `output_cols` explicitly."
592
597
  ),
593
598
  )
599
+ else:
600
+ raise exceptions.SnowflakeMLException(
601
+ error_code=error_codes.INVALID_ARGUMENT,
602
+ original_exception=ValueError(
603
+ f"Unable to infer output columns for object {self._sklearn_object}."
604
+ f"Please include `output_cols` explicitly."
605
+ ),
606
+ )
607
+
608
+ def _infer_input_output_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> None:
609
+ """
610
+ Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
611
+
612
+ Args:
613
+ dataset: Input dataset.
614
+ """
615
+ if not self.input_cols:
616
+ cols = self._infer_input_cols(dataset=dataset)
617
+ self.set_input_cols(input_cols=cols)
618
+
619
+ if not self.output_cols:
620
+ cols = self._infer_output_cols()
621
+ self.set_output_cols(output_cols=cols)
594
622
 
595
623
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
596
624
  """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.