snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Nystroem(BaseTransformer):
70
64
  r"""Approximate a kernel map using a subset of the training data
71
65
  For more details on this class, see [sklearn.kernel_approximation.Nystroem]
@@ -308,20 +302,17 @@ class Nystroem(BaseTransformer):
308
302
  self,
309
303
  dataset: DataFrame,
310
304
  inference_method: str,
311
- ) -> List[str]:
312
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
313
- return the available package that exists in the snowflake anaconda channel
305
+ ) -> None:
306
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
314
307
 
315
308
  Args:
316
309
  dataset: snowpark dataframe
317
310
  inference_method: the inference method such as predict, score...
318
-
311
+
319
312
  Raises:
320
313
  SnowflakeMLException: If the estimator is not fitted, raise error
321
314
  SnowflakeMLException: If the session is None, raise error
322
315
 
323
- Returns:
324
- A list of available package that exists in the snowflake anaconda channel
325
316
  """
326
317
  if not self._is_fitted:
327
318
  raise exceptions.SnowflakeMLException(
@@ -339,9 +330,7 @@ class Nystroem(BaseTransformer):
339
330
  "Session must not specified for snowpark dataset."
340
331
  ),
341
332
  )
342
- # Validate that key package version in user workspace are supported in snowflake conda channel
343
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
344
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
333
+
345
334
 
