snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LinearSVC(BaseTransformer):
70
64
  r"""Linear Support Vector Classification
71
65
  For more details on this class, see [sklearn.svm.LinearSVC]
@@ -354,20 +348,17 @@ class LinearSVC(BaseTransformer):
354
348
  self,
355
349
  dataset: DataFrame,
356
350
  inference_method: str,
357
- ) -> List[str]:
358
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
359
- return the available package that exists in the snowflake anaconda channel
351
+ ) -> None:
352
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
360
353
 
361
354
  Args:
362
355
  dataset: snowpark dataframe
363
356
  inference_method: the inference method such as predict, score...
364
-
357
+
365
358
  Raises:
366
359
  SnowflakeMLException: If the estimator is not fitted, raise error
367
360
  SnowflakeMLException: If the session is None, raise error
368
361
 
369
- Returns:
370
- A list of available package that exists in the snowflake anaconda channel
371
362
  """
372
363
  if not self._is_fitted:
373
364
  raise exceptions.SnowflakeMLException(
@@ -385,9 +376,7 @@ class LinearSVC(BaseTransformer):
385
376
  "Session must not specified for snowpark dataset."
386
377
  ),
387
378
  )
388
- # Validate that key package version in user workspace are supported in snowflake conda channel
389
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
390
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
379
+
391
380
 
