snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".r
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LinearDiscriminantAnalysis(BaseTransformer):
70
64
  r"""Linear Discriminant Analysis
71
65
  For more details on this class, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis]
@@ -318,20 +312,17 @@ class LinearDiscriminantAnalysis(BaseTransformer):
318
312
  self,
319
313
  dataset: DataFrame,
320
314
  inference_method: str,
321
- ) -> List[str]:
322
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
323
- return the available package that exists in the snowflake anaconda channel
315
+ ) -> None:
316
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
324
317
 
325
318
  Args:
326
319
  dataset: snowpark dataframe
327
320
  inference_method: the inference method such as predict, score...
328
-
321
+
329
322
  Raises:
330
323
  SnowflakeMLException: If the estimator is not fitted, raise error
331
324
  SnowflakeMLException: If the session is None, raise error
332
325
 
333
- Returns:
334
- A list of available package that exists in the snowflake anaconda channel
335
326
  """
336
327
  if not self._is_fitted:
337
328
  raise exceptions.SnowflakeMLException(
@@ -349,9 +340,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
349
340
  "Session must not specified for snowpark dataset."
350
341
  ),
351
342
  )
352
- # Validate that key package version in user workspace are supported in snowflake conda channel
353
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
354
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
343
+
355
344
 
356
345
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
357
346
  @telemetry.send_api_usage_telemetry(
@@ -399,7 +388,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
399
388
 
400
389
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
401
390
 
402
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
391
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
392
+ self._deps = self._get_dependencies()
403
393
  assert isinstance(
404
394
  dataset._session, Session
405
395
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
484
474
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
485
475
  expected_dtype = convert_sp_to_sf_type(output_types[0])
486
476
 
487
- self._deps = self._batch_inference_validate_snowpark(
488
- dataset=dataset,
489
- inference_method=inference_method,
490
- )
477
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
478
+ self._deps = self._get_dependencies()
491
479
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
492
480
 
493
481
  transform_kwargs = dict(
@@ -554,16 +542,42 @@ class LinearDiscriminantAnalysis(BaseTransformer):
554
542
  self._is_fitted = True
555
543
  return output_result
556
544
 
545
+
546
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
547
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
548
+ """ Fit to data, then transform it
549
+ For more details on this function, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis.fit_transform]
550
+ (https://scikit-learn.org/stable/modules/generated/sklearn.discriminant_analysis.LinearDiscriminantAnalysis.html#sklearn.discriminant_analysis.LinearDiscriminantAnalysis.fit_transform)
551
+
557
552
 
558
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
559
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
560
- """
553
+ Raises:
554
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
555
+
556
+ Args:
557
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
558
+ Snowpark or Pandas DataFrame.
559
+ output_cols_prefix: Prefix for the response columns
561
560
  Returns:
562
561
  Transformed dataset.
563
562
  """
564
- self.fit(dataset)
565
- assert self._sklearn_object is not None
566
- return self._sklearn_object.embedding_
563
+ self._infer_input_output_cols(dataset)
564
+ super()._check_dataset_type(dataset)
565
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
566
+ estimator=self._sklearn_object,
567
+ dataset=dataset,
568
+ input_cols=self.input_cols,
569
+ label_cols=self.label_cols,
570
+ sample_weight_col=self.sample_weight_col,
571
+ autogenerated=self._autogenerated,
572
+ subproject=_SUBPROJECT,
573
+ )
574
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=self.output_cols,
577
+ )
578
+ self._sklearn_object = fitted_estimator
579
+ self._is_fitted = True
580
+ return output_result
567
581
 
568
582
 
569
583
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +670,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
656
670
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
657
671
 
658
672
  if isinstance(dataset, DataFrame):
659
- self._deps = self._batch_inference_validate_snowpark(
660
- dataset=dataset,
661
- inference_method=inference_method,
662
- )
673
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
674
+ self._deps = self._get_dependencies()
663
675
  assert isinstance(
664
676
  dataset._session, Session
665
677
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -726,10 +738,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
726
738
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
727
739
 
728
740
  if isinstance(dataset, DataFrame):
729
- self._deps = self._batch_inference_validate_snowpark(
730
- dataset=dataset,
731
- inference_method=inference_method,
732
- )
741
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
742
+ self._deps = self._get_dependencies()
733
743
  assert isinstance(
734
744
  dataset._session, Session
735
745
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -793,10 +803,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
793
803
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
794
804
 
795
805
  if isinstance(dataset, DataFrame):
796
- self._deps = self._batch_inference_validate_snowpark(
797
- dataset=dataset,
798
- inference_method=inference_method,
799
- )
806
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
807
+ self._deps = self._get_dependencies()
800
808
  assert isinstance(
801
809
  dataset._session, Session
802
810
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -862,10 +870,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
862
870
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
863
871
 
864
872
  if isinstance(dataset, DataFrame):
865
- self._deps = self._batch_inference_validate_snowpark(
866
- dataset=dataset,
867
- inference_method=inference_method,
868
- )
873
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
874
+ self._deps = self._get_dependencies()
869
875
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
870
876
  transform_kwargs = dict(
871
877
  session=dataset._session,
@@ -929,17 +935,15 @@ class LinearDiscriminantAnalysis(BaseTransformer):
929
935
  transform_kwargs: ScoreKwargsTypedDict = dict()
930
936
 
931
937
  if isinstance(dataset, DataFrame):
932
- self._deps = self._batch_inference_validate_snowpark(
933
- dataset=dataset,
934
- inference_method="score",
935
- )
938
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
939
+ self._deps = self._get_dependencies()
936
940
  selected_cols = self._get_active_columns()
937
941
  if len(selected_cols) > 0:
938
942
  dataset = dataset.select(selected_cols)
939
943
  assert isinstance(dataset._session, Session) # keep mypy happy
940
944
  transform_kwargs = dict(
941
945
  session=dataset._session,
942
- dependencies=["snowflake-snowpark-python"] + self._deps,
946
+ dependencies=self._deps,
943
947
  score_sproc_imports=['sklearn'],
944
948
  )
945
949
  elif isinstance(dataset, pd.DataFrame):
@@ -1004,11 +1008,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
1004
1008
 
1005
1009
  if isinstance(dataset, DataFrame):
1006
1010
 
1007
- self._deps = self._batch_inference_validate_snowpark(
1008
- dataset=dataset,
1009
- inference_method=inference_method,
1010
-
1011
- )
1011
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1012
+ self._deps = self._get_dependencies()
1012
1013
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1013
1014
  transform_kwargs = dict(
1014
1015
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".r
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class QuadraticDiscriminantAnalysis(BaseTransformer):
70
64
  r"""Quadratic Discriminant Analysis
71
65
  For more details on this class, see [sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis]
@@ -280,20 +274,17 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
280
274
  self,
281
275
  dataset: DataFrame,
282
276
  inference_method: str,
283
- ) -> List[str]:
284
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
285
- return the available package that exists in the snowflake anaconda channel
277
+ ) -> None:
278
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
286
279
 
287
280
  Args:
288
281
  dataset: snowpark dataframe
289
282
  inference_method: the inference method such as predict, score...
290
-
283
+
291
284
  Raises:
292
285
  SnowflakeMLException: If the estimator is not fitted, raise error
293
286
  SnowflakeMLException: If the session is None, raise error
294
287
 
295
- Returns:
296
- A list of available package that exists in the snowflake anaconda channel
297
288
  """
298
289
  if not self._is_fitted:
299
290
  raise exceptions.SnowflakeMLException(
@@ -311,9 +302,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
311
302
  "Session must not specified for snowpark dataset."
312
303
  ),
313
304
  )
314
- # Validate that key package version in user workspace are supported in snowflake conda channel
315
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
316
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
305
+
317
306
 
318
307
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
319
308
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
444
434
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
445
435
  expected_dtype = convert_sp_to_sf_type(output_types[0])
446
436
 
447
- self._deps = self._batch_inference_validate_snowpark(
448
- dataset=dataset,
449
- inference_method=inference_method,
450
- )
437
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._deps = self._get_dependencies()
451
439
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
452
440
 
453
441
  transform_kwargs = dict(
@@ -514,16 +502,40 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
514
502
  self._is_fitted = True
515
503
  return output_result
516
504
 
505
+
506
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
507
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
508
+ """ Method not supported for this class.
517
509
 
518
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
519
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
520
- """
510
+
511
+ Raises:
512
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
513
+
514
+ Args:
515
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
516
+ Snowpark or Pandas DataFrame.
517
+ output_cols_prefix: Prefix for the response columns
521
518
  Returns:
522
519
  Transformed dataset.
523
520
  """
524
- self.fit(dataset)
525
- assert self._sklearn_object is not None
526
- return self._sklearn_object.embedding_
521
+ self._infer_input_output_cols(dataset)
522
+ super()._check_dataset_type(dataset)
523
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
524
+ estimator=self._sklearn_object,
525
+ dataset=dataset,
526
+ input_cols=self.input_cols,
527
+ label_cols=self.label_cols,
528
+ sample_weight_col=self.sample_weight_col,
529
+ autogenerated=self._autogenerated,
530
+ subproject=_SUBPROJECT,
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=self.output_cols,
535
+ )
536
+ self._sklearn_object = fitted_estimator
537
+ self._is_fitted = True
538
+ return output_result
527
539
 
