snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218):
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Lasso(BaseTransformer):
70
64
  r"""Linear Model trained with L1 prior as regularizer (aka the Lasso)
71
65
  For more details on this class, see [sklearn.linear_model.Lasso]
@@ -323,20 +317,17 @@ class Lasso(BaseTransformer):
323
317
  self,
324
318
  dataset: DataFrame,
325
319
  inference_method: str,
326
- ) -> List[str]:
327
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
328
- return the available package that exists in the snowflake anaconda channel
320
+ ) -> None:
321
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
329
322
 
330
323
  Args:
331
324
  dataset: snowpark dataframe
332
325
  inference_method: the inference method such as predict, score...
333
-
326
+
334
327
  Raises:
335
328
  SnowflakeMLException: If the estimator is not fitted, raise error
336
329
  SnowflakeMLException: If the session is None, raise error
337
330
 
338
- Returns:
339
- A list of available package that exists in the snowflake anaconda channel
340
331
  """
341
332
  if not self._is_fitted:
342
333
  raise exceptions.SnowflakeMLException(
@@ -354,9 +345,7 @@ class Lasso(BaseTransformer):
354
345
  "Session must not specified for snowpark dataset."
355
346
  ),
356
347
  )
357
- # Validate that key package version in user workspace are supported in snowflake conda channel
358
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
359
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
348
+
360
349
 