392
381
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
393
382
  @telemetry.send_api_usage_telemetry(
@@ -435,7 +424,8 @@ class LinearSVC(BaseTransformer):
435
424
 
436
425
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
437
426
 
438
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
427
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
428
+ self._deps = self._get_dependencies()
439
429
  assert isinstance(
440
430
  dataset._session, Session
441
431
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -518,10 +508,8 @@ class LinearSVC(BaseTransformer):
518
508
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
519
509
  expected_dtype = convert_sp_to_sf_type(output_types[0])
520
510
 
521
- self._deps = self._batch_inference_validate_snowpark(
522
- dataset=dataset,
523
- inference_method=inference_method,
524
- )
511
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
512
+ self._deps = self._get_dependencies()
525
513
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
526
514
 
527
515
  transform_kwargs = dict(
@@ -588,16 +576,40 @@ class LinearSVC(BaseTransformer):
588
576
  self._is_fitted = True
589
577
  return output_result
590
578
 
579
+
580
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
581
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
582
+ """ Method not supported for this class.
591
583
 
592
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
593
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
594
- """
584
+
585
+ Raises:
586
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
587
+
588
+ Args:
589
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
590
+ Snowpark or Pandas DataFrame.
591
+ output_cols_prefix: Prefix for the response columns
595
592
  Returns:
596
593
  Transformed dataset.
597
594
  """
598
- self.fit(dataset)
599
- assert self._sklearn_object is not None
600
- return self._sklearn_object.embedding_
595
+ self._infer_input_output_cols(dataset)
596
+ super()._check_dataset_type(dataset)
597
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
598
+ estimator=self._sklearn_object,
599
+ dataset=dataset,
600
+ input_cols=self.input_cols,
601
+ label_cols=self.label_cols,
602
+ sample_weight_col=self.sample_weight_col,
603
+ autogenerated=self._autogenerated,
604
+ subproject=_SUBPROJECT,
605
+ )
606
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
607
+ drop_input_cols=self._drop_input_cols,
608
+ expected_output_cols_list=self.output_cols,
609
+ )
610
+ self._sklearn_object = fitted_estimator
611
+ self._is_fitted = True
612
+ return output_result
601
613
 
602
614
 
603
615
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -688,10 +700,8 @@ class LinearSVC(BaseTransformer):
688
700
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
689
701
 
690
702
  if isinstance(dataset, DataFrame):
691
- self._deps = self._batch_inference_validate_snowpark(
692
- dataset=dataset,
693
- inference_method=inference_method,
694
- )
703
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
704
+ self._deps = self._get_dependencies()
695
705
  assert isinstance(
696
706
  dataset._session, Session
697
707
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -756,10 +766,8 @@ class LinearSVC(BaseTransformer):
756
766
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
757
767
 
758
768
  if isinstance(dataset, DataFrame):
759
- self._deps = self._batch_inference_validate_snowpark(
760
- dataset=dataset,
761
- inference_method=inference_method,
762
- )
769
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
770
+ self._deps = self._get_dependencies()
763
771
  assert isinstance(
764
772
  dataset._session, Session
765
773
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -823,10 +831,8 @@ class LinearSVC(BaseTransformer):
823
831
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
824
832
 
825
833
  if isinstance(dataset, DataFrame):
826
- self._deps = self._batch_inference_validate_snowpark(
827
- dataset=dataset,
828
- inference_method=inference_method,
829
- )
834
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
835
+ self._deps = self._get_dependencies()
830
836
  assert isinstance(
831
837
  dataset._session, Session
832
838
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -892,10 +898,8 @@ class LinearSVC(BaseTransformer):
892
898
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
893
899
 
894
900
  if isinstance(dataset, DataFrame):
895
- self._deps = self._batch_inference_validate_snowpark(
896
- dataset=dataset,
897
- inference_method=inference_method,
898
- )
901
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
902
+ self._deps = self._get_dependencies()
899
903
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
900
904
  transform_kwargs = dict(
901
905
  session=dataset._session,
@@ -959,17 +963,15 @@ class LinearSVC(BaseTransformer):
959
963
  transform_kwargs: ScoreKwargsTypedDict = dict()
960
964
 
961
965
  if isinstance(dataset, DataFrame):
962
- self._deps = self._batch_inference_validate_snowpark(
963
- dataset=dataset,
964
- inference_method="score",
965
- )
966
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
967
+ self._deps = self._get_dependencies()
966
968
  selected_cols = self._get_active_columns()
967
969
  if len(selected_cols) > 0:
968
970
  dataset = dataset.select(selected_cols)
969
971
  assert isinstance(dataset._session, Session) # keep mypy happy
970
972
  transform_kwargs = dict(
971
973
  session=dataset._session,
972
- dependencies=["snowflake-snowpark-python"] + self._deps,
974
+ dependencies=self._deps,
973
975
  score_sproc_imports=['sklearn'],
974
976
  )
975
977
  elif isinstance(dataset, pd.DataFrame):
@@ -1034,11 +1036,8 @@ class LinearSVC(BaseTransformer):
1034
1036
 
1035
1037
  if isinstance(dataset, DataFrame):
1036
1038
 
1037
- self._deps = self._batch_inference_validate_snowpark(
1038
- dataset=dataset,
1039
- inference_method=inference_method,
1040
-
1041
- )
1039
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1040
+ self._deps = self._get_dependencies()
1042
1041
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1043
1042
  transform_kwargs = dict(
1044
1043
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LinearSVR(BaseTransformer):
70
64
  r"""Linear Support Vector Regression
71
65
  For more details on this class, see [sklearn.svm.LinearSVR]
@@ -326,20 +320,17 @@ class LinearSVR(BaseTransformer):
326
320
  self,
327
321
  dataset: DataFrame,
328
322
  inference_method: str,
329
- ) -> List[str]:
330
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
331
- return the available package that exists in the snowflake anaconda channel
323
+ ) -> None:
324
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
332
325
 
333
326
  Args:
334
327
  dataset: snowpark dataframe
335
328
  inference_method: the inference method such as predict, score...
336
-
329
+
337
330
  Raises:
338
331
  SnowflakeMLException: If the estimator is not fitted, raise error
339
332
  SnowflakeMLException: If the session is None, raise error
340
333
 
341
- Returns:
342
- A list of available package that exists in the snowflake anaconda channel
343
334
  """
344
335
  if not self._is_fitted:
345
336
  raise exceptions.SnowflakeMLException(
@@ -357,9 +348,7 @@ class LinearSVR(BaseTransformer):
357
348
  "Session must not specified for snowpark dataset."
358
349
  ),
359
350
  )
360
- # Validate that key package version in user workspace are supported in snowflake conda channel
361
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
362
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
351
+
363
352
 
364
353
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
365
354
  @telemetry.send_api_usage_telemetry(
@@ -407,7 +396,8 @@ class LinearSVR(BaseTransformer):
407
396
 
408
397
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
409
398
 
410
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
399
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
400
+ self._deps = self._get_dependencies()
411
401
  assert isinstance(
412
402
  dataset._session, Session
413
403
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -490,10 +480,8 @@ class LinearSVR(BaseTransformer):
490
480
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
491
481
  expected_dtype = convert_sp_to_sf_type(output_types[0])
492
482
 
493
- self._deps = self._batch_inference_validate_snowpark(
494
- dataset=dataset,
495
- inference_method=inference_method,
496
- )
483
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
484
+ self._deps = self._get_dependencies()
497
485
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
498
486
 
499
487
  transform_kwargs = dict(
@@ -560,16 +548,40 @@ class LinearSVR(BaseTransformer):
560
548
  self._is_fitted = True
561
549
  return output_result
562
550
 
551
+
552
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
553
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
554
+ """ Method not supported for this class.
563
555
 
564
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
565
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
566
- """
556
+
557
+ Raises:
558
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
559
+
560
+ Args:
561
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
562
+ Snowpark or Pandas DataFrame.
563
+ output_cols_prefix: Prefix for the response columns
567
564
  Returns:
568
565
  Transformed dataset.
569
566
  """
570
- self.fit(dataset)
571
- assert self._sklearn_object is not None
572
- return self._sklearn_object.embedding_
567
+ self._infer_input_output_cols(dataset)
568
+ super()._check_dataset_type(dataset)
569
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
570
+ estimator=self._sklearn_object,
571
+ dataset=dataset,
572
+ input_cols=self.input_cols,
573
+ label_cols=self.label_cols,
574
+ sample_weight_col=self.sample_weight_col,
575
+ autogenerated=self._autogenerated,
576
+ subproject=_SUBPROJECT,
577
+ )
578
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
579
+ drop_input_cols=self._drop_input_cols,
580
+ expected_output_cols_list=self.output_cols,
581
+ )
582
+ self._sklearn_object = fitted_estimator
583
+ self._is_fitted = True
584
+ return output_result
573
585
 
