snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RidgeCV(BaseTransformer):
70
64
  r"""Ridge regression with built-in cross-validation
71
65
  For more details on this class, see [sklearn.linear_model.RidgeCV]
@@ -330,20 +324,17 @@ class RidgeCV(BaseTransformer):
330
324
  self,
331
325
  dataset: DataFrame,
332
326
  inference_method: str,
333
- ) -> List[str]:
334
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
335
- return the available package that exists in the snowflake anaconda channel
327
+ ) -> None:
328
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
336
329
 
337
330
  Args:
338
331
  dataset: snowpark dataframe
339
332
  inference_method: the inference method such as predict, score...
340
-
333
+
341
334
  Raises:
342
335
  SnowflakeMLException: If the estimator is not fitted, raise error
343
336
  SnowflakeMLException: If the session is None, raise error
344
337
 
345
- Returns:
346
- A list of available package that exists in the snowflake anaconda channel
347
338
  """
348
339
  if not self._is_fitted:
349
340
  raise exceptions.SnowflakeMLException(
@@ -361,9 +352,7 @@ class RidgeCV(BaseTransformer):
361
352
  "Session must not specified for snowpark dataset."
362
353
  ),
363
354
  )
364
- # Validate that key package version in user workspace are supported in snowflake conda channel
365
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
366
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
355
+
367
356
 
