snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class BernoulliRBM(BaseTransformer):
70
64
  r"""Bernoulli Restricted Boltzmann Machine (RBM)
71
65
  For more details on this class, see [sklearn.neural_network.BernoulliRBM]
@@ -293,20 +287,17 @@ class BernoulliRBM(BaseTransformer):
293
287
  self,
294
288
  dataset: DataFrame,
295
289
  inference_method: str,
296
- ) -> List[str]:
297
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
298
- return the available package that exists in the snowflake anaconda channel
290
+ ) -> None:
291
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
299
292
 
300
293
  Args:
301
294
  dataset: snowpark dataframe
302
295
  inference_method: the inference method such as predict, score...
303
-
296
+
304
297
  Raises:
305
298
  SnowflakeMLException: If the estimator is not fitted, raise error
306
299
  SnowflakeMLException: If the session is None, raise error
307
300
 
308
- Returns:
309
- A list of available package that exists in the snowflake anaconda channel
310
301
  """
311
302
  if not self._is_fitted:
312
303
  raise exceptions.SnowflakeMLException(
@@ -324,9 +315,7 @@ class BernoulliRBM(BaseTransformer):
324
315
  "Session must not specified for snowpark dataset."
325
316
  ),
326
317
  )
327
- # Validate that key package version in user workspace are supported in snowflake conda channel
328
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
329
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
318
+
330
319
 
