snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Ridge(BaseTransformer):
70
64
  r"""Linear least squares with l2 regularization
71
65
  For more details on this class, see [sklearn.linear_model.Ridge]
@@ -358,20 +352,17 @@ class Ridge(BaseTransformer):
358
352
  self,
359
353
  dataset: DataFrame,
360
354
  inference_method: str,
361
- ) -> List[str]:
362
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
363
- return the available package that exists in the snowflake anaconda channel
355
+ ) -> None:
356
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
364
357
 
365
358
  Args:
366
359
  dataset: snowpark dataframe
367
360
  inference_method: the inference method such as predict, score...
368
-
361
+
369
362
  Raises:
370
363
  SnowflakeMLException: If the estimator is not fitted, raise error
371
364
  SnowflakeMLException: If the session is None, raise error
372
365
 
373
- Returns:
374
- A list of available package that exists in the snowflake anaconda channel
375
366
  """
376
367
  if not self._is_fitted:
377
368
  raise exceptions.SnowflakeMLException(
@@ -389,9 +380,7 @@ class Ridge(BaseTransformer):
389
380
  "Session must not specified for snowpark dataset."
390
381
  ),
391
382
  )
392
- # Validate that key package version in user workspace are supported in snowflake conda channel
393
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
394
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
383
+
395
384
 
