snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class OrthogonalMatchingPursuit(BaseTransformer):
70
64
  r"""Orthogonal Matching Pursuit model (OMP)
71
65
  For more details on this class, see [sklearn.linear_model.OrthogonalMatchingPursuit]
@@ -288,20 +282,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
288
282
  self,
289
283
  dataset: DataFrame,
290
284
  inference_method: str,
291
- ) -> List[str]:
292
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
293
- return the available package that exists in the snowflake anaconda channel
285
+ ) -> None:
286
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
294
287
 
295
288
  Args:
296
289
  dataset: snowpark dataframe
297
290
  inference_method: the inference method such as predict, score...
298
-
291
+
299
292
  Raises:
300
293
  SnowflakeMLException: If the estimator is not fitted, raise error
301
294
  SnowflakeMLException: If the session is None, raise error
302
295
 
303
- Returns:
304
- A list of available package that exists in the snowflake anaconda channel
305
296
  """
306
297
  if not self._is_fitted:
307
298
  raise exceptions.SnowflakeMLException(
@@ -319,9 +310,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
319
310
  "Session must not specified for snowpark dataset."
320
311
  ),
321
312
  )
322
- # Validate that key package version in user workspace are supported in snowflake conda channel
323
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
324
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
313
+
325
314
 
326
315
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
327
316
  @telemetry.send_api_usage_telemetry(
@@ -369,7 +358,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
369
358
 
370
359
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
371
360
 
372
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
361
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
362
+ self._deps = self._get_dependencies()
373
363
  assert isinstance(
374
364
  dataset._session, Session
375
365
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -452,10 +442,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
452
442
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
453
443
  expected_dtype = convert_sp_to_sf_type(output_types[0])
454
444
 
455
- self._deps = self._batch_inference_validate_snowpark(
456
- dataset=dataset,
457
- inference_method=inference_method,
458
- )
445
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
446
+ self._deps = self._get_dependencies()
459
447
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
460
448
 
461
449
  transform_kwargs = dict(
@@ -522,16 +510,40 @@ class OrthogonalMatchingPursuit(BaseTransformer):
522
510
  self._is_fitted = True
523
511
  return output_result
524
512
 
513
+
514
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
515
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
516
+ """ Method not supported for this class.
525
517
 
526
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
527
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
528
- """
518
+
519
+ Raises:
520
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
521
+
522
+ Args:
523
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
524
+ Snowpark or Pandas DataFrame.
525
+ output_cols_prefix: Prefix for the response columns
529
526
  Returns:
530
527
  Transformed dataset.
531
528
  """
532
- self.fit(dataset)
533
- assert self._sklearn_object is not None
534
- return self._sklearn_object.embedding_
529
+ self._infer_input_output_cols(dataset)
530
+ super()._check_dataset_type(dataset)
531
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
532
+ estimator=self._sklearn_object,
533
+ dataset=dataset,
534
+ input_cols=self.input_cols,
535
+ label_cols=self.label_cols,
536
+ sample_weight_col=self.sample_weight_col,
537
+ autogenerated=self._autogenerated,
538
+ subproject=_SUBPROJECT,
539
+ )
540
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
541
+ drop_input_cols=self._drop_input_cols,
542
+ expected_output_cols_list=self.output_cols,
543
+ )
544
+ self._sklearn_object = fitted_estimator
545
+ self._is_fitted = True
546
+ return output_result
535
547
 
536
548
 
537
549
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -622,10 +634,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
622
634
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
623
635
 
624
636
  if isinstance(dataset, DataFrame):
625
- self._deps = self._batch_inference_validate_snowpark(
626
- dataset=dataset,
627
- inference_method=inference_method,
628
- )
637
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
638
+ self._deps = self._get_dependencies()
629
639
  assert isinstance(
630
640
  dataset._session, Session
631
641
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -690,10 +700,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
690
700
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
691
701
 
692
702
  if isinstance(dataset, DataFrame):
693
- self._deps = self._batch_inference_validate_snowpark(
694
- dataset=dataset,
695
- inference_method=inference_method,
696
- )
703
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
704
+ self._deps = self._get_dependencies()
697
705
  assert isinstance(
698
706
  dataset._session, Session
699
707
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -755,10 +763,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
755
763
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
756
764
 
757
765
  if isinstance(dataset, DataFrame):
758
- self._deps = self._batch_inference_validate_snowpark(
759
- dataset=dataset,
760
- inference_method=inference_method,
761
- )
766
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
767
+ self._deps = self._get_dependencies()
762
768
  assert isinstance(
763
769
  dataset._session, Session
764
770
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -824,10 +830,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
824
830
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
825
831
 
826
832
  if isinstance(dataset, DataFrame):
827
- self._deps = self._batch_inference_validate_snowpark(
828
- dataset=dataset,
829
- inference_method=inference_method,
830
- )
833
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
834
+ self._deps = self._get_dependencies()
831
835
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
832
836
  transform_kwargs = dict(
833
837
  session=dataset._session,
@@ -891,17 +895,15 @@ class OrthogonalMatchingPursuit(BaseTransformer):
891
895
  transform_kwargs: ScoreKwargsTypedDict = dict()
892
896
 
893
897
  if isinstance(dataset, DataFrame):
894
- self._deps = self._batch_inference_validate_snowpark(
895
- dataset=dataset,
896
- inference_method="score",
897
- )
898
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
899
+ self._deps = self._get_dependencies()
898
900
  selected_cols = self._get_active_columns()
899
901
  if len(selected_cols) > 0:
900
902
  dataset = dataset.select(selected_cols)
901
903
  assert isinstance(dataset._session, Session) # keep mypy happy
902
904
  transform_kwargs = dict(
903
905
  session=dataset._session,
904
- dependencies=["snowflake-snowpark-python"] + self._deps,
906
+ dependencies=self._deps,
905
907
  score_sproc_imports=['sklearn'],
906
908
  )
907
909
  elif isinstance(dataset, pd.DataFrame):
@@ -966,11 +968,8 @@ class OrthogonalMatchingPursuit(BaseTransformer):
966
968
 
967
969
  if isinstance(dataset, DataFrame):
968
970
 
969
- self._deps = self._batch_inference_validate_snowpark(
970
- dataset=dataset,
971
- inference_method=inference_method,
972
-
973
- )
971
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
972
+ self._deps = self._get_dependencies()
974
973
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
975
974
  transform_kwargs = dict(
976
975
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class PassiveAggressiveClassifier(BaseTransformer):
70
64
  r"""Passive Aggressive Classifier
71
65
  For more details on this class, see [sklearn.linear_model.PassiveAggressiveClassifier]
@@ -362,20 +356,17 @@ class PassiveAggressiveClassifier(BaseTransformer):
362
356
  self,
363
357
  dataset: DataFrame,
364
358
  inference_method: str,
365
- ) -> List[str]:
366
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
367
- return the available package that exists in the snowflake anaconda channel
359
+ ) -> None:
360
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
368
361
 
369
362
  Args:
370
363
  dataset: snowpark dataframe
371
364
  inference_method: the inference method such as predict, score...
372
-
365
+
373
366
  Raises:
374
367
  SnowflakeMLException: If the estimator is not fitted, raise error
375
368
  SnowflakeMLException: If the session is None, raise error
376
369
 
377
- Returns:
378
- A list of available package that exists in the snowflake anaconda channel
379
370
  """
380
371
  if not self._is_fitted:
381
372
  raise exceptions.SnowflakeMLException(
@@ -393,9 +384,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
393
384
  "Session must not specified for snowpark dataset."
394
385
  ),
395
386
  )
396
- # Validate that key package version in user workspace are supported in snowflake conda channel
397
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
398
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
387
+
399
388
 
400
389
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
401
390
  @telemetry.send_api_usage_telemetry(
@@ -443,7 +432,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
443
432
 
444
433
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
445
434
 
446
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
435
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
436
+ self._deps = self._get_dependencies()
447
437
  assert isinstance(
448
438
  dataset._session, Session
449
439
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -526,10 +516,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
526
516
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
527
517
  expected_dtype = convert_sp_to_sf_type(output_types[0])
528
518
 
529
- self._deps = self._batch_inference_validate_snowpark(
530
- dataset=dataset,
531
- inference_method=inference_method,
532
- )
519
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
520
+ self._deps = self._get_dependencies()
533
521
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
534
522
 
535
523
  transform_kwargs = dict(
@@ -596,16 +584,40 @@ class PassiveAggressiveClassifier(BaseTransformer):
596
584
  self._is_fitted = True
597
585
  return output_result
598
586
 
587
+
588
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
589
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
590
+ """ Method not supported for this class.
599
591
 
600
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
601
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
602
- """
592
+
593
+ Raises:
594
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
595
+
596
+ Args:
597
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
598
+ Snowpark or Pandas DataFrame.
599
+ output_cols_prefix: Prefix for the response columns
603
600
  Returns:
604
601
  Transformed dataset.
605
602
  """
606
- self.fit(dataset)
607
- assert self._sklearn_object is not None
608
- return self._sklearn_object.embedding_
603
+ self._infer_input_output_cols(dataset)
604
+ super()._check_dataset_type(dataset)
605
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
606
+ estimator=self._sklearn_object,
607
+ dataset=dataset,
608
+ input_cols=self.input_cols,
609
+ label_cols=self.label_cols,
610
+ sample_weight_col=self.sample_weight_col,
611
+ autogenerated=self._autogenerated,
612
+ subproject=_SUBPROJECT,
613
+ )
614
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
615
+ drop_input_cols=self._drop_input_cols,
616
+ expected_output_cols_list=self.output_cols,
617
+ )
618
+ self._sklearn_object = fitted_estimator
619
+ self._is_fitted = True
620
+ return output_result
609
621
 