528
540
 
529
541
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -616,10 +628,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
616
628
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
617
629
 
618
630
  if isinstance(dataset, DataFrame):
619
- self._deps = self._batch_inference_validate_snowpark(
620
- dataset=dataset,
621
- inference_method=inference_method,
622
- )
631
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
632
+ self._deps = self._get_dependencies()
623
633
  assert isinstance(
624
634
  dataset._session, Session
625
635
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -686,10 +696,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
686
696
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
687
697
 
688
698
  if isinstance(dataset, DataFrame):
689
- self._deps = self._batch_inference_validate_snowpark(
690
- dataset=dataset,
691
- inference_method=inference_method,
692
- )
699
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
700
+ self._deps = self._get_dependencies()
693
701
  assert isinstance(
694
702
  dataset._session, Session
695
703
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -753,10 +761,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
753
761
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
754
762
 
755
763
  if isinstance(dataset, DataFrame):
756
- self._deps = self._batch_inference_validate_snowpark(
757
- dataset=dataset,
758
- inference_method=inference_method,
759
- )
764
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
765
+ self._deps = self._get_dependencies()
760
766
  assert isinstance(
761
767
  dataset._session, Session
762
768
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -822,10 +828,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
822
828
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
823
829
 
824
830
  if isinstance(dataset, DataFrame):
825
- self._deps = self._batch_inference_validate_snowpark(
826
- dataset=dataset,
827
- inference_method=inference_method,
828
- )
831
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
832
+ self._deps = self._get_dependencies()
829
833
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
830
834
  transform_kwargs = dict(
831
835
  session=dataset._session,
@@ -889,17 +893,15 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
889
893
  transform_kwargs: ScoreKwargsTypedDict = dict()
890
894
 
891
895
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method="score",
895
- )
896
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
897
+ self._deps = self._get_dependencies()
896
898
  selected_cols = self._get_active_columns()
897
899
  if len(selected_cols) > 0:
898
900
  dataset = dataset.select(selected_cols)
899
901
  assert isinstance(dataset._session, Session) # keep mypy happy
900
902
  transform_kwargs = dict(
901
903
  session=dataset._session,
902
- dependencies=["snowflake-snowpark-python"] + self._deps,
904
+ dependencies=self._deps,
903
905
  score_sproc_imports=['sklearn'],
904
906
  )
905
907
  elif isinstance(dataset, pd.DataFrame):
@@ -964,11 +966,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
964
966
 
965
967
  if isinstance(dataset, DataFrame):
966
968
 
967
- self._deps = self._batch_inference_validate_snowpark(
968
- dataset=dataset,
969
- inference_method=inference_method,
970
-
971
- )
969
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
970
+ self._deps = self._get_dependencies()
972
971
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
973
972
  transform_kwargs = dict(
974
973
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class AdaBoostClassifier(BaseTransformer):
70
64
  r"""An AdaBoost classifier
71
65
  For more details on this class, see [sklearn.ensemble.AdaBoostClassifier]
@@ -305,20 +299,17 @@ class AdaBoostClassifier(BaseTransformer):
305
299
  self,
306
300
  dataset: DataFrame,
307
301
  inference_method: str,
308
- ) -> List[str]:
309
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
310
- return the available package that exists in the snowflake anaconda channel
302
+ ) -> None:
303
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
311
304
 
312
305
  Args:
313
306
  dataset: snowpark dataframe
314
307
  inference_method: the inference method such as predict, score...
315
-
308
+
316
309
  Raises:
317
310
  SnowflakeMLException: If the estimator is not fitted, raise error
318
311
  SnowflakeMLException: If the session is None, raise error
319
312
 
320
- Returns:
321
- A list of available package that exists in the snowflake anaconda channel
322
313
  """
323
314
  if not self._is_fitted:
324
315
  raise exceptions.SnowflakeMLException(
@@ -336,9 +327,7 @@ class AdaBoostClassifier(BaseTransformer):
336
327
  "Session must not specified for snowpark dataset."
337
328
  ),
338
329
  )
339
- # Validate that key package version in user workspace are supported in snowflake conda channel
340
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
341
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
330
+
342
331
 
343
332
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
344
333
  @telemetry.send_api_usage_telemetry(
@@ -386,7 +375,8 @@ class AdaBoostClassifier(BaseTransformer):
386
375
 
387
376
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
388
377
 
389
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
378
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
379
+ self._deps = self._get_dependencies()
390
380
  assert isinstance(
391
381
  dataset._session, Session
392
382
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -469,10 +459,8 @@ class AdaBoostClassifier(BaseTransformer):
469
459
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
470
460
  expected_dtype = convert_sp_to_sf_type(output_types[0])
471
461
 
472
- self._deps = self._batch_inference_validate_snowpark(
473
- dataset=dataset,
474
- inference_method=inference_method,
475
- )
462
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
463
+ self._deps = self._get_dependencies()
476
464
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
477
465
 
478
466
  transform_kwargs = dict(
@@ -539,16 +527,40 @@ class AdaBoostClassifier(BaseTransformer):
539
527
  self._is_fitted = True
540
528
  return output_result
541
529
 
530
+
531
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
532
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
533
+ """ Method not supported for this class.
542
534
 
543
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
544
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
545
- """
535
+
536
+ Raises:
537
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
538
+
539
+ Args:
540
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
541
+ Snowpark or Pandas DataFrame.
542
+ output_cols_prefix: Prefix for the response columns
546
543
  Returns:
547
544
  Transformed dataset.
548
545
  """
549
- self.fit(dataset)
550
- assert self._sklearn_object is not None
551
- return self._sklearn_object.embedding_
546
+ self._infer_input_output_cols(dataset)
547
+ super()._check_dataset_type(dataset)
548
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
549
+ estimator=self._sklearn_object,
550
+ dataset=dataset,
551
+ input_cols=self.input_cols,
552
+ label_cols=self.label_cols,
553
+ sample_weight_col=self.sample_weight_col,
554
+ autogenerated=self._autogenerated,
555
+ subproject=_SUBPROJECT,
556
+ )
557
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=self.output_cols,
560
+ )
561
+ self._sklearn_object = fitted_estimator
562
+ self._is_fitted = True
563
+ return output_result
552
564
 
553
565
 
554
566
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -641,10 +653,8 @@ class AdaBoostClassifier(BaseTransformer):
641
653
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
642
654
 
643
655
  if isinstance(dataset, DataFrame):
644
- self._deps = self._batch_inference_validate_snowpark(
645
- dataset=dataset,
646
- inference_method=inference_method,
647
- )
656
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
657
+ self._deps = self._get_dependencies()
648
658
  assert isinstance(
649
659
  dataset._session, Session
650
660
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -711,10 +721,8 @@ class AdaBoostClassifier(BaseTransformer):
711
721
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
712
722
 
713
723
  if isinstance(dataset, DataFrame):
714
- self._deps = self._batch_inference_validate_snowpark(
715
- dataset=dataset,
716
- inference_method=inference_method,
717
- )
724
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
725
+ self._deps = self._get_dependencies()
718
726
  assert isinstance(
719
727
  dataset._session, Session
720
728
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +786,8 @@ class AdaBoostClassifier(BaseTransformer):
778
786
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
779
787
 
780
788
  if isinstance(dataset, DataFrame):
781
- self._deps = self._batch_inference_validate_snowpark(
782
- dataset=dataset,
783
- inference_method=inference_method,
784
- )
789
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
790
+ self._deps = self._get_dependencies()
785
791
  assert isinstance(
786
792
  dataset._session, Session
787
793
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +853,8 @@ class AdaBoostClassifier(BaseTransformer):
847
853
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
848
854
 
849
855
  if isinstance(dataset, DataFrame):
850
- self._deps = self._batch_inference_validate_snowpark(
851
- dataset=dataset,
852
- inference_method=inference_method,
853
- )
856
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
857
+ self._deps = self._get_dependencies()
854
858
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
855
859
  transform_kwargs = dict(
856
860
  session=dataset._session,
@@ -914,17 +918,15 @@ class AdaBoostClassifier(BaseTransformer):
914
918
  transform_kwargs: ScoreKwargsTypedDict = dict()
915
919
 
916
920
  if isinstance(dataset, DataFrame):
917
- self._deps = self._batch_inference_validate_snowpark(
918
- dataset=dataset,
919
- inference_method="score",
920
- )
921
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
922
+ self._deps = self._get_dependencies()
921
923
  selected_cols = self._get_active_columns()
922
924
  if len(selected_cols) > 0:
923
925
  dataset = dataset.select(selected_cols)
924
926
  assert isinstance(dataset._session, Session) # keep mypy happy
925
927
  transform_kwargs = dict(
926
928
  session=dataset._session,
927
- dependencies=["snowflake-snowpark-python"] + self._deps,
929
+ dependencies=self._deps,
928
930
  score_sproc_imports=['sklearn'],
929
931
  )
930
932
  elif isinstance(dataset, pd.DataFrame):
@@ -989,11 +991,8 @@ class AdaBoostClassifier(BaseTransformer):
989
991
 
990
992
  if isinstance(dataset, DataFrame):
991
993
 
992
- self._deps = self._batch_inference_validate_snowpark(
993
- dataset=dataset,
994
- inference_method=inference_method,
995
-
996
- )
994
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
995
+ self._deps = self._get_dependencies()
997
996
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
998
997
  transform_kwargs = dict(
999
998
  session = dataset._session,