snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206):
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Perceptron(BaseTransformer):
70
64
  r"""Linear perceptron classifier
71
65
  For more details on this class, see [sklearn.linear_model.Perceptron]
@@ -361,20 +355,17 @@ class Perceptron(BaseTransformer):
361
355
  self,
362
356
  dataset: DataFrame,
363
357
  inference_method: str,
364
- ) -> List[str]:
365
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
366
- return the available package that exists in the snowflake anaconda channel
358
+ ) -> None:
359
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
367
360
 
368
361
  Args:
369
362
  dataset: snowpark dataframe
370
363
  inference_method: the inference method such as predict, score...
371
-
364
+
372
365
  Raises:
373
366
  SnowflakeMLException: If the estimator is not fitted, raise error
374
367
  SnowflakeMLException: If the session is None, raise error
375
368
 
376
- Returns:
377
- A list of available package that exists in the snowflake anaconda channel
378
369
  """
379
370
  if not self._is_fitted:
380
371
  raise exceptions.SnowflakeMLException(
@@ -392,9 +383,7 @@ class Perceptron(BaseTransformer):
392
383
  "Session must not specified for snowpark dataset."
393
384
  ),
394
385
  )
395
- # Validate that key package version in user workspace are supported in snowflake conda channel
396
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
397
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
386
+
398
387
 