610
622
 
611
623
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -696,10 +708,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
696
708
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
697
709
 
698
710
  if isinstance(dataset, DataFrame):
699
- self._deps = self._batch_inference_validate_snowpark(
700
- dataset=dataset,
701
- inference_method=inference_method,
702
- )
711
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
712
+ self._deps = self._get_dependencies()
703
713
  assert isinstance(
704
714
  dataset._session, Session
705
715
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -764,10 +774,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
764
774
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
765
775
 
766
776
  if isinstance(dataset, DataFrame):
767
- self._deps = self._batch_inference_validate_snowpark(
768
- dataset=dataset,
769
- inference_method=inference_method,
770
- )
777
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
778
+ self._deps = self._get_dependencies()
771
779
  assert isinstance(
772
780
  dataset._session, Session
773
781
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -831,10 +839,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
831
839
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
832
840
 
833
841
  if isinstance(dataset, DataFrame):
834
- self._deps = self._batch_inference_validate_snowpark(
835
- dataset=dataset,
836
- inference_method=inference_method,
837
- )
842
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
843
+ self._deps = self._get_dependencies()
838
844
  assert isinstance(
839
845
  dataset._session, Session
840
846
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -900,10 +906,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
900
906
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
901
907
 
902
908
  if isinstance(dataset, DataFrame):
903
- self._deps = self._batch_inference_validate_snowpark(
904
- dataset=dataset,
905
- inference_method=inference_method,
906
- )
909
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
910
+ self._deps = self._get_dependencies()
907
911
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
908
912
  transform_kwargs = dict(
909
913
  session=dataset._session,
@@ -967,17 +971,15 @@ class PassiveAggressiveClassifier(BaseTransformer):
967
971
  transform_kwargs: ScoreKwargsTypedDict = dict()
968
972
 
969
973
  if isinstance(dataset, DataFrame):
970
- self._deps = self._batch_inference_validate_snowpark(
971
- dataset=dataset,
972
- inference_method="score",
973
- )
974
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
975
+ self._deps = self._get_dependencies()
974
976
  selected_cols = self._get_active_columns()
975
977
  if len(selected_cols) > 0:
976
978
  dataset = dataset.select(selected_cols)
977
979
  assert isinstance(dataset._session, Session) # keep mypy happy
978
980
  transform_kwargs = dict(
979
981
  session=dataset._session,
980
- dependencies=["snowflake-snowpark-python"] + self._deps,
982
+ dependencies=self._deps,
981
983
  score_sproc_imports=['sklearn'],
982
984
  )
983
985
  elif isinstance(dataset, pd.DataFrame):
@@ -1042,11 +1044,8 @@ class PassiveAggressiveClassifier(BaseTransformer):
1042
1044
 
1043
1045
  if isinstance(dataset, DataFrame):
1044
1046
 
1045
- self._deps = self._batch_inference_validate_snowpark(
1046
- dataset=dataset,
1047
- inference_method=inference_method,
1048
-
1049
- )
1047
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1048
+ self._deps = self._get_dependencies()
1050
1049
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1051
1050
  transform_kwargs = dict(
1052
1051
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class PassiveAggressiveRegressor(BaseTransformer):
70
64
  r"""Passive Aggressive Regressor
71
65
  For more details on this class, see [sklearn.linear_model.PassiveAggressiveRegressor]
@@ -348,20 +342,17 @@ class PassiveAggressiveRegressor(BaseTransformer):
348
342
  self,
349
343
  dataset: DataFrame,
350
344
  inference_method: str,
351
- ) -> List[str]:
352
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
353
- return the available package that exists in the snowflake anaconda channel
345
+ ) -> None:
346
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
354
347
 
355
348
  Args:
356
349
  dataset: snowpark dataframe
357
350
  inference_method: the inference method such as predict, score...
358
-
351
+
359
352
  Raises:
360
353
  SnowflakeMLException: If the estimator is not fitted, raise error
361
354
  SnowflakeMLException: If the session is None, raise error
362
355
 
363
- Returns:
364
- A list of available package that exists in the snowflake anaconda channel
365
356
  """
366
357
  if not self._is_fitted:
367
358
  raise exceptions.SnowflakeMLException(
@@ -379,9 +370,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
379
370
  "Session must not specified for snowpark dataset."
380
371
  ),
381
372
  )
382
- # Validate that key package version in user workspace are supported in snowflake conda channel
383
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
384
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
373
+
385
374
 
386
375
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
387
376
  @telemetry.send_api_usage_telemetry(
@@ -429,7 +418,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
429
418
 
430
419
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
431
420
 
432
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
421
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
422
+ self._deps = self._get_dependencies()
433
423
  assert isinstance(
434
424
  dataset._session, Session
435
425
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -512,10 +502,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
512
502
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
513
503
  expected_dtype = convert_sp_to_sf_type(output_types[0])
514
504
 
515
- self._deps = self._batch_inference_validate_snowpark(
516
- dataset=dataset,
517
- inference_method=inference_method,
518
- )
505
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
506
+ self._deps = self._get_dependencies()
519
507
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
520
508
 
521
509
  transform_kwargs = dict(
@@ -582,16 +570,40 @@ class PassiveAggressiveRegressor(BaseTransformer):
582
570
  self._is_fitted = True
583
571
  return output_result
584
572
 
573
+
574
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
575
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
576
+ """ Method not supported for this class.
585
577
 
586
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
587
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
588
- """
578
+
579
+ Raises:
580
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
581
+
582
+ Args:
583
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
584
+ Snowpark or Pandas DataFrame.
585
+ output_cols_prefix: Prefix for the response columns
589
586
  Returns:
590
587
  Transformed dataset.
591
588
  """
592
- self.fit(dataset)
593
- assert self._sklearn_object is not None
594
- return self._sklearn_object.embedding_
589
+ self._infer_input_output_cols(dataset)
590
+ super()._check_dataset_type(dataset)
591
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
592
+ estimator=self._sklearn_object,
593
+ dataset=dataset,
594
+ input_cols=self.input_cols,
595
+ label_cols=self.label_cols,
596
+ sample_weight_col=self.sample_weight_col,
597
+ autogenerated=self._autogenerated,
598
+ subproject=_SUBPROJECT,
599
+ )
600
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
601
+ drop_input_cols=self._drop_input_cols,
602
+ expected_output_cols_list=self.output_cols,
603
+ )
604
+ self._sklearn_object = fitted_estimator
605
+ self._is_fitted = True
606
+ return output_result
595
607
 
596
608
 
597
609
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -682,10 +694,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
682
694
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
683
695
 
684
696
  if isinstance(dataset, DataFrame):
685
- self._deps = self._batch_inference_validate_snowpark(
686
- dataset=dataset,
687
- inference_method=inference_method,
688
- )
697
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
698
+ self._deps = self._get_dependencies()
689
699
  assert isinstance(
690
700
  dataset._session, Session
691
701
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -750,10 +760,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
750
760
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
751
761
 
752
762
  if isinstance(dataset, DataFrame):
753
- self._deps = self._batch_inference_validate_snowpark(
754
- dataset=dataset,
755
- inference_method=inference_method,
756
- )
763
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
764
+ self._deps = self._get_dependencies()
757
765
  assert isinstance(
758
766
  dataset._session, Session
759
767
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -815,10 +823,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
815
823
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
816
824
 
817
825
  if isinstance(dataset, DataFrame):
818
- self._deps = self._batch_inference_validate_snowpark(
819
- dataset=dataset,
820
- inference_method=inference_method,
821
- )
826
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
827
+ self._deps = self._get_dependencies()
822
828
  assert isinstance(
823
829
  dataset._session, Session
824
830
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -884,10 +890,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
884
890
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
885
891
 
886
892
  if isinstance(dataset, DataFrame):
887
- self._deps = self._batch_inference_validate_snowpark(
888
- dataset=dataset,
889
- inference_method=inference_method,
890
- )
893
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
894
+ self._deps = self._get_dependencies()
891
895
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
892
896
  transform_kwargs = dict(
893
897
  session=dataset._session,
@@ -951,17 +955,15 @@ class PassiveAggressiveRegressor(BaseTransformer):
951
955
  transform_kwargs: ScoreKwargsTypedDict = dict()
952
956
 
953
957
  if isinstance(dataset, DataFrame):
954
- self._deps = self._batch_inference_validate_snowpark(
955
- dataset=dataset,
956
- inference_method="score",
957
- )
958
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
959
+ self._deps = self._get_dependencies()
958
960
  selected_cols = self._get_active_columns()
959
961
  if len(selected_cols) > 0:
960
962
  dataset = dataset.select(selected_cols)
961
963
  assert isinstance(dataset._session, Session) # keep mypy happy
962
964
  transform_kwargs = dict(
963
965
  session=dataset._session,
964
- dependencies=["snowflake-snowpark-python"] + self._deps,
966
+ dependencies=self._deps,
965
967
  score_sproc_imports=['sklearn'],
966
968
  )
967
969
  elif isinstance(dataset, pd.DataFrame):
@@ -1026,11 +1028,8 @@ class PassiveAggressiveRegressor(BaseTransformer):
1026
1028
 
1027
1029
  if isinstance(dataset, DataFrame):
1028
1030
 
1029
- self._deps = self._batch_inference_validate_snowpark(
1030
- dataset=dataset,
1031
- inference_method=inference_method,
1032
-
1033
- )
1031
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1032
+ self._deps = self._get_dependencies()
1034
1033
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1035
1034
  transform_kwargs = dict(
1036
1035
  session = dataset._session,