snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GraphicalLassoCV(BaseTransformer):
70
64
  r"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty
71
65
  For more details on this class, see [sklearn.covariance.GraphicalLassoCV]
@@ -337,20 +331,17 @@ class GraphicalLassoCV(BaseTransformer):
337
331
  self,
338
332
  dataset: DataFrame,
339
333
  inference_method: str,
340
- ) -> List[str]:
341
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
342
- return the available package that exists in the snowflake anaconda channel
334
+ ) -> None:
335
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
343
336
 
344
337
  Args:
345
338
  dataset: snowpark dataframe
346
339
  inference_method: the inference method such as predict, score...
347
-
340
+
348
341
  Raises:
349
342
  SnowflakeMLException: If the estimator is not fitted, raise error
350
343
  SnowflakeMLException: If the session is None, raise error
351
344
 
352
- Returns:
353
- A list of available package that exists in the snowflake anaconda channel
354
345
  """
355
346
  if not self._is_fitted:
356
347
  raise exceptions.SnowflakeMLException(
@@ -368,9 +359,7 @@ class GraphicalLassoCV(BaseTransformer):
368
359
  "Session must not specified for snowpark dataset."
369
360
  ),
370
361
  )
371
- # Validate that key package version in user workspace are supported in snowflake conda channel
372
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
373
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
362
+
374
363
 
375
364
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
376
365
  @telemetry.send_api_usage_telemetry(
@@ -416,7 +405,8 @@ class GraphicalLassoCV(BaseTransformer):
416
405
 
417
406
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
418
407
 
419
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
408
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
409
+ self._deps = self._get_dependencies()
420
410
  assert isinstance(
421
411
  dataset._session, Session
422
412
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -499,10 +489,8 @@ class GraphicalLassoCV(BaseTransformer):
499
489
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
500
490
  expected_dtype = convert_sp_to_sf_type(output_types[0])
501
491
 
502
- self._deps = self._batch_inference_validate_snowpark(
503
- dataset=dataset,
504
- inference_method=inference_method,
505
- )
492
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
493
+ self._deps = self._get_dependencies()
506
494
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
507
495
 
508
496
  transform_kwargs = dict(
@@ -569,16 +557,40 @@ class GraphicalLassoCV(BaseTransformer):
569
557
  self._is_fitted = True
570
558
  return output_result
571
559
 
560
+
561
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
562
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
563
+ """ Method not supported for this class.
572
564
 
573
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
574
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
575
- """
565
+
566
+ Raises:
567
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
568
+
569
+ Args:
570
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
571
+ Snowpark or Pandas DataFrame.
572
+ output_cols_prefix: Prefix for the response columns
576
573
  Returns:
577
574
  Transformed dataset.
578
575
  """
579
- self.fit(dataset)
580
- assert self._sklearn_object is not None
581
- return self._sklearn_object.embedding_
576
+ self._infer_input_output_cols(dataset)
577
+ super()._check_dataset_type(dataset)
578
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
579
+ estimator=self._sklearn_object,
580
+ dataset=dataset,
581
+ input_cols=self.input_cols,
582
+ label_cols=self.label_cols,
583
+ sample_weight_col=self.sample_weight_col,
584
+ autogenerated=self._autogenerated,
585
+ subproject=_SUBPROJECT,
586
+ )
587
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
588
+ drop_input_cols=self._drop_input_cols,
589
+ expected_output_cols_list=self.output_cols,
590
+ )
591
+ self._sklearn_object = fitted_estimator
592
+ self._is_fitted = True
593
+ return output_result
582
594
 
583
595
 
584
596
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -669,10 +681,8 @@ class GraphicalLassoCV(BaseTransformer):
669
681
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
670
682
 
671
683
  if isinstance(dataset, DataFrame):