331
320
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
332
321
  @telemetry.send_api_usage_telemetry(
@@ -372,7 +361,8 @@ class BernoulliRBM(BaseTransformer):
372
361
 
373
362
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
374
363
 
375
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
364
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
365
+ self._deps = self._get_dependencies()
376
366
  assert isinstance(
377
367
  dataset._session, Session
378
368
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -457,10 +447,8 @@ class BernoulliRBM(BaseTransformer):
457
447
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
458
448
  expected_dtype = convert_sp_to_sf_type(output_types[0])
459
449
 
460
- self._deps = self._batch_inference_validate_snowpark(
461
- dataset=dataset,
462
- inference_method=inference_method,
463
- )
450
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
451
+ self._deps = self._get_dependencies()
464
452
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
465
453
 
466
454
  transform_kwargs = dict(
@@ -527,16 +515,42 @@ class BernoulliRBM(BaseTransformer):
527
515
  self._is_fitted = True
528
516
  return output_result
529
517
 
518
+
519
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
520
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
521
+ """ Fit to data, then transform it
522
+ For more details on this function, see [sklearn.neural_network.BernoulliRBM.fit_transform]
523
+ (https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.BernoulliRBM.html#sklearn.neural_network.BernoulliRBM.fit_transform)
524
+
530
525
 
531
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
532
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
533
- """
526
+ Raises:
527
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
528
+
529
+ Args:
530
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
531
+ Snowpark or Pandas DataFrame.
532
+ output_cols_prefix: Prefix for the response columns
534
533
  Returns:
535
534
  Transformed dataset.
536
535
  """
537
- self.fit(dataset)
538
- assert self._sklearn_object is not None
539
- return self._sklearn_object.embedding_
536
+ self._infer_input_output_cols(dataset)
537
+ super()._check_dataset_type(dataset)
538
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
539
+ estimator=self._sklearn_object,
540
+ dataset=dataset,
541
+ input_cols=self.input_cols,
542
+ label_cols=self.label_cols,
543
+ sample_weight_col=self.sample_weight_col,
544
+ autogenerated=self._autogenerated,
545
+ subproject=_SUBPROJECT,
546
+ )
547
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=self.output_cols,
550
+ )
551
+ self._sklearn_object = fitted_estimator
552
+ self._is_fitted = True
553
+ return output_result
540
554
 
541
555
 
542
556
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -627,10 +641,8 @@ class BernoulliRBM(BaseTransformer):
627
641
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
628
642
 
629
643
  if isinstance(dataset, DataFrame):
630
- self._deps = self._batch_inference_validate_snowpark(
631
- dataset=dataset,
632
- inference_method=inference_method,
633
- )
644
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
645
+ self._deps = self._get_dependencies()
634
646
  assert isinstance(
635
647
  dataset._session, Session
636
648
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +707,8 @@ class BernoulliRBM(BaseTransformer):
695
707
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
696
708
 
697
709
  if isinstance(dataset, DataFrame):
698
- self._deps = self._batch_inference_validate_snowpark(
699
- dataset=dataset,
700
- inference_method=inference_method,
701
- )
710
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
711
+ self._deps = self._get_dependencies()
702
712
  assert isinstance(
703
713
  dataset._session, Session
704
714
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +770,8 @@ class BernoulliRBM(BaseTransformer):
760
770
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
761
771
 
762
772
  if isinstance(dataset, DataFrame):
763
- self._deps = self._batch_inference_validate_snowpark(
764
- dataset=dataset,
765
- inference_method=inference_method,
766
- )
773
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
774
+ self._deps = self._get_dependencies()
767
775
  assert isinstance(
768
776
  dataset._session, Session
769
777
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -831,10 +839,8 @@ class BernoulliRBM(BaseTransformer):
831
839
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
832
840
 
833
841
  if isinstance(dataset, DataFrame):
834
- self._deps = self._batch_inference_validate_snowpark(
835
- dataset=dataset,
836
- inference_method=inference_method,
837
- )
842
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
843
+ self._deps = self._get_dependencies()
838
844
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
839
845
  transform_kwargs = dict(
840
846
  session=dataset._session,
@@ -896,17 +902,15 @@ class BernoulliRBM(BaseTransformer):
896
902
  transform_kwargs: ScoreKwargsTypedDict = dict()
897
903
 
898
904
  if isinstance(dataset, DataFrame):
899
- self._deps = self._batch_inference_validate_snowpark(
900
- dataset=dataset,
901
- inference_method="score",
902
- )
905
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
906
+ self._deps = self._get_dependencies()
903
907
  selected_cols = self._get_active_columns()
904
908
  if len(selected_cols) > 0:
905
909
  dataset = dataset.select(selected_cols)
906
910
  assert isinstance(dataset._session, Session) # keep mypy happy
907
911
  transform_kwargs = dict(
908
912
  session=dataset._session,
909
- dependencies=["snowflake-snowpark-python"] + self._deps,
913
+ dependencies=self._deps,
910
914
  score_sproc_imports=['sklearn'],
911
915
  )
912
916
  elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +975,8 @@ class BernoulliRBM(BaseTransformer):
971
975
 
972
976
  if isinstance(dataset, DataFrame):
973
977
 
974
- self._deps = self._batch_inference_validate_snowpark(
975
- dataset=dataset,
976
- inference_method=inference_method,
977
-
978
- )
978
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
979
+ self._deps = self._get_dependencies()
979
980
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
980
981
  transform_kwargs = dict(
981
982
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MLPClassifier(BaseTransformer):
70
64
  r"""Multi-layer Perceptron classifier
71
65
  For more details on this class, see [sklearn.neural_network.MLPClassifier]
@@ -448,20 +442,17 @@ class MLPClassifier(BaseTransformer):
448
442
  self,
449
443
  dataset: DataFrame,
450
444
  inference_method: str,
451
- ) -> List[str]:
452
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
453
- return the available package that exists in the snowflake anaconda channel
445
+ ) -> None:
446
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
454
447
 
455
448
  Args:
456
449
  dataset: snowpark dataframe
457
450
  inference_method: the inference method such as predict, score...
458
-
451
+
459
452
  Raises:
460
453
  SnowflakeMLException: If the estimator is not fitted, raise error
461
454
  SnowflakeMLException: If the session is None, raise error
462
455
 
463
- Returns:
464
- A list of available package that exists in the snowflake anaconda channel
465
456
  """
466
457
  if not self._is_fitted:
467
458
  raise exceptions.SnowflakeMLException(
@@ -479,9 +470,7 @@ class MLPClassifier(BaseTransformer):
479
470
  "Session must not specified for snowpark dataset."
480
471
  ),
481
472
  )
482
- # Validate that key package version in user workspace are supported in snowflake conda channel
483
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
484
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
473
+
485
474
 
486
475
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
487
476
  @telemetry.send_api_usage_telemetry(
@@ -529,7 +518,8 @@ class MLPClassifier(BaseTransformer):
529
518
 
530
519
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
531
520
 
532
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
521
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
522
+ self._deps = self._get_dependencies()
533
523
  assert isinstance(
534
524
  dataset._session, Session
535
525
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -612,10 +602,8 @@ class MLPClassifier(BaseTransformer):
612
602
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
613
603
  expected_dtype = convert_sp_to_sf_type(output_types[0])
614
604
 
615
- self._deps = self._batch_inference_validate_snowpark(
616
- dataset=dataset,
617
- inference_method=inference_method,
618
- )
605
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
606
+ self._deps = self._get_dependencies()
619
607
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
620
608
 
621
609
  transform_kwargs = dict(
@@ -682,16 +670,40 @@ class MLPClassifier(BaseTransformer):
682
670
  self._is_fitted = True
683
671
  return output_result
684
672
 
673
+
674
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
675
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
676
+ """ Method not supported for this class.
685
677
 
686
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
687
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
688
- """
678
+
679
+ Raises:
680
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
681
+
682
+ Args:
683
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
684
+ Snowpark or Pandas DataFrame.
685
+ output_cols_prefix: Prefix for the response columns
689
686
  Returns:
690
687
  Transformed dataset.
691
688
  """
692
- self.fit(dataset)
693
- assert self._sklearn_object is not None
694
- return self._sklearn_object.embedding_
689
+ self._infer_input_output_cols(dataset)
690
+ super()._check_dataset_type(dataset)
691
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
692
+ estimator=self._sklearn_object,
693
+ dataset=dataset,
694
+ input_cols=self.input_cols,
695
+ label_cols=self.label_cols,
696
+ sample_weight_col=self.sample_weight_col,
697
+ autogenerated=self._autogenerated,
698
+ subproject=_SUBPROJECT,
699
+ )
700
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
701
+ drop_input_cols=self._drop_input_cols,
702
+ expected_output_cols_list=self.output_cols,
703
+ )
704
+ self._sklearn_object = fitted_estimator
705
+ self._is_fitted = True
706
+ return output_result
695
707
 