346
335
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
347
336
  @telemetry.send_api_usage_telemetry(
@@ -387,7 +376,8 @@ class Nystroem(BaseTransformer):
387
376
 
388
377
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
389
378
 
390
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
379
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
380
+ self._deps = self._get_dependencies()
391
381
  assert isinstance(
392
382
  dataset._session, Session
393
383
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -472,10 +462,8 @@ class Nystroem(BaseTransformer):
472
462
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
473
463
  expected_dtype = convert_sp_to_sf_type(output_types[0])
474
464
 
475
- self._deps = self._batch_inference_validate_snowpark(
476
- dataset=dataset,
477
- inference_method=inference_method,
478
- )
465
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
466
+ self._deps = self._get_dependencies()
479
467
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
480
468
 
481
469
  transform_kwargs = dict(
@@ -542,16 +530,42 @@ class Nystroem(BaseTransformer):
542
530
  self._is_fitted = True
543
531
  return output_result
544
532
 
533
+
534
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
535
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
536
+ """ Fit to data, then transform it
537
+ For more details on this function, see [sklearn.kernel_approximation.Nystroem.fit_transform]
538
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.Nystroem.html#sklearn.kernel_approximation.Nystroem.fit_transform)
539
+
545
540
 
546
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
547
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
548
- """
541
+ Raises:
542
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
543
+
544
+ Args:
545
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
546
+ Snowpark or Pandas DataFrame.
547
+ output_cols_prefix: Prefix for the response columns
549
548
  Returns:
550
549
  Transformed dataset.
551
550
  """
552
- self.fit(dataset)
553
- assert self._sklearn_object is not None
554
- return self._sklearn_object.embedding_
551
+ self._infer_input_output_cols(dataset)
552
+ super()._check_dataset_type(dataset)
553
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
554
+ estimator=self._sklearn_object,
555
+ dataset=dataset,
556
+ input_cols=self.input_cols,
557
+ label_cols=self.label_cols,
558
+ sample_weight_col=self.sample_weight_col,
559
+ autogenerated=self._autogenerated,
560
+ subproject=_SUBPROJECT,
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=self.output_cols,
565
+ )
566
+ self._sklearn_object = fitted_estimator
567
+ self._is_fitted = True
568
+ return output_result
555
569
 
556
570
 
557
571
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -642,10 +656,8 @@ class Nystroem(BaseTransformer):
642
656
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
643
657
 
644
658
  if isinstance(dataset, DataFrame):
645
- self._deps = self._batch_inference_validate_snowpark(
646
- dataset=dataset,
647
- inference_method=inference_method,
648
- )
659
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
660
+ self._deps = self._get_dependencies()
649
661
  assert isinstance(
650
662
  dataset._session, Session
651
663
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -710,10 +722,8 @@ class Nystroem(BaseTransformer):
710
722
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
711
723
 
712
724
  if isinstance(dataset, DataFrame):
713
- self._deps = self._batch_inference_validate_snowpark(
714
- dataset=dataset,
715
- inference_method=inference_method,
716
- )
725
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
726
+ self._deps = self._get_dependencies()
717
727
  assert isinstance(
718
728
  dataset._session, Session
719
729
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -775,10 +785,8 @@ class Nystroem(BaseTransformer):
775
785
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
776
786
 
777
787
  if isinstance(dataset, DataFrame):
778
- self._deps = self._batch_inference_validate_snowpark(
779
- dataset=dataset,
780
- inference_method=inference_method,
781
- )
788
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
789
+ self._deps = self._get_dependencies()
782
790
  assert isinstance(
783
791
  dataset._session, Session
784
792
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -844,10 +852,8 @@ class Nystroem(BaseTransformer):
844
852
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
845
853
 
846
854
  if isinstance(dataset, DataFrame):
847
- self._deps = self._batch_inference_validate_snowpark(
848
- dataset=dataset,
849
- inference_method=inference_method,
850
- )
855
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
856
+ self._deps = self._get_dependencies()
851
857
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
852
858
  transform_kwargs = dict(
853
859
  session=dataset._session,
@@ -909,17 +915,15 @@ class Nystroem(BaseTransformer):
909
915
  transform_kwargs: ScoreKwargsTypedDict = dict()
910
916
 
911
917
  if isinstance(dataset, DataFrame):
912
- self._deps = self._batch_inference_validate_snowpark(
913
- dataset=dataset,
914
- inference_method="score",
915
- )
918
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
919
+ self._deps = self._get_dependencies()
916
920
  selected_cols = self._get_active_columns()
917
921
  if len(selected_cols) > 0:
918
922
  dataset = dataset.select(selected_cols)
919
923
  assert isinstance(dataset._session, Session) # keep mypy happy
920
924
  transform_kwargs = dict(
921
925
  session=dataset._session,
922
- dependencies=["snowflake-snowpark-python"] + self._deps,
926
+ dependencies=self._deps,
923
927
  score_sproc_imports=['sklearn'],
924
928
  )
925
929
  elif isinstance(dataset, pd.DataFrame):
@@ -984,11 +988,8 @@ class Nystroem(BaseTransformer):
984
988
 
985
989
  if isinstance(dataset, DataFrame):
986
990
 
987
- self._deps = self._batch_inference_validate_snowpark(
988
- dataset=dataset,
989
- inference_method=inference_method,
990
-
991
- )
991
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
992
+ self._deps = self._get_dependencies()
992
993
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
993
994
  transform_kwargs = dict(
994
995
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class PolynomialCountSketch(BaseTransformer):
70
64
  r"""Polynomial kernel approximation via Tensor Sketch
71
65
  For more details on this class, see [sklearn.kernel_approximation.PolynomialCountSketch]
@@ -284,20 +278,17 @@ class PolynomialCountSketch(BaseTransformer):
284
278
  self,
285
279
  dataset: DataFrame,
286
280
  inference_method: str,
287
- ) -> List[str]:
288
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
289
- return the available package that exists in the snowflake anaconda channel
281
+ ) -> None:
282
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
290
283
 
291
284
  Args:
292
285
  dataset: snowpark dataframe
293
286
  inference_method: the inference method such as predict, score...
294
-
287
+
295
288
  Raises:
296
289
  SnowflakeMLException: If the estimator is not fitted, raise error
297
290
  SnowflakeMLException: If the session is None, raise error
298
291
 
299
- Returns:
300
- A list of available package that exists in the snowflake anaconda channel
301
292
  """
302
293
  if not self._is_fitted:
303
294
  raise exceptions.SnowflakeMLException(
@@ -315,9 +306,7 @@ class PolynomialCountSketch(BaseTransformer):
315
306
  "Session must not specified for snowpark dataset."
316
307
  ),
317
308
  )
318
- # Validate that key package version in user workspace are supported in snowflake conda channel
319
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
320
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
309
+
321
310
 
322
311
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
323
312
  @telemetry.send_api_usage_telemetry(
@@ -363,7 +352,8 @@ class PolynomialCountSketch(BaseTransformer):
363
352
 
364
353
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
365
354
 
366
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
355
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
356
+ self._deps = self._get_dependencies()
367
357
  assert isinstance(
368
358
  dataset._session, Session
369
359
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -448,10 +438,8 @@ class PolynomialCountSketch(BaseTransformer):
448
438
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
449
439
  expected_dtype = convert_sp_to_sf_type(output_types[0])
450
440
 
451
- self._deps = self._batch_inference_validate_snowpark(
452
- dataset=dataset,
453
- inference_method=inference_method,
454
- )
441
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
442
+ self._deps = self._get_dependencies()
455
443
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
456
444
 
457
445
  transform_kwargs = dict(
@@ -518,16 +506,42 @@ class PolynomialCountSketch(BaseTransformer):
518
506
  self._is_fitted = True
519
507
  return output_result
520
508
 
509
+
510
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
511
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
512
+ """ Fit to data, then transform it
513
+ For more details on this function, see [sklearn.kernel_approximation.PolynomialCountSketch.fit_transform]
514
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.PolynomialCountSketch.html#sklearn.kernel_approximation.PolynomialCountSketch.fit_transform)
515
+
521
516
 
522
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
523
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
524
- """
517
+ Raises:
518
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
519
+
520
+ Args:
521
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
522
+ Snowpark or Pandas DataFrame.
523
+ output_cols_prefix: Prefix for the response columns
525
524
  Returns:
526
525
  Transformed dataset.
527
526
  """
528
- self.fit(dataset)
529
- assert self._sklearn_object is not None
530
- return self._sklearn_object.embedding_
527
+ self._infer_input_output_cols(dataset)
528
+ super()._check_dataset_type(dataset)
529
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
530
+ estimator=self._sklearn_object,
531
+ dataset=dataset,
532
+ input_cols=self.input_cols,
533
+ label_cols=self.label_cols,
534
+ sample_weight_col=self.sample_weight_col,
535
+ autogenerated=self._autogenerated,
536
+ subproject=_SUBPROJECT,
537
+ )
538
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=self.output_cols,
541
+ )
542
+ self._sklearn_object = fitted_estimator
543
+ self._is_fitted = True
544
+ return output_result
531
545
 
