snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218):
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LassoLarsCV(BaseTransformer):
70
64
  r"""Cross-validated Lasso, using the LARS algorithm
71
65
  For more details on this class, see [sklearn.linear_model.LassoLarsCV]
@@ -344,20 +338,17 @@ class LassoLarsCV(BaseTransformer):
344
338
  self,
345
339
  dataset: DataFrame,
346
340
  inference_method: str,
347
- ) -> List[str]:
348
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
349
- return the available package that exists in the snowflake anaconda channel
341
+ ) -> None:
342
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
350
343
 
351
344
  Args:
352
345
  dataset: snowpark dataframe
353
346
  inference_method: the inference method such as predict, score...
354
-
347
+
355
348
  Raises:
356
349
  SnowflakeMLException: If the estimator is not fitted, raise error
357
350
  SnowflakeMLException: If the session is None, raise error
358
351
 
359
- Returns:
360
- A list of available package that exists in the snowflake anaconda channel
361
352
  """
362
353
  if not self._is_fitted:
363
354
  raise exceptions.SnowflakeMLException(
@@ -375,9 +366,7 @@ class LassoLarsCV(BaseTransformer):
375
366
  "Session must not specified for snowpark dataset."
376
367
  ),
377
368
  )
378
- # Validate that key package version in user workspace are supported in snowflake conda channel
379
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
380
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
369
+
381
370
 