399
388
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
400
389
  @telemetry.send_api_usage_telemetry(
@@ -442,7 +431,8 @@ class Perceptron(BaseTransformer):
442
431
 
443
432
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
444
433
 
445
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
434
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
435
+ self._deps = self._get_dependencies()
446
436
  assert isinstance(
447
437
  dataset._session, Session
448
438
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -525,10 +515,8 @@ class Perceptron(BaseTransformer):
525
515
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
526
516
  expected_dtype = convert_sp_to_sf_type(output_types[0])
527
517
 
528
- self._deps = self._batch_inference_validate_snowpark(
529
- dataset=dataset,
530
- inference_method=inference_method,
531
- )
518
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
519
+ self._deps = self._get_dependencies()
532
520
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
533
521
 
534
522
  transform_kwargs = dict(
@@ -595,16 +583,40 @@ class Perceptron(BaseTransformer):
595
583
  self._is_fitted = True
596
584
  return output_result
597
585
 
586
+
587
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
588
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
589
+ """ Method not supported for this class.
598
590
 
599
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
600
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
601
- """
591
+
592
+ Raises:
593
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
594
+
595
+ Args:
596
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
597
+ Snowpark or Pandas DataFrame.
598
+ output_cols_prefix: Prefix for the response columns
602
599
  Returns:
603
600
  Transformed dataset.
604
601
  """
605
- self.fit(dataset)
606
- assert self._sklearn_object is not None
607
- return self._sklearn_object.embedding_
602
+ self._infer_input_output_cols(dataset)
603
+ super()._check_dataset_type(dataset)
604
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
605
+ estimator=self._sklearn_object,
606
+ dataset=dataset,
607
+ input_cols=self.input_cols,
608
+ label_cols=self.label_cols,
609
+ sample_weight_col=self.sample_weight_col,
610
+ autogenerated=self._autogenerated,
611
+ subproject=_SUBPROJECT,
612
+ )
613
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
614
+ drop_input_cols=self._drop_input_cols,
615
+ expected_output_cols_list=self.output_cols,
616
+ )
617
+ self._sklearn_object = fitted_estimator
618
+ self._is_fitted = True
619
+ return output_result
608
620
 
609
621
 
610
622
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -695,10 +707,8 @@ class Perceptron(BaseTransformer):
695
707
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
696
708
 
697
709
  if isinstance(dataset, DataFrame):
698
- self._deps = self._batch_inference_validate_snowpark(
699
- dataset=dataset,
700
- inference_method=inference_method,
701
- )
710
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
711
+ self._deps = self._get_dependencies()
702
712
  assert isinstance(
703
713
  dataset._session, Session
704
714
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -763,10 +773,8 @@ class Perceptron(BaseTransformer):
763
773
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
764
774
 
765
775
  if isinstance(dataset, DataFrame):
766
- self._deps = self._batch_inference_validate_snowpark(
767
- dataset=dataset,
768
- inference_method=inference_method,
769
- )
776
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
777
+ self._deps = self._get_dependencies()
770
778
  assert isinstance(
771
779
  dataset._session, Session
772
780
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -830,10 +838,8 @@ class Perceptron(BaseTransformer):
830
838
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
831
839
 
832
840
  if isinstance(dataset, DataFrame):
833
- self._deps = self._batch_inference_validate_snowpark(
834
- dataset=dataset,
835
- inference_method=inference_method,
836
- )
841
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
842
+ self._deps = self._get_dependencies()
837
843
  assert isinstance(
838
844
  dataset._session, Session
839
845
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -899,10 +905,8 @@ class Perceptron(BaseTransformer):
899
905
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
900
906
 
901
907
  if isinstance(dataset, DataFrame):
902
- self._deps = self._batch_inference_validate_snowpark(
903
- dataset=dataset,
904
- inference_method=inference_method,
905
- )
908
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
909
+ self._deps = self._get_dependencies()
906
910
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
907
911
  transform_kwargs = dict(
908
912
  session=dataset._session,
@@ -966,17 +970,15 @@ class Perceptron(BaseTransformer):
966
970
  transform_kwargs: ScoreKwargsTypedDict = dict()
967
971
 
968
972
  if isinstance(dataset, DataFrame):
969
- self._deps = self._batch_inference_validate_snowpark(
970
- dataset=dataset,
971
- inference_method="score",
972
- )
973
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
974
+ self._deps = self._get_dependencies()
973
975
  selected_cols = self._get_active_columns()
974
976
  if len(selected_cols) > 0:
975
977
  dataset = dataset.select(selected_cols)
976
978
  assert isinstance(dataset._session, Session) # keep mypy happy
977
979
  transform_kwargs = dict(
978
980
  session=dataset._session,
979
- dependencies=["snowflake-snowpark-python"] + self._deps,
981
+ dependencies=self._deps,
980
982
  score_sproc_imports=['sklearn'],
981
983
  )
982
984
  elif isinstance(dataset, pd.DataFrame):
@@ -1041,11 +1043,8 @@ class Perceptron(BaseTransformer):
1041
1043
 
1042
1044
  if isinstance(dataset, DataFrame):
1043
1045
 
1044
- self._deps = self._batch_inference_validate_snowpark(
1045
- dataset=dataset,
1046
- inference_method=inference_method,
1047
-
1048
- )
1046
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1047
+ self._deps = self._get_dependencies()
1049
1048
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1050
1049
  transform_kwargs = dict(
1051
1050
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class PoissonRegressor(BaseTransformer):
70
64
  r"""Generalized Linear Model with a Poisson distribution
71
65
  For more details on this class, see [sklearn.linear_model.PoissonRegressor]
@@ -310,20 +304,17 @@ class PoissonRegressor(BaseTransformer):
310
304
  self,
311
305
  dataset: DataFrame,
312
306
  inference_method: str,
313
- ) -> List[str]:
314
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
315
- return the available package that exists in the snowflake anaconda channel
307
+ ) -> None:
308
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
316
309
 
317
310
  Args:
318
311
  dataset: snowpark dataframe
319
312
  inference_method: the inference method such as predict, score...
320
-
313
+
321
314
  Raises:
322
315
  SnowflakeMLException: If the estimator is not fitted, raise error
323
316
  SnowflakeMLException: If the session is None, raise error
324
317
 
325
- Returns:
326
- A list of available package that exists in the snowflake anaconda channel
327
318
  """
328
319
  if not self._is_fitted:
329
320
  raise exceptions.SnowflakeMLException(
@@ -341,9 +332,7 @@ class PoissonRegressor(BaseTransformer):
341
332
  "Session must not specified for snowpark dataset."
342
333
  ),
343
334
  )
344
- # Validate that key package version in user workspace are supported in snowflake conda channel
345
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
346
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
335
+
347
336
 
348
337
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
349
338
  @telemetry.send_api_usage_telemetry(
@@ -391,7 +380,8 @@ class PoissonRegressor(BaseTransformer):
391
380
 
392
381
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
393
382
 
394
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
383
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
384
+ self._deps = self._get_dependencies()
395
385
  assert isinstance(
396
386
  dataset._session, Session
397
387
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -474,10 +464,8 @@ class PoissonRegressor(BaseTransformer):
474
464
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
475
465
  expected_dtype = convert_sp_to_sf_type(output_types[0])
476
466
 
477
- self._deps = self._batch_inference_validate_snowpark(
478
- dataset=dataset,
479
- inference_method=inference_method,
480
- )
467
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
468
+ self._deps = self._get_dependencies()
481
469
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
482
470
 
483
471
  transform_kwargs = dict(
@@ -544,16 +532,40 @@ class PoissonRegressor(BaseTransformer):
544
532
  self._is_fitted = True
545
533
  return output_result
546
534
 
535
+
536
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
537
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
538
+ """ Method not supported for this class.
547
539
 
548
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
549
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
550
- """
540
+
541
+ Raises:
542
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
543
+
544
+ Args:
545
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
546
+ Snowpark or Pandas DataFrame.
547
+ output_cols_prefix: Prefix for the response columns
551
548
  Returns:
552
549
  Transformed dataset.
553
550
  """
554
- self.fit(dataset)
555
- assert self._sklearn_object is not None
556
- return self._sklearn_object.embedding_
551
+ self._infer_input_output_cols(dataset)
552
+ super()._check_dataset_type(dataset)
553
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
554
+ estimator=self._sklearn_object,
555
+ dataset=dataset,
556
+ input_cols=self.input_cols,
557
+ label_cols=self.label_cols,
558
+ sample_weight_col=self.sample_weight_col,
559
+ autogenerated=self._autogenerated,
560
+ subproject=_SUBPROJECT,
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=self.output_cols,
565
+ )
566
+ self._sklearn_object = fitted_estimator
567
+ self._is_fitted = True
568
+ return output_result
557
569
 