361
350
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
362
351
  @telemetry.send_api_usage_telemetry(
@@ -404,7 +393,8 @@ class Lasso(BaseTransformer):
404
393
 
405
394
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
406
395
 
407
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
396
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
397
+ self._deps = self._get_dependencies()
408
398
  assert isinstance(
409
399
  dataset._session, Session
410
400
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -487,10 +477,8 @@ class Lasso(BaseTransformer):
487
477
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
488
478
  expected_dtype = convert_sp_to_sf_type(output_types[0])
489
479
 
490
- self._deps = self._batch_inference_validate_snowpark(
491
- dataset=dataset,
492
- inference_method=inference_method,
493
- )
480
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
481
+ self._deps = self._get_dependencies()
494
482
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
495
483
 
496
484
  transform_kwargs = dict(
@@ -557,16 +545,40 @@ class Lasso(BaseTransformer):
557
545
  self._is_fitted = True
558
546
  return output_result
559
547
 
548
+
549
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
550
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
551
+ """ Method not supported for this class.
560
552
 
561
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
562
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
563
- """
553
+
554
+ Raises:
555
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
556
+
557
+ Args:
558
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
559
+ Snowpark or Pandas DataFrame.
560
+ output_cols_prefix: Prefix for the response columns
564
561
  Returns:
565
562
  Transformed dataset.
566
563
  """
567
- self.fit(dataset)
568
- assert self._sklearn_object is not None
569
- return self._sklearn_object.embedding_
564
+ self._infer_input_output_cols(dataset)
565
+ super()._check_dataset_type(dataset)
566
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
567
+ estimator=self._sklearn_object,
568
+ dataset=dataset,
569
+ input_cols=self.input_cols,
570
+ label_cols=self.label_cols,
571
+ sample_weight_col=self.sample_weight_col,
572
+ autogenerated=self._autogenerated,
573
+ subproject=_SUBPROJECT,
574
+ )
575
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
576
+ drop_input_cols=self._drop_input_cols,
577
+ expected_output_cols_list=self.output_cols,
578
+ )
579
+ self._sklearn_object = fitted_estimator
580
+ self._is_fitted = True
581
+ return output_result
570
582
 
571
583
 
572
584
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -657,10 +669,8 @@ class Lasso(BaseTransformer):
657
669
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
658
670
 
659
671
  if isinstance(dataset, DataFrame):
660
- self._deps = self._batch_inference_validate_snowpark(
661
- dataset=dataset,
662
- inference_method=inference_method,
663
- )
672
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
673
+ self._deps = self._get_dependencies()
664
674
  assert isinstance(
665
675
  dataset._session, Session
666
676
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -725,10 +735,8 @@ class Lasso(BaseTransformer):
725
735
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
726
736
 
727
737
  if isinstance(dataset, DataFrame):
728
- self._deps = self._batch_inference_validate_snowpark(
729
- dataset=dataset,
730
- inference_method=inference_method,
731
- )
738
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
739
+ self._deps = self._get_dependencies()
732
740
  assert isinstance(
733
741
  dataset._session, Session
734
742
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -790,10 +798,8 @@ class Lasso(BaseTransformer):
790
798
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
791
799
 
792
800
  if isinstance(dataset, DataFrame):
793
- self._deps = self._batch_inference_validate_snowpark(
794
- dataset=dataset,
795
- inference_method=inference_method,
796
- )
801
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
802
+ self._deps = self._get_dependencies()
797
803
  assert isinstance(
798
804
  dataset._session, Session
799
805
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -859,10 +865,8 @@ class Lasso(BaseTransformer):
859
865
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
860
866
 
861
867
  if isinstance(dataset, DataFrame):
862
- self._deps = self._batch_inference_validate_snowpark(
863
- dataset=dataset,
864
- inference_method=inference_method,
865
- )
868
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
869
+ self._deps = self._get_dependencies()
866
870
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
867
871
  transform_kwargs = dict(
868
872
  session=dataset._session,
@@ -926,17 +930,15 @@ class Lasso(BaseTransformer):
926
930
  transform_kwargs: ScoreKwargsTypedDict = dict()
927
931
 
928
932
  if isinstance(dataset, DataFrame):
929
- self._deps = self._batch_inference_validate_snowpark(
930
- dataset=dataset,
931
- inference_method="score",
932
- )
933
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
934
+ self._deps = self._get_dependencies()
933
935
  selected_cols = self._get_active_columns()
934
936
  if len(selected_cols) > 0:
935
937
  dataset = dataset.select(selected_cols)
936
938
  assert isinstance(dataset._session, Session) # keep mypy happy
937
939
  transform_kwargs = dict(
938
940
  session=dataset._session,
939
- dependencies=["snowflake-snowpark-python"] + self._deps,
941
+ dependencies=self._deps,
940
942
  score_sproc_imports=['sklearn'],
941
943
  )
942
944
  elif isinstance(dataset, pd.DataFrame):
@@ -1001,11 +1003,8 @@ class Lasso(BaseTransformer):
1001
1003
 
1002
1004
  if isinstance(dataset, DataFrame):
1003
1005
 
1004
- self._deps = self._batch_inference_validate_snowpark(
1005
- dataset=dataset,
1006
- inference_method=inference_method,
1007
-
1008
- )
1006
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1007
+ self._deps = self._get_dependencies()
1009
1008
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1010
1009
  transform_kwargs = dict(
1011
1010
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LassoCV(BaseTransformer):
70
64
  r"""Lasso linear model with iterative fitting along a regularization path
71
65
  For more details on this class, see [sklearn.linear_model.LassoCV]
@@ -351,20 +345,17 @@ class LassoCV(BaseTransformer):
351
345
  self,
352
346
  dataset: DataFrame,
353
347
  inference_method: str,
354
- ) -> List[str]:
355
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
356
- return the available package that exists in the snowflake anaconda channel
348
+ ) -> None:
349
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
357
350
 
358
351
  Args:
359
352
  dataset: snowpark dataframe
360
353
  inference_method: the inference method such as predict, score...
361
-
354
+
362
355
  Raises:
363
356
  SnowflakeMLException: If the estimator is not fitted, raise error
364
357
  SnowflakeMLException: If the session is None, raise error
365
358
 
366
- Returns:
367
- A list of available package that exists in the snowflake anaconda channel
368
359
  """
369
360
  if not self._is_fitted:
370
361
  raise exceptions.SnowflakeMLException(
@@ -382,9 +373,7 @@ class LassoCV(BaseTransformer):
382
373
  "Session must not specified for snowpark dataset."
383
374
  ),
384
375
  )
385
- # Validate that key package version in user workspace are supported in snowflake conda channel
386
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
387
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
376
+
388
377
 
389
378
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
390
379
  @telemetry.send_api_usage_telemetry(
@@ -432,7 +421,8 @@ class LassoCV(BaseTransformer):
432
421
 
433
422
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
434
423
 
435
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
424
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
425
+ self._deps = self._get_dependencies()
436
426
  assert isinstance(
437
427
  dataset._session, Session
438
428
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -515,10 +505,8 @@ class LassoCV(BaseTransformer):
515
505
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
516
506
  expected_dtype = convert_sp_to_sf_type(output_types[0])
517
507
 
518
- self._deps = self._batch_inference_validate_snowpark(
519
- dataset=dataset,
520
- inference_method=inference_method,
521
- )
508
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
509
+ self._deps = self._get_dependencies()
522
510
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
523
511
 
524
512
  transform_kwargs = dict(
@@ -585,16 +573,40 @@ class LassoCV(BaseTransformer):
585
573
  self._is_fitted = True
586
574
  return output_result
587
575
 
576
+
577
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
578
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
579
+ """ Method not supported for this class.
588
580
 
589
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
590
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
591
- """
581
+
582
+ Raises:
583
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
584
+
585
+ Args:
586
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
587
+ Snowpark or Pandas DataFrame.
588
+ output_cols_prefix: Prefix for the response columns
592
589
  Returns:
593
590
  Transformed dataset.
594
591
  """
595
- self.fit(dataset)
596
- assert self._sklearn_object is not None
597
- return self._sklearn_object.embedding_
592
+ self._infer_input_output_cols(dataset)
593
+ super()._check_dataset_type(dataset)
594
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
595
+ estimator=self._sklearn_object,
596
+ dataset=dataset,
597
+ input_cols=self.input_cols,
598
+ label_cols=self.label_cols,
599
+ sample_weight_col=self.sample_weight_col,
600
+ autogenerated=self._autogenerated,
601
+ subproject=_SUBPROJECT,
602
+ )
603
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
604
+ drop_input_cols=self._drop_input_cols,
605
+ expected_output_cols_list=self.output_cols,
606
+ )
607
+ self._sklearn_object = fitted_estimator
608
+ self._is_fitted = True
609
+ return output_result
598
610
 
