snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MultiTaskElasticNetCV(BaseTransformer):
70
64
  r"""Multi-task L1/L2 ElasticNet with built-in cross-validation
71
65
  For more details on this class, see [sklearn.linear_model.MultiTaskElasticNetCV]
@@ -354,20 +348,17 @@ class MultiTaskElasticNetCV(BaseTransformer):
354
348
  self,
355
349
  dataset: DataFrame,
356
350
  inference_method: str,
357
- ) -> List[str]:
358
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
359
- return the available package that exists in the snowflake anaconda channel
351
+ ) -> None:
352
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
360
353
 
361
354
  Args:
362
355
  dataset: snowpark dataframe
363
356
  inference_method: the inference method such as predict, score...
364
-
357
+
365
358
  Raises:
366
359
  SnowflakeMLException: If the estimator is not fitted, raise error
367
360
  SnowflakeMLException: If the session is None, raise error
368
361
 
369
- Returns:
370
- A list of available package that exists in the snowflake anaconda channel
371
362
  """
372
363
  if not self._is_fitted:
373
364
  raise exceptions.SnowflakeMLException(
@@ -385,9 +376,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
385
376
  "Session must not specified for snowpark dataset."
386
377
  ),
387
378
  )
388
- # Validate that key package version in user workspace are supported in snowflake conda channel
389
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
390
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
379
+
391
380
 