696
708
 
697
709
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -784,10 +796,8 @@ class MLPClassifier(BaseTransformer):
784
796
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
785
797
 
786
798
  if isinstance(dataset, DataFrame):
787
- self._deps = self._batch_inference_validate_snowpark(
788
- dataset=dataset,
789
- inference_method=inference_method,
790
- )
799
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
800
+ self._deps = self._get_dependencies()
791
801
  assert isinstance(
792
802
  dataset._session, Session
793
803
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -854,10 +864,8 @@ class MLPClassifier(BaseTransformer):
854
864
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
855
865
 
856
866
  if isinstance(dataset, DataFrame):
857
- self._deps = self._batch_inference_validate_snowpark(
858
- dataset=dataset,
859
- inference_method=inference_method,
860
- )
867
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
868
+ self._deps = self._get_dependencies()
861
869
  assert isinstance(
862
870
  dataset._session, Session
863
871
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -919,10 +927,8 @@ class MLPClassifier(BaseTransformer):
919
927
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
920
928
 
921
929
  if isinstance(dataset, DataFrame):
922
- self._deps = self._batch_inference_validate_snowpark(
923
- dataset=dataset,
924
- inference_method=inference_method,
925
- )
930
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
931
+ self._deps = self._get_dependencies()
926
932
  assert isinstance(
927
933
  dataset._session, Session
928
934
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -988,10 +994,8 @@ class MLPClassifier(BaseTransformer):
988
994
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
989
995
 
990
996
  if isinstance(dataset, DataFrame):
991
- self._deps = self._batch_inference_validate_snowpark(
992
- dataset=dataset,
993
- inference_method=inference_method,
994
- )
997
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
998
+ self._deps = self._get_dependencies()
995
999
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
996
1000
  transform_kwargs = dict(
997
1001
  session=dataset._session,
@@ -1055,17 +1059,15 @@ class MLPClassifier(BaseTransformer):
1055
1059
  transform_kwargs: ScoreKwargsTypedDict = dict()
1056
1060
 
1057
1061
  if isinstance(dataset, DataFrame):
1058
- self._deps = self._batch_inference_validate_snowpark(
1059
- dataset=dataset,
1060
- inference_method="score",
1061
- )
1062
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1063
+ self._deps = self._get_dependencies()
1062
1064
  selected_cols = self._get_active_columns()
1063
1065
  if len(selected_cols) > 0:
1064
1066
  dataset = dataset.select(selected_cols)
1065
1067
  assert isinstance(dataset._session, Session) # keep mypy happy
1066
1068
  transform_kwargs = dict(
1067
1069
  session=dataset._session,
1068
- dependencies=["snowflake-snowpark-python"] + self._deps,
1070
+ dependencies=self._deps,
1069
1071
  score_sproc_imports=['sklearn'],
1070
1072
  )
1071
1073
  elif isinstance(dataset, pd.DataFrame):
@@ -1130,11 +1132,8 @@ class MLPClassifier(BaseTransformer):
1130
1132
 
1131
1133
  if isinstance(dataset, DataFrame):
1132
1134
 
1133
- self._deps = self._batch_inference_validate_snowpark(
1134
- dataset=dataset,
1135
- inference_method=inference_method,
1136
-
1137
- )
1135
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1136
+ self._deps = self._get_dependencies()
1138
1137
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1139
1138
  transform_kwargs = dict(
1140
1139
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neural_network".replace(
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MLPRegressor(BaseTransformer):
70
64
  r"""Multi-layer Perceptron regressor
71
65
  For more details on this class, see [sklearn.neural_network.MLPRegressor]
@@ -444,20 +438,17 @@ class MLPRegressor(BaseTransformer):
444
438
  self,
445
439
  dataset: DataFrame,
446
440
  inference_method: str,
447
- ) -> List[str]:
448
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
449
- return the available package that exists in the snowflake anaconda channel
441
+ ) -> None:
442
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
450
443
 
451
444
  Args:
452
445
  dataset: snowpark dataframe
453
446
  inference_method: the inference method such as predict, score...
454
-
447
+
455
448
  Raises:
456
449
  SnowflakeMLException: If the estimator is not fitted, raise error
457
450
  SnowflakeMLException: If the session is None, raise error
458
451
 
459
- Returns:
460
- A list of available package that exists in the snowflake anaconda channel
461
452
  """
462
453
  if not self._is_fitted:
463
454
  raise exceptions.SnowflakeMLException(
@@ -475,9 +466,7 @@ class MLPRegressor(BaseTransformer):
475
466
  "Session must not specified for snowpark dataset."
476
467
  ),
477
468
  )
478
- # Validate that key package version in user workspace are supported in snowflake conda channel
479
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
480
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
469
+
481
470
 
482
471
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
483
472
  @telemetry.send_api_usage_telemetry(
@@ -525,7 +514,8 @@ class MLPRegressor(BaseTransformer):
525
514
 
526
515
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
527
516
 
528
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
517
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
518
+ self._deps = self._get_dependencies()
529
519
  assert isinstance(
530
520
  dataset._session, Session
531
521
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -608,10 +598,8 @@ class MLPRegressor(BaseTransformer):
608
598
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
609
599
  expected_dtype = convert_sp_to_sf_type(output_types[0])
610
600
 
611
- self._deps = self._batch_inference_validate_snowpark(
612
- dataset=dataset,
613
- inference_method=inference_method,
614
- )
601
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
602
+ self._deps = self._get_dependencies()
615
603
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
616
604
 
617
605
  transform_kwargs = dict(
@@ -678,16 +666,40 @@ class MLPRegressor(BaseTransformer):
678
666
  self._is_fitted = True
679
667
  return output_result
680
668
 
669
+
670
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
671
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
672
+ """ Method not supported for this class.
681
673
 
682
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
683
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
684
- """
674
+
675
+ Raises:
676
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
677
+
678
+ Args:
679
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
680
+ Snowpark or Pandas DataFrame.
681
+ output_cols_prefix: Prefix for the response columns
685
682
  Returns:
686
683
  Transformed dataset.
687
684
  """
688
- self.fit(dataset)
689
- assert self._sklearn_object is not None
690
- return self._sklearn_object.embedding_
685
+ self._infer_input_output_cols(dataset)
686
+ super()._check_dataset_type(dataset)
687
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
688
+ estimator=self._sklearn_object,
689
+ dataset=dataset,
690
+ input_cols=self.input_cols,
691
+ label_cols=self.label_cols,
692
+ sample_weight_col=self.sample_weight_col,
693
+ autogenerated=self._autogenerated,
694
+ subproject=_SUBPROJECT,
695
+ )
696
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
697
+ drop_input_cols=self._drop_input_cols,
698
+ expected_output_cols_list=self.output_cols,
699
+ )
700
+ self._sklearn_object = fitted_estimator
701
+ self._is_fitted = True
702
+ return output_result
691
703
 
692
704
 
693
705
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -778,10 +790,8 @@ class MLPRegressor(BaseTransformer):
778
790
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
779
791
 
780
792
  if isinstance(dataset, DataFrame):
781
- self._deps = self._batch_inference_validate_snowpark(
782
- dataset=dataset,
783
- inference_method=inference_method,
784
- )
793
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
794
+ self._deps = self._get_dependencies()
785
795
  assert isinstance(
786
796
  dataset._session, Session
787
797
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +856,8 @@ class MLPRegressor(BaseTransformer):
846
856
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
847
857
 
848
858
  if isinstance(dataset, DataFrame):
849
- self._deps = self._batch_inference_validate_snowpark(
850
- dataset=dataset,
851
- inference_method=inference_method,
852
- )
859
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
860
+ self._deps = self._get_dependencies()
853
861
  assert isinstance(
854
862
  dataset._session, Session
855
863
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -911,10 +919,8 @@ class MLPRegressor(BaseTransformer):
911
919
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
912
920
 
913
921
  if isinstance(dataset, DataFrame):
914
- self._deps = self._batch_inference_validate_snowpark(
915
- dataset=dataset,
916
- inference_method=inference_method,
917
- )
922
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
923
+ self._deps = self._get_dependencies()
918
924
  assert isinstance(
919
925
  dataset._session, Session
920
926
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -980,10 +986,8 @@ class MLPRegressor(BaseTransformer):
980
986
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
981
987
 
982
988
  if isinstance(dataset, DataFrame):
983
- self._deps = self._batch_inference_validate_snowpark(
984
- dataset=dataset,
985
- inference_method=inference_method,
986
- )
989
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
990
+ self._deps = self._get_dependencies()
987
991
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
988
992
  transform_kwargs = dict(
989
993
  session=dataset._session,
@@ -1047,17 +1051,15 @@ class MLPRegressor(BaseTransformer):
1047
1051
  transform_kwargs: ScoreKwargsTypedDict = dict()
1048
1052
 
1049
1053
  if isinstance(dataset, DataFrame):
1050
- self._deps = self._batch_inference_validate_snowpark(
1051
- dataset=dataset,
1052
- inference_method="score",
1053
- )
1054
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1055
+ self._deps = self._get_dependencies()
1054
1056
  selected_cols = self._get_active_columns()
1055
1057
  if len(selected_cols) > 0:
1056
1058
  dataset = dataset.select(selected_cols)
1057
1059
  assert isinstance(dataset._session, Session) # keep mypy happy
1058
1060
  transform_kwargs = dict(
1059
1061
  session=dataset._session,
1060
- dependencies=["snowflake-snowpark-python"] + self._deps,
1062
+ dependencies=self._deps,
1061
1063
  score_sproc_imports=['sklearn'],
1062
1064
  )
1063
1065
  elif isinstance(dataset, pd.DataFrame):
@@ -1122,11 +1124,8 @@ class MLPRegressor(BaseTransformer):
1122
1124
 
1123
1125
  if isinstance(dataset, DataFrame):
1124
1126
 
1125
- self._deps = self._batch_inference_validate_snowpark(
1126
- dataset=dataset,
1127
- inference_method=inference_method,
1128
-
1129
- )
1127
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1128
+ self._deps = self._get_dependencies()
1130
1129
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1131
1130
  transform_kwargs = dict(
1132
1131
  session = dataset._session,