snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class EllipticEnvelope(BaseTransformer):
70
64
  r"""An object for detecting outliers in a Gaussian distributed dataset
71
65
  For more details on this class, see [sklearn.covariance.EllipticEnvelope]
@@ -287,20 +281,17 @@ class EllipticEnvelope(BaseTransformer):
287
281
  self,
288
282
  dataset: DataFrame,
289
283
  inference_method: str,
290
- ) -> List[str]:
291
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
292
- return the available package that exists in the snowflake anaconda channel
284
+ ) -> None:
285
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
293
286
 
294
287
  Args:
295
288
  dataset: snowpark dataframe
296
289
  inference_method: the inference method such as predict, score...
297
-
290
+
298
291
  Raises:
299
292
  SnowflakeMLException: If the estimator is not fitted, raise error
300
293
  SnowflakeMLException: If the session is None, raise error
301
294
 
302
- Returns:
303
- A list of available package that exists in the snowflake anaconda channel
304
295
  """
305
296
  if not self._is_fitted:
306
297
  raise exceptions.SnowflakeMLException(
@@ -318,9 +309,7 @@ class EllipticEnvelope(BaseTransformer):
318
309
  "Session must not specified for snowpark dataset."
319
310
  ),
320
311
  )
321
- # Validate that key package version in user workspace are supported in snowflake conda channel
322
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
323
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
312
+
324
313
 
