snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ElasticNet(BaseTransformer):
70
64
  r"""Linear regression with combined L1 and L2 priors as regularizer
71
65
  For more details on this class, see [sklearn.linear_model.ElasticNet]
@@ -329,20 +323,17 @@ class ElasticNet(BaseTransformer):
329
323
  self,
330
324
  dataset: DataFrame,
331
325
  inference_method: str,
332
- ) -> List[str]:
333
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
334
- return the available package that exists in the snowflake anaconda channel
326
+ ) -> None:
327
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
335
328
 
336
329
  Args:
337
330
  dataset: snowpark dataframe
338
331
  inference_method: the inference method such as predict, score...
339
-
332
+
340
333
  Raises:
341
334
  SnowflakeMLException: If the estimator is not fitted, raise error
342
335
  SnowflakeMLException: If the session is None, raise error
343
336
 
344
- Returns:
345
- A list of available package that exists in the snowflake anaconda channel
346
337
  """
347
338
  if not self._is_fitted:
348
339
  raise exceptions.SnowflakeMLException(
@@ -360,9 +351,7 @@ class ElasticNet(BaseTransformer):
360
351
  "Session must not specified for snowpark dataset."
361
352
  ),
362
353
  )
363
- # Validate that key package version in user workspace are supported in snowflake conda channel
364
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
365
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
354
+
366
355
 