382
371
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
383
372
  @telemetry.send_api_usage_telemetry(
@@ -425,7 +414,8 @@ class LassoLarsCV(BaseTransformer):
425
414
 
426
415
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
427
416
 
428
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
418
+ self._deps = self._get_dependencies()
429
419
  assert isinstance(
430
420
  dataset._session, Session
431
421
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -508,10 +498,8 @@ class LassoLarsCV(BaseTransformer):
508
498
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
509
499
  expected_dtype = convert_sp_to_sf_type(output_types[0])
510
500
 
511
- self._deps = self._batch_inference_validate_snowpark(
512
- dataset=dataset,
513
- inference_method=inference_method,
514
- )
501
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
502
+ self._deps = self._get_dependencies()
515
503
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
516
504
 
517
505
  transform_kwargs = dict(
@@ -578,16 +566,40 @@ class LassoLarsCV(BaseTransformer):
578
566
  self._is_fitted = True
579
567
  return output_result
580
568
 
569
+
570
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
571
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
572
+ """ Method not supported for this class.
581
573
 
582
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
583
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
584
- """
574
+
575
+ Raises:
576
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
577
+
578
+ Args:
579
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
580
+ Snowpark or Pandas DataFrame.
581
+ output_cols_prefix: Prefix for the response columns
585
582
  Returns:
586
583
  Transformed dataset.
587
584
  """
588
- self.fit(dataset)
589
- assert self._sklearn_object is not None
590
- return self._sklearn_object.embedding_
585
+ self._infer_input_output_cols(dataset)
586
+ super()._check_dataset_type(dataset)
587
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
588
+ estimator=self._sklearn_object,
589
+ dataset=dataset,
590
+ input_cols=self.input_cols,
591
+ label_cols=self.label_cols,
592
+ sample_weight_col=self.sample_weight_col,
593
+ autogenerated=self._autogenerated,
594
+ subproject=_SUBPROJECT,
595
+ )
596
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
597
+ drop_input_cols=self._drop_input_cols,
598
+ expected_output_cols_list=self.output_cols,
599
+ )
600
+ self._sklearn_object = fitted_estimator
601
+ self._is_fitted = True
602
+ return output_result
591
603
 
592
604
 
593
605
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -678,10 +690,8 @@ class LassoLarsCV(BaseTransformer):
678
690
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
679
691
 
680
692
  if isinstance(dataset, DataFrame):
681
- self._deps = self._batch_inference_validate_snowpark(
682
- dataset=dataset,
683
- inference_method=inference_method,
684
- )
693
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
694
+ self._deps = self._get_dependencies()
685
695
  assert isinstance(
686
696
  dataset._session, Session
687
697
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -746,10 +756,8 @@ class LassoLarsCV(BaseTransformer):
746
756
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
747
757
 
748
758
  if isinstance(dataset, DataFrame):
749
- self._deps = self._batch_inference_validate_snowpark(
750
- dataset=dataset,
751
- inference_method=inference_method,
752
- )
759
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
760
+ self._deps = self._get_dependencies()
753
761
  assert isinstance(
754
762
  dataset._session, Session
755
763
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -811,10 +819,8 @@ class LassoLarsCV(BaseTransformer):
811
819
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
812
820
 
813
821
  if isinstance(dataset, DataFrame):
814
- self._deps = self._batch_inference_validate_snowpark(
815
- dataset=dataset,
816
- inference_method=inference_method,
817
- )
822
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
823
+ self._deps = self._get_dependencies()
818
824
  assert isinstance(
819
825
  dataset._session, Session
820
826
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -880,10 +886,8 @@ class LassoLarsCV(BaseTransformer):
880
886
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
881
887
 
882
888
  if isinstance(dataset, DataFrame):
883
- self._deps = self._batch_inference_validate_snowpark(
884
- dataset=dataset,
885
- inference_method=inference_method,
886
- )
889
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
890
+ self._deps = self._get_dependencies()
887
891
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
888
892
  transform_kwargs = dict(
889
893
  session=dataset._session,
@@ -947,17 +951,15 @@ class LassoLarsCV(BaseTransformer):
947
951
  transform_kwargs: ScoreKwargsTypedDict = dict()
948
952
 
949
953
  if isinstance(dataset, DataFrame):
950
- self._deps = self._batch_inference_validate_snowpark(
951
- dataset=dataset,
952
- inference_method="score",
953
- )
954
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
955
+ self._deps = self._get_dependencies()
954
956
  selected_cols = self._get_active_columns()
955
957
  if len(selected_cols) > 0:
956
958
  dataset = dataset.select(selected_cols)
957
959
  assert isinstance(dataset._session, Session) # keep mypy happy
958
960
  transform_kwargs = dict(
959
961
  session=dataset._session,
960
- dependencies=["snowflake-snowpark-python"] + self._deps,
962
+ dependencies=self._deps,
961
963
  score_sproc_imports=['sklearn'],
962
964
  )
963
965
  elif isinstance(dataset, pd.DataFrame):
@@ -1022,11 +1024,8 @@ class LassoLarsCV(BaseTransformer):
1022
1024
 
1023
1025
  if isinstance(dataset, DataFrame):
1024
1026
 
1025
- self._deps = self._batch_inference_validate_snowpark(
1026
- dataset=dataset,
1027
- inference_method=inference_method,
1028
-
1029
- )
1027
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1028
+ self._deps = self._get_dependencies()
1030
1029
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1031
1030
  transform_kwargs = dict(
1032
1031
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LassoLarsIC(BaseTransformer):
70
64
  r"""Lasso model fit with Lars using BIC or AIC for model selection
71
65
  For more details on this class, see [sklearn.linear_model.LassoLarsIC]
@@ -327,20 +321,17 @@ class LassoLarsIC(BaseTransformer):
327
321
  self,
328
322
  dataset: DataFrame,
329
323
  inference_method: str,
330
- ) -> List[str]:
331
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
332
- return the available package that exists in the snowflake anaconda channel
324
+ ) -> None:
325
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
333
326
 
334
327
  Args:
335
328
  dataset: snowpark dataframe
336
329
  inference_method: the inference method such as predict, score...
337
-
330
+
338
331
  Raises:
339
332
  SnowflakeMLException: If the estimator is not fitted, raise error
340
333
  SnowflakeMLException: If the session is None, raise error
341
334
 
342
- Returns:
343
- A list of available package that exists in the snowflake anaconda channel
344
335
  """
345
336
  if not self._is_fitted:
346
337
  raise exceptions.SnowflakeMLException(
@@ -358,9 +349,7 @@ class LassoLarsIC(BaseTransformer):
358
349
  "Session must not specified for snowpark dataset."
359
350
  ),
360
351
  )
361
- # Validate that key package version in user workspace are supported in snowflake conda channel
362
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
363
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
352
+
364
353
 
365
354
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
366
355
  @telemetry.send_api_usage_telemetry(
@@ -408,7 +397,8 @@ class LassoLarsIC(BaseTransformer):
408
397
 
409
398
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
410
399
 
411
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
400
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
401
+ self._deps = self._get_dependencies()
412
402
  assert isinstance(
413
403
  dataset._session, Session
414
404
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -491,10 +481,8 @@ class LassoLarsIC(BaseTransformer):
491
481
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
492
482
  expected_dtype = convert_sp_to_sf_type(output_types[0])
493
483
 
494
- self._deps = self._batch_inference_validate_snowpark(
495
- dataset=dataset,
496
- inference_method=inference_method,
497
- )
484
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
485
+ self._deps = self._get_dependencies()
498
486
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
499
487
 
500
488
  transform_kwargs = dict(
@@ -561,16 +549,40 @@ class LassoLarsIC(BaseTransformer):
561
549
  self._is_fitted = True
562
550
  return output_result
563
551
 
552
+
553
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
554
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
555
+ """ Method not supported for this class.
564
556
 
565
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
566
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
567
- """
557
+
558
+ Raises:
559
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
560
+
561
+ Args:
562
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
563
+ Snowpark or Pandas DataFrame.
564
+ output_cols_prefix: Prefix for the response columns
568
565
  Returns:
569
566
  Transformed dataset.
570
567
  """
571
- self.fit(dataset)
572
- assert self._sklearn_object is not None
573
- return self._sklearn_object.embedding_
568
+ self._infer_input_output_cols(dataset)
569
+ super()._check_dataset_type(dataset)
570
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
571
+ estimator=self._sklearn_object,
572
+ dataset=dataset,
573
+ input_cols=self.input_cols,
574
+ label_cols=self.label_cols,
575
+ sample_weight_col=self.sample_weight_col,
576
+ autogenerated=self._autogenerated,
577
+ subproject=_SUBPROJECT,
578
+ )
579
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=self.output_cols,
582
+ )
583
+ self._sklearn_object = fitted_estimator
584
+ self._is_fitted = True
585
+ return output_result
574
586
 