672
- self._deps = self._batch_inference_validate_snowpark(
673
- dataset=dataset,
674
- inference_method=inference_method,
675
- )
684
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
685
+ self._deps = self._get_dependencies()
676
686
  assert isinstance(
677
687
  dataset._session, Session
678
688
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -737,10 +747,8 @@ class GraphicalLassoCV(BaseTransformer):
737
747
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
738
748
 
739
749
  if isinstance(dataset, DataFrame):
740
- self._deps = self._batch_inference_validate_snowpark(
741
- dataset=dataset,
742
- inference_method=inference_method,
743
- )
750
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
751
+ self._deps = self._get_dependencies()
744
752
  assert isinstance(
745
753
  dataset._session, Session
746
754
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -802,10 +810,8 @@ class GraphicalLassoCV(BaseTransformer):
802
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
803
811
 
804
812
  if isinstance(dataset, DataFrame):
805
- self._deps = self._batch_inference_validate_snowpark(
806
- dataset=dataset,
807
- inference_method=inference_method,
808
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
809
815
  assert isinstance(
810
816
  dataset._session, Session
811
817
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -871,10 +877,8 @@ class GraphicalLassoCV(BaseTransformer):
871
877
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
872
878
 
873
879
  if isinstance(dataset, DataFrame):
874
- self._deps = self._batch_inference_validate_snowpark(
875
- dataset=dataset,
876
- inference_method=inference_method,
877
- )
880
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
881
+ self._deps = self._get_dependencies()
878
882
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
879
883
  transform_kwargs = dict(
880
884
  session=dataset._session,
@@ -938,17 +942,15 @@ class GraphicalLassoCV(BaseTransformer):
938
942
  transform_kwargs: ScoreKwargsTypedDict = dict()
939
943
 
940
944
  if isinstance(dataset, DataFrame):
941
- self._deps = self._batch_inference_validate_snowpark(
942
- dataset=dataset,
943
- inference_method="score",
944
- )
945
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
946
+ self._deps = self._get_dependencies()
945
947
  selected_cols = self._get_active_columns()
946
948
  if len(selected_cols) > 0:
947
949
  dataset = dataset.select(selected_cols)
948
950
  assert isinstance(dataset._session, Session) # keep mypy happy
949
951
  transform_kwargs = dict(
950
952
  session=dataset._session,
951
- dependencies=["snowflake-snowpark-python"] + self._deps,
953
+ dependencies=self._deps,
952
954
  score_sproc_imports=['sklearn'],
953
955
  )
954
956
  elif isinstance(dataset, pd.DataFrame):
@@ -1013,11 +1015,8 @@ class GraphicalLassoCV(BaseTransformer):
1013
1015
 
1014
1016
  if isinstance(dataset, DataFrame):
1015
1017
 
1016
- self._deps = self._batch_inference_validate_snowpark(
1017
- dataset=dataset,
1018
- inference_method=inference_method,
1019
-
1020
- )
1018
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1019
+ self._deps = self._get_dependencies()
1021
1020
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1022
1021
  transform_kwargs = dict(
1023
1022
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LedoitWolf(BaseTransformer):
70
64
  r"""LedoitWolf Estimator
71
65
  For more details on this class, see [sklearn.covariance.LedoitWolf]
@@ -270,20 +264,17 @@ class LedoitWolf(BaseTransformer):
270
264
  self,
271
265
  dataset: DataFrame,
272
266
  inference_method: str,
273
- ) -> List[str]:
274
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
275
- return the available package that exists in the snowflake anaconda channel
267
+ ) -> None:
268
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
276
269
 
277
270
  Args:
278
271
  dataset: snowpark dataframe
279
272
  inference_method: the inference method such as predict, score...
280
-
273
+
281
274
  Raises:
282
275
  SnowflakeMLException: If the estimator is not fitted, raise error
283
276
  SnowflakeMLException: If the session is None, raise error
284
277
 
285
- Returns:
286
- A list of available package that exists in the snowflake anaconda channel
287
278
  """
288
279
  if not self._is_fitted:
289
280
  raise exceptions.SnowflakeMLException(
@@ -301,9 +292,7 @@ class LedoitWolf(BaseTransformer):
301
292
  "Session must not specified for snowpark dataset."
302
293
  ),
303
294
  )
304
- # Validate that key package version in user workspace are supported in snowflake conda channel
305
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
306
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
295
+
307
296
 
308
297
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
309
298
  @telemetry.send_api_usage_telemetry(
@@ -349,7 +338,8 @@ class LedoitWolf(BaseTransformer):
349
338
 
350
339
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
351
340
 
352
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
341
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
342
+ self._deps = self._get_dependencies()
353
343
  assert isinstance(
354
344
  dataset._session, Session
355
345
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -432,10 +422,8 @@ class LedoitWolf(BaseTransformer):
432
422
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
433
423
  expected_dtype = convert_sp_to_sf_type(output_types[0])
434
424
 
435
- self._deps = self._batch_inference_validate_snowpark(
436
- dataset=dataset,
437
- inference_method=inference_method,
438
- )
425
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
426
+ self._deps = self._get_dependencies()
439
427
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
440
428
 
441
429
  transform_kwargs = dict(
@@ -502,16 +490,40 @@ class LedoitWolf(BaseTransformer):
502
490
  self._is_fitted = True
503
491
  return output_result
504
492
 
493
+
494
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
495
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
496
+ """ Method not supported for this class.
505
497
 
506
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
507
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
508
- """
498
+
499
+ Raises:
500
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
501
+
502
+ Args:
503
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
504
+ Snowpark or Pandas DataFrame.
505
+ output_cols_prefix: Prefix for the response columns
509
506
  Returns:
510
507
  Transformed dataset.
511
508
  """
512
- self.fit(dataset)
513
- assert self._sklearn_object is not None
514
- return self._sklearn_object.embedding_
509
+ self._infer_input_output_cols(dataset)
510
+ super()._check_dataset_type(dataset)
511
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
512
+ estimator=self._sklearn_object,
513
+ dataset=dataset,
514
+ input_cols=self.input_cols,
515
+ label_cols=self.label_cols,
516
+ sample_weight_col=self.sample_weight_col,
517
+ autogenerated=self._autogenerated,
518
+ subproject=_SUBPROJECT,
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=self.output_cols,
523
+ )
524
+ self._sklearn_object = fitted_estimator
525
+ self._is_fitted = True
526
+ return output_result
515
527
 
516
528
 
517
529
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -602,10 +614,8 @@ class LedoitWolf(BaseTransformer):
602
614
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
603
615
 
604
616
  if isinstance(dataset, DataFrame):
605
- self._deps = self._batch_inference_validate_snowpark(
606
- dataset=dataset,
607
- inference_method=inference_method,
608
- )
617
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
618
+ self._deps = self._get_dependencies()
609
619
  assert isinstance(
610
620
  dataset._session, Session
611
621
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -670,10 +680,8 @@ class LedoitWolf(BaseTransformer):
670
680
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
671
681
 
672
682
  if isinstance(dataset, DataFrame):
673
- self._deps = self._batch_inference_validate_snowpark(
674
- dataset=dataset,
675
- inference_method=inference_method,
676
- )
683
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
684
+ self._deps = self._get_dependencies()
677
685
  assert isinstance(
678
686
  dataset._session, Session
679
687
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -735,10 +743,8 @@ class LedoitWolf(BaseTransformer):
735
743
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
736
744
 
737
745
  if isinstance(dataset, DataFrame):
738
- self._deps = self._batch_inference_validate_snowpark(
739
- dataset=dataset,
740
- inference_method=inference_method,
741
- )
746
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
747
+ self._deps = self._get_dependencies()
742
748
  assert isinstance(
743
749
  dataset._session, Session
744
750
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -804,10 +810,8 @@ class LedoitWolf(BaseTransformer):
804
810
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
805
811
 
806
812
  if isinstance(dataset, DataFrame):
807
- self._deps = self._batch_inference_validate_snowpark(
808
- dataset=dataset,
809
- inference_method=inference_method,
810
- )
813
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
814
+ self._deps = self._get_dependencies()
811
815
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
812
816
  transform_kwargs = dict(
813
817
  session=dataset._session,
@@ -871,17 +875,15 @@ class LedoitWolf(BaseTransformer):
871
875
  transform_kwargs: ScoreKwargsTypedDict = dict()
872
876
 
873
877
  if isinstance(dataset, DataFrame):
874
- self._deps = self._batch_inference_validate_snowpark(
875
- dataset=dataset,
876
- inference_method="score",
877
- )
878
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
879
+ self._deps = self._get_dependencies()
878
880
  selected_cols = self._get_active_columns()
879
881
  if len(selected_cols) > 0:
880
882
  dataset = dataset.select(selected_cols)
881
883
  assert isinstance(dataset._session, Session) # keep mypy happy
882
884
  transform_kwargs = dict(
883
885
  session=dataset._session,
884
- dependencies=["snowflake-snowpark-python"] + self._deps,
886
+ dependencies=self._deps,
885
887
  score_sproc_imports=['sklearn'],
886
888
  )
887
889
  elif isinstance(dataset, pd.DataFrame):
@@ -946,11 +948,8 @@ class LedoitWolf(BaseTransformer):
946
948
 
947
949
  if isinstance(dataset, DataFrame):
948
950
 
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method=inference_method,
952
-
953
- )
951
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
952
+ self._deps = self._get_dependencies()
954
953
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
955
954
  transform_kwargs = dict(
956
955
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MinCovDet(BaseTransformer):
70
64
  r"""Minimum Covariance Determinant (MCD): robust estimator of covariance
71
65
  For more details on this class, see [sklearn.covariance.MinCovDet]
@@ -282,20 +276,17 @@ class MinCovDet(BaseTransformer):
282
276
  self,
283
277
  dataset: DataFrame,
284
278
  inference_method: str,
285
- ) -> List[str]:
286
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
287
- return the available package that exists in the snowflake anaconda channel
279
+ ) -> None:
280
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
288
281
 
289
282
  Args:
290
283
  dataset: snowpark dataframe
291
284
  inference_method: the inference method such as predict, score...
292
-
285
+
293
286
  Raises:
294
287
  SnowflakeMLException: If the estimator is not fitted, raise error
295
288
  SnowflakeMLException: If the session is None, raise error
296
289
 
297
- Returns:
298
- A list of available package that exists in the snowflake anaconda channel
299
290
  """
300
291
  if not self._is_fitted:
301
292
  raise exceptions.SnowflakeMLException(
@@ -313,9 +304,7 @@ class MinCovDet(BaseTransformer):
313
304
  "Session must not specified for snowpark dataset."
314
305
  ),
315
306
  )
316
- # Validate that key package version in user workspace are supported in snowflake conda channel
317
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
318
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
307
+
319
308
 
320
309
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
321
310
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class MinCovDet(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class MinCovDet(BaseTransformer):
444
434
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
445
435
  expected_dtype = convert_sp_to_sf_type(output_types[0])
446
436
 
447
- self._deps = self._batch_inference_validate_snowpark(
448
- dataset=dataset,
449
- inference_method=inference_method,
450
- )
437
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._deps = self._get_dependencies()
451
439
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
452
440
 
453
441
  transform_kwargs = dict(
@@ -514,16 +502,40 @@ class MinCovDet(BaseTransformer):
514
502
  self._is_fitted = True
515
503
  return output_result
516
504
 
505
+
506
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
507
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
508
+ """ Method not supported for this class.
517
509
 
518
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
519
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
520
- """
510
+
511
+ Raises:
512
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
513
+
514
+ Args:
515
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
516
+ Snowpark or Pandas DataFrame.
517
+ output_cols_prefix: Prefix for the response columns
521
518
  Returns:
522
519
  Transformed dataset.
523
520
  """
524
- self.fit(dataset)
525
- assert self._sklearn_object is not None
526
- return self._sklearn_object.embedding_
521
+ self._infer_input_output_cols(dataset)
522
+ super()._check_dataset_type(dataset)
523
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
524
+ estimator=self._sklearn_object,
525
+ dataset=dataset,
526
+ input_cols=self.input_cols,
527
+ label_cols=self.label_cols,
528
+ sample_weight_col=self.sample_weight_col,
529
+ autogenerated=self._autogenerated,
530
+ subproject=_SUBPROJECT,
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=self.output_cols,
535
+ )
536
+ self._sklearn_object = fitted_estimator
537
+ self._is_fitted = True
538
+ return output_result
527
539
 
528
540
 
529
541
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -614,10 +626,8 @@ class MinCovDet(BaseTransformer):
614
626
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
615
627
 
616
628
  if isinstance(dataset, DataFrame):
617
- self._deps = self._batch_inference_validate_snowpark(
618
- dataset=dataset,
619
- inference_method=inference_method,
620
- )
629
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
630
+ self._deps = self._get_dependencies()
621
631
  assert isinstance(
622
632
  dataset._session, Session
623
633
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -682,10 +692,8 @@ class MinCovDet(BaseTransformer):
682
692
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
683
693
 
684
694
  if isinstance(dataset, DataFrame):
685
- self._deps = self._batch_inference_validate_snowpark(
686
- dataset=dataset,
687
- inference_method=inference_method,
688
- )
695
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
696
+ self._deps = self._get_dependencies()
689
697
  assert isinstance(
690
698
  dataset._session, Session
691
699
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +755,8 @@ class MinCovDet(BaseTransformer):
747
755
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
748
756
 
749
757
  if isinstance(dataset, DataFrame):
750
- self._deps = self._batch_inference_validate_snowpark(
751
- dataset=dataset,
752
- inference_method=inference_method,
753
- )
758
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
759
+ self._deps = self._get_dependencies()
754
760
  assert isinstance(
755
761
  dataset._session, Session
756
762
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -816,10 +822,8 @@ class MinCovDet(BaseTransformer):
816
822
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
817
823
 
818
824
  if isinstance(dataset, DataFrame):
819
- self._deps = self._batch_inference_validate_snowpark(
820
- dataset=dataset,
821
- inference_method=inference_method,
822
- )
825
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
826
+ self._deps = self._get_dependencies()
823
827
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
824
828
  transform_kwargs = dict(
825
829
  session=dataset._session,
@@ -883,17 +887,15 @@ class MinCovDet(BaseTransformer):
883
887
  transform_kwargs: ScoreKwargsTypedDict = dict()
884
888
 
885
889
  if isinstance(dataset, DataFrame):
886
- self._deps = self._batch_inference_validate_snowpark(
887
- dataset=dataset,
888
- inference_method="score",
889
- )
890
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
891
+ self._deps = self._get_dependencies()
890
892
  selected_cols = self._get_active_columns()
891
893
  if len(selected_cols) > 0:
892
894
  dataset = dataset.select(selected_cols)
893
895
  assert isinstance(dataset._session, Session) # keep mypy happy
894
896
  transform_kwargs = dict(
895
897
  session=dataset._session,
896
- dependencies=["snowflake-snowpark-python"] + self._deps,
898
+ dependencies=self._deps,
897
899
  score_sproc_imports=['sklearn'],
898
900
  )
899
901
  elif isinstance(dataset, pd.DataFrame):
@@ -958,11 +960,8 @@ class MinCovDet(BaseTransformer):
958
960
 
959
961
  if isinstance(dataset, DataFrame):
960
962
 
961
- self._deps = self._batch_inference_validate_snowpark(
962
- dataset=dataset,
963
- inference_method=inference_method,
964
-
965
- )
963
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
964
+ self._deps = self._get_dependencies()
966
965
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
967
966
  transform_kwargs = dict(
968
967
  session = dataset._session,