558
570
 
559
571
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -644,10 +656,8 @@ class PoissonRegressor(BaseTransformer):
644
656
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
645
657
 
646
658
  if isinstance(dataset, DataFrame):
647
- self._deps = self._batch_inference_validate_snowpark(
648
- dataset=dataset,
649
- inference_method=inference_method,
650
- )
659
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
660
+ self._deps = self._get_dependencies()
651
661
  assert isinstance(
652
662
  dataset._session, Session
653
663
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -712,10 +722,8 @@ class PoissonRegressor(BaseTransformer):
712
722
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
713
723
 
714
724
  if isinstance(dataset, DataFrame):
715
- self._deps = self._batch_inference_validate_snowpark(
716
- dataset=dataset,
717
- inference_method=inference_method,
718
- )
725
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
726
+ self._deps = self._get_dependencies()
719
727
  assert isinstance(
720
728
  dataset._session, Session
721
729
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -777,10 +785,8 @@ class PoissonRegressor(BaseTransformer):
777
785
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
778
786
 
779
787
  if isinstance(dataset, DataFrame):
780
- self._deps = self._batch_inference_validate_snowpark(
781
- dataset=dataset,
782
- inference_method=inference_method,
783
- )
788
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
789
+ self._deps = self._get_dependencies()
784
790
  assert isinstance(
785
791
  dataset._session, Session
786
792
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -846,10 +852,8 @@ class PoissonRegressor(BaseTransformer):
846
852
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
847
853
 
848
854
  if isinstance(dataset, DataFrame):
849
- self._deps = self._batch_inference_validate_snowpark(
850
- dataset=dataset,
851
- inference_method=inference_method,
852
- )
855
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
856
+ self._deps = self._get_dependencies()
853
857
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
854
858
  transform_kwargs = dict(
855
859
  session=dataset._session,
@@ -913,17 +917,15 @@ class PoissonRegressor(BaseTransformer):
913
917
  transform_kwargs: ScoreKwargsTypedDict = dict()
914
918
 
915
919
  if isinstance(dataset, DataFrame):
916
- self._deps = self._batch_inference_validate_snowpark(
917
- dataset=dataset,
918
- inference_method="score",
919
- )
920
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
921
+ self._deps = self._get_dependencies()
920
922
  selected_cols = self._get_active_columns()
921
923
  if len(selected_cols) > 0:
922
924
  dataset = dataset.select(selected_cols)
923
925
  assert isinstance(dataset._session, Session) # keep mypy happy
924
926
  transform_kwargs = dict(
925
927
  session=dataset._session,
926
- dependencies=["snowflake-snowpark-python"] + self._deps,
928
+ dependencies=self._deps,
927
929
  score_sproc_imports=['sklearn'],
928
930
  )
929
931
  elif isinstance(dataset, pd.DataFrame):
@@ -988,11 +990,8 @@ class PoissonRegressor(BaseTransformer):
988
990
 
989
991
  if isinstance(dataset, DataFrame):
990
992
 
991
- self._deps = self._batch_inference_validate_snowpark(
992
- dataset=dataset,
993
- inference_method=inference_method,
994
-
995
- )
993
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
994
+ self._deps = self._get_dependencies()
996
995
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
997
996
  transform_kwargs = dict(
998
997
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RANSACRegressor(BaseTransformer):
70
64
  r"""RANSAC (RANdom SAmple Consensus) algorithm
71
65
  For more details on this class, see [sklearn.linear_model.RANSACRegressor]
@@ -366,20 +360,17 @@ class RANSACRegressor(BaseTransformer):
366
360
  self,
367
361
  dataset: DataFrame,
368
362
  inference_method: str,
369
- ) -> List[str]:
370
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
371
- return the available package that exists in the snowflake anaconda channel
363
+ ) -> None:
364
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
372
365
 
373
366
  Args:
374
367
  dataset: snowpark dataframe
375
368
  inference_method: the inference method such as predict, score...
376
-
369
+
377
370
  Raises:
378
371
  SnowflakeMLException: If the estimator is not fitted, raise error
379
372
  SnowflakeMLException: If the session is None, raise error
380
373
 
381
- Returns:
382
- A list of available package that exists in the snowflake anaconda channel
383
374
  """
384
375
  if not self._is_fitted:
385
376
  raise exceptions.SnowflakeMLException(
@@ -397,9 +388,7 @@ class RANSACRegressor(BaseTransformer):
397
388
  "Session must not specified for snowpark dataset."
398
389
  ),
399
390
  )
400
- # Validate that key package version in user workspace are supported in snowflake conda channel
401
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
402
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
391
+
403
392
 
404
393
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
405
394
  @telemetry.send_api_usage_telemetry(
@@ -447,7 +436,8 @@ class RANSACRegressor(BaseTransformer):
447
436
 
448
437
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
449
438
 
450
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
439
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
440
+ self._deps = self._get_dependencies()
451
441
  assert isinstance(
452
442
  dataset._session, Session
453
443
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -530,10 +520,8 @@ class RANSACRegressor(BaseTransformer):
530
520
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
531
521
  expected_dtype = convert_sp_to_sf_type(output_types[0])
532
522
 
533
- self._deps = self._batch_inference_validate_snowpark(
534
- dataset=dataset,
535
- inference_method=inference_method,
536
- )
523
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
524
+ self._deps = self._get_dependencies()
537
525
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
538
526
 
539
527
  transform_kwargs = dict(
@@ -600,16 +588,40 @@ class RANSACRegressor(BaseTransformer):
600
588
  self._is_fitted = True
601
589
  return output_result
602
590
 
591
+
592
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
593
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
594
+ """ Method not supported for this class.
603
595
 
604
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
605
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
606
- """
596
+
597
+ Raises:
598
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
599
+
600
+ Args:
601
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
602
+ Snowpark or Pandas DataFrame.
603
+ output_cols_prefix: Prefix for the response columns
607
604
  Returns:
608
605
  Transformed dataset.
609
606
  """
610
- self.fit(dataset)
611
- assert self._sklearn_object is not None
612
- return self._sklearn_object.embedding_
607
+ self._infer_input_output_cols(dataset)
608
+ super()._check_dataset_type(dataset)
609
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
610
+ estimator=self._sklearn_object,
611
+ dataset=dataset,
612
+ input_cols=self.input_cols,
613
+ label_cols=self.label_cols,
614
+ sample_weight_col=self.sample_weight_col,
615
+ autogenerated=self._autogenerated,
616
+ subproject=_SUBPROJECT,
617
+ )
618
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
619
+ drop_input_cols=self._drop_input_cols,
620
+ expected_output_cols_list=self.output_cols,
621
+ )
622
+ self._sklearn_object = fitted_estimator
623
+ self._is_fitted = True
624
+ return output_result
613
625
 
614
626
 
615
627
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -700,10 +712,8 @@ class RANSACRegressor(BaseTransformer):
700
712
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
701
713
 
702
714
  if isinstance(dataset, DataFrame):
703
- self._deps = self._batch_inference_validate_snowpark(
704
- dataset=dataset,
705
- inference_method=inference_method,
706
- )
715
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
716
+ self._deps = self._get_dependencies()
707
717
  assert isinstance(
708
718
  dataset._session, Session
709
719
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -768,10 +778,8 @@ class RANSACRegressor(BaseTransformer):
768
778
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
769
779
 
770
780
  if isinstance(dataset, DataFrame):
771
- self._deps = self._batch_inference_validate_snowpark(
772
- dataset=dataset,
773
- inference_method=inference_method,
774
- )
781
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
782
+ self._deps = self._get_dependencies()
775
783
  assert isinstance(
776
784
  dataset._session, Session
777
785
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -833,10 +841,8 @@ class RANSACRegressor(BaseTransformer):
833
841
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
834
842
 
835
843
  if isinstance(dataset, DataFrame):
836
- self._deps = self._batch_inference_validate_snowpark(
837
- dataset=dataset,
838
- inference_method=inference_method,
839
- )
844
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
845
+ self._deps = self._get_dependencies()
840
846
  assert isinstance(
841
847
  dataset._session, Session
842
848
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -902,10 +908,8 @@ class RANSACRegressor(BaseTransformer):
902
908
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
903
909
 
904
910
  if isinstance(dataset, DataFrame):
905
- self._deps = self._batch_inference_validate_snowpark(
906
- dataset=dataset,
907
- inference_method=inference_method,
908
- )
911
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
912
+ self._deps = self._get_dependencies()
909
913
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
910
914
  transform_kwargs = dict(
911
915
  session=dataset._session,
@@ -969,17 +973,15 @@ class RANSACRegressor(BaseTransformer):
969
973
  transform_kwargs: ScoreKwargsTypedDict = dict()
970
974
 
971
975
  if isinstance(dataset, DataFrame):
972
- self._deps = self._batch_inference_validate_snowpark(
973
- dataset=dataset,
974
- inference_method="score",
975
- )
976
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
977
+ self._deps = self._get_dependencies()
976
978
  selected_cols = self._get_active_columns()
977
979
  if len(selected_cols) > 0:
978
980
  dataset = dataset.select(selected_cols)
979
981
  assert isinstance(dataset._session, Session) # keep mypy happy
980
982
  transform_kwargs = dict(
981
983
  session=dataset._session,
982
- dependencies=["snowflake-snowpark-python"] + self._deps,
984
+ dependencies=self._deps,
983
985
  score_sproc_imports=['sklearn'],
984
986
  )
985
987
  elif isinstance(dataset, pd.DataFrame):
@@ -1044,11 +1046,8 @@ class RANSACRegressor(BaseTransformer):
1044
1046
 
1045
1047
  if isinstance(dataset, DataFrame):
1046
1048
 
1047
- self._deps = self._batch_inference_validate_snowpark(
1048
- dataset=dataset,
1049
- inference_method=inference_method,
1050
-
1051
- )
1049
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1050
+ self._deps = self._get_dependencies()
1052
1051
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1053
1052
  transform_kwargs = dict(
1054
1053
  session = dataset._session,