368
357
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
369
358
  @telemetry.send_api_usage_telemetry(
@@ -411,7 +400,8 @@ class RidgeCV(BaseTransformer):
411
400
 
412
401
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
413
402
 
414
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
403
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
404
+ self._deps = self._get_dependencies()
415
405
  assert isinstance(
416
406
  dataset._session, Session
417
407
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -494,10 +484,8 @@ class RidgeCV(BaseTransformer):
494
484
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
495
485
  expected_dtype = convert_sp_to_sf_type(output_types[0])
496
486
 
497
- self._deps = self._batch_inference_validate_snowpark(
498
- dataset=dataset,
499
- inference_method=inference_method,
500
- )
487
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
488
+ self._deps = self._get_dependencies()
501
489
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
502
490
 
503
491
  transform_kwargs = dict(
@@ -564,16 +552,40 @@ class RidgeCV(BaseTransformer):
564
552
  self._is_fitted = True
565
553
  return output_result
566
554
 
555
+
556
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
557
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
558
+ """ Method not supported for this class.
567
559
 
568
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
569
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
570
- """
560
+
561
+ Raises:
562
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
563
+
564
+ Args:
565
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
566
+ Snowpark or Pandas DataFrame.
567
+ output_cols_prefix: Prefix for the response columns
571
568
  Returns:
572
569
  Transformed dataset.
573
570
  """
574
- self.fit(dataset)
575
- assert self._sklearn_object is not None
576
- return self._sklearn_object.embedding_
571
+ self._infer_input_output_cols(dataset)
572
+ super()._check_dataset_type(dataset)
573
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
574
+ estimator=self._sklearn_object,
575
+ dataset=dataset,
576
+ input_cols=self.input_cols,
577
+ label_cols=self.label_cols,
578
+ sample_weight_col=self.sample_weight_col,
579
+ autogenerated=self._autogenerated,
580
+ subproject=_SUBPROJECT,
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=self.output_cols,
585
+ )
586
+ self._sklearn_object = fitted_estimator
587
+ self._is_fitted = True
588
+ return output_result
577
589
 
578
590
 
579
591
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -664,10 +676,8 @@ class RidgeCV(BaseTransformer):
664
676
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
665
677
 
666
678
  if isinstance(dataset, DataFrame):
667
- self._deps = self._batch_inference_validate_snowpark(
668
- dataset=dataset,
669
- inference_method=inference_method,
670
- )
679
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
680
+ self._deps = self._get_dependencies()
671
681
  assert isinstance(
672
682
  dataset._session, Session
673
683
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -732,10 +742,8 @@ class RidgeCV(BaseTransformer):
732
742
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
733
743
 
734
744
  if isinstance(dataset, DataFrame):
735
- self._deps = self._batch_inference_validate_snowpark(
736
- dataset=dataset,
737
- inference_method=inference_method,
738
- )
745
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
746
+ self._deps = self._get_dependencies()
739
747
  assert isinstance(
740
748
  dataset._session, Session
741
749
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +805,8 @@ class RidgeCV(BaseTransformer):
797
805
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
798
806
 
799
807
  if isinstance(dataset, DataFrame):
800
- self._deps = self._batch_inference_validate_snowpark(
801
- dataset=dataset,
802
- inference_method=inference_method,
803
- )
808
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
809
+ self._deps = self._get_dependencies()
804
810
  assert isinstance(
805
811
  dataset._session, Session
806
812
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -866,10 +872,8 @@ class RidgeCV(BaseTransformer):
866
872
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
867
873
 
868
874
  if isinstance(dataset, DataFrame):
869
- self._deps = self._batch_inference_validate_snowpark(
870
- dataset=dataset,
871
- inference_method=inference_method,
872
- )
875
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
876
+ self._deps = self._get_dependencies()
873
877
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
874
878
  transform_kwargs = dict(
875
879
  session=dataset._session,
@@ -933,17 +937,15 @@ class RidgeCV(BaseTransformer):
933
937
  transform_kwargs: ScoreKwargsTypedDict = dict()
934
938
 
935
939
  if isinstance(dataset, DataFrame):
936
- self._deps = self._batch_inference_validate_snowpark(
937
- dataset=dataset,
938
- inference_method="score",
939
- )
940
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
941
+ self._deps = self._get_dependencies()
940
942
  selected_cols = self._get_active_columns()
941
943
  if len(selected_cols) > 0:
942
944
  dataset = dataset.select(selected_cols)
943
945
  assert isinstance(dataset._session, Session) # keep mypy happy
944
946
  transform_kwargs = dict(
945
947
  session=dataset._session,
946
- dependencies=["snowflake-snowpark-python"] + self._deps,
948
+ dependencies=self._deps,
947
949
  score_sproc_imports=['sklearn'],
948
950
  )
949
951
  elif isinstance(dataset, pd.DataFrame):
@@ -1008,11 +1010,8 @@ class RidgeCV(BaseTransformer):
1008
1010
 
1009
1011
  if isinstance(dataset, DataFrame):
1010
1012
 
1011
- self._deps = self._batch_inference_validate_snowpark(
1012
- dataset=dataset,
1013
- inference_method=inference_method,
1014
-
1015
- )
1013
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1014
+ self._deps = self._get_dependencies()
1016
1015
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1017
1016
  transform_kwargs = dict(
1018
1017
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SGDClassifier(BaseTransformer):
70
64
  r"""Linear classifiers (SVM, logistic regression, etc
71
65
  For more details on this class, see [sklearn.linear_model.SGDClassifier]
@@ -449,20 +443,17 @@ class SGDClassifier(BaseTransformer):
449
443
  self,
450
444
  dataset: DataFrame,
451
445
  inference_method: str,
452
- ) -> List[str]:
453
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
454
- return the available package that exists in the snowflake anaconda channel
446
+ ) -> None:
447
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
455
448
 
456
449
  Args:
457
450
  dataset: snowpark dataframe
458
451
  inference_method: the inference method such as predict, score...
459
-
452
+
460
453
  Raises:
461
454
  SnowflakeMLException: If the estimator is not fitted, raise error
462
455
  SnowflakeMLException: If the session is None, raise error
463
456
 
464
- Returns:
465
- A list of available package that exists in the snowflake anaconda channel
466
457
  """
467
458
  if not self._is_fitted:
468
459
  raise exceptions.SnowflakeMLException(
@@ -480,9 +471,7 @@ class SGDClassifier(BaseTransformer):
480
471
  "Session must not specified for snowpark dataset."
481
472
  ),
482
473
  )
483
- # Validate that key package version in user workspace are supported in snowflake conda channel
484
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
485
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
474
+
486
475
 
487
476
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
488
477
  @telemetry.send_api_usage_telemetry(
@@ -530,7 +519,8 @@ class SGDClassifier(BaseTransformer):
530
519
 
531
520
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
532
521
 
533
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
522
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
523
+ self._deps = self._get_dependencies()
534
524
  assert isinstance(
535
525
  dataset._session, Session
536
526
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -613,10 +603,8 @@ class SGDClassifier(BaseTransformer):
613
603
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
614
604
  expected_dtype = convert_sp_to_sf_type(output_types[0])
615
605
 
616
- self._deps = self._batch_inference_validate_snowpark(
617
- dataset=dataset,
618
- inference_method=inference_method,
619
- )
606
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
607
+ self._deps = self._get_dependencies()
620
608
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
621
609
 
622
610
  transform_kwargs = dict(
@@ -683,16 +671,40 @@ class SGDClassifier(BaseTransformer):
683
671
  self._is_fitted = True
684
672
  return output_result
685
673
 
674
+
675
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
676
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
677
+ """ Method not supported for this class.
686
678
 
687
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
688
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
689
- """
679
+
680
+ Raises:
681
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
682
+
683
+ Args:
684
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
685
+ Snowpark or Pandas DataFrame.
686
+ output_cols_prefix: Prefix for the response columns
690
687
  Returns:
691
688
  Transformed dataset.
692
689
  """
693
- self.fit(dataset)
694
- assert self._sklearn_object is not None
695
- return self._sklearn_object.embedding_
690
+ self._infer_input_output_cols(dataset)
691
+ super()._check_dataset_type(dataset)
692
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
693
+ estimator=self._sklearn_object,
694
+ dataset=dataset,
695
+ input_cols=self.input_cols,
696
+ label_cols=self.label_cols,
697
+ sample_weight_col=self.sample_weight_col,
698
+ autogenerated=self._autogenerated,
699
+ subproject=_SUBPROJECT,
700
+ )
701
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
702
+ drop_input_cols=self._drop_input_cols,
703
+ expected_output_cols_list=self.output_cols,
704
+ )
705
+ self._sklearn_object = fitted_estimator
706
+ self._is_fitted = True
707
+ return output_result
696
708
 