325
314
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
326
315
  @telemetry.send_api_usage_telemetry(
@@ -368,7 +357,8 @@ class EllipticEnvelope(BaseTransformer):
368
357
 
369
358
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
370
359
 
371
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
360
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
361
+ self._deps = self._get_dependencies()
372
362
  assert isinstance(
373
363
  dataset._session, Session
374
364
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -451,10 +441,8 @@ class EllipticEnvelope(BaseTransformer):
451
441
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
452
442
  expected_dtype = convert_sp_to_sf_type(output_types[0])
453
443
 
454
- self._deps = self._batch_inference_validate_snowpark(
455
- dataset=dataset,
456
- inference_method=inference_method,
457
- )
444
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
445
+ self._deps = self._get_dependencies()
458
446
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
459
447
 
460
448
  transform_kwargs = dict(
@@ -523,16 +511,40 @@ class EllipticEnvelope(BaseTransformer):
523
511
  self._is_fitted = True
524
512
  return output_result
525
513
 
514
+
515
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
516
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
517
+ """ Method not supported for this class.
518
+
526
519
 
527
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
528
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
529
- """
520
+ Raises:
521
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
522
+
523
+ Args:
524
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
525
+ Snowpark or Pandas DataFrame.
526
+ output_cols_prefix: Prefix for the response columns
530
527
  Returns:
531
528
  Transformed dataset.
532
529
  """
533
- self.fit(dataset)
534
- assert self._sklearn_object is not None
535
- return self._sklearn_object.embedding_
530
+ self._infer_input_output_cols(dataset)
531
+ super()._check_dataset_type(dataset)
532
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
533
+ estimator=self._sklearn_object,
534
+ dataset=dataset,
535
+ input_cols=self.input_cols,
536
+ label_cols=self.label_cols,
537
+ sample_weight_col=self.sample_weight_col,
538
+ autogenerated=self._autogenerated,
539
+ subproject=_SUBPROJECT,
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=self.output_cols,
544
+ )
545
+ self._sklearn_object = fitted_estimator
546
+ self._is_fitted = True
547
+ return output_result
536
548
 
537
549
 
538
550
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -623,10 +635,8 @@ class EllipticEnvelope(BaseTransformer):
623
635
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
624
636
 
625
637
  if isinstance(dataset, DataFrame):
626
- self._deps = self._batch_inference_validate_snowpark(
627
- dataset=dataset,
628
- inference_method=inference_method,
629
- )
638
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
639
+ self._deps = self._get_dependencies()
630
640
  assert isinstance(
631
641
  dataset._session, Session
632
642
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -691,10 +701,8 @@ class EllipticEnvelope(BaseTransformer):
691
701
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
692
702
 
693
703
  if isinstance(dataset, DataFrame):
694
- self._deps = self._batch_inference_validate_snowpark(
695
- dataset=dataset,
696
- inference_method=inference_method,
697
- )
704
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
705
+ self._deps = self._get_dependencies()
698
706
  assert isinstance(
699
707
  dataset._session, Session
700
708
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -758,10 +766,8 @@ class EllipticEnvelope(BaseTransformer):
758
766
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
759
767
 
760
768
  if isinstance(dataset, DataFrame):
761
- self._deps = self._batch_inference_validate_snowpark(
762
- dataset=dataset,
763
- inference_method=inference_method,
764
- )
769
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
770
+ self._deps = self._get_dependencies()
765
771
  assert isinstance(
766
772
  dataset._session, Session
767
773
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -829,10 +835,8 @@ class EllipticEnvelope(BaseTransformer):
829
835
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
830
836
 
831
837
  if isinstance(dataset, DataFrame):
832
- self._deps = self._batch_inference_validate_snowpark(
833
- dataset=dataset,
834
- inference_method=inference_method,
835
- )
838
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
839
+ self._deps = self._get_dependencies()
836
840
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
837
841
  transform_kwargs = dict(
838
842
  session=dataset._session,
@@ -896,17 +900,15 @@ class EllipticEnvelope(BaseTransformer):
896
900
  transform_kwargs: ScoreKwargsTypedDict = dict()
897
901
 
898
902
  if isinstance(dataset, DataFrame):
899
- self._deps = self._batch_inference_validate_snowpark(
900
- dataset=dataset,
901
- inference_method="score",
902
- )
903
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
904
+ self._deps = self._get_dependencies()
903
905
  selected_cols = self._get_active_columns()
904
906
  if len(selected_cols) > 0:
905
907
  dataset = dataset.select(selected_cols)
906
908
  assert isinstance(dataset._session, Session) # keep mypy happy
907
909
  transform_kwargs = dict(
908
910
  session=dataset._session,
909
- dependencies=["snowflake-snowpark-python"] + self._deps,
911
+ dependencies=self._deps,
910
912
  score_sproc_imports=['sklearn'],
911
913
  )
912
914
  elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +973,8 @@ class EllipticEnvelope(BaseTransformer):
971
973
 
972
974
  if isinstance(dataset, DataFrame):
973
975
 
974
- self._deps = self._batch_inference_validate_snowpark(
975
- dataset=dataset,
976
- inference_method=inference_method,
977
-
978
- )
976
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
977
+ self._deps = self._get_dependencies()
979
978
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
980
979
  transform_kwargs = dict(
981
980
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class EmpiricalCovariance(BaseTransformer):
70
64
  r"""Maximum likelihood covariance estimator
71
65
  For more details on this class, see [sklearn.covariance.EmpiricalCovariance]
@@ -263,20 +257,17 @@ class EmpiricalCovariance(BaseTransformer):
263
257
  self,
264
258
  dataset: DataFrame,
265
259
  inference_method: str,
266
- ) -> List[str]:
267
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
268
- return the available package that exists in the snowflake anaconda channel
260
+ ) -> None:
261
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
269
262
 
270
263
  Args:
271
264
  dataset: snowpark dataframe
272
265
  inference_method: the inference method such as predict, score...
273
-
266
+
274
267
  Raises:
275
268
  SnowflakeMLException: If the estimator is not fitted, raise error
276
269
  SnowflakeMLException: If the session is None, raise error
277
270
 
278
- Returns:
279
- A list of available package that exists in the snowflake anaconda channel
280
271
  """
281
272
  if not self._is_fitted:
282
273
  raise exceptions.SnowflakeMLException(
@@ -294,9 +285,7 @@ class EmpiricalCovariance(BaseTransformer):
294
285
  "Session must not specified for snowpark dataset."
295
286
  ),
296
287
  )
297
- # Validate that key package version in user workspace are supported in snowflake conda channel
298
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
299
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
288
+
300
289
 
301
290
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
302
291
  @telemetry.send_api_usage_telemetry(
@@ -342,7 +331,8 @@ class EmpiricalCovariance(BaseTransformer):
342
331
 
343
332
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
344
333
 
345
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
334
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
335
+ self._deps = self._get_dependencies()
346
336
  assert isinstance(
347
337
  dataset._session, Session
348
338
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -425,10 +415,8 @@ class EmpiricalCovariance(BaseTransformer):
425
415
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
426
416
  expected_dtype = convert_sp_to_sf_type(output_types[0])
427
417
 
428
- self._deps = self._batch_inference_validate_snowpark(
429
- dataset=dataset,
430
- inference_method=inference_method,
431
- )
418
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
419
+ self._deps = self._get_dependencies()
432
420
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
433
421
 
434
422
  transform_kwargs = dict(
@@ -495,16 +483,40 @@ class EmpiricalCovariance(BaseTransformer):
495
483
  self._is_fitted = True
496
484
  return output_result
497
485
 
486
+
487
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
488
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
489
+ """ Method not supported for this class.
498
490
 
499
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
500
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
501
- """
491
+
492
+ Raises:
493
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
494
+
495
+ Args:
496
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
497
+ Snowpark or Pandas DataFrame.
498
+ output_cols_prefix: Prefix for the response columns
502
499
  Returns:
503
500
  Transformed dataset.
504
501
  """
505
- self.fit(dataset)
506
- assert self._sklearn_object is not None
507
- return self._sklearn_object.embedding_
502
+ self._infer_input_output_cols(dataset)
503
+ super()._check_dataset_type(dataset)
504
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
505
+ estimator=self._sklearn_object,
506
+ dataset=dataset,
507
+ input_cols=self.input_cols,
508
+ label_cols=self.label_cols,
509
+ sample_weight_col=self.sample_weight_col,
510
+ autogenerated=self._autogenerated,
511
+ subproject=_SUBPROJECT,
512
+ )
513
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=self.output_cols,
516
+ )
517
+ self._sklearn_object = fitted_estimator
518
+ self._is_fitted = True
519
+ return output_result
508
520
 