392
381
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
393
382
  @telemetry.send_api_usage_telemetry(
@@ -435,7 +424,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
435
424
 
436
425
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
437
426
 
438
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
427
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
428
+ self._deps = self._get_dependencies()
439
429
  assert isinstance(
440
430
  dataset._session, Session
441
431
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -518,10 +508,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
518
508
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
519
509
  expected_dtype = convert_sp_to_sf_type(output_types[0])
520
510
 
521
- self._deps = self._batch_inference_validate_snowpark(
522
- dataset=dataset,
523
- inference_method=inference_method,
524
- )
511
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
512
+ self._deps = self._get_dependencies()
525
513
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
526
514
 
527
515
  transform_kwargs = dict(
@@ -588,16 +576,40 @@ class MultiTaskElasticNetCV(BaseTransformer):
588
576
  self._is_fitted = True
589
577
  return output_result
590
578
 
579
+
580
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
581
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
582
+ """ Method not supported for this class.
591
583
 
592
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
593
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
594
- """
584
+
585
+ Raises:
586
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
587
+
588
+ Args:
589
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
590
+ Snowpark or Pandas DataFrame.
591
+ output_cols_prefix: Prefix for the response columns
595
592
  Returns:
596
593
  Transformed dataset.
597
594
  """
598
- self.fit(dataset)
599
- assert self._sklearn_object is not None
600
- return self._sklearn_object.embedding_
595
+ self._infer_input_output_cols(dataset)
596
+ super()._check_dataset_type(dataset)
597
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
598
+ estimator=self._sklearn_object,
599
+ dataset=dataset,
600
+ input_cols=self.input_cols,
601
+ label_cols=self.label_cols,
602
+ sample_weight_col=self.sample_weight_col,
603
+ autogenerated=self._autogenerated,
604
+ subproject=_SUBPROJECT,
605
+ )
606
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
607
+ drop_input_cols=self._drop_input_cols,
608
+ expected_output_cols_list=self.output_cols,
609
+ )
610
+ self._sklearn_object = fitted_estimator
611
+ self._is_fitted = True
612
+ return output_result
601
613
 
602
614
 
603
615
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -688,10 +700,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
688
700
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
689
701
 
690
702
  if isinstance(dataset, DataFrame):
691
- self._deps = self._batch_inference_validate_snowpark(
692
- dataset=dataset,
693
- inference_method=inference_method,
694
- )
703
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
704
+ self._deps = self._get_dependencies()
695
705
  assert isinstance(
696
706
  dataset._session, Session
697
707
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -756,10 +766,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
756
766
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
757
767
 
758
768
  if isinstance(dataset, DataFrame):
759
- self._deps = self._batch_inference_validate_snowpark(
760
- dataset=dataset,
761
- inference_method=inference_method,
762
- )
769
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
770
+ self._deps = self._get_dependencies()
763
771
  assert isinstance(
764
772
  dataset._session, Session
765
773
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -821,10 +829,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
821
829
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
822
830
 
823
831
  if isinstance(dataset, DataFrame):
824
- self._deps = self._batch_inference_validate_snowpark(
825
- dataset=dataset,
826
- inference_method=inference_method,
827
- )
832
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
833
+ self._deps = self._get_dependencies()
828
834
  assert isinstance(
829
835
  dataset._session, Session
830
836
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -890,10 +896,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
890
896
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
891
897
 
892
898
  if isinstance(dataset, DataFrame):
893
- self._deps = self._batch_inference_validate_snowpark(
894
- dataset=dataset,
895
- inference_method=inference_method,
896
- )
899
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
900
+ self._deps = self._get_dependencies()
897
901
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
898
902
  transform_kwargs = dict(
899
903
  session=dataset._session,
@@ -957,17 +961,15 @@ class MultiTaskElasticNetCV(BaseTransformer):
957
961
  transform_kwargs: ScoreKwargsTypedDict = dict()
958
962
 
959
963
  if isinstance(dataset, DataFrame):
960
- self._deps = self._batch_inference_validate_snowpark(
961
- dataset=dataset,
962
- inference_method="score",
963
- )
964
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
965
+ self._deps = self._get_dependencies()
964
966
  selected_cols = self._get_active_columns()
965
967
  if len(selected_cols) > 0:
966
968
  dataset = dataset.select(selected_cols)
967
969
  assert isinstance(dataset._session, Session) # keep mypy happy
968
970
  transform_kwargs = dict(
969
971
  session=dataset._session,
970
- dependencies=["snowflake-snowpark-python"] + self._deps,
972
+ dependencies=self._deps,
971
973
  score_sproc_imports=['sklearn'],
972
974
  )
973
975
  elif isinstance(dataset, pd.DataFrame):
@@ -1032,11 +1034,8 @@ class MultiTaskElasticNetCV(BaseTransformer):
1032
1034
 
1033
1035
  if isinstance(dataset, DataFrame):
1034
1036
 
1035
- self._deps = self._batch_inference_validate_snowpark(
1036
- dataset=dataset,
1037
- inference_method=inference_method,
1038
-
1039
- )
1037
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1038
+ self._deps = self._get_dependencies()
1040
1039
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1041
1040
  transform_kwargs = dict(
1042
1041
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MultiTaskLasso(BaseTransformer):
70
64
  r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
71
65
  For more details on this class, see [sklearn.linear_model.MultiTaskLasso]
@@ -305,20 +299,17 @@ class MultiTaskLasso(BaseTransformer):
305
299
  self,
306
300
  dataset: DataFrame,
307
301
  inference_method: str,
308
- ) -> List[str]:
309
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
310
- return the available package that exists in the snowflake anaconda channel
302
+ ) -> None:
303
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
311
304
 
312
305
  Args:
313
306
  dataset: snowpark dataframe
314
307
  inference_method: the inference method such as predict, score...
315
-
308
+
316
309
  Raises:
317
310
  SnowflakeMLException: If the estimator is not fitted, raise error
318
311
  SnowflakeMLException: If the session is None, raise error
319
312
 
320
- Returns:
321
- A list of available package that exists in the snowflake anaconda channel
322
313
  """
323
314
  if not self._is_fitted:
324
315
  raise exceptions.SnowflakeMLException(
@@ -336,9 +327,7 @@ class MultiTaskLasso(BaseTransformer):
336
327
  "Session must not specified for snowpark dataset."
337
328
  ),
338
329
  )
339
- # Validate that key package version in user workspace are supported in snowflake conda channel
340
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
341
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
330
+
342
331
 
343
332
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
344
333
  @telemetry.send_api_usage_telemetry(
@@ -386,7 +375,8 @@ class MultiTaskLasso(BaseTransformer):
386
375
 
387
376
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
388
377
 
389
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
378
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
379
+ self._deps = self._get_dependencies()
390
380
  assert isinstance(
391
381
  dataset._session, Session
392
382
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -469,10 +459,8 @@ class MultiTaskLasso(BaseTransformer):
469
459
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
470
460
  expected_dtype = convert_sp_to_sf_type(output_types[0])
471
461
 
472
- self._deps = self._batch_inference_validate_snowpark(
473
- dataset=dataset,
474
- inference_method=inference_method,
475
- )
462
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
463
+ self._deps = self._get_dependencies()
476
464
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
477
465
 
478
466
  transform_kwargs = dict(
@@ -539,16 +527,40 @@ class MultiTaskLasso(BaseTransformer):
539
527
  self._is_fitted = True
540
528
  return output_result
541
529
 
530
+
531
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
532
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
533
+ """ Method not supported for this class.
542
534
 
543
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
544
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
545
- """
535
+
536
+ Raises:
537
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
538
+
539
+ Args:
540
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
541
+ Snowpark or Pandas DataFrame.
542
+ output_cols_prefix: Prefix for the response columns
546
543
  Returns:
547
544
  Transformed dataset.
548
545
  """
549
- self.fit(dataset)
550
- assert self._sklearn_object is not None
551
- return self._sklearn_object.embedding_
546
+ self._infer_input_output_cols(dataset)
547
+ super()._check_dataset_type(dataset)
548
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
549
+ estimator=self._sklearn_object,
550
+ dataset=dataset,
551
+ input_cols=self.input_cols,
552
+ label_cols=self.label_cols,
553
+ sample_weight_col=self.sample_weight_col,
554
+ autogenerated=self._autogenerated,
555
+ subproject=_SUBPROJECT,
556
+ )
557
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=self.output_cols,
560
+ )
561
+ self._sklearn_object = fitted_estimator
562
+ self._is_fitted = True
563
+ return output_result
552
564
 
553
565
 
554
566
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -639,10 +651,8 @@ class MultiTaskLasso(BaseTransformer):
639
651
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
640
652
 
641
653
  if isinstance(dataset, DataFrame):
642
- self._deps = self._batch_inference_validate_snowpark(
643
- dataset=dataset,
644
- inference_method=inference_method,
645
- )
654
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
655
+ self._deps = self._get_dependencies()
646
656
  assert isinstance(
647
657
  dataset._session, Session
648
658
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -707,10 +717,8 @@ class MultiTaskLasso(BaseTransformer):
707
717
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
708
718
 
709
719
  if isinstance(dataset, DataFrame):
710
- self._deps = self._batch_inference_validate_snowpark(
711
- dataset=dataset,
712
- inference_method=inference_method,
713
- )
720
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
721
+ self._deps = self._get_dependencies()
714
722
  assert isinstance(
715
723
  dataset._session, Session
716
724
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -772,10 +780,8 @@ class MultiTaskLasso(BaseTransformer):
772
780
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
773
781
 
774
782
  if isinstance(dataset, DataFrame):
775
- self._deps = self._batch_inference_validate_snowpark(
776
- dataset=dataset,
777
- inference_method=inference_method,
778
- )
783
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
784
+ self._deps = self._get_dependencies()
779
785
  assert isinstance(
780
786
  dataset._session, Session
781
787
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -841,10 +847,8 @@ class MultiTaskLasso(BaseTransformer):
841
847
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
842
848
 
843
849
  if isinstance(dataset, DataFrame):
844
- self._deps = self._batch_inference_validate_snowpark(
845
- dataset=dataset,
846
- inference_method=inference_method,
847
- )
850
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
851
+ self._deps = self._get_dependencies()
848
852
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
849
853
  transform_kwargs = dict(
850
854
  session=dataset._session,
@@ -908,17 +912,15 @@ class MultiTaskLasso(BaseTransformer):
908
912
  transform_kwargs: ScoreKwargsTypedDict = dict()
909
913
 
910
914
  if isinstance(dataset, DataFrame):
911
- self._deps = self._batch_inference_validate_snowpark(
912
- dataset=dataset,
913
- inference_method="score",
914
- )
915
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
916
+ self._deps = self._get_dependencies()
915
917
  selected_cols = self._get_active_columns()
916
918
  if len(selected_cols) > 0:
917
919
  dataset = dataset.select(selected_cols)
918
920
  assert isinstance(dataset._session, Session) # keep mypy happy
919
921
  transform_kwargs = dict(
920
922
  session=dataset._session,
921
- dependencies=["snowflake-snowpark-python"] + self._deps,
923
+ dependencies=self._deps,
922
924
  score_sproc_imports=['sklearn'],
923
925
  )
924
926
  elif isinstance(dataset, pd.DataFrame):
@@ -983,11 +985,8 @@ class MultiTaskLasso(BaseTransformer):
983
985
 
984
986
  if isinstance(dataset, DataFrame):
985
987
 
986
- self._deps = self._batch_inference_validate_snowpark(
987
- dataset=dataset,
988
- inference_method=inference_method,
989
-
990
- )
988
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
989
+ self._deps = self._get_dependencies()
991
990
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
992
991
  transform_kwargs = dict(
993
992
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.linear_model".replace("s
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class MultiTaskLassoCV(BaseTransformer):
70
64
  r"""Multi-task Lasso model trained with L1/L2 mixed-norm as regularizer
71
65
  For more details on this class, see [sklearn.linear_model.MultiTaskLassoCV]
@@ -340,20 +334,17 @@ class MultiTaskLassoCV(BaseTransformer):
340
334
  self,
341
335
  dataset: DataFrame,
342
336
  inference_method: str,
343
- ) -> List[str]:
344
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
345
- return the available package that exists in the snowflake anaconda channel
337
+ ) -> None:
338
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
346
339
 
347
340
  Args:
348
341
  dataset: snowpark dataframe
349
342
  inference_method: the inference method such as predict, score...
350
-
343
+
351
344
  Raises:
352
345
  SnowflakeMLException: If the estimator is not fitted, raise error
353
346
  SnowflakeMLException: If the session is None, raise error
354
347
 
355
- Returns:
356
- A list of available package that exists in the snowflake anaconda channel
357
348
  """
358
349
  if not self._is_fitted:
359
350
  raise exceptions.SnowflakeMLException(
@@ -371,9 +362,7 @@ class MultiTaskLassoCV(BaseTransformer):
371
362
  "Session must not specified for snowpark dataset."
372
363
  ),
373
364
  )
374
- # Validate that key package version in user workspace are supported in snowflake conda channel
375
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
376
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
365
+
377
366
 
378
367
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
379
368
  @telemetry.send_api_usage_telemetry(
@@ -421,7 +410,8 @@ class MultiTaskLassoCV(BaseTransformer):
421
410
 
422
411
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
423
412
 
424
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
413
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
414
+ self._deps = self._get_dependencies()
425
415
  assert isinstance(
426
416
  dataset._session, Session
427
417
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -504,10 +494,8 @@ class MultiTaskLassoCV(BaseTransformer):
504
494
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
505
495
  expected_dtype = convert_sp_to_sf_type(output_types[0])
506
496
 
507
- self._deps = self._batch_inference_validate_snowpark(
508
- dataset=dataset,
509
- inference_method=inference_method,
510
- )
497
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
498
+ self._deps = self._get_dependencies()
511
499
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
512
500
 
513
501
  transform_kwargs = dict(
@@ -574,16 +562,40 @@ class MultiTaskLassoCV(BaseTransformer):
574
562
  self._is_fitted = True
575
563
  return output_result
576
564
 
565
+
566
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
567
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
568
+ """ Method not supported for this class.
577
569
 
578
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
579
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
580
- """
570
+
571
+ Raises:
572
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
573
+
574
+ Args:
575
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
576
+ Snowpark or Pandas DataFrame.
577
+ output_cols_prefix: Prefix for the response columns
581
578
  Returns:
582
579
  Transformed dataset.
583
580
  """
584
- self.fit(dataset)
585
- assert self._sklearn_object is not None
586
- return self._sklearn_object.embedding_
581
+ self._infer_input_output_cols(dataset)
582
+ super()._check_dataset_type(dataset)
583
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
584
+ estimator=self._sklearn_object,
585
+ dataset=dataset,
586
+ input_cols=self.input_cols,
587
+ label_cols=self.label_cols,
588
+ sample_weight_col=self.sample_weight_col,
589
+ autogenerated=self._autogenerated,
590
+ subproject=_SUBPROJECT,
591
+ )
592
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
593
+ drop_input_cols=self._drop_input_cols,
594
+ expected_output_cols_list=self.output_cols,
595
+ )
596
+ self._sklearn_object = fitted_estimator
597
+ self._is_fitted = True
598
+ return output_result
587
599
 
588
600
 
589
601
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -674,10 +686,8 @@ class MultiTaskLassoCV(BaseTransformer):
674
686
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
675
687
 
676
688
  if isinstance(dataset, DataFrame):
677
- self._deps = self._batch_inference_validate_snowpark(
678
- dataset=dataset,
679
- inference_method=inference_method,
680
- )
689
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
690
+ self._deps = self._get_dependencies()
681
691
  assert isinstance(
682
692
  dataset._session, Session
683
693
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -742,10 +752,8 @@ class MultiTaskLassoCV(BaseTransformer):
742
752
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
743
753
 
744
754
  if isinstance(dataset, DataFrame):
745
- self._deps = self._batch_inference_validate_snowpark(
746
- dataset=dataset,
747
- inference_method=inference_method,
748
- )
755
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
756
+ self._deps = self._get_dependencies()
749
757
  assert isinstance(
750
758
  dataset._session, Session
751
759
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -807,10 +815,8 @@ class MultiTaskLassoCV(BaseTransformer):
807
815
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
808
816
 
809
817
  if isinstance(dataset, DataFrame):
810
- self._deps = self._batch_inference_validate_snowpark(
811
- dataset=dataset,
812
- inference_method=inference_method,
813
- )
818
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
819
+ self._deps = self._get_dependencies()
814
820
  assert isinstance(
815
821
  dataset._session, Session
816
822
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -876,10 +882,8 @@ class MultiTaskLassoCV(BaseTransformer):
876
882
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
877
883
 
878
884
  if isinstance(dataset, DataFrame):
879
- self._deps = self._batch_inference_validate_snowpark(
880
- dataset=dataset,
881
- inference_method=inference_method,
882
- )
885
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
886
+ self._deps = self._get_dependencies()
883
887
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
884
888
  transform_kwargs = dict(
885
889
  session=dataset._session,
@@ -943,17 +947,15 @@ class MultiTaskLassoCV(BaseTransformer):
943
947
  transform_kwargs: ScoreKwargsTypedDict = dict()
944
948
 
945
949
  if isinstance(dataset, DataFrame):
946
- self._deps = self._batch_inference_validate_snowpark(
947
- dataset=dataset,
948
- inference_method="score",
949
- )
950
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
951
+ self._deps = self._get_dependencies()
950
952
  selected_cols = self._get_active_columns()
951
953
  if len(selected_cols) > 0:
952
954
  dataset = dataset.select(selected_cols)
953
955
  assert isinstance(dataset._session, Session) # keep mypy happy
954
956
  transform_kwargs = dict(
955
957
  session=dataset._session,
956
- dependencies=["snowflake-snowpark-python"] + self._deps,
958
+ dependencies=self._deps,
957
959
  score_sproc_imports=['sklearn'],
958
960
  )
959
961
  elif isinstance(dataset, pd.DataFrame):
@@ -1018,11 +1020,8 @@ class MultiTaskLassoCV(BaseTransformer):
1018
1020
 
1019
1021
  if isinstance(dataset, DataFrame):
1020
1022
 
1021
- self._deps = self._batch_inference_validate_snowpark(
1022
- dataset=dataset,
1023
- inference_method=inference_method,
1024
-
1025
- )
1023
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1024
+ self._deps = self._get_dependencies()
1026
1025
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1027
1026
  transform_kwargs = dict(
1028
1027
  session = dataset._session,