599
611
 
600
612
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -685,10 +697,8 @@ class LassoCV(BaseTransformer):
685
697
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
686
698
 
687
699
  if isinstance(dataset, DataFrame):
688
- self._deps = self._batch_inference_validate_snowpark(
689
- dataset=dataset,
690
- inference_method=inference_method,
691
- )
700
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
701
+ self._deps = self._get_dependencies()
692
702
  assert isinstance(
693
703
  dataset._session, Session
694
704
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -753,10 +763,8 @@ class LassoCV(BaseTransformer):
753
763
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
754
764
 
755
765
  if isinstance(dataset, DataFrame):
756
- self._deps = self._batch_inference_validate_snowpark(
757
- dataset=dataset,
758
- inference_method=inference_method,
759
- )
766
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
767
+ self._deps = self._get_dependencies()
760
768
  assert isinstance(
761
769
  dataset._session, Session
762
770
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -818,10 +826,8 @@ class LassoCV(BaseTransformer):
818
826
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
819
827
 
820
828
  if isinstance(dataset, DataFrame):
821
- self._deps = self._batch_inference_validate_snowpark(
822
- dataset=dataset,
823
- inference_method=inference_method,
824
- )
829
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
830
+ self._deps = self._get_dependencies()
825
831
  assert isinstance(
826
832
  dataset._session, Session
827
833
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -887,10 +893,8 @@ class LassoCV(BaseTransformer):
887
893
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
888
894
 
889
895
  if isinstance(dataset, DataFrame):
890
- self._deps = self._batch_inference_validate_snowpark(
891
- dataset=dataset,
892
- inference_method=inference_method,
893
- )
896
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
897
+ self._deps = self._get_dependencies()
894
898
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
895
899
  transform_kwargs = dict(
896
900
  session=dataset._session,
@@ -954,17 +958,15 @@ class LassoCV(BaseTransformer):
954
958
  transform_kwargs: ScoreKwargsTypedDict = dict()
955
959
 
956
960
  if isinstance(dataset, DataFrame):
957
- self._deps = self._batch_inference_validate_snowpark(
958
- dataset=dataset,
959
- inference_method="score",
960
- )
961
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
962
+ self._deps = self._get_dependencies()
961
963
  selected_cols = self._get_active_columns()
962
964
  if len(selected_cols) > 0:
963
965
  dataset = dataset.select(selected_cols)
964
966
  assert isinstance(dataset._session, Session) # keep mypy happy
965
967
  transform_kwargs = dict(
966
968
  session=dataset._session,
967
- dependencies=["snowflake-snowpark-python"] + self._deps,
969
+ dependencies=self._deps,
968
970
  score_sproc_imports=['sklearn'],
969
971
  )
970
972
  elif isinstance(dataset, pd.DataFrame):
@@ -1029,11 +1031,8 @@ class LassoCV(BaseTransformer):
1029
1031
 
1030
1032
  if isinstance(dataset, DataFrame):
1031
1033
 
1032
- self._deps = self._batch_inference_validate_snowpark(
1033
- dataset=dataset,
1034
- inference_method=inference_method,
1035
-
1036
- )
1034
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1035
+ self._deps = self._get_dependencies()
1037
1036
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1038
1037
  transform_kwargs = dict(
1039
1038
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LassoLars(BaseTransformer):
70
64
  r"""Lasso model fit with Least Angle Regression a
71
65
  For more details on this class, see [sklearn.linear_model.LassoLars]
@@ -343,20 +337,17 @@ class LassoLars(BaseTransformer):
343
337
  self,
344
338
  dataset: DataFrame,
345
339
  inference_method: str,
346
- ) -> List[str]:
347
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
348
- return the available package that exists in the snowflake anaconda channel
340
+ ) -> None:
341
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
349
342
 
350
343
  Args:
351
344
  dataset: snowpark dataframe
352
345
  inference_method: the inference method such as predict, score...
353
-
346
+
354
347
  Raises:
355
348
  SnowflakeMLException: If the estimator is not fitted, raise error
356
349
  SnowflakeMLException: If the session is None, raise error
357
350
 
358
- Returns:
359
- A list of available package that exists in the snowflake anaconda channel
360
351
  """
361
352
  if not self._is_fitted:
362
353
  raise exceptions.SnowflakeMLException(
@@ -374,9 +365,7 @@ class LassoLars(BaseTransformer):
374
365
  "Session must not specified for snowpark dataset."
375
366
  ),
376
367
  )
377
- # Validate that key package version in user workspace are supported in snowflake conda channel
378
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
379
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
368
+
380
369
 
381
370
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
382
371
  @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class LassoLars(BaseTransformer):
424
413
 
425
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
426
415
 
427
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._deps = self._get_dependencies()
428
418
  assert isinstance(
429
419
  dataset._session, Session
430
420
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -507,10 +497,8 @@ class LassoLars(BaseTransformer):
507
497
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
508
498
  expected_dtype = convert_sp_to_sf_type(output_types[0])
509
499
 
510
- self._deps = self._batch_inference_validate_snowpark(
511
- dataset=dataset,
512
- inference_method=inference_method,
513
- )
500
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
501
+ self._deps = self._get_dependencies()
514
502
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
515
503
 
516
504
  transform_kwargs = dict(
@@ -577,16 +565,40 @@ class LassoLars(BaseTransformer):
577
565
  self._is_fitted = True
578
566
  return output_result
579
567
 
568
+
569
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
570
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
571
+ """ Method not supported for this class.
580
572
 
581
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
582
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
583
- """
573
+
574
+ Raises:
575
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
576
+
577
+ Args:
578
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
579
+ Snowpark or Pandas DataFrame.
580
+ output_cols_prefix: Prefix for the response columns
584
581
  Returns:
585
582
  Transformed dataset.
586
583
  """
587
- self.fit(dataset)
588
- assert self._sklearn_object is not None
589
- return self._sklearn_object.embedding_
584
+ self._infer_input_output_cols(dataset)
585
+ super()._check_dataset_type(dataset)
586
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
587
+ estimator=self._sklearn_object,
588
+ dataset=dataset,
589
+ input_cols=self.input_cols,
590
+ label_cols=self.label_cols,
591
+ sample_weight_col=self.sample_weight_col,
592
+ autogenerated=self._autogenerated,
593
+ subproject=_SUBPROJECT,
594
+ )
595
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
596
+ drop_input_cols=self._drop_input_cols,
597
+ expected_output_cols_list=self.output_cols,
598
+ )
599
+ self._sklearn_object = fitted_estimator
600
+ self._is_fitted = True
601
+ return output_result
590
602
 
591
603
 
592
604
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -677,10 +689,8 @@ class LassoLars(BaseTransformer):
677
689
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
678
690
 
679
691
  if isinstance(dataset, DataFrame):
680
- self._deps = self._batch_inference_validate_snowpark(
681
- dataset=dataset,
682
- inference_method=inference_method,
683
- )
692
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
693
+ self._deps = self._get_dependencies()
684
694
  assert isinstance(
685
695
  dataset._session, Session
686
696
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -745,10 +755,8 @@ class LassoLars(BaseTransformer):
745
755
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
746
756
 
747
757
  if isinstance(dataset, DataFrame):
748
- self._deps = self._batch_inference_validate_snowpark(
749
- dataset=dataset,
750
- inference_method=inference_method,
751
- )
758
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
759
+ self._deps = self._get_dependencies()
752
760
  assert isinstance(
753
761
  dataset._session, Session
754
762
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -810,10 +818,8 @@ class LassoLars(BaseTransformer):
810
818
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
811
819
 
812
820
  if isinstance(dataset, DataFrame):
813
- self._deps = self._batch_inference_validate_snowpark(
814
- dataset=dataset,
815
- inference_method=inference_method,
816
- )
821
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
822
+ self._deps = self._get_dependencies()
817
823
  assert isinstance(
818
824
  dataset._session, Session
819
825
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -879,10 +885,8 @@ class LassoLars(BaseTransformer):
879
885
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
880
886
 
881
887
  if isinstance(dataset, DataFrame):
882
- self._deps = self._batch_inference_validate_snowpark(
883
- dataset=dataset,
884
- inference_method=inference_method,
885
- )
888
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
889
+ self._deps = self._get_dependencies()
886
890
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
887
891
  transform_kwargs = dict(
888
892
  session=dataset._session,
@@ -946,17 +950,15 @@ class LassoLars(BaseTransformer):
946
950
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
951
 
948
952
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
953
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
954
+ self._deps = self._get_dependencies()
953
955
  selected_cols = self._get_active_columns()
954
956
  if len(selected_cols) > 0:
955
957
  dataset = dataset.select(selected_cols)
956
958
  assert isinstance(dataset._session, Session) # keep mypy happy
957
959
  transform_kwargs = dict(
958
960
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
961
+ dependencies=self._deps,
960
962
  score_sproc_imports=['sklearn'],
961
963
  )
962
964
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1023,8 @@ class LassoLars(BaseTransformer):
1021
1023
 
1022
1024
  if isinstance(dataset, DataFrame):
1023
1025
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1026
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1027
+ self._deps = self._get_dependencies()
1029
1028
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1029
  transform_kwargs = dict(
1031
1030
  session = dataset._session,