532
546
 
533
547
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -618,10 +632,8 @@ class PolynomialCountSketch(BaseTransformer):
618
632
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
619
633
 
620
634
  if isinstance(dataset, DataFrame):
621
- self._deps = self._batch_inference_validate_snowpark(
622
- dataset=dataset,
623
- inference_method=inference_method,
624
- )
635
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
636
+ self._deps = self._get_dependencies()
625
637
  assert isinstance(
626
638
  dataset._session, Session
627
639
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -686,10 +698,8 @@ class PolynomialCountSketch(BaseTransformer):
686
698
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
687
699
 
688
700
  if isinstance(dataset, DataFrame):
689
- self._deps = self._batch_inference_validate_snowpark(
690
- dataset=dataset,
691
- inference_method=inference_method,
692
- )
701
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
702
+ self._deps = self._get_dependencies()
693
703
  assert isinstance(
694
704
  dataset._session, Session
695
705
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -751,10 +761,8 @@ class PolynomialCountSketch(BaseTransformer):
751
761
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
752
762
 
753
763
  if isinstance(dataset, DataFrame):
754
- self._deps = self._batch_inference_validate_snowpark(
755
- dataset=dataset,
756
- inference_method=inference_method,
757
- )
764
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
765
+ self._deps = self._get_dependencies()
758
766
  assert isinstance(
759
767
  dataset._session, Session
760
768
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -820,10 +828,8 @@ class PolynomialCountSketch(BaseTransformer):
820
828
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
821
829
 
822
830
  if isinstance(dataset, DataFrame):
823
- self._deps = self._batch_inference_validate_snowpark(
824
- dataset=dataset,
825
- inference_method=inference_method,
826
- )
831
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
832
+ self._deps = self._get_dependencies()
827
833
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
828
834
  transform_kwargs = dict(
829
835
  session=dataset._session,
@@ -885,17 +891,15 @@ class PolynomialCountSketch(BaseTransformer):
885
891
  transform_kwargs: ScoreKwargsTypedDict = dict()
886
892
 
887
893
  if isinstance(dataset, DataFrame):
888
- self._deps = self._batch_inference_validate_snowpark(
889
- dataset=dataset,
890
- inference_method="score",
891
- )
894
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
895
+ self._deps = self._get_dependencies()
892
896
  selected_cols = self._get_active_columns()
893
897
  if len(selected_cols) > 0:
894
898
  dataset = dataset.select(selected_cols)
895
899
  assert isinstance(dataset._session, Session) # keep mypy happy
896
900
  transform_kwargs = dict(
897
901
  session=dataset._session,
898
- dependencies=["snowflake-snowpark-python"] + self._deps,
902
+ dependencies=self._deps,
899
903
  score_sproc_imports=['sklearn'],
900
904
  )
901
905
  elif isinstance(dataset, pd.DataFrame):
@@ -960,11 +964,8 @@ class PolynomialCountSketch(BaseTransformer):
960
964
 
961
965
  if isinstance(dataset, DataFrame):
962
966
 
963
- self._deps = self._batch_inference_validate_snowpark(
964
- dataset=dataset,
965
- inference_method=inference_method,
966
-
967
- )
967
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
968
+ self._deps = self._get_dependencies()
968
969
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
969
970
  transform_kwargs = dict(
970
971
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RBFSampler(BaseTransformer):
70
64
  r"""Approximate a RBF kernel feature map using random Fourier features
71
65
  For more details on this class, see [sklearn.kernel_approximation.RBFSampler]
@@ -271,20 +265,17 @@ class RBFSampler(BaseTransformer):
271
265
  self,
272
266
  dataset: DataFrame,
273
267
  inference_method: str,
274
- ) -> List[str]:
275
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
276
- return the available package that exists in the snowflake anaconda channel
268
+ ) -> None:
269
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
277
270
 
278
271
  Args:
279
272
  dataset: snowpark dataframe
280
273
  inference_method: the inference method such as predict, score...
281
-
274
+
282
275
  Raises:
283
276
  SnowflakeMLException: If the estimator is not fitted, raise error
284
277
  SnowflakeMLException: If the session is None, raise error
285
278
 
286
- Returns:
287
- A list of available package that exists in the snowflake anaconda channel
288
279
  """
289
280
  if not self._is_fitted:
290
281
  raise exceptions.SnowflakeMLException(
@@ -302,9 +293,7 @@ class RBFSampler(BaseTransformer):
302
293
  "Session must not specified for snowpark dataset."
303
294
  ),
304
295
  )
305
- # Validate that key package version in user workspace are supported in snowflake conda channel
306
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
307
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
296
+
308
297
 
309
298
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
310
299
  @telemetry.send_api_usage_telemetry(
@@ -350,7 +339,8 @@ class RBFSampler(BaseTransformer):
350
339
 
351
340
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
352
341
 
353
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
342
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
343
+ self._deps = self._get_dependencies()
354
344
  assert isinstance(
355
345
  dataset._session, Session
356
346
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -435,10 +425,8 @@ class RBFSampler(BaseTransformer):
435
425
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
436
426
  expected_dtype = convert_sp_to_sf_type(output_types[0])
437
427
 
438
- self._deps = self._batch_inference_validate_snowpark(
439
- dataset=dataset,
440
- inference_method=inference_method,
441
- )
428
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
429
+ self._deps = self._get_dependencies()
442
430
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
443
431
 
444
432
  transform_kwargs = dict(
@@ -505,16 +493,42 @@ class RBFSampler(BaseTransformer):
505
493
  self._is_fitted = True
506
494
  return output_result
507
495
 
496
+
497
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
498
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
499
+ """ Fit to data, then transform it
500
+ For more details on this function, see [sklearn.kernel_approximation.RBFSampler.fit_transform]
501
+ (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.RBFSampler.html#sklearn.kernel_approximation.RBFSampler.fit_transform)
502
+
508
503
 
509
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
510
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
511
- """
504
+ Raises:
505
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
506
+
507
+ Args:
508
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
509
+ Snowpark or Pandas DataFrame.
510
+ output_cols_prefix: Prefix for the response columns
512
511
  Returns:
513
512
  Transformed dataset.
514
513
  """
515
- self.fit(dataset)
516
- assert self._sklearn_object is not None
517
- return self._sklearn_object.embedding_
514
+ self._infer_input_output_cols(dataset)
515
+ super()._check_dataset_type(dataset)
516
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
517
+ estimator=self._sklearn_object,
518
+ dataset=dataset,
519
+ input_cols=self.input_cols,
520
+ label_cols=self.label_cols,
521
+ sample_weight_col=self.sample_weight_col,
522
+ autogenerated=self._autogenerated,
523
+ subproject=_SUBPROJECT,
524
+ )
525
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
526
+ drop_input_cols=self._drop_input_cols,
527
+ expected_output_cols_list=self.output_cols,
528
+ )
529
+ self._sklearn_object = fitted_estimator
530
+ self._is_fitted = True
531
+ return output_result
518
532
 
519
533
 
520
534
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -605,10 +619,8 @@ class RBFSampler(BaseTransformer):
605
619
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
606
620
 
607
621
  if isinstance(dataset, DataFrame):
608
- self._deps = self._batch_inference_validate_snowpark(
609
- dataset=dataset,
610
- inference_method=inference_method,
611
- )
622
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
623
+ self._deps = self._get_dependencies()
612
624
  assert isinstance(
613
625
  dataset._session, Session
614
626
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -673,10 +685,8 @@ class RBFSampler(BaseTransformer):
673
685
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
674
686
 
675
687
  if isinstance(dataset, DataFrame):
676
- self._deps = self._batch_inference_validate_snowpark(
677
- dataset=dataset,
678
- inference_method=inference_method,
679
- )
688
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
689
+ self._deps = self._get_dependencies()
680
690
  assert isinstance(
681
691
  dataset._session, Session
682
692
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -738,10 +748,8 @@ class RBFSampler(BaseTransformer):
738
748
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
739
749
 
740
750
  if isinstance(dataset, DataFrame):
741
- self._deps = self._batch_inference_validate_snowpark(
742
- dataset=dataset,
743
- inference_method=inference_method,
744
- )
751
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
752
+ self._deps = self._get_dependencies()
745
753
  assert isinstance(
746
754
  dataset._session, Session
747
755
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -807,10 +815,8 @@ class RBFSampler(BaseTransformer):
807
815
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
808
816
 
809
817
  if isinstance(dataset, DataFrame):
810
- self._deps = self._batch_inference_validate_snowpark(
811
- dataset=dataset,
812
- inference_method=inference_method,
813
- )
818
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
819
+ self._deps = self._get_dependencies()
814
820
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
815
821
  transform_kwargs = dict(
816
822
  session=dataset._session,
@@ -872,17 +878,15 @@ class RBFSampler(BaseTransformer):
872
878
  transform_kwargs: ScoreKwargsTypedDict = dict()
873
879
 
874
880
  if isinstance(dataset, DataFrame):
875
- self._deps = self._batch_inference_validate_snowpark(
876
- dataset=dataset,
877
- inference_method="score",
878
- )
881
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
882
+ self._deps = self._get_dependencies()
879
883
  selected_cols = self._get_active_columns()
880
884
  if len(selected_cols) > 0:
881
885
  dataset = dataset.select(selected_cols)
882
886
  assert isinstance(dataset._session, Session) # keep mypy happy
883
887
  transform_kwargs = dict(
884
888
  session=dataset._session,
885
- dependencies=["snowflake-snowpark-python"] + self._deps,
889
+ dependencies=self._deps,
886
890
  score_sproc_imports=['sklearn'],
887
891
  )
888
892
  elif isinstance(dataset, pd.DataFrame):
@@ -947,11 +951,8 @@ class RBFSampler(BaseTransformer):
947
951
 
948
952
  if isinstance(dataset, DataFrame):
949
953
 
950
- self._deps = self._batch_inference_validate_snowpark(
951
- dataset=dataset,
952
- inference_method=inference_method,
953
-
954
- )
954
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
955
+ self._deps = self._get_dependencies()
955
956
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
956
957
  transform_kwargs = dict(
957
958
  session = dataset._session,