snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206) hide show
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class NuSVR(BaseTransformer):
70
64
  r"""Nu Support Vector Regression
71
65
  For more details on this class, see [sklearn.svm.NuSVR]
@@ -321,20 +315,17 @@ class NuSVR(BaseTransformer):
321
315
  self,
322
316
  dataset: DataFrame,
323
317
  inference_method: str,
324
- ) -> List[str]:
325
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
326
- return the available package that exists in the snowflake anaconda channel
318
+ ) -> None:
319
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
327
320
 
328
321
  Args:
329
322
  dataset: snowpark dataframe
330
323
  inference_method: the inference method such as predict, score...
331
-
324
+
332
325
  Raises:
333
326
  SnowflakeMLException: If the estimator is not fitted, raise error
334
327
  SnowflakeMLException: If the session is None, raise error
335
328
 
336
- Returns:
337
- A list of available package that exists in the snowflake anaconda channel
338
329
  """
339
330
  if not self._is_fitted:
340
331
  raise exceptions.SnowflakeMLException(
@@ -352,9 +343,7 @@ class NuSVR(BaseTransformer):
352
343
  "Session must not specified for snowpark dataset."
353
344
  ),
354
345
  )
355
- # Validate that key package version in user workspace are supported in snowflake conda channel
356
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
357
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
346
+
358
347
 