396
385
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
397
386
  @telemetry.send_api_usage_telemetry(
@@ -439,7 +428,8 @@ class Ridge(BaseTransformer):
439
428
 
440
429
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
441
430
 
442
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
431
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
432
+ self._deps = self._get_dependencies()
443
433
  assert isinstance(
444
434
  dataset._session, Session
445
435
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -522,10 +512,8 @@ class Ridge(BaseTransformer):
522
512
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
523
513
  expected_dtype = convert_sp_to_sf_type(output_types[0])
524
514
 
525
- self._deps = self._batch_inference_validate_snowpark(
526
- dataset=dataset,
527
- inference_method=inference_method,
528
- )
515
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
516
+ self._deps = self._get_dependencies()
529
517
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
530
518
 
531
519
  transform_kwargs = dict(
@@ -592,16 +580,40 @@ class Ridge(BaseTransformer):
592
580
  self._is_fitted = True
593
581
  return output_result
594
582
 
583
+
584
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
585
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
586
+ """ Method not supported for this class.
595
587
 
596
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
597
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
598
- """
588
+
589
+ Raises:
590
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
591
+
592
+ Args:
593
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
594
+ Snowpark or Pandas DataFrame.
595
+ output_cols_prefix: Prefix for the response columns
599
596
  Returns:
600
597
  Transformed dataset.
601
598
  """
602
- self.fit(dataset)
603
- assert self._sklearn_object is not None
604
- return self._sklearn_object.embedding_
599
+ self._infer_input_output_cols(dataset)
600
+ super()._check_dataset_type(dataset)
601
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
602
+ estimator=self._sklearn_object,
603
+ dataset=dataset,
604
+ input_cols=self.input_cols,
605
+ label_cols=self.label_cols,
606
+ sample_weight_col=self.sample_weight_col,
607
+ autogenerated=self._autogenerated,
608
+ subproject=_SUBPROJECT,
609
+ )
610
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
611
+ drop_input_cols=self._drop_input_cols,
612
+ expected_output_cols_list=self.output_cols,
613
+ )
614
+ self._sklearn_object = fitted_estimator
615
+ self._is_fitted = True
616
+ return output_result
605
617
 
606
618
 
607
619
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -692,10 +704,8 @@ class Ridge(BaseTransformer):
692
704
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
693
705
 
694
706
  if isinstance(dataset, DataFrame):
695
- self._deps = self._batch_inference_validate_snowpark(
696
- dataset=dataset,
697
- inference_method=inference_method,
698
- )
707
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
708
+ self._deps = self._get_dependencies()
699
709
  assert isinstance(
700
710
  dataset._session, Session
701
711
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +770,8 @@ class Ridge(BaseTransformer):
760
770
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
761
771
 
762
772
  if isinstance(dataset, DataFrame):
763
- self._deps = self._batch_inference_validate_snowpark(
764
- dataset=dataset,
765
- inference_method=inference_method,
766
- )
773
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
774
+ self._deps = self._get_dependencies()
767
775
  assert isinstance(
768
776
  dataset._session, Session
769
777
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -825,10 +833,8 @@ class Ridge(BaseTransformer):
825
833
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
826
834
 
827
835
  if isinstance(dataset, DataFrame):
828
- self._deps = self._batch_inference_validate_snowpark(
829
- dataset=dataset,
830
- inference_method=inference_method,
831
- )
836
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
837
+ self._deps = self._get_dependencies()
832
838
  assert isinstance(
833
839
  dataset._session, Session
834
840
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -894,10 +900,8 @@ class Ridge(BaseTransformer):
894
900
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
895
901
 
896
902
  if isinstance(dataset, DataFrame):
897
- self._deps = self._batch_inference_validate_snowpark(
898
- dataset=dataset,
899
- inference_method=inference_method,
900
- )
903
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
904
+ self._deps = self._get_dependencies()
901
905
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
902
906
  transform_kwargs = dict(
903
907
  session=dataset._session,
@@ -961,17 +965,15 @@ class Ridge(BaseTransformer):
961
965
  transform_kwargs: ScoreKwargsTypedDict = dict()
962
966
 
963
967
  if isinstance(dataset, DataFrame):
964
- self._deps = self._batch_inference_validate_snowpark(
965
- dataset=dataset,
966
- inference_method="score",
967
- )
968
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
969
+ self._deps = self._get_dependencies()
968
970
  selected_cols = self._get_active_columns()
969
971
  if len(selected_cols) > 0:
970
972
  dataset = dataset.select(selected_cols)
971
973
  assert isinstance(dataset._session, Session) # keep mypy happy
972
974
  transform_kwargs = dict(
973
975
  session=dataset._session,
974
- dependencies=["snowflake-snowpark-python"] + self._deps,
976
+ dependencies=self._deps,
975
977
  score_sproc_imports=['sklearn'],
976
978
  )
977
979
  elif isinstance(dataset, pd.DataFrame):
@@ -1036,11 +1038,8 @@ class Ridge(BaseTransformer):
1036
1038
 
1037
1039
  if isinstance(dataset, DataFrame):
1038
1040
 
1039
- self._deps = self._batch_inference_validate_snowpark(
1040
- dataset=dataset,
1041
- inference_method=inference_method,
1042
-
1043
- )
1041
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1042
+ self._deps = self._get_dependencies()
1044
1043
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1045
1044
  transform_kwargs = dict(
1046
1045
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RidgeClassifier(BaseTransformer):
70
64
  r"""Classifier using Ridge regression
71
65
  For more details on this class, see [sklearn.linear_model.RidgeClassifier]
@@ -358,20 +352,17 @@ class RidgeClassifier(BaseTransformer):
358
352
  self,
359
353
  dataset: DataFrame,
360
354
  inference_method: str,
361
- ) -> List[str]:
362
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
363
- return the available package that exists in the snowflake anaconda channel
355
+ ) -> None:
356
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
364
357
 
365
358
  Args:
366
359
  dataset: snowpark dataframe
367
360
  inference_method: the inference method such as predict, score...
368
-
361
+
369
362
  Raises:
370
363
  SnowflakeMLException: If the estimator is not fitted, raise error
371
364
  SnowflakeMLException: If the session is None, raise error
372
365
 
373
- Returns:
374
- A list of available package that exists in the snowflake anaconda channel
375
366
  """
376
367
  if not self._is_fitted:
377
368
  raise exceptions.SnowflakeMLException(
@@ -389,9 +380,7 @@ class RidgeClassifier(BaseTransformer):
389
380
  "Session must not specified for snowpark dataset."
390
381
  ),
391
382
  )
392
- # Validate that key package version in user workspace are supported in snowflake conda channel
393
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
394
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
383
+
395
384
 
396
385
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
397
386
  @telemetry.send_api_usage_telemetry(
@@ -439,7 +428,8 @@ class RidgeClassifier(BaseTransformer):
439
428
 
440
429
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
441
430
 
442
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
431
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
432
+ self._deps = self._get_dependencies()
443
433
  assert isinstance(
444
434
  dataset._session, Session
445
435
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -522,10 +512,8 @@ class RidgeClassifier(BaseTransformer):
522
512
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
523
513
  expected_dtype = convert_sp_to_sf_type(output_types[0])
524
514
 
525
- self._deps = self._batch_inference_validate_snowpark(
526
- dataset=dataset,
527
- inference_method=inference_method,
528
- )
515
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
516
+ self._deps = self._get_dependencies()
529
517
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
530
518
 
531
519
  transform_kwargs = dict(
@@ -592,16 +580,40 @@ class RidgeClassifier(BaseTransformer):
592
580
  self._is_fitted = True
593
581
  return output_result
594
582
 
583
+
584
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
585
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
586
+ """ Method not supported for this class.
595
587
 
596
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
597
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
598
- """
588
+
589
+ Raises:
590
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
591
+
592
+ Args:
593
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
594
+ Snowpark or Pandas DataFrame.
595
+ output_cols_prefix: Prefix for the response columns
599
596
  Returns:
600
597
  Transformed dataset.
601
598
  """
602
- self.fit(dataset)
603
- assert self._sklearn_object is not None
604
- return self._sklearn_object.embedding_
599
+ self._infer_input_output_cols(dataset)
600
+ super()._check_dataset_type(dataset)
601
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
602
+ estimator=self._sklearn_object,
603
+ dataset=dataset,
604
+ input_cols=self.input_cols,
605
+ label_cols=self.label_cols,
606
+ sample_weight_col=self.sample_weight_col,
607
+ autogenerated=self._autogenerated,
608
+ subproject=_SUBPROJECT,
609
+ )
610
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
611
+ drop_input_cols=self._drop_input_cols,
612
+ expected_output_cols_list=self.output_cols,
613
+ )
614
+ self._sklearn_object = fitted_estimator
615
+ self._is_fitted = True
616
+ return output_result
605
617
 