367
356
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
368
357
  @telemetry.send_api_usage_telemetry(
@@ -410,7 +399,8 @@ class ElasticNet(BaseTransformer):
410
399
 
411
400
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
412
401
 
413
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
402
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
403
+ self._deps = self._get_dependencies()
414
404
  assert isinstance(
415
405
  dataset._session, Session
416
406
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -493,10 +483,8 @@ class ElasticNet(BaseTransformer):
493
483
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
494
484
  expected_dtype = convert_sp_to_sf_type(output_types[0])
495
485
 
496
- self._deps = self._batch_inference_validate_snowpark(
497
- dataset=dataset,
498
- inference_method=inference_method,
499
- )
486
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
487
+ self._deps = self._get_dependencies()
500
488
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
501
489
 
502
490
  transform_kwargs = dict(
@@ -563,16 +551,40 @@ class ElasticNet(BaseTransformer):
563
551
  self._is_fitted = True
564
552
  return output_result
565
553
 
554
+
555
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
556
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
557
+ """ Method not supported for this class.
566
558
 
567
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
568
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
569
- """
559
+
560
+ Raises:
561
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
562
+
563
+ Args:
564
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
565
+ Snowpark or Pandas DataFrame.
566
+ output_cols_prefix: Prefix for the response columns
570
567
  Returns:
571
568
  Transformed dataset.
572
569
  """
573
- self.fit(dataset)
574
- assert self._sklearn_object is not None
575
- return self._sklearn_object.embedding_
570
+ self._infer_input_output_cols(dataset)
571
+ super()._check_dataset_type(dataset)
572
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
573
+ estimator=self._sklearn_object,
574
+ dataset=dataset,
575
+ input_cols=self.input_cols,
576
+ label_cols=self.label_cols,
577
+ sample_weight_col=self.sample_weight_col,
578
+ autogenerated=self._autogenerated,
579
+ subproject=_SUBPROJECT,
580
+ )
581
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
582
+ drop_input_cols=self._drop_input_cols,
583
+ expected_output_cols_list=self.output_cols,
584
+ )
585
+ self._sklearn_object = fitted_estimator
586
+ self._is_fitted = True
587
+ return output_result
576
588
 
577
589
 
578
590
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -663,10 +675,8 @@ class ElasticNet(BaseTransformer):
663
675
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
664
676
 
665
677
  if isinstance(dataset, DataFrame):
666
- self._deps = self._batch_inference_validate_snowpark(
667
- dataset=dataset,
668
- inference_method=inference_method,
669
- )
678
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
679
+ self._deps = self._get_dependencies()
670
680
  assert isinstance(
671
681
  dataset._session, Session
672
682
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -731,10 +741,8 @@ class ElasticNet(BaseTransformer):
731
741
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
732
742
 
733
743
  if isinstance(dataset, DataFrame):
734
- self._deps = self._batch_inference_validate_snowpark(
735
- dataset=dataset,
736
- inference_method=inference_method,
737
- )
744
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
745
+ self._deps = self._get_dependencies()
738
746
  assert isinstance(
739
747
  dataset._session, Session
740
748
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -796,10 +804,8 @@ class ElasticNet(BaseTransformer):
796
804
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
797
805
 
798
806
  if isinstance(dataset, DataFrame):
799
- self._deps = self._batch_inference_validate_snowpark(
800
- dataset=dataset,
801
- inference_method=inference_method,
802
- )
807
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
808
+ self._deps = self._get_dependencies()
803
809
  assert isinstance(
804
810
  dataset._session, Session
805
811
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -865,10 +871,8 @@ class ElasticNet(BaseTransformer):
865
871
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
866
872
 
867
873
  if isinstance(dataset, DataFrame):
868
- self._deps = self._batch_inference_validate_snowpark(
869
- dataset=dataset,
870
- inference_method=inference_method,
871
- )
874
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
875
+ self._deps = self._get_dependencies()
872
876
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
873
877
  transform_kwargs = dict(
874
878
  session=dataset._session,
@@ -932,17 +936,15 @@ class ElasticNet(BaseTransformer):
932
936
  transform_kwargs: ScoreKwargsTypedDict = dict()
933
937
 
934
938
  if isinstance(dataset, DataFrame):
935
- self._deps = self._batch_inference_validate_snowpark(
936
- dataset=dataset,
937
- inference_method="score",
938
- )
939
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
940
+ self._deps = self._get_dependencies()
939
941
  selected_cols = self._get_active_columns()
940
942
  if len(selected_cols) > 0:
941
943
  dataset = dataset.select(selected_cols)
942
944
  assert isinstance(dataset._session, Session) # keep mypy happy
943
945
  transform_kwargs = dict(
944
946
  session=dataset._session,
945
- dependencies=["snowflake-snowpark-python"] + self._deps,
947
+ dependencies=self._deps,
946
948
  score_sproc_imports=['sklearn'],
947
949
  )
948
950
  elif isinstance(dataset, pd.DataFrame):
@@ -1007,11 +1009,8 @@ class ElasticNet(BaseTransformer):
1007
1009
 
1008
1010
  if isinstance(dataset, DataFrame):
1009
1011
 
1010
- self._deps = self._batch_inference_validate_snowpark(
1011
- dataset=dataset,
1012
- inference_method=inference_method,
1013
-
1014
- )
1012
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1013
+ self._deps = self._get_dependencies()
1015
1014
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1016
1015
  transform_kwargs = dict(
1017
1016
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ElasticNetCV(BaseTransformer):
70
64
  r"""Elastic Net model with iterative fitting along a regularization path
71
65
  For more details on this class, see [sklearn.linear_model.ElasticNetCV]
@@ -365,20 +359,17 @@ class ElasticNetCV(BaseTransformer):
365
359
  self,
366
360
  dataset: DataFrame,
367
361
  inference_method: str,
368
- ) -> List[str]:
369
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
370
- return the available package that exists in the snowflake anaconda channel
362
+ ) -> None:
363
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
371
364
 
372
365
  Args:
373
366
  dataset: snowpark dataframe
374
367
  inference_method: the inference method such as predict, score...
375
-
368
+
376
369
  Raises:
377
370
  SnowflakeMLException: If the estimator is not fitted, raise error
378
371
  SnowflakeMLException: If the session is None, raise error
379
372
 
380
- Returns:
381
- A list of available package that exists in the snowflake anaconda channel
382
373
  """
383
374
  if not self._is_fitted:
384
375
  raise exceptions.SnowflakeMLException(
@@ -396,9 +387,7 @@ class ElasticNetCV(BaseTransformer):
396
387
  "Session must not specified for snowpark dataset."
397
388
  ),
398
389
  )
399
- # Validate that key package version in user workspace are supported in snowflake conda channel
400
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
401
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
390
+
402
391
 
403
392
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
404
393
  @telemetry.send_api_usage_telemetry(
@@ -446,7 +435,8 @@ class ElasticNetCV(BaseTransformer):
446
435
 
447
436
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
448
437
 
449
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
439
+ self._deps = self._get_dependencies()
450
440
  assert isinstance(
451
441
  dataset._session, Session
452
442
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -529,10 +519,8 @@ class ElasticNetCV(BaseTransformer):
529
519
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
530
520
  expected_dtype = convert_sp_to_sf_type(output_types[0])
531
521
 
532
- self._deps = self._batch_inference_validate_snowpark(
533
- dataset=dataset,
534
- inference_method=inference_method,
535
- )
522
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
523
+ self._deps = self._get_dependencies()
536
524
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
537
525
 
538
526
  transform_kwargs = dict(
@@ -599,16 +587,40 @@ class ElasticNetCV(BaseTransformer):
599
587
  self._is_fitted = True
600
588
  return output_result
601
589
 
590
+
591
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
592
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
593
+ """ Method not supported for this class.
602
594
 
603
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
604
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
605
- """
595
+
596
+ Raises:
597
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
598
+
599
+ Args:
600
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
601
+ Snowpark or Pandas DataFrame.
602
+ output_cols_prefix: Prefix for the response columns
606
603
  Returns:
607
604
  Transformed dataset.
608
605
  """
609
- self.fit(dataset)
610
- assert self._sklearn_object is not None
611
- return self._sklearn_object.embedding_
606
+ self._infer_input_output_cols(dataset)
607
+ super()._check_dataset_type(dataset)
608
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
609
+ estimator=self._sklearn_object,
610
+ dataset=dataset,
611
+ input_cols=self.input_cols,
612
+ label_cols=self.label_cols,
613
+ sample_weight_col=self.sample_weight_col,
614
+ autogenerated=self._autogenerated,
615
+ subproject=_SUBPROJECT,
616
+ )
617
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
618
+ drop_input_cols=self._drop_input_cols,
619
+ expected_output_cols_list=self.output_cols,
620
+ )
621
+ self._sklearn_object = fitted_estimator
622
+ self._is_fitted = True
623
+ return output_result
612
624
 
