snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
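The version bump itself shows up in snowflake/ml/version.py (+1 -1) and in the renamed dist-info directories above. A quick post-upgrade sanity check, assuming version.py keeps exposing a module-level VERSION string as it did in 1.4.1:

    # Hypothetical check that the upgrade took effect; VERSION is assumed to stay
    # the module-level constant defined in snowflake/ml/version.py.
    from snowflake.ml import version

    print(version.VERSION)  # expected to print "1.5.0"

The hunks that follow are from three of the autogenerated wrappers in snowflake/ml/modeling/linear_model; the +51/-52 and +53/-52 counts above suggest the same change is repeated across the other modeling files.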
snowflake/ml/modeling/linear_model/huber_regressor.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class HuberRegressor(BaseTransformer):
     r"""L2-regularized linear regression model that is robust to outliers
     For more details on this class, see [sklearn.linear_model.HuberRegressor]
@@ -293,20 +287,17 @@ class HuberRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -324,9 +315,7 @@ class HuberRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -374,7 +363,8 @@ class HuberRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -457,10 +447,8 @@ class HuberRegressor(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -527,16 +515,40 @@ class HuberRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -627,10 +639,8 @@ class HuberRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +705,8 @@ class HuberRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +768,8 @@ class HuberRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -829,10 +835,8 @@ class HuberRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -896,17 +900,15 @@ class HuberRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +973,8 @@ class HuberRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
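Most of the hunks above make the same change: _batch_inference_validate_snowpark now returns None and is called only for validation, self._deps is then populated from self._get_dependencies(), and the score path stops prepending "snowflake-snowpark-python" to the dependency list by hand. A minimal sketch of the new call sequence, with a hypothetical wrapper class standing in for the generated code:

    from typing import List


    class FakeWrapper:
        """Hypothetical stand-in for an autogenerated snowflake.ml.modeling wrapper."""

        def __init__(self, deps: List[str]) -> None:
            self._deps: List[str] = []
            self._dependencies = deps
            self._is_fitted = True

        def _get_dependencies(self) -> List[str]:
            # In the real wrappers this returns the dependency pins recorded on the estimator.
            return self._dependencies

        def _batch_inference_validate_snowpark(self, dataset: object, inference_method: str) -> None:
            # 1.5.0 behavior: validate only, return nothing.
            if not self._is_fitted:
                raise RuntimeError("estimator is not fitted")

        def predict(self, dataset: object) -> None:
            # New pattern from the diff: validate, then populate self._deps separately.
            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
            self._deps = self._get_dependencies()


    FakeWrapper(["scikit-learn==1.3.0"]).predict(dataset=object())

The same hunks repeat in the two files that follow.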
snowflake/ml/modeling/linear_model/lars.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class Lars(BaseTransformer):
     r"""Least Angle Regression model a
     For more details on this class, see [sklearn.linear_model.Lars]
@@ -322,20 +316,17 @@ class Lars(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -353,9 +344,7 @@ class Lars(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
            )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -403,7 +392,8 @@ class Lars(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -486,10 +476,8 @@ class Lars(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -556,16 +544,40 @@ class Lars(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +668,8 @@ class Lars(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -724,10 +734,8 @@ class Lars(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -789,10 +797,8 @@ class Lars(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -858,10 +864,8 @@ class Lars(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -925,17 +929,15 @@ class Lars(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1000,11 +1002,8 @@ class Lars(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/linear_model/lars_cv.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LarsCV(BaseTransformer):
     r"""Cross-validated Least Angle Regression model
     For more details on this class, see [sklearn.linear_model.LarsCV]
@@ -330,20 +324,17 @@ class LarsCV(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -361,9 +352,7 @@ class LarsCV(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -411,7 +400,8 @@ class LarsCV(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -494,10 +484,8 @@ class LarsCV(BaseTransformer):
                 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                     expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -564,16 +552,40 @@ class LarsCV(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -664,10 +676,8 @@ class LarsCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -732,10 +742,8 @@ class LarsCV(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -797,10 +805,8 @@ class LarsCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -866,10 +872,8 @@ class LarsCV(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -933,17 +937,15 @@ class LarsCV(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1008,11 +1010,8 @@ class LarsCV(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
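In all three files the previously dead fit_transform stub (guarded by _is_fit_transform_method_enabled, which always returned False) is replaced by a working implementation that delegates to ModelTrainerBuilder.build_fit_transform and train_fit_transform. A rough usage sketch against a wrapper whose underlying scikit-learn estimator actually defines fit_transform (the available_if guard keeps the method hidden on plain regressors such as HuberRegressor); the data, column names, and constructor arguments below are illustrative and not taken from this diff:

    import pandas as pd

    # PCA is one of the autogenerated wrappers listed above
    # (snowflake/ml/modeling/decomposition/pca.py, +53 -52).
    from snowflake.ml.modeling.decomposition import PCA

    df = pd.DataFrame({"F1": [1.0, 2.0, 3.0, 4.0], "F2": [4.0, 3.0, 2.0, 1.0]})

    # Illustrative configuration; input_cols/output_cols are the usual wrapper
    # constructor arguments, not something introduced by this release.
    pca = PCA(n_components=1, input_cols=["F1", "F2"], output_cols=["PC1"])

    # New in 1.5.0: fit_transform fits the wrapped estimator and returns the
    # transformed dataset (a pandas DataFrame here, a Snowpark DataFrame if one
    # is passed in), with output columns driven by output_cols.
    transformed = pca.fit_transform(df)
    print(transformed)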