697
709
 
698
710
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -785,10 +797,8 @@ class SGDClassifier(BaseTransformer):
785
797
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
786
798
 
787
799
  if isinstance(dataset, DataFrame):
788
- self._deps = self._batch_inference_validate_snowpark(
789
- dataset=dataset,
790
- inference_method=inference_method,
791
- )
800
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
801
+ self._deps = self._get_dependencies()
792
802
  assert isinstance(
793
803
  dataset._session, Session
794
804
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -855,10 +865,8 @@ class SGDClassifier(BaseTransformer):
855
865
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
856
866
 
857
867
  if isinstance(dataset, DataFrame):
858
- self._deps = self._batch_inference_validate_snowpark(
859
- dataset=dataset,
860
- inference_method=inference_method,
861
- )
868
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
869
+ self._deps = self._get_dependencies()
862
870
  assert isinstance(
863
871
  dataset._session, Session
864
872
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -922,10 +930,8 @@ class SGDClassifier(BaseTransformer):
922
930
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
923
931
 
924
932
  if isinstance(dataset, DataFrame):
925
- self._deps = self._batch_inference_validate_snowpark(
926
- dataset=dataset,
927
- inference_method=inference_method,
928
- )
933
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
934
+ self._deps = self._get_dependencies()
929
935
  assert isinstance(
930
936
  dataset._session, Session
931
937
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -991,10 +997,8 @@ class SGDClassifier(BaseTransformer):
991
997
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
992
998
 
993
999
  if isinstance(dataset, DataFrame):
994
- self._deps = self._batch_inference_validate_snowpark(
995
- dataset=dataset,
996
- inference_method=inference_method,
997
- )
1000
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1001
+ self._deps = self._get_dependencies()
998
1002
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
999
1003
  transform_kwargs = dict(
1000
1004
  session=dataset._session,
@@ -1058,17 +1062,15 @@ class SGDClassifier(BaseTransformer):
1058
1062
  transform_kwargs: ScoreKwargsTypedDict = dict()
1059
1063
 
1060
1064
  if isinstance(dataset, DataFrame):
1061
- self._deps = self._batch_inference_validate_snowpark(
1062
- dataset=dataset,
1063
- inference_method="score",
1064
- )
1065
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1066
+ self._deps = self._get_dependencies()
1065
1067
  selected_cols = self._get_active_columns()
1066
1068
  if len(selected_cols) > 0:
1067
1069
  dataset = dataset.select(selected_cols)
1068
1070
  assert isinstance(dataset._session, Session) # keep mypy happy
1069
1071
  transform_kwargs = dict(
1070
1072
  session=dataset._session,
1071
- dependencies=["snowflake-snowpark-python"] + self._deps,
1073
+ dependencies=self._deps,
1072
1074
  score_sproc_imports=['sklearn'],
1073
1075
  )
1074
1076
  elif isinstance(dataset, pd.DataFrame):
@@ -1133,11 +1135,8 @@ class SGDClassifier(BaseTransformer):
1133
1135
 
1134
1136
  if isinstance(dataset, DataFrame):
1135
1137
 
1136
- self._deps = self._batch_inference_validate_snowpark(
1137
- dataset=dataset,
1138
- inference_method=inference_method,
1139
-
1140
- )
1138
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1139
+ self._deps = self._get_dependencies()
1141
1140
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1142
1141
  transform_kwargs = dict(
1143
1142
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SGDOneClassSVM(BaseTransformer):
70
64
  r"""Solves linear One-Class SVM using Stochastic Gradient Descent
71
65
  For more details on this class, see [sklearn.linear_model.SGDOneClassSVM]
@@ -347,20 +341,17 @@ class SGDOneClassSVM(BaseTransformer):
347
341
  self,
348
342
  dataset: DataFrame,
349
343
  inference_method: str,
350
- ) -> List[str]:
351
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
352
- return the available package that exists in the snowflake anaconda channel
344
+ ) -> None:
345
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
353
346
 
354
347
  Args:
355
348
  dataset: snowpark dataframe
356
349
  inference_method: the inference method such as predict, score...
357
-
350
+
358
351
  Raises:
359
352
  SnowflakeMLException: If the estimator is not fitted, raise error
360
353
  SnowflakeMLException: If the session is None, raise error
361
354
 
362
- Returns:
363
- A list of available package that exists in the snowflake anaconda channel
364
355
  """
365
356
  if not self._is_fitted:
366
357
  raise exceptions.SnowflakeMLException(
@@ -378,9 +369,7 @@ class SGDOneClassSVM(BaseTransformer):
378
369
  "Session must not specified for snowpark dataset."
379
370
  ),
380
371
  )
381
- # Validate that key package version in user workspace are supported in snowflake conda channel
382
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
383
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
372
+
384
373
 
385
374
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
386
375
  @telemetry.send_api_usage_telemetry(
@@ -428,7 +417,8 @@ class SGDOneClassSVM(BaseTransformer):
428
417
 
429
418
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
430
419
 
431
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
420
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
421
+ self._deps = self._get_dependencies()
432
422
  assert isinstance(
433
423
  dataset._session, Session
434
424
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -511,10 +501,8 @@ class SGDOneClassSVM(BaseTransformer):
511
501
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
512
502
  expected_dtype = convert_sp_to_sf_type(output_types[0])
513
503
 
514
- self._deps = self._batch_inference_validate_snowpark(
515
- dataset=dataset,
516
- inference_method=inference_method,
517
- )
504
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
505
+ self._deps = self._get_dependencies()
518
506
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
519
507
 
520
508
  transform_kwargs = dict(
@@ -583,16 +571,40 @@ class SGDOneClassSVM(BaseTransformer):
583
571
  self._is_fitted = True
584
572
  return output_result
585
573
 
574
+
575
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
576
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
577
+ """ Method not supported for this class.
578
+
586
579
 
587
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
588
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
589
- """
580
+ Raises:
581
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
582
+
583
+ Args:
584
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
585
+ Snowpark or Pandas DataFrame.
586
+ output_cols_prefix: Prefix for the response columns
590
587
  Returns:
591
588
  Transformed dataset.
592
589
  """
593
- self.fit(dataset)
594
- assert self._sklearn_object is not None
595
- return self._sklearn_object.embedding_
590
+ self._infer_input_output_cols(dataset)
591
+ super()._check_dataset_type(dataset)
592
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
593
+ estimator=self._sklearn_object,
594
+ dataset=dataset,
595
+ input_cols=self.input_cols,
596
+ label_cols=self.label_cols,
597
+ sample_weight_col=self.sample_weight_col,
598
+ autogenerated=self._autogenerated,
599
+ subproject=_SUBPROJECT,
600
+ )
601
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
602
+ drop_input_cols=self._drop_input_cols,
603
+ expected_output_cols_list=self.output_cols,
604
+ )
605
+ self._sklearn_object = fitted_estimator
606
+ self._is_fitted = True
607
+ return output_result
596
608
 
597
609
 
598
610
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -683,10 +695,8 @@ class SGDOneClassSVM(BaseTransformer):
683
695
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
684
696
 
685
697
  if isinstance(dataset, DataFrame):
686
- self._deps = self._batch_inference_validate_snowpark(
687
- dataset=dataset,
688
- inference_method=inference_method,
689
- )
698
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
699
+ self._deps = self._get_dependencies()
690
700
  assert isinstance(
691
701
  dataset._session, Session
692
702
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -751,10 +761,8 @@ class SGDOneClassSVM(BaseTransformer):
751
761
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
752
762
 
753
763
  if isinstance(dataset, DataFrame):
754
- self._deps = self._batch_inference_validate_snowpark(
755
- dataset=dataset,
756
- inference_method=inference_method,
757
- )
764
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
765
+ self._deps = self._get_dependencies()
758
766
  assert isinstance(
759
767
  dataset._session, Session
760
768
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -818,10 +826,8 @@ class SGDOneClassSVM(BaseTransformer):
818
826
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
819
827
 
820
828
  if isinstance(dataset, DataFrame):
821
- self._deps = self._batch_inference_validate_snowpark(
822
- dataset=dataset,
823
- inference_method=inference_method,
824
- )
829
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
830
+ self._deps = self._get_dependencies()
825
831
  assert isinstance(
826
832
  dataset._session, Session
827
833
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -889,10 +895,8 @@ class SGDOneClassSVM(BaseTransformer):
889
895
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
890
896
 
891
897
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method=inference_method,
895
- )
898
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
899
+ self._deps = self._get_dependencies()
896
900
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
897
901
  transform_kwargs = dict(
898
902
  session=dataset._session,
@@ -954,17 +958,15 @@ class SGDOneClassSVM(BaseTransformer):
954
958
  transform_kwargs: ScoreKwargsTypedDict = dict()
955
959
 
956
960
  if isinstance(dataset, DataFrame):
957
- self._deps = self._batch_inference_validate_snowpark(
958
- dataset=dataset,
959
- inference_method="score",
960
- )
961
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
962
+ self._deps = self._get_dependencies()
961
963
  selected_cols = self._get_active_columns()
962
964
  if len(selected_cols) > 0:
963
965
  dataset = dataset.select(selected_cols)
964
966
  assert isinstance(dataset._session, Session) # keep mypy happy
965
967
  transform_kwargs = dict(
966
968
  session=dataset._session,
967
- dependencies=["snowflake-snowpark-python"] + self._deps,
969
+ dependencies=self._deps,
968
970
  score_sproc_imports=['sklearn'],
969
971
  )
970
972
  elif isinstance(dataset, pd.DataFrame):
@@ -1029,11 +1031,8 @@ class SGDOneClassSVM(BaseTransformer):
1029
1031
 
1030
1032
  if isinstance(dataset, DataFrame):
1031
1033
 
1032
- self._deps = self._batch_inference_validate_snowpark(
1033
- dataset=dataset,
1034
- inference_method=inference_method,
1035
-
1036
- )
1034
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1035
+ self._deps = self._get_dependencies()
1037
1036
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1038
1037
  transform_kwargs = dict(
1039
1038
  session = dataset._session,