613
625
 
614
626
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -699,10 +711,8 @@ class ElasticNetCV(BaseTransformer):
699
711
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
700
712
 
701
713
  if isinstance(dataset, DataFrame):
702
- self._deps = self._batch_inference_validate_snowpark(
703
- dataset=dataset,
704
- inference_method=inference_method,
705
- )
714
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
715
+ self._deps = self._get_dependencies()
706
716
  assert isinstance(
707
717
  dataset._session, Session
708
718
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -767,10 +777,8 @@ class ElasticNetCV(BaseTransformer):
767
777
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
768
778
 
769
779
  if isinstance(dataset, DataFrame):
770
- self._deps = self._batch_inference_validate_snowpark(
771
- dataset=dataset,
772
- inference_method=inference_method,
773
- )
780
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
781
+ self._deps = self._get_dependencies()
774
782
  assert isinstance(
775
783
  dataset._session, Session
776
784
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -832,10 +840,8 @@ class ElasticNetCV(BaseTransformer):
832
840
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
833
841
 
834
842
  if isinstance(dataset, DataFrame):
835
- self._deps = self._batch_inference_validate_snowpark(
836
- dataset=dataset,
837
- inference_method=inference_method,
838
- )
843
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
844
+ self._deps = self._get_dependencies()
839
845
  assert isinstance(
840
846
  dataset._session, Session
841
847
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -901,10 +907,8 @@ class ElasticNetCV(BaseTransformer):
901
907
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
902
908
 
903
909
  if isinstance(dataset, DataFrame):
904
- self._deps = self._batch_inference_validate_snowpark(
905
- dataset=dataset,
906
- inference_method=inference_method,
907
- )
910
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
911
+ self._deps = self._get_dependencies()
908
912
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
909
913
  transform_kwargs = dict(
910
914
  session=dataset._session,
@@ -968,17 +972,15 @@ class ElasticNetCV(BaseTransformer):
968
972
  transform_kwargs: ScoreKwargsTypedDict = dict()
969
973
 
970
974
  if isinstance(dataset, DataFrame):
971
- self._deps = self._batch_inference_validate_snowpark(
972
- dataset=dataset,
973
- inference_method="score",
974
- )
975
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
976
+ self._deps = self._get_dependencies()
975
977
  selected_cols = self._get_active_columns()
976
978
  if len(selected_cols) > 0:
977
979
  dataset = dataset.select(selected_cols)
978
980
  assert isinstance(dataset._session, Session) # keep mypy happy
979
981
  transform_kwargs = dict(
980
982
  session=dataset._session,
981
- dependencies=["snowflake-snowpark-python"] + self._deps,
983
+ dependencies=self._deps,
982
984
  score_sproc_imports=['sklearn'],
983
985
  )
984
986
  elif isinstance(dataset, pd.DataFrame):
@@ -1043,11 +1045,8 @@ class ElasticNetCV(BaseTransformer):
1043
1045
 
1044
1046
  if isinstance(dataset, DataFrame):
1045
1047
 
1046
- self._deps = self._batch_inference_validate_snowpark(
1047
- dataset=dataset,
1048
- inference_method=inference_method,
1049
-
1050
- )
1048
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1049
+ self._deps = self._get_dependencies()
1051
1050
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1052
1051
  transform_kwargs = dict(
1053
1052
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class GammaRegressor(BaseTransformer):
70
64
  r"""Generalized Linear Model with a Gamma distribution
71
65
  For more details on this class, see [sklearn.linear_model.GammaRegressor]
@@ -310,20 +304,17 @@ class GammaRegressor(BaseTransformer):
310
304
  self,
311
305
  dataset: DataFrame,
312
306
  inference_method: str,
313
- ) -> List[str]:
314
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
315
- return the available package that exists in the snowflake anaconda channel
307
+ ) -> None:
308
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
316
309
 
317
310
  Args:
318
311
  dataset: snowpark dataframe
319
312
  inference_method: the inference method such as predict, score...
320
-
313
+
321
314
  Raises:
322
315
  SnowflakeMLException: If the estimator is not fitted, raise error
323
316
  SnowflakeMLException: If the session is None, raise error
324
317
 
325
- Returns:
326
- A list of available package that exists in the snowflake anaconda channel
327
318
  """
328
319
  if not self._is_fitted:
329
320
  raise exceptions.SnowflakeMLException(
@@ -341,9 +332,7 @@ class GammaRegressor(BaseTransformer):
341
332
  "Session must not specified for snowpark dataset."
342
333
  ),
343
334
  )
344
- # Validate that key package version in user workspace are supported in snowflake conda channel
345
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
346
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
335
+
347
336
 
348
337
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
349
338
  @telemetry.send_api_usage_telemetry(
@@ -391,7 +380,8 @@ class GammaRegressor(BaseTransformer):
391
380
 
392
381
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
393
382
 
394
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
384
+ self._deps = self._get_dependencies()
395
385
  assert isinstance(
396
386
  dataset._session, Session
397
387
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -474,10 +464,8 @@ class GammaRegressor(BaseTransformer):
474
464
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
475
465
  expected_dtype = convert_sp_to_sf_type(output_types[0])
476
466
 
477
- self._deps = self._batch_inference_validate_snowpark(
478
- dataset=dataset,
479
- inference_method=inference_method,
480
- )
467
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
468
+ self._deps = self._get_dependencies()
481
469
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
482
470
 
483
471
  transform_kwargs = dict(
@@ -544,16 +532,40 @@ class GammaRegressor(BaseTransformer):
544
532
  self._is_fitted = True
545
533
  return output_result
546
534
 
535
+
536
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
537
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
538
+ """ Method not supported for this class.
547
539
 
548
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
549
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
550
- """
540
+
541
+ Raises:
542
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
543
+
544
+ Args:
545
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
546
+ Snowpark or Pandas DataFrame.
547
+ output_cols_prefix: Prefix for the response columns
551
548
  Returns:
552
549
  Transformed dataset.
553
550
  """
554
- self.fit(dataset)
555
- assert self._sklearn_object is not None
556
- return self._sklearn_object.embedding_
551
+ self._infer_input_output_cols(dataset)
552
+ super()._check_dataset_type(dataset)
553
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
554
+ estimator=self._sklearn_object,
555
+ dataset=dataset,
556
+ input_cols=self.input_cols,
557
+ label_cols=self.label_cols,
558
+ sample_weight_col=self.sample_weight_col,
559
+ autogenerated=self._autogenerated,
560
+ subproject=_SUBPROJECT,
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=self.output_cols,
565
+ )
566
+ self._sklearn_object = fitted_estimator
567
+ self._is_fitted = True
568
+ return output_result
557
569
 
558
570
 
559
571
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -644,10 +656,8 @@ class GammaRegressor(BaseTransformer):
644
656
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
645
657
 
646
658
  if isinstance(dataset, DataFrame):
647
- self._deps = self._batch_inference_validate_snowpark(
648
- dataset=dataset,
649
- inference_method=inference_method,
650
- )
659
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
660
+ self._deps = self._get_dependencies()
651
661
  assert isinstance(
652
662
  dataset._session, Session
653
663
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -712,10 +722,8 @@ class GammaRegressor(BaseTransformer):
712
722
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
713
723
 
714
724
  if isinstance(dataset, DataFrame):
715
- self._deps = self._batch_inference_validate_snowpark(
716
- dataset=dataset,
717
- inference_method=inference_method,
718
- )
725
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
726
+ self._deps = self._get_dependencies()
719
727
  assert isinstance(
720
728
  dataset._session, Session
721
729
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -777,10 +785,8 @@ class GammaRegressor(BaseTransformer):
777
785
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
778
786
 
779
787
  if isinstance(dataset, DataFrame):
780
- self._deps = self._batch_inference_validate_snowpark(
781
- dataset=dataset,
782
- inference_method=inference_method,
783
- )
788
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
789
+ self._deps = self._get_dependencies()
784
790
  assert isinstance(
785
791
  dataset._session, Session
786
792
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +852,8 @@ class GammaRegressor(BaseTransformer):
846
852
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
847
853
 
848
854
  if isinstance(dataset, DataFrame):
849
- self._deps = self._batch_inference_validate_snowpark(
850
- dataset=dataset,
851
- inference_method=inference_method,
852
- )
855
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
856
+ self._deps = self._get_dependencies()
853
857
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
854
858
  transform_kwargs = dict(
855
859
  session=dataset._session,
@@ -913,17 +917,15 @@ class GammaRegressor(BaseTransformer):
913
917
  transform_kwargs: ScoreKwargsTypedDict = dict()
914
918
 
915
919
  if isinstance(dataset, DataFrame):
916
- self._deps = self._batch_inference_validate_snowpark(
917
- dataset=dataset,
918
- inference_method="score",
919
- )
920
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
921
+ self._deps = self._get_dependencies()
920
922
  selected_cols = self._get_active_columns()
921
923
  if len(selected_cols) > 0:
922
924
  dataset = dataset.select(selected_cols)
923
925
  assert isinstance(dataset._session, Session) # keep mypy happy
924
926
  transform_kwargs = dict(
925
927
  session=dataset._session,
926
- dependencies=["snowflake-snowpark-python"] + self._deps,
928
+ dependencies=self._deps,
927
929
  score_sproc_imports=['sklearn'],
928
930
  )
929
931
  elif isinstance(dataset, pd.DataFrame):
@@ -988,11 +990,8 @@ class GammaRegressor(BaseTransformer):
988
990
 
989
991
  if isinstance(dataset, DataFrame):
990
992
 
991
- self._deps = self._batch_inference_validate_snowpark(
992
- dataset=dataset,
993
- inference_method=inference_method,
994
-
995
- )
993
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
994
+ self._deps = self._get_dependencies()
996
995
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
997
996
  transform_kwargs = dict(
998
997
  session = dataset._session,