359
348
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
360
349
  @telemetry.send_api_usage_telemetry(
@@ -402,7 +391,8 @@ class NuSVR(BaseTransformer):
402
391
 
403
392
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
404
393
 
405
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
394
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
395
+ self._deps = self._get_dependencies()
406
396
  assert isinstance(
407
397
  dataset._session, Session
408
398
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -485,10 +475,8 @@ class NuSVR(BaseTransformer):
485
475
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
486
476
  expected_dtype = convert_sp_to_sf_type(output_types[0])
487
477
 
488
- self._deps = self._batch_inference_validate_snowpark(
489
- dataset=dataset,
490
- inference_method=inference_method,
491
- )
478
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
479
+ self._deps = self._get_dependencies()
492
480
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
493
481
 
494
482
  transform_kwargs = dict(
@@ -555,16 +543,40 @@ class NuSVR(BaseTransformer):
555
543
  self._is_fitted = True
556
544
  return output_result
557
545
 
546
+
547
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
548
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
549
+ """ Method not supported for this class.
558
550
 
559
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
560
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
561
- """
551
+
552
+ Raises:
553
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
554
+
555
+ Args:
556
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
557
+ Snowpark or Pandas DataFrame.
558
+ output_cols_prefix: Prefix for the response columns
562
559
  Returns:
563
560
  Transformed dataset.
564
561
  """
565
- self.fit(dataset)
566
- assert self._sklearn_object is not None
567
- return self._sklearn_object.embedding_
562
+ self._infer_input_output_cols(dataset)
563
+ super()._check_dataset_type(dataset)
564
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
565
+ estimator=self._sklearn_object,
566
+ dataset=dataset,
567
+ input_cols=self.input_cols,
568
+ label_cols=self.label_cols,
569
+ sample_weight_col=self.sample_weight_col,
570
+ autogenerated=self._autogenerated,
571
+ subproject=_SUBPROJECT,
572
+ )
573
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=self.output_cols,
576
+ )
577
+ self._sklearn_object = fitted_estimator
578
+ self._is_fitted = True
579
+ return output_result
568
580
 
569
581
 
570
582
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -655,10 +667,8 @@ class NuSVR(BaseTransformer):
655
667
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
656
668
 
657
669
  if isinstance(dataset, DataFrame):
658
- self._deps = self._batch_inference_validate_snowpark(
659
- dataset=dataset,
660
- inference_method=inference_method,
661
- )
670
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
671
+ self._deps = self._get_dependencies()
662
672
  assert isinstance(
663
673
  dataset._session, Session
664
674
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -723,10 +733,8 @@ class NuSVR(BaseTransformer):
723
733
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
724
734
 
725
735
  if isinstance(dataset, DataFrame):
726
- self._deps = self._batch_inference_validate_snowpark(
727
- dataset=dataset,
728
- inference_method=inference_method,
729
- )
736
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
737
+ self._deps = self._get_dependencies()
730
738
  assert isinstance(
731
739
  dataset._session, Session
732
740
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -788,10 +796,8 @@ class NuSVR(BaseTransformer):
788
796
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
789
797
 
790
798
  if isinstance(dataset, DataFrame):
791
- self._deps = self._batch_inference_validate_snowpark(
792
- dataset=dataset,
793
- inference_method=inference_method,
794
- )
799
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
800
+ self._deps = self._get_dependencies()
795
801
  assert isinstance(
796
802
  dataset._session, Session
797
803
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -857,10 +863,8 @@ class NuSVR(BaseTransformer):
857
863
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
858
864
 
859
865
  if isinstance(dataset, DataFrame):
860
- self._deps = self._batch_inference_validate_snowpark(
861
- dataset=dataset,
862
- inference_method=inference_method,
863
- )
866
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
867
+ self._deps = self._get_dependencies()
864
868
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
865
869
  transform_kwargs = dict(
866
870
  session=dataset._session,
@@ -924,17 +928,15 @@ class NuSVR(BaseTransformer):
924
928
  transform_kwargs: ScoreKwargsTypedDict = dict()
925
929
 
926
930
  if isinstance(dataset, DataFrame):
927
- self._deps = self._batch_inference_validate_snowpark(
928
- dataset=dataset,
929
- inference_method="score",
930
- )
931
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
932
+ self._deps = self._get_dependencies()
931
933
  selected_cols = self._get_active_columns()
932
934
  if len(selected_cols) > 0:
933
935
  dataset = dataset.select(selected_cols)
934
936
  assert isinstance(dataset._session, Session) # keep mypy happy
935
937
  transform_kwargs = dict(
936
938
  session=dataset._session,
937
- dependencies=["snowflake-snowpark-python"] + self._deps,
939
+ dependencies=self._deps,
938
940
  score_sproc_imports=['sklearn'],
939
941
  )
940
942
  elif isinstance(dataset, pd.DataFrame):
@@ -999,11 +1001,8 @@ class NuSVR(BaseTransformer):
999
1001
 
1000
1002
  if isinstance(dataset, DataFrame):
1001
1003
 
1002
- self._deps = self._batch_inference_validate_snowpark(
1003
- dataset=dataset,
1004
- inference_method=inference_method,
1005
-
1006
- )
1004
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1005
+ self._deps = self._get_dependencies()
1007
1006
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1008
1007
  transform_kwargs = dict(
1009
1008
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SVC(BaseTransformer):
70
64
  r"""C-Support Vector Classification
71
65
  For more details on this class, see [sklearn.svm.SVC]
@@ -363,20 +357,17 @@ class SVC(BaseTransformer):
363
357
  self,
364
358
  dataset: DataFrame,
365
359
  inference_method: str,
366
- ) -> List[str]:
367
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
368
- return the available package that exists in the snowflake anaconda channel
360
+ ) -> None:
361
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
369
362
 
370
363
  Args:
371
364
  dataset: snowpark dataframe
372
365
  inference_method: the inference method such as predict, score...
373
-
366
+
374
367
  Raises:
375
368
  SnowflakeMLException: If the estimator is not fitted, raise error
376
369
  SnowflakeMLException: If the session is None, raise error
377
370
 
378
- Returns:
379
- A list of available package that exists in the snowflake anaconda channel
380
371
  """
381
372
  if not self._is_fitted:
382
373
  raise exceptions.SnowflakeMLException(
@@ -394,9 +385,7 @@ class SVC(BaseTransformer):
394
385
  "Session must not specified for snowpark dataset."
395
386
  ),
396
387
  )
397
- # Validate that key package version in user workspace are supported in snowflake conda channel
398
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
399
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
388
+
400
389
 
401
390
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
402
391
  @telemetry.send_api_usage_telemetry(
@@ -444,7 +433,8 @@ class SVC(BaseTransformer):
444
433
 
445
434
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
446
435
 
447
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
436
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
437
+ self._deps = self._get_dependencies()
448
438
  assert isinstance(
449
439
  dataset._session, Session
450
440
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -527,10 +517,8 @@ class SVC(BaseTransformer):
527
517
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
528
518
  expected_dtype = convert_sp_to_sf_type(output_types[0])
529
519
 
530
- self._deps = self._batch_inference_validate_snowpark(
531
- dataset=dataset,
532
- inference_method=inference_method,
533
- )
520
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
521
+ self._deps = self._get_dependencies()
534
522
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
535
523
 
536
524
  transform_kwargs = dict(
@@ -597,16 +585,40 @@ class SVC(BaseTransformer):
597
585
  self._is_fitted = True
598
586
  return output_result
599
587
 
588
+
589
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
590
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
591
+ """ Method not supported for this class.
600
592
 
601
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
602
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
603
- """
593
+
594
+ Raises:
595
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
596
+
597
+ Args:
598
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
599
+ Snowpark or Pandas DataFrame.
600
+ output_cols_prefix: Prefix for the response columns
604
601
  Returns:
605
602
  Transformed dataset.
606
603
  """
607
- self.fit(dataset)
608
- assert self._sklearn_object is not None
609
- return self._sklearn_object.embedding_
604
+ self._infer_input_output_cols(dataset)
605
+ super()._check_dataset_type(dataset)
606
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
607
+ estimator=self._sklearn_object,
608
+ dataset=dataset,
609
+ input_cols=self.input_cols,
610
+ label_cols=self.label_cols,
611
+ sample_weight_col=self.sample_weight_col,
612
+ autogenerated=self._autogenerated,
613
+ subproject=_SUBPROJECT,
614
+ )
615
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
616
+ drop_input_cols=self._drop_input_cols,
617
+ expected_output_cols_list=self.output_cols,
618
+ )
619
+ self._sklearn_object = fitted_estimator
620
+ self._is_fitted = True
621
+ return output_result
610
622
 