606
618
 
607
619
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -692,10 +704,8 @@ class RidgeClassifier(BaseTransformer):
692
704
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
693
705
 
694
706
  if isinstance(dataset, DataFrame):
695
- self._deps = self._batch_inference_validate_snowpark(
696
- dataset=dataset,
697
- inference_method=inference_method,
698
- )
707
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
708
+ self._deps = self._get_dependencies()
699
709
  assert isinstance(
700
710
  dataset._session, Session
701
711
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +770,8 @@ class RidgeClassifier(BaseTransformer):
760
770
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
761
771
 
762
772
  if isinstance(dataset, DataFrame):
763
- self._deps = self._batch_inference_validate_snowpark(
764
- dataset=dataset,
765
- inference_method=inference_method,
766
- )
773
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
774
+ self._deps = self._get_dependencies()
767
775
  assert isinstance(
768
776
  dataset._session, Session
769
777
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -827,10 +835,8 @@ class RidgeClassifier(BaseTransformer):
827
835
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
828
836
 
829
837
  if isinstance(dataset, DataFrame):
830
- self._deps = self._batch_inference_validate_snowpark(
831
- dataset=dataset,
832
- inference_method=inference_method,
833
- )
838
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
839
+ self._deps = self._get_dependencies()
834
840
  assert isinstance(
835
841
  dataset._session, Session
836
842
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -896,10 +902,8 @@ class RidgeClassifier(BaseTransformer):
896
902
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
897
903
 
898
904
  if isinstance(dataset, DataFrame):
899
- self._deps = self._batch_inference_validate_snowpark(
900
- dataset=dataset,
901
- inference_method=inference_method,
902
- )
905
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
906
+ self._deps = self._get_dependencies()
903
907
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
904
908
  transform_kwargs = dict(
905
909
  session=dataset._session,
@@ -963,17 +967,15 @@ class RidgeClassifier(BaseTransformer):
963
967
  transform_kwargs: ScoreKwargsTypedDict = dict()
964
968
 
965
969
  if isinstance(dataset, DataFrame):
966
- self._deps = self._batch_inference_validate_snowpark(
967
- dataset=dataset,
968
- inference_method="score",
969
- )
970
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
971
+ self._deps = self._get_dependencies()
970
972
  selected_cols = self._get_active_columns()
971
973
  if len(selected_cols) > 0:
972
974
  dataset = dataset.select(selected_cols)
973
975
  assert isinstance(dataset._session, Session) # keep mypy happy
974
976
  transform_kwargs = dict(
975
977
  session=dataset._session,
976
- dependencies=["snowflake-snowpark-python"] + self._deps,
978
+ dependencies=self._deps,
977
979
  score_sproc_imports=['sklearn'],
978
980
  )
979
981
  elif isinstance(dataset, pd.DataFrame):
@@ -1038,11 +1040,8 @@ class RidgeClassifier(BaseTransformer):
1038
1040
 
1039
1041
  if isinstance(dataset, DataFrame):
1040
1042
 
1041
- self._deps = self._batch_inference_validate_snowpark(
1042
- dataset=dataset,
1043
- inference_method=inference_method,
1044
-
1045
- )
1043
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1044
+ self._deps = self._get_dependencies()
1046
1045
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1047
1046
  transform_kwargs = dict(
1048
1047
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RidgeClassifierCV(BaseTransformer):
70
64
  r"""Ridge classifier with built-in cross-validation
71
65
  For more details on this class, see [sklearn.linear_model.RidgeClassifierCV]
@@ -309,20 +303,17 @@ class RidgeClassifierCV(BaseTransformer):
309
303
  self,
310
304
  dataset: DataFrame,
311
305
  inference_method: str,
312
- ) -> List[str]:
313
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
314
- return the available package that exists in the snowflake anaconda channel
306
+ ) -> None:
307
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
315
308
 
316
309
  Args:
317
310
  dataset: snowpark dataframe
318
311
  inference_method: the inference method such as predict, score...
319
-
312
+
320
313
  Raises:
321
314
  SnowflakeMLException: If the estimator is not fitted, raise error
322
315
  SnowflakeMLException: If the session is None, raise error
323
316
 
324
- Returns:
325
- A list of available package that exists in the snowflake anaconda channel
326
317
  """
327
318
  if not self._is_fitted:
328
319
  raise exceptions.SnowflakeMLException(
@@ -340,9 +331,7 @@ class RidgeClassifierCV(BaseTransformer):
340
331
  "Session must not specified for snowpark dataset."
341
332
  ),
342
333
  )
343
- # Validate that key package version in user workspace are supported in snowflake conda channel
344
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
345
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
334
+
346
335
 
347
336
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
348
337
  @telemetry.send_api_usage_telemetry(
@@ -390,7 +379,8 @@ class RidgeClassifierCV(BaseTransformer):
390
379
 
391
380
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
392
381
 
393
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
382
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._deps = self._get_dependencies()
394
384
  assert isinstance(
395
385
  dataset._session, Session
396
386
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -473,10 +463,8 @@ class RidgeClassifierCV(BaseTransformer):
473
463
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
474
464
  expected_dtype = convert_sp_to_sf_type(output_types[0])
475
465
 
476
- self._deps = self._batch_inference_validate_snowpark(
477
- dataset=dataset,
478
- inference_method=inference_method,
479
- )
466
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
467
+ self._deps = self._get_dependencies()
480
468
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
481
469
 
482
470
  transform_kwargs = dict(
@@ -543,16 +531,40 @@ class RidgeClassifierCV(BaseTransformer):
543
531
  self._is_fitted = True
544
532
  return output_result
545
533
 
534
+
535
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
536
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
537
+ """ Method not supported for this class.
546
538
 
547
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
548
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
549
- """
539
+
540
+ Raises:
541
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
542
+
543
+ Args:
544
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
545
+ Snowpark or Pandas DataFrame.
546
+ output_cols_prefix: Prefix for the response columns
550
547
  Returns:
551
548
  Transformed dataset.
552
549
  """
553
- self.fit(dataset)
554
- assert self._sklearn_object is not None
555
- return self._sklearn_object.embedding_
550
+ self._infer_input_output_cols(dataset)
551
+ super()._check_dataset_type(dataset)
552
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
553
+ estimator=self._sklearn_object,
554
+ dataset=dataset,
555
+ input_cols=self.input_cols,
556
+ label_cols=self.label_cols,
557
+ sample_weight_col=self.sample_weight_col,
558
+ autogenerated=self._autogenerated,
559
+ subproject=_SUBPROJECT,
560
+ )
561
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
562
+ drop_input_cols=self._drop_input_cols,
563
+ expected_output_cols_list=self.output_cols,
564
+ )
565
+ self._sklearn_object = fitted_estimator
566
+ self._is_fitted = True
567
+ return output_result
556
568
 
557
569
 
558
570
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -643,10 +655,8 @@ class RidgeClassifierCV(BaseTransformer):
643
655
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
644
656
 
645
657
  if isinstance(dataset, DataFrame):
646
- self._deps = self._batch_inference_validate_snowpark(
647
- dataset=dataset,
648
- inference_method=inference_method,
649
- )
658
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
659
+ self._deps = self._get_dependencies()
650
660
  assert isinstance(
651
661
  dataset._session, Session
652
662
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -711,10 +721,8 @@ class RidgeClassifierCV(BaseTransformer):
711
721
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
712
722
 
713
723
  if isinstance(dataset, DataFrame):
714
- self._deps = self._batch_inference_validate_snowpark(
715
- dataset=dataset,
716
- inference_method=inference_method,
717
- )
724
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
725
+ self._deps = self._get_dependencies()
718
726
  assert isinstance(
719
727
  dataset._session, Session
720
728
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +786,8 @@ class RidgeClassifierCV(BaseTransformer):
778
786
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
779
787
 
780
788
  if isinstance(dataset, DataFrame):
781
- self._deps = self._batch_inference_validate_snowpark(
782
- dataset=dataset,
783
- inference_method=inference_method,
784
- )
789
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
790
+ self._deps = self._get_dependencies()
785
791
  assert isinstance(
786
792
  dataset._session, Session
787
793
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +853,8 @@ class RidgeClassifierCV(BaseTransformer):
847
853
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
848
854
 
849
855
  if isinstance(dataset, DataFrame):
850
- self._deps = self._batch_inference_validate_snowpark(
851
- dataset=dataset,
852
- inference_method=inference_method,
853
- )
856
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
857
+ self._deps = self._get_dependencies()
854
858
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
855
859
  transform_kwargs = dict(
856
860
  session=dataset._session,
@@ -914,17 +918,15 @@ class RidgeClassifierCV(BaseTransformer):
914
918
  transform_kwargs: ScoreKwargsTypedDict = dict()
915
919
 
916
920
  if isinstance(dataset, DataFrame):
917
- self._deps = self._batch_inference_validate_snowpark(
918
- dataset=dataset,
919
- inference_method="score",
920
- )
921
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
922
+ self._deps = self._get_dependencies()
921
923
  selected_cols = self._get_active_columns()
922
924
  if len(selected_cols) > 0:
923
925
  dataset = dataset.select(selected_cols)
924
926
  assert isinstance(dataset._session, Session) # keep mypy happy
925
927
  transform_kwargs = dict(
926
928
  session=dataset._session,
927
- dependencies=["snowflake-snowpark-python"] + self._deps,
929
+ dependencies=self._deps,
928
930
  score_sproc_imports=['sklearn'],
929
931
  )
930
932
  elif isinstance(dataset, pd.DataFrame):
@@ -989,11 +991,8 @@ class RidgeClassifierCV(BaseTransformer):
989
991
 
990
992
  if isinstance(dataset, DataFrame):
991
993
 
992
- self._deps = self._batch_inference_validate_snowpark(
993
- dataset=dataset,
994
- inference_method=inference_method,
995
-
996
- )
994
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
995
+ self._deps = self._get_dependencies()
997
996
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
998
997
  transform_kwargs = dict(
999
998
  session = dataset._session,