snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0

snowflake/ml/modeling/covariance/oas.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class OAS(BaseTransformer):
     r"""Oracle Approximating Shrinkage Estimator as proposed in [1]_
     For more details on this class, see [sklearn.covariance.OAS]
@@ -263,20 +257,17 @@ class OAS(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -294,9 +285,7 @@ class OAS(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -342,7 +331,8 @@ class OAS(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -425,10 +415,8 @@ class OAS(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -495,16 +483,40 @@ class OAS(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -595,10 +607,8 @@ class OAS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -663,10 +673,8 @@ class OAS(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -728,10 +736,8 @@ class OAS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +803,8 @@ class OAS(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -864,17 +868,15 @@ class OAS(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -939,11 +941,8 @@ class OAS(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
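
Across the estimator wrappers in this diff (oas.py above, and shrunk_covariance.py and dictionary_learning.py below), the change has the same shape: `_batch_inference_validate_snowpark` now returns `None` and only performs validation, while the dependency list that used to be its return value is obtained from `self._get_dependencies()` and assigned to `self._deps` separately. The snippet below is a toy sketch of that split, not snowflake-ml-python code; every class, field, and value in it is invented for illustration, and only the two method names are taken from the diff.

# Toy sketch of the validate/resolve split introduced in 1.5.0 (illustration only).
from typing import List, Optional


class _StubWrapper:
    """Stand-in for an autogenerated estimator wrapper; not library code."""

    def __init__(self, deps: List[str]) -> None:
        self._is_fitted = True
        self._session: Optional[object] = object()
        self._dependencies = deps
        self._deps: List[str] = []

    def _get_dependencies(self) -> List[str]:
        # 1.5.0 pattern: dependency resolution is its own explicit step.
        return list(self._dependencies)

    def _batch_inference_validate_snowpark(self, dataset: object, inference_method: str) -> None:
        # 1.5.0 pattern: validation only; it raises on bad state and returns nothing.
        if not self._is_fitted:
            raise RuntimeError("estimator is not fitted")
        if self._session is None:
            raise RuntimeError("session must not be None for a Snowpark dataset")

    def predict(self, dataset: object) -> List[str]:
        # 1.4.1 wrote: self._deps = self._batch_inference_validate_snowpark(...)
        # 1.5.0 validates first, then fetches dependencies separately.
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        self._deps = self._get_dependencies()
        return self._deps


if __name__ == "__main__":
    print(_StubWrapper(deps=["scikit-learn", "xgboost"]).predict(dataset=None))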

snowflake/ml/modeling/covariance/shrunk_covariance.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.covariance".replace("skl
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class ShrunkCovariance(BaseTransformer):
     r"""Covariance estimator with shrinkage
     For more details on this class, see [sklearn.covariance.ShrunkCovariance]
@@ -269,20 +263,17 @@ class ShrunkCovariance(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -300,9 +291,7 @@ class ShrunkCovariance(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -348,7 +337,8 @@ class ShrunkCovariance(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -431,10 +421,8 @@ class ShrunkCovariance(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -501,16 +489,40 @@ class ShrunkCovariance(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -601,10 +613,8 @@ class ShrunkCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -669,10 +679,8 @@ class ShrunkCovariance(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -734,10 +742,8 @@ class ShrunkCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -803,10 +809,8 @@ class ShrunkCovariance(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -870,17 +874,15 @@ class ShrunkCovariance(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -945,11 +947,8 @@ class ShrunkCovariance(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,

snowflake/ml/modeling/decomposition/dictionary_learning.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class DictionaryLearning(BaseTransformer):
     r"""Dictionary learning
     For more details on this class, see [sklearn.decomposition.DictionaryLearning]
@@ -375,20 +369,17 @@ class DictionaryLearning(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
        """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -406,9 +397,7 @@ class DictionaryLearning(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -454,7 +443,8 @@ class DictionaryLearning(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -539,10 +529,8 @@ class DictionaryLearning(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -609,16 +597,42 @@ class DictionaryLearning(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit the model from data in X and return the transformed data
+        For more details on this function, see [sklearn.decomposition.DictionaryLearning.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.DictionaryLearning.html#sklearn.decomposition.DictionaryLearning.fit_transform)
+
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -709,10 +723,8 @@ class DictionaryLearning(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -777,10 +789,8 @@ class DictionaryLearning(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -842,10 +852,8 @@ class DictionaryLearning(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -911,10 +919,8 @@ class DictionaryLearning(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -976,17 +982,15 @@ class DictionaryLearning(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1051,11 +1055,8 @@ class DictionaryLearning(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
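
For transformers such as DictionaryLearning, the new `fit_transform` goes through `ModelTrainerBuilder.build_fit_transform` instead of the old `self._sklearn_object.embedding_` shortcut, and it now takes an `output_cols_prefix` argument. The usage sketch below targets that 1.5.0 signature; it assumes the usual wrapper constructor arguments (`input_cols`, `output_cols`, plus pass-through sklearn parameters), and the column names and data are made up for illustration rather than taken from the package.

# Usage sketch for the fit_transform signature added in 1.5.0 (illustrative values only).
import pandas as pd

from snowflake.ml.modeling.decomposition.dictionary_learning import DictionaryLearning

# A small local frame; per the diff, a snowflake.snowpark.DataFrame is accepted as well.
df = pd.DataFrame(
    {
        "F1": [0.1, 0.4, 0.9, 0.3],
        "F2": [1.2, 0.8, 0.5, 1.1],
        "F3": [0.0, 0.2, 0.7, 0.9],
    }
)

dl = DictionaryLearning(
    n_components=2,  # forwarded to sklearn.decomposition.DictionaryLearning
    input_cols=["F1", "F2", "F3"],
    output_cols=["DICT_0", "DICT_1"],
)

# New in 1.5.0: fit and transform in one call. The optional output_cols_prefix
# (default "fit_transform_") is used when output column names have to be inferred.
transformed = dl.fit_transform(df)
print(transformed)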