509
521
 
510
522
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -595,10 +607,8 @@ class EmpiricalCovariance(BaseTransformer):
595
607
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
596
608
 
597
609
  if isinstance(dataset, DataFrame):
598
- self._deps = self._batch_inference_validate_snowpark(
599
- dataset=dataset,
600
- inference_method=inference_method,
601
- )
610
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
611
+ self._deps = self._get_dependencies()
602
612
  assert isinstance(
603
613
  dataset._session, Session
604
614
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -663,10 +673,8 @@ class EmpiricalCovariance(BaseTransformer):
663
673
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
664
674
 
665
675
  if isinstance(dataset, DataFrame):
666
- self._deps = self._batch_inference_validate_snowpark(
667
- dataset=dataset,
668
- inference_method=inference_method,
669
- )
676
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
677
+ self._deps = self._get_dependencies()
670
678
  assert isinstance(
671
679
  dataset._session, Session
672
680
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -728,10 +736,8 @@ class EmpiricalCovariance(BaseTransformer):
728
736
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
729
737
 
730
738
  if isinstance(dataset, DataFrame):
731
- self._deps = self._batch_inference_validate_snowpark(
732
- dataset=dataset,
733
- inference_method=inference_method,
734
- )
739
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
740
+ self._deps = self._get_dependencies()
735
741
  assert isinstance(
736
742
  dataset._session, Session
737
743
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +803,8 @@ class EmpiricalCovariance(BaseTransformer):
797
803
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
798
804
 
799
805
  if isinstance(dataset, DataFrame):
800
- self._deps = self._batch_inference_validate_snowpark(
801
- dataset=dataset,
802
- inference_method=inference_method,
803
- )
806
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
807
+ self._deps = self._get_dependencies()
804
808
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
805
809
  transform_kwargs = dict(
806
810
  session=dataset._session,
@@ -864,17 +868,15 @@ class EmpiricalCovariance(BaseTransformer):
864
868
  transform_kwargs: ScoreKwargsTypedDict = dict()
865
869
 
866
870
  if isinstance(dataset, DataFrame):
867
- self._deps = self._batch_inference_validate_snowpark(
868
- dataset=dataset,
869
- inference_method="score",
870
- )
871
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
872
+ self._deps = self._get_dependencies()
871
873
  selected_cols = self._get_active_columns()
872
874
  if len(selected_cols) > 0:
873
875
  dataset = dataset.select(selected_cols)
874
876
  assert isinstance(dataset._session, Session) # keep mypy happy
875
877
  transform_kwargs = dict(
876
878
  session=dataset._session,
877
- dependencies=["snowflake-snowpark-python"] + self._deps,
879
+ dependencies=self._deps,
878
880
  score_sproc_imports=['sklearn'],
879
881
  )
880
882
  elif isinstance(dataset, pd.DataFrame):
@@ -939,11 +941,8 @@ class EmpiricalCovariance(BaseTransformer):
939
941
 
940
942
  if isinstance(dataset, DataFrame):
941
943
 
942
- self._deps = self._batch_inference_validate_snowpark(
943
- dataset=dataset,
944
- inference_method=inference_method,
945
-
946
- )
944
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
945
+ self._deps = self._get_dependencies()
947
946
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
948
947
  transform_kwargs = dict(
949
948
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GraphicalLasso(BaseTransformer):
70
64
  r"""Sparse inverse covariance estimation with an l1-penalized estimator
71
65
  For more details on this class, see [sklearn.covariance.GraphicalLasso]
@@ -311,20 +305,17 @@ class GraphicalLasso(BaseTransformer):
311
305
  self,
312
306
  dataset: DataFrame,
313
307
  inference_method: str,
314
- ) -> List[str]:
315
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
316
- return the available package that exists in the snowflake anaconda channel
308
+ ) -> None:
309
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
317
310
 
318
311
  Args:
319
312
  dataset: snowpark dataframe
320
313
  inference_method: the inference method such as predict, score...
321
-
314
+
322
315
  Raises:
323
316
  SnowflakeMLException: If the estimator is not fitted, raise error
324
317
  SnowflakeMLException: If the session is None, raise error
325
318
 
326
- Returns:
327
- A list of available package that exists in the snowflake anaconda channel
328
319
  """
329
320
  if not self._is_fitted:
330
321
  raise exceptions.SnowflakeMLException(
@@ -342,9 +333,7 @@ class GraphicalLasso(BaseTransformer):
342
333
  "Session must not specified for snowpark dataset."
343
334
  ),
344
335
  )
345
- # Validate that key package version in user workspace are supported in snowflake conda channel
346
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
347
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
336
+
348
337
 
349
338
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
350
339
  @telemetry.send_api_usage_telemetry(
@@ -390,7 +379,8 @@ class GraphicalLasso(BaseTransformer):
390
379
 
391
380
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
392
381
 
393
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
382
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._deps = self._get_dependencies()
394
384
  assert isinstance(
395
385
  dataset._session, Session
396
386
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -473,10 +463,8 @@ class GraphicalLasso(BaseTransformer):
473
463
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
474
464
  expected_dtype = convert_sp_to_sf_type(output_types[0])
475
465
 
476
- self._deps = self._batch_inference_validate_snowpark(
477
- dataset=dataset,
478
- inference_method=inference_method,
479
- )
466
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
467
+ self._deps = self._get_dependencies()
480
468
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
481
469
 
482
470
  transform_kwargs = dict(
@@ -543,16 +531,40 @@ class GraphicalLasso(BaseTransformer):
543
531
  self._is_fitted = True
544
532
  return output_result
545
533
 
534
+
535
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
536
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
537
+ """ Method not supported for this class.
546
538
 
547
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
548
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
549
- """
539
+
540
+ Raises:
541
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
542
+
543
+ Args:
544
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
545
+ Snowpark or Pandas DataFrame.
546
+ output_cols_prefix: Prefix for the response columns
550
547
  Returns:
551
548
  Transformed dataset.
552
549
  """
553
- self.fit(dataset)
554
- assert self._sklearn_object is not None
555
- return self._sklearn_object.embedding_
550
+ self._infer_input_output_cols(dataset)
551
+ super()._check_dataset_type(dataset)
552
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
553
+ estimator=self._sklearn_object,
554
+ dataset=dataset,
555
+ input_cols=self.input_cols,
556
+ label_cols=self.label_cols,
557
+ sample_weight_col=self.sample_weight_col,
558
+ autogenerated=self._autogenerated,
559
+ subproject=_SUBPROJECT,
560
+ )
561
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
562
+ drop_input_cols=self._drop_input_cols,
563
+ expected_output_cols_list=self.output_cols,
564
+ )
565
+ self._sklearn_object = fitted_estimator
566
+ self._is_fitted = True
567
+ return output_result
556
568
 
557
569
 
558
570
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -643,10 +655,8 @@ class GraphicalLasso(BaseTransformer):
643
655
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
644
656
 
645
657
  if isinstance(dataset, DataFrame):
646
- self._deps = self._batch_inference_validate_snowpark(
647
- dataset=dataset,
648
- inference_method=inference_method,
649
- )
658
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
659
+ self._deps = self._get_dependencies()
650
660
  assert isinstance(
651
661
  dataset._session, Session
652
662
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -711,10 +721,8 @@ class GraphicalLasso(BaseTransformer):
711
721
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
712
722
 
713
723
  if isinstance(dataset, DataFrame):
714
- self._deps = self._batch_inference_validate_snowpark(
715
- dataset=dataset,
716
- inference_method=inference_method,
717
- )
724
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
725
+ self._deps = self._get_dependencies()
718
726
  assert isinstance(
719
727
  dataset._session, Session
720
728
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -776,10 +784,8 @@ class GraphicalLasso(BaseTransformer):
776
784
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
777
785
 
778
786
  if isinstance(dataset, DataFrame):
779
- self._deps = self._batch_inference_validate_snowpark(
780
- dataset=dataset,
781
- inference_method=inference_method,
782
- )
787
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
788
+ self._deps = self._get_dependencies()
783
789
  assert isinstance(
784
790
  dataset._session, Session
785
791
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -845,10 +851,8 @@ class GraphicalLasso(BaseTransformer):
845
851
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
846
852
 
847
853
  if isinstance(dataset, DataFrame):
848
- self._deps = self._batch_inference_validate_snowpark(
849
- dataset=dataset,
850
- inference_method=inference_method,
851
- )
854
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
855
+ self._deps = self._get_dependencies()
852
856
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
853
857
  transform_kwargs = dict(
854
858
  session=dataset._session,
@@ -912,17 +916,15 @@ class GraphicalLasso(BaseTransformer):
912
916
  transform_kwargs: ScoreKwargsTypedDict = dict()
913
917
 
914
918
  if isinstance(dataset, DataFrame):
915
- self._deps = self._batch_inference_validate_snowpark(
916
- dataset=dataset,
917
- inference_method="score",
918
- )
919
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
920
+ self._deps = self._get_dependencies()
919
921
  selected_cols = self._get_active_columns()
920
922
  if len(selected_cols) > 0:
921
923
  dataset = dataset.select(selected_cols)
922
924
  assert isinstance(dataset._session, Session) # keep mypy happy
923
925
  transform_kwargs = dict(
924
926
  session=dataset._session,
925
- dependencies=["snowflake-snowpark-python"] + self._deps,
927
+ dependencies=self._deps,
926
928
  score_sproc_imports=['sklearn'],
927
929
  )
928
930
  elif isinstance(dataset, pd.DataFrame):
@@ -987,11 +989,8 @@ class GraphicalLasso(BaseTransformer):
987
989
 
988
990
  if isinstance(dataset, DataFrame):
989
991
 
990
- self._deps = self._batch_inference_validate_snowpark(
991
- dataset=dataset,
992
- inference_method=inference_method,
993
-
994
- )
992
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
993
+ self._deps = self._get_dependencies()
995
994
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
996
995
  transform_kwargs = dict(
997
996
  session = dataset._session,