575
587
 
576
588
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -661,10 +673,8 @@ class LassoLarsIC(BaseTransformer):
661
673
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
662
674
 
663
675
  if isinstance(dataset, DataFrame):
664
- self._deps = self._batch_inference_validate_snowpark(
665
- dataset=dataset,
666
- inference_method=inference_method,
667
- )
676
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
677
+ self._deps = self._get_dependencies()
668
678
  assert isinstance(
669
679
  dataset._session, Session
670
680
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -729,10 +739,8 @@ class LassoLarsIC(BaseTransformer):
729
739
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
730
740
 
731
741
  if isinstance(dataset, DataFrame):
732
- self._deps = self._batch_inference_validate_snowpark(
733
- dataset=dataset,
734
- inference_method=inference_method,
735
- )
742
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
743
+ self._deps = self._get_dependencies()
736
744
  assert isinstance(
737
745
  dataset._session, Session
738
746
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -794,10 +802,8 @@ class LassoLarsIC(BaseTransformer):
794
802
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
795
803
 
796
804
  if isinstance(dataset, DataFrame):
797
- self._deps = self._batch_inference_validate_snowpark(
798
- dataset=dataset,
799
- inference_method=inference_method,
800
- )
805
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
806
+ self._deps = self._get_dependencies()
801
807
  assert isinstance(
802
808
  dataset._session, Session
803
809
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -863,10 +869,8 @@ class LassoLarsIC(BaseTransformer):
863
869
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
864
870
 
865
871
  if isinstance(dataset, DataFrame):
866
- self._deps = self._batch_inference_validate_snowpark(
867
- dataset=dataset,
868
- inference_method=inference_method,
869
- )
872
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
873
+ self._deps = self._get_dependencies()
870
874
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
871
875
  transform_kwargs = dict(
872
876
  session=dataset._session,
@@ -930,17 +934,15 @@ class LassoLarsIC(BaseTransformer):
930
934
  transform_kwargs: ScoreKwargsTypedDict = dict()
931
935
 
932
936
  if isinstance(dataset, DataFrame):
933
- self._deps = self._batch_inference_validate_snowpark(
934
- dataset=dataset,
935
- inference_method="score",
936
- )
937
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
938
+ self._deps = self._get_dependencies()
937
939
  selected_cols = self._get_active_columns()
938
940
  if len(selected_cols) > 0:
939
941
  dataset = dataset.select(selected_cols)
940
942
  assert isinstance(dataset._session, Session) # keep mypy happy
941
943
  transform_kwargs = dict(
942
944
  session=dataset._session,
943
- dependencies=["snowflake-snowpark-python"] + self._deps,
945
+ dependencies=self._deps,
944
946
  score_sproc_imports=['sklearn'],
945
947
  )
946
948
  elif isinstance(dataset, pd.DataFrame):
@@ -1005,11 +1007,8 @@ class LassoLarsIC(BaseTransformer):
1005
1007
 
1006
1008
  if isinstance(dataset, DataFrame):
1007
1009
 
1008
- self._deps = self._batch_inference_validate_snowpark(
1009
- dataset=dataset,
1010
- inference_method=inference_method,
1011
-
1012
- )
1010
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1011
+ self._deps = self._get_dependencies()
1013
1012
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1014
1013
  transform_kwargs = dict(
1015
1014
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LinearRegression(BaseTransformer):
70
64
  r"""Ordinary least squares Linear Regression
71
65
  For more details on this class, see [sklearn.linear_model.LinearRegression]
@@ -280,20 +274,17 @@ class LinearRegression(BaseTransformer):
280
274
  self,
281
275
  dataset: DataFrame,
282
276
  inference_method: str,
283
- ) -> List[str]:
284
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
285
- return the available package that exists in the snowflake anaconda channel
277
+ ) -> None:
278
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
286
279
 
287
280
  Args:
288
281
  dataset: snowpark dataframe
289
282
  inference_method: the inference method such as predict, score...
290
-
283
+
291
284
  Raises:
292
285
  SnowflakeMLException: If the estimator is not fitted, raise error
293
286
  SnowflakeMLException: If the session is None, raise error
294
287
 
295
- Returns:
296
- A list of available package that exists in the snowflake anaconda channel
297
288
  """
298
289
  if not self._is_fitted:
299
290
  raise exceptions.SnowflakeMLException(
@@ -311,9 +302,7 @@ class LinearRegression(BaseTransformer):
311
302
  "Session must not specified for snowpark dataset."
312
303
  ),
313
304
  )
314
- # Validate that key package version in user workspace are supported in snowflake conda channel
315
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
316
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
305
+
317
306
 
318
307
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
319
308
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class LinearRegression(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class LinearRegression(BaseTransformer):
444
434
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
445
435
  expected_dtype = convert_sp_to_sf_type(output_types[0])
446
436
 
447
- self._deps = self._batch_inference_validate_snowpark(
448
- dataset=dataset,
449
- inference_method=inference_method,
450
- )
437
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._deps = self._get_dependencies()
451
439
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
452
440
 
453
441
  transform_kwargs = dict(
@@ -514,16 +502,40 @@ class LinearRegression(BaseTransformer):
514
502
  self._is_fitted = True
515
503
  return output_result
516
504
 
505
+
506
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
507
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
508
+ """ Method not supported for this class.
517
509
 
518
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
519
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
520
- """
510
+
511
+ Raises:
512
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
513
+
514
+ Args:
515
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
516
+ Snowpark or Pandas DataFrame.
517
+ output_cols_prefix: Prefix for the response columns
521
518
  Returns:
522
519
  Transformed dataset.
523
520
  """
524
- self.fit(dataset)
525
- assert self._sklearn_object is not None
526
- return self._sklearn_object.embedding_
521
+ self._infer_input_output_cols(dataset)
522
+ super()._check_dataset_type(dataset)
523
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
524
+ estimator=self._sklearn_object,
525
+ dataset=dataset,
526
+ input_cols=self.input_cols,
527
+ label_cols=self.label_cols,
528
+ sample_weight_col=self.sample_weight_col,
529
+ autogenerated=self._autogenerated,
530
+ subproject=_SUBPROJECT,
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=self.output_cols,
535
+ )
536
+ self._sklearn_object = fitted_estimator
537
+ self._is_fitted = True
538
+ return output_result
527
539
 
528
540
 
529
541
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -614,10 +626,8 @@ class LinearRegression(BaseTransformer):
614
626
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
615
627
 
616
628
  if isinstance(dataset, DataFrame):
617
- self._deps = self._batch_inference_validate_snowpark(
618
- dataset=dataset,
619
- inference_method=inference_method,
620
- )
629
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
630
+ self._deps = self._get_dependencies()
621
631
  assert isinstance(
622
632
  dataset._session, Session
623
633
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -682,10 +692,8 @@ class LinearRegression(BaseTransformer):
682
692
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
683
693
 
684
694
  if isinstance(dataset, DataFrame):
685
- self._deps = self._batch_inference_validate_snowpark(
686
- dataset=dataset,
687
- inference_method=inference_method,
688
- )
695
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
696
+ self._deps = self._get_dependencies()
689
697
  assert isinstance(
690
698
  dataset._session, Session
691
699
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +755,8 @@ class LinearRegression(BaseTransformer):
747
755
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
748
756
 
749
757
  if isinstance(dataset, DataFrame):
750
- self._deps = self._batch_inference_validate_snowpark(
751
- dataset=dataset,
752
- inference_method=inference_method,
753
- )
758
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
759
+ self._deps = self._get_dependencies()
754
760
  assert isinstance(
755
761
  dataset._session, Session
756
762
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -816,10 +822,8 @@ class LinearRegression(BaseTransformer):
816
822
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
817
823
 
818
824
  if isinstance(dataset, DataFrame):
819
- self._deps = self._batch_inference_validate_snowpark(
820
- dataset=dataset,
821
- inference_method=inference_method,
822
- )
825
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
826
+ self._deps = self._get_dependencies()
823
827
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
824
828
  transform_kwargs = dict(
825
829
  session=dataset._session,
@@ -883,17 +887,15 @@ class LinearRegression(BaseTransformer):
883
887
  transform_kwargs: ScoreKwargsTypedDict = dict()
884
888
 
885
889
  if isinstance(dataset, DataFrame):
886
- self._deps = self._batch_inference_validate_snowpark(
887
- dataset=dataset,
888
- inference_method="score",
889
- )
890
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
891
+ self._deps = self._get_dependencies()
890
892
  selected_cols = self._get_active_columns()
891
893
  if len(selected_cols) > 0:
892
894
  dataset = dataset.select(selected_cols)
893
895
  assert isinstance(dataset._session, Session) # keep mypy happy
894
896
  transform_kwargs = dict(
895
897
  session=dataset._session,
896
- dependencies=["snowflake-snowpark-python"] + self._deps,
898
+ dependencies=self._deps,
897
899
  score_sproc_imports=['sklearn'],
898
900
  )
899
901
  elif isinstance(dataset, pd.DataFrame):
@@ -958,11 +960,8 @@ class LinearRegression(BaseTransformer):
958
960
 
959
961
  if isinstance(dataset, DataFrame):
960
962
 
961
- self._deps = self._batch_inference_validate_snowpark(
962
- dataset=dataset,
963
- inference_method=inference_method,
964
-
965
- )
963
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
964
+ self._deps = self._get_dependencies()
966
965
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
967
966
  transform_kwargs = dict(
968
967
  session = dataset._session,