611
623
 
612
624
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -699,10 +711,8 @@ class SVC(BaseTransformer):
699
711
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
700
712
 
701
713
  if isinstance(dataset, DataFrame):
702
- self._deps = self._batch_inference_validate_snowpark(
703
- dataset=dataset,
704
- inference_method=inference_method,
705
- )
714
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
715
+ self._deps = self._get_dependencies()
706
716
  assert isinstance(
707
717
  dataset._session, Session
708
718
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -769,10 +779,8 @@ class SVC(BaseTransformer):
769
779
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
770
780
 
771
781
  if isinstance(dataset, DataFrame):
772
- self._deps = self._batch_inference_validate_snowpark(
773
- dataset=dataset,
774
- inference_method=inference_method,
775
- )
782
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
783
+ self._deps = self._get_dependencies()
776
784
  assert isinstance(
777
785
  dataset._session, Session
778
786
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -836,10 +844,8 @@ class SVC(BaseTransformer):
836
844
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
837
845
 
838
846
  if isinstance(dataset, DataFrame):
839
- self._deps = self._batch_inference_validate_snowpark(
840
- dataset=dataset,
841
- inference_method=inference_method,
842
- )
847
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
848
+ self._deps = self._get_dependencies()
843
849
  assert isinstance(
844
850
  dataset._session, Session
845
851
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -905,10 +911,8 @@ class SVC(BaseTransformer):
905
911
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
906
912
 
907
913
  if isinstance(dataset, DataFrame):
908
- self._deps = self._batch_inference_validate_snowpark(
909
- dataset=dataset,
910
- inference_method=inference_method,
911
- )
914
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
915
+ self._deps = self._get_dependencies()
912
916
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
913
917
  transform_kwargs = dict(
914
918
  session=dataset._session,
@@ -972,17 +976,15 @@ class SVC(BaseTransformer):
972
976
  transform_kwargs: ScoreKwargsTypedDict = dict()
973
977
 
974
978
  if isinstance(dataset, DataFrame):
975
- self._deps = self._batch_inference_validate_snowpark(
976
- dataset=dataset,
977
- inference_method="score",
978
- )
979
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
980
+ self._deps = self._get_dependencies()
979
981
  selected_cols = self._get_active_columns()
980
982
  if len(selected_cols) > 0:
981
983
  dataset = dataset.select(selected_cols)
982
984
  assert isinstance(dataset._session, Session) # keep mypy happy
983
985
  transform_kwargs = dict(
984
986
  session=dataset._session,
985
- dependencies=["snowflake-snowpark-python"] + self._deps,
987
+ dependencies=self._deps,
986
988
  score_sproc_imports=['sklearn'],
987
989
  )
988
990
  elif isinstance(dataset, pd.DataFrame):
@@ -1047,11 +1049,8 @@ class SVC(BaseTransformer):
1047
1049
 
1048
1050
  if isinstance(dataset, DataFrame):
1049
1051
 
1050
- self._deps = self._batch_inference_validate_snowpark(
1051
- dataset=dataset,
1052
- inference_method=inference_method,
1053
-
1054
- )
1052
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1053
+ self._deps = self._get_dependencies()
1055
1054
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1056
1055
  transform_kwargs = dict(
1057
1056
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class SVR(BaseTransformer):
70
64
  r"""Epsilon-Support Vector Regression
71
65
  For more details on this class, see [sklearn.svm.SVR]
@@ -324,20 +318,17 @@ class SVR(BaseTransformer):
324
318
  self,
325
319
  dataset: DataFrame,
326
320
  inference_method: str,
327
- ) -> List[str]:
328
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
329
- return the available package that exists in the snowflake anaconda channel
321
+ ) -> None:
322
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
330
323
 
331
324
  Args:
332
325
  dataset: snowpark dataframe
333
326
  inference_method: the inference method such as predict, score...
334
-
327
+
335
328
  Raises:
336
329
  SnowflakeMLException: If the estimator is not fitted, raise error
337
330
  SnowflakeMLException: If the session is None, raise error
338
331
 
339
- Returns:
340
- A list of available package that exists in the snowflake anaconda channel
341
332
  """
342
333
  if not self._is_fitted:
343
334
  raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class SVR(BaseTransformer):
355
346
  "Session must not specified for snowpark dataset."
356
347
  ),
357
348
  )
358
- # Validate that key package version in user workspace are supported in snowflake conda channel
359
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
360
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
349
+
361
350
 
362
351
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
363
352
  @telemetry.send_api_usage_telemetry(
@@ -405,7 +394,8 @@ class SVR(BaseTransformer):
405
394
 
406
395
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
407
396
 
408
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
397
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
398
+ self._deps = self._get_dependencies()
409
399
  assert isinstance(
410
400
  dataset._session, Session
411
401
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -488,10 +478,8 @@ class SVR(BaseTransformer):
488
478
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
489
479
  expected_dtype = convert_sp_to_sf_type(output_types[0])
490
480
 
491
- self._deps = self._batch_inference_validate_snowpark(
492
- dataset=dataset,
493
- inference_method=inference_method,
494
- )
481
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
482
+ self._deps = self._get_dependencies()
495
483
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
496
484
 
497
485
  transform_kwargs = dict(
@@ -558,16 +546,40 @@ class SVR(BaseTransformer):
558
546
  self._is_fitted = True
559
547
  return output_result
560
548
 
549
+
550
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
551
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
552
+ """ Method not supported for this class.
561
553
 
562
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
563
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
564
- """
554
+
555
+ Raises:
556
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
557
+
558
+ Args:
559
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
560
+ Snowpark or Pandas DataFrame.
561
+ output_cols_prefix: Prefix for the response columns
565
562
  Returns:
566
563
  Transformed dataset.
567
564
  """
568
- self.fit(dataset)
569
- assert self._sklearn_object is not None
570
- return self._sklearn_object.embedding_
565
+ self._infer_input_output_cols(dataset)
566
+ super()._check_dataset_type(dataset)
567
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
568
+ estimator=self._sklearn_object,
569
+ dataset=dataset,
570
+ input_cols=self.input_cols,
571
+ label_cols=self.label_cols,
572
+ sample_weight_col=self.sample_weight_col,
573
+ autogenerated=self._autogenerated,
574
+ subproject=_SUBPROJECT,
575
+ )
576
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=self.output_cols,
579
+ )
580
+ self._sklearn_object = fitted_estimator
581
+ self._is_fitted = True
582
+ return output_result
571
583
 
572
584
 
573
585
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -658,10 +670,8 @@ class SVR(BaseTransformer):
658
670
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
659
671
 
660
672
  if isinstance(dataset, DataFrame):
661
- self._deps = self._batch_inference_validate_snowpark(
662
- dataset=dataset,
663
- inference_method=inference_method,
664
- )
673
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
674
+ self._deps = self._get_dependencies()
665
675
  assert isinstance(
666
676
  dataset._session, Session
667
677
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -726,10 +736,8 @@ class SVR(BaseTransformer):
726
736
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
727
737
 
728
738
  if isinstance(dataset, DataFrame):
729
- self._deps = self._batch_inference_validate_snowpark(
730
- dataset=dataset,
731
- inference_method=inference_method,
732
- )
739
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
740
+ self._deps = self._get_dependencies()
733
741
  assert isinstance(
734
742
  dataset._session, Session
735
743
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -791,10 +799,8 @@ class SVR(BaseTransformer):
791
799
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
792
800
 
793
801
  if isinstance(dataset, DataFrame):
794
- self._deps = self._batch_inference_validate_snowpark(
795
- dataset=dataset,
796
- inference_method=inference_method,
797
- )
802
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
803
+ self._deps = self._get_dependencies()
798
804
  assert isinstance(
799
805
  dataset._session, Session
800
806
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -860,10 +866,8 @@ class SVR(BaseTransformer):
860
866
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
861
867
 
862
868
  if isinstance(dataset, DataFrame):
863
- self._deps = self._batch_inference_validate_snowpark(
864
- dataset=dataset,
865
- inference_method=inference_method,
866
- )
869
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
870
+ self._deps = self._get_dependencies()
867
871
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
868
872
  transform_kwargs = dict(
869
873
  session=dataset._session,
@@ -927,17 +931,15 @@ class SVR(BaseTransformer):
927
931
  transform_kwargs: ScoreKwargsTypedDict = dict()
928
932
 
929
933
  if isinstance(dataset, DataFrame):
930
- self._deps = self._batch_inference_validate_snowpark(
931
- dataset=dataset,
932
- inference_method="score",
933
- )
934
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
935
+ self._deps = self._get_dependencies()
934
936
  selected_cols = self._get_active_columns()
935
937
  if len(selected_cols) > 0:
936
938
  dataset = dataset.select(selected_cols)
937
939
  assert isinstance(dataset._session, Session) # keep mypy happy
938
940
  transform_kwargs = dict(
939
941
  session=dataset._session,
940
- dependencies=["snowflake-snowpark-python"] + self._deps,
942
+ dependencies=self._deps,
941
943
  score_sproc_imports=['sklearn'],
942
944
  )
943
945
  elif isinstance(dataset, pd.DataFrame):
@@ -1002,11 +1004,8 @@ class SVR(BaseTransformer):
1002
1004
 
1003
1005
  if isinstance(dataset, DataFrame):
1004
1006
 
1005
- self._deps = self._batch_inference_validate_snowpark(
1006
- dataset=dataset,
1007
- inference_method=inference_method,
1008
-
1009
- )
1007
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1008
+ self._deps = self._get_dependencies()
1010
1009
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1011
1010
  transform_kwargs = dict(
1012
1011
  session = dataset._session,