574
586
 
575
587
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -660,10 +672,8 @@ class LinearSVR(BaseTransformer):
660
672
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
661
673
 
662
674
  if isinstance(dataset, DataFrame):
663
- self._deps = self._batch_inference_validate_snowpark(
664
- dataset=dataset,
665
- inference_method=inference_method,
666
- )
675
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
676
+ self._deps = self._get_dependencies()
667
677
  assert isinstance(
668
678
  dataset._session, Session
669
679
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -728,10 +738,8 @@ class LinearSVR(BaseTransformer):
728
738
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
729
739
 
730
740
  if isinstance(dataset, DataFrame):
731
- self._deps = self._batch_inference_validate_snowpark(
732
- dataset=dataset,
733
- inference_method=inference_method,
734
- )
741
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
742
+ self._deps = self._get_dependencies()
735
743
  assert isinstance(
736
744
  dataset._session, Session
737
745
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -793,10 +801,8 @@ class LinearSVR(BaseTransformer):
793
801
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
794
802
 
795
803
  if isinstance(dataset, DataFrame):
796
- self._deps = self._batch_inference_validate_snowpark(
797
- dataset=dataset,
798
- inference_method=inference_method,
799
- )
804
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
805
+ self._deps = self._get_dependencies()
800
806
  assert isinstance(
801
807
  dataset._session, Session
802
808
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -862,10 +868,8 @@ class LinearSVR(BaseTransformer):
862
868
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
863
869
 
864
870
  if isinstance(dataset, DataFrame):
865
- self._deps = self._batch_inference_validate_snowpark(
866
- dataset=dataset,
867
- inference_method=inference_method,
868
- )
871
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
872
+ self._deps = self._get_dependencies()
869
873
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
870
874
  transform_kwargs = dict(
871
875
  session=dataset._session,
@@ -929,17 +933,15 @@ class LinearSVR(BaseTransformer):
929
933
  transform_kwargs: ScoreKwargsTypedDict = dict()
930
934
 
931
935
  if isinstance(dataset, DataFrame):
932
- self._deps = self._batch_inference_validate_snowpark(
933
- dataset=dataset,
934
- inference_method="score",
935
- )
936
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
937
+ self._deps = self._get_dependencies()
936
938
  selected_cols = self._get_active_columns()
937
939
  if len(selected_cols) > 0:
938
940
  dataset = dataset.select(selected_cols)
939
941
  assert isinstance(dataset._session, Session) # keep mypy happy
940
942
  transform_kwargs = dict(
941
943
  session=dataset._session,
942
- dependencies=["snowflake-snowpark-python"] + self._deps,
944
+ dependencies=self._deps,
943
945
  score_sproc_imports=['sklearn'],
944
946
  )
945
947
  elif isinstance(dataset, pd.DataFrame):
@@ -1004,11 +1006,8 @@ class LinearSVR(BaseTransformer):
1004
1006
 
1005
1007
  if isinstance(dataset, DataFrame):
1006
1008
 
1007
- self._deps = self._batch_inference_validate_snowpark(
1008
- dataset=dataset,
1009
- inference_method=inference_method,
1010
-
1011
- )
1009
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1010
+ self._deps = self._get_dependencies()
1012
1011
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1013
1012
  transform_kwargs = dict(
1014
1013
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.svm".replace("sklearn.",
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class NuSVC(BaseTransformer):
70
64
  r"""Nu-Support Vector Classification
71
65
  For more details on this class, see [sklearn.svm.NuSVC]
@@ -360,20 +354,17 @@ class NuSVC(BaseTransformer):
360
354
  self,
361
355
  dataset: DataFrame,
362
356
  inference_method: str,
363
- ) -> List[str]:
364
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
365
- return the available package that exists in the snowflake anaconda channel
357
+ ) -> None:
358
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
366
359
 
367
360
  Args:
368
361
  dataset: snowpark dataframe
369
362
  inference_method: the inference method such as predict, score...
370
-
363
+
371
364
  Raises:
372
365
  SnowflakeMLException: If the estimator is not fitted, raise error
373
366
  SnowflakeMLException: If the session is None, raise error
374
367
 
375
- Returns:
376
- A list of available package that exists in the snowflake anaconda channel
377
368
  """
378
369
  if not self._is_fitted:
379
370
  raise exceptions.SnowflakeMLException(
@@ -391,9 +382,7 @@ class NuSVC(BaseTransformer):
391
382
  "Session must not specified for snowpark dataset."
392
383
  ),
393
384
  )
394
- # Validate that key package version in user workspace are supported in snowflake conda channel
395
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
396
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
385
+
397
386
 
398
387
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
399
388
  @telemetry.send_api_usage_telemetry(
@@ -441,7 +430,8 @@ class NuSVC(BaseTransformer):
441
430
 
442
431
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
443
432
 
444
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
433
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
434
+ self._deps = self._get_dependencies()
445
435
  assert isinstance(
446
436
  dataset._session, Session
447
437
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -524,10 +514,8 @@ class NuSVC(BaseTransformer):
524
514
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
525
515
  expected_dtype = convert_sp_to_sf_type(output_types[0])
526
516
 
527
- self._deps = self._batch_inference_validate_snowpark(
528
- dataset=dataset,
529
- inference_method=inference_method,
530
- )
517
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
518
+ self._deps = self._get_dependencies()
531
519
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
532
520
 
533
521
  transform_kwargs = dict(
@@ -594,16 +582,40 @@ class NuSVC(BaseTransformer):
594
582
  self._is_fitted = True
595
583
  return output_result
596
584
 
585
+
586
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
587
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
588
+ """ Method not supported for this class.
597
589
 
598
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
599
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
600
- """
590
+
591
+ Raises:
592
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
593
+
594
+ Args:
595
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
596
+ Snowpark or Pandas DataFrame.
597
+ output_cols_prefix: Prefix for the response columns
601
598
  Returns:
602
599
  Transformed dataset.
603
600
  """
604
- self.fit(dataset)
605
- assert self._sklearn_object is not None
606
- return self._sklearn_object.embedding_
601
+ self._infer_input_output_cols(dataset)
602
+ super()._check_dataset_type(dataset)
603
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
604
+ estimator=self._sklearn_object,
605
+ dataset=dataset,
606
+ input_cols=self.input_cols,
607
+ label_cols=self.label_cols,
608
+ sample_weight_col=self.sample_weight_col,
609
+ autogenerated=self._autogenerated,
610
+ subproject=_SUBPROJECT,
611
+ )
612
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
613
+ drop_input_cols=self._drop_input_cols,
614
+ expected_output_cols_list=self.output_cols,
615
+ )
616
+ self._sklearn_object = fitted_estimator
617
+ self._is_fitted = True
618
+ return output_result
607
619
 
608
620
 
609
621
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -696,10 +708,8 @@ class NuSVC(BaseTransformer):
696
708
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
697
709
 
698
710
  if isinstance(dataset, DataFrame):
699
- self._deps = self._batch_inference_validate_snowpark(
700
- dataset=dataset,
701
- inference_method=inference_method,
702
- )
711
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
712
+ self._deps = self._get_dependencies()
703
713
  assert isinstance(
704
714
  dataset._session, Session
705
715
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -766,10 +776,8 @@ class NuSVC(BaseTransformer):
766
776
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
767
777
 
768
778
  if isinstance(dataset, DataFrame):
769
- self._deps = self._batch_inference_validate_snowpark(
770
- dataset=dataset,
771
- inference_method=inference_method,
772
- )
779
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
780
+ self._deps = self._get_dependencies()
773
781
  assert isinstance(
774
782
  dataset._session, Session
775
783
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -833,10 +841,8 @@ class NuSVC(BaseTransformer):
833
841
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
834
842
 
835
843
  if isinstance(dataset, DataFrame):
836
- self._deps = self._batch_inference_validate_snowpark(
837
- dataset=dataset,
838
- inference_method=inference_method,
839
- )
844
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
845
+ self._deps = self._get_dependencies()
840
846
  assert isinstance(
841
847
  dataset._session, Session
842
848
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -902,10 +908,8 @@ class NuSVC(BaseTransformer):
902
908
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
903
909
 
904
910
  if isinstance(dataset, DataFrame):
905
- self._deps = self._batch_inference_validate_snowpark(
906
- dataset=dataset,
907
- inference_method=inference_method,
908
- )
911
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
912
+ self._deps = self._get_dependencies()
909
913
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
910
914
  transform_kwargs = dict(
911
915
  session=dataset._session,
@@ -969,17 +973,15 @@ class NuSVC(BaseTransformer):
969
973
  transform_kwargs: ScoreKwargsTypedDict = dict()
970
974
 
971
975
  if isinstance(dataset, DataFrame):
972
- self._deps = self._batch_inference_validate_snowpark(
973
- dataset=dataset,
974
- inference_method="score",
975
- )
976
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
977
+ self._deps = self._get_dependencies()
976
978
  selected_cols = self._get_active_columns()
977
979
  if len(selected_cols) > 0:
978
980
  dataset = dataset.select(selected_cols)
979
981
  assert isinstance(dataset._session, Session) # keep mypy happy
980
982
  transform_kwargs = dict(
981
983
  session=dataset._session,
982
- dependencies=["snowflake-snowpark-python"] + self._deps,
984
+ dependencies=self._deps,
983
985
  score_sproc_imports=['sklearn'],
984
986
  )
985
987
  elif isinstance(dataset, pd.DataFrame):
@@ -1044,11 +1046,8 @@ class NuSVC(BaseTransformer):
1044
1046
 
1045
1047
  if isinstance(dataset, DataFrame):
1046
1048
 
1047
- self._deps = self._batch_inference_validate_snowpark(
1048
- dataset=dataset,
1049
- inference_method=inference_method,
1050
-
1051
- )
1049
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1050
+ self._deps = self._get_dependencies()
1052
1051
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1053
1052
  transform_kwargs = dict(
1054
1053
  session = dataset._session,