snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206):
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -334,9 +334,12 @@ class GridSearchCV(BaseTransformer):
334
334
  self._generate_model_signatures(dataset)
335
335
  return self
336
336
 
337
- def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
338
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
339
- return the available package that exists in the snowflake anaconda channel
337
+ def _batch_inference_validate_snowpark(
338
+ self,
339
+ dataset: DataFrame,
340
+ inference_method: str,
341
+ ) -> None:
342
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
340
343
 
341
344
  Args:
342
345
  dataset: snowpark dataframe
@@ -346,8 +349,6 @@ class GridSearchCV(BaseTransformer):
346
349
  SnowflakeMLException: If the estimator is not fitted, raise error
347
350
  SnowflakeMLException: If the session is None, raise error
348
351
 
349
- Returns:
350
- A list of available package that exists in the snowflake anaconda channel
351
352
  """
352
353
  if not self._is_fitted:
353
354
  raise exceptions.SnowflakeMLException(
@@ -363,10 +364,6 @@ class GridSearchCV(BaseTransformer):
363
364
  error_code=error_codes.NOT_FOUND,
364
365
  original_exception=ValueError("Session must not specified for snowpark dataset."),
365
366
  )
366
- # Validate that key package version in user workspace are supported in snowflake conda channel
367
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
368
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
369
- )
370
367
 
371
368
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
372
369
  @telemetry.send_api_usage_telemetry(
@@ -415,10 +412,8 @@ class GridSearchCV(BaseTransformer):
415
412
  )
416
413
 
417
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
418
- self._deps = self._batch_inference_validate_snowpark(
419
- dataset=dataset,
420
- inference_method=inference_method,
421
- )
415
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._deps = self._get_dependencies()
422
417
 
423
418
  assert isinstance(
424
419
  dataset._session, Session
@@ -476,7 +471,8 @@ class GridSearchCV(BaseTransformer):
476
471
  inference_method = "transform"
477
472
 
478
473
  if isinstance(dataset, DataFrame):
479
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
474
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
475
+ self._deps = self._get_dependencies()
480
476
  assert isinstance(
481
477
  dataset._session, Session
482
478
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -535,7 +531,8 @@ class GridSearchCV(BaseTransformer):
535
531
  inference_method = "predict_proba"
536
532
 
537
533
  if isinstance(dataset, DataFrame):
538
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
534
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
535
+ self._deps = self._get_dependencies()
539
536
  assert isinstance(
540
537
  dataset._session, Session
541
538
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -595,7 +592,8 @@ class GridSearchCV(BaseTransformer):
595
592
  inference_method = "predict_log_proba"
596
593
 
597
594
  if isinstance(dataset, DataFrame):
598
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
595
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
596
+ self._deps = self._get_dependencies()
599
597
  assert isinstance(
600
598
  dataset._session, Session
601
599
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -655,7 +653,8 @@ class GridSearchCV(BaseTransformer):
655
653
  inference_method = "decision_function"
656
654
 
657
655
  if isinstance(dataset, DataFrame):
658
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
656
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
657
+ self._deps = self._get_dependencies()
659
658
  assert isinstance(
660
659
  dataset._session, Session
661
660
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -716,7 +715,8 @@ class GridSearchCV(BaseTransformer):
716
715
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
717
716
 
718
717
  if isinstance(dataset, DataFrame):
719
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
718
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
719
+ self._deps = self._get_dependencies()
720
720
  assert isinstance(
721
721
  dataset._session, Session
722
722
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -767,17 +767,15 @@ class GridSearchCV(BaseTransformer):
767
767
  transform_kwargs: ScoreKwargsTypedDict = dict()
768
768
 
769
769
  if isinstance(dataset, DataFrame):
770
- self._deps = self._batch_inference_validate_snowpark(
771
- dataset=dataset,
772
- inference_method="score",
773
- )
770
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
771
+ self._deps = self._get_dependencies()
774
772
  selected_cols = self._get_active_columns()
775
773
  if len(selected_cols) > 0:
776
774
  dataset = dataset.select(selected_cols)
777
775
  assert isinstance(dataset._session, Session) # keep mypy happy
778
776
  transform_kwargs = dict(
779
777
  session=dataset._session,
780
- dependencies=["snowflake-snowpark-python"] + self._deps,
778
+ dependencies=self._deps,
781
779
  score_sproc_imports=["sklearn"],
782
780
  )
783
781
  elif isinstance(dataset, pd.DataFrame):
@@ -347,8 +347,22 @@ class RandomizedSearchCV(BaseTransformer):
347
347
  self._generate_model_signatures(dataset)
348
348
  return self
349
349
 
350
- def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
351
- """Util method to run validate that batch inference can be run on a snowpark dataframe."""
350
+ def _batch_inference_validate_snowpark(
351
+ self,
352
+ dataset: DataFrame,
353
+ inference_method: str,
354
+ ) -> None:
355
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
356
+
357
+ Args:
358
+ dataset: snowpark dataframe
359
+ inference_method: the inference method such as predict, score...
360
+
361
+ Raises:
362
+ SnowflakeMLException: If the estimator is not fitted, raise error
363
+ SnowflakeMLException: If the session is None, raise error
364
+
365
+ """
352
366
  if not self._is_fitted:
353
367
  raise exceptions.SnowflakeMLException(
354
368
  error_code=error_codes.METHOD_NOT_ALLOWED,
@@ -363,10 +377,6 @@ class RandomizedSearchCV(BaseTransformer):
363
377
  error_code=error_codes.NOT_FOUND,
364
378
  original_exception=ValueError("Session must not specified for snowpark dataset."),
365
379
  )
366
- # Validate that key package version in user workspace are supported in snowflake conda channel
367
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
368
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
369
- )
370
380
 
371
381
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
372
382
  @telemetry.send_api_usage_telemetry(
@@ -414,10 +424,9 @@ class RandomizedSearchCV(BaseTransformer):
414
424
  )
415
425
 
416
426
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
417
- self._deps = self._batch_inference_validate_snowpark(
418
- dataset=dataset,
419
- inference_method=inference_method,
420
- )
427
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
428
+ self._deps = self._get_dependencies()
429
+
421
430
  assert isinstance(
422
431
  dataset._session, Session
423
432
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -473,7 +482,9 @@ class RandomizedSearchCV(BaseTransformer):
473
482
  inference_method = "transform"
474
483
 
475
484
  if isinstance(dataset, DataFrame):
476
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
485
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
486
+ self._deps = self._get_dependencies()
487
+
477
488
  assert isinstance(
478
489
  dataset._session, Session
479
490
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -531,7 +542,9 @@ class RandomizedSearchCV(BaseTransformer):
531
542
  inference_method = "predict_proba"
532
543
 
533
544
  if isinstance(dataset, DataFrame):
534
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
545
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
546
+ self._deps = self._get_dependencies()
547
+
535
548
  assert isinstance(
536
549
  dataset._session, Session
537
550
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -591,7 +604,9 @@ class RandomizedSearchCV(BaseTransformer):
591
604
  inference_method = "predict_log_proba"
592
605
 
593
606
  if isinstance(dataset, DataFrame):
594
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
607
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
608
+ self._deps = self._get_dependencies()
609
+
595
610
  assert isinstance(
596
611
  dataset._session, Session
597
612
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -650,7 +665,9 @@ class RandomizedSearchCV(BaseTransformer):
650
665
  inference_method = "decision_function"
651
666
 
652
667
  if isinstance(dataset, DataFrame):
653
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
668
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
669
+ self._deps = self._get_dependencies()
670
+
654
671
  assert isinstance(
655
672
  dataset._session, Session
656
673
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -711,7 +728,9 @@ class RandomizedSearchCV(BaseTransformer):
711
728
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
712
729
 
713
730
  if isinstance(dataset, DataFrame):
714
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
731
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
732
+ self._deps = self._get_dependencies()
733
+
715
734
  assert isinstance(
716
735
  dataset._session, Session
717
736
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -761,10 +780,9 @@ class RandomizedSearchCV(BaseTransformer):
761
780
  transform_kwargs: ScoreKwargsTypedDict = dict()
762
781
 
763
782
  if isinstance(dataset, DataFrame):
764
- self._deps = self._batch_inference_validate_snowpark(
765
- dataset=dataset,
766
- inference_method="score",
767
- )
783
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
784
+ self._deps = self._get_dependencies()
785
+
768
786
  selected_cols = self._get_active_columns()
769
787
  if len(selected_cols) > 0:
770
788
  dataset = dataset.select(selected_cols)
@@ -772,7 +790,7 @@ class RandomizedSearchCV(BaseTransformer):
772
790
  assert isinstance(dataset._session, Session) # keep mypy happy
773
791
  transform_kwargs = dict(
774
792
  session=dataset._session,
775
- dependencies=["snowflake-snowpark-python"] + self._deps,
793
+ dependencies=self._deps,
776
794
  score_sproc_imports=["sklearn"],
777
795
  )
778
796
  elif isinstance(dataset, pd.DataFrame):
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class OneVsOneClassifier(BaseTransformer):
70
64
  r"""One-vs-one multiclass strategy
71
65
  For more details on this class, see [sklearn.multiclass.OneVsOneClassifier]
@@ -271,20 +265,17 @@ class OneVsOneClassifier(BaseTransformer):
271
265
  self,
272
266
  dataset: DataFrame,
273
267
  inference_method: str,
274
- ) -> List[str]:
275
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
276
- return the available package that exists in the snowflake anaconda channel
268
+ ) -> None:
269
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
277
270
 
278
271
  Args:
279
272
  dataset: snowpark dataframe
280
273
  inference_method: the inference method such as predict, score...
281
-
274
+
282
275
  Raises:
283
276
  SnowflakeMLException: If the estimator is not fitted, raise error
284
277
  SnowflakeMLException: If the session is None, raise error
285
278
 
286
- Returns:
287
- A list of available package that exists in the snowflake anaconda channel
288
279
  """
289
280
  if not self._is_fitted:
290
281
  raise exceptions.SnowflakeMLException(
@@ -302,9 +293,7 @@ class OneVsOneClassifier(BaseTransformer):
302
293
  "Session must not specified for snowpark dataset."
303
294
  ),
304
295
  )
305
- # Validate that key package version in user workspace are supported in snowflake conda channel
306
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
307
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
296
+
308
297
 
309
298
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
310
299
  @telemetry.send_api_usage_telemetry(
@@ -352,7 +341,8 @@ class OneVsOneClassifier(BaseTransformer):
352
341
 
353
342
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
354
343
 
355
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
344
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
345
+ self._deps = self._get_dependencies()
356
346
  assert isinstance(
357
347
  dataset._session, Session
358
348
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -435,10 +425,8 @@ class OneVsOneClassifier(BaseTransformer):
435
425
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
436
426
  expected_dtype = convert_sp_to_sf_type(output_types[0])
437
427
 
438
- self._deps = self._batch_inference_validate_snowpark(
439
- dataset=dataset,
440
- inference_method=inference_method,
441
- )
428
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
429
+ self._deps = self._get_dependencies()
442
430
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
443
431
 
444
432
  transform_kwargs = dict(
@@ -505,16 +493,40 @@ class OneVsOneClassifier(BaseTransformer):
505
493
  self._is_fitted = True
506
494
  return output_result
507
495
 
496
+
497
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
498
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
499
+ """ Method not supported for this class.
508
500
 
509
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
510
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
511
- """
501
+
502
+ Raises:
503
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
504
+
505
+ Args:
506
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
507
+ Snowpark or Pandas DataFrame.
508
+ output_cols_prefix: Prefix for the response columns
512
509
  Returns:
513
510
  Transformed dataset.
514
511
  """
515
- self.fit(dataset)
516
- assert self._sklearn_object is not None
517
- return self._sklearn_object.embedding_
512
+ self._infer_input_output_cols(dataset)
513
+ super()._check_dataset_type(dataset)
514
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
515
+ estimator=self._sklearn_object,
516
+ dataset=dataset,
517
+ input_cols=self.input_cols,
518
+ label_cols=self.label_cols,
519
+ sample_weight_col=self.sample_weight_col,
520
+ autogenerated=self._autogenerated,
521
+ subproject=_SUBPROJECT,
522
+ )
523
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
524
+ drop_input_cols=self._drop_input_cols,
525
+ expected_output_cols_list=self.output_cols,
526
+ )
527
+ self._sklearn_object = fitted_estimator
528
+ self._is_fitted = True
529
+ return output_result
518
530
 
519
531
 
520
532
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -605,10 +617,8 @@ class OneVsOneClassifier(BaseTransformer):
605
617
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
606
618
 
607
619
  if isinstance(dataset, DataFrame):
608
- self._deps = self._batch_inference_validate_snowpark(
609
- dataset=dataset,
610
- inference_method=inference_method,
611
- )
620
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
621
+ self._deps = self._get_dependencies()
612
622
  assert isinstance(
613
623
  dataset._session, Session
614
624
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -673,10 +683,8 @@ class OneVsOneClassifier(BaseTransformer):
673
683
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
674
684
 
675
685
  if isinstance(dataset, DataFrame):
676
- self._deps = self._batch_inference_validate_snowpark(
677
- dataset=dataset,
678
- inference_method=inference_method,
679
- )
686
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
687
+ self._deps = self._get_dependencies()
680
688
  assert isinstance(
681
689
  dataset._session, Session
682
690
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -740,10 +748,8 @@ class OneVsOneClassifier(BaseTransformer):
740
748
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
741
749
 
742
750
  if isinstance(dataset, DataFrame):
743
- self._deps = self._batch_inference_validate_snowpark(
744
- dataset=dataset,
745
- inference_method=inference_method,
746
- )
751
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
752
+ self._deps = self._get_dependencies()
747
753
  assert isinstance(
748
754
  dataset._session, Session
749
755
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -809,10 +815,8 @@ class OneVsOneClassifier(BaseTransformer):
809
815
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
810
816
 
811
817
  if isinstance(dataset, DataFrame):
812
- self._deps = self._batch_inference_validate_snowpark(
813
- dataset=dataset,
814
- inference_method=inference_method,
815
- )
818
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
819
+ self._deps = self._get_dependencies()
816
820
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
817
821
  transform_kwargs = dict(
818
822
  session=dataset._session,
@@ -876,17 +880,15 @@ class OneVsOneClassifier(BaseTransformer):
876
880
  transform_kwargs: ScoreKwargsTypedDict = dict()
877
881
 
878
882
  if isinstance(dataset, DataFrame):
879
- self._deps = self._batch_inference_validate_snowpark(
880
- dataset=dataset,
881
- inference_method="score",
882
- )
883
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
884
+ self._deps = self._get_dependencies()
883
885
  selected_cols = self._get_active_columns()
884
886
  if len(selected_cols) > 0:
885
887
  dataset = dataset.select(selected_cols)
886
888
  assert isinstance(dataset._session, Session) # keep mypy happy
887
889
  transform_kwargs = dict(
888
890
  session=dataset._session,
889
- dependencies=["snowflake-snowpark-python"] + self._deps,
891
+ dependencies=self._deps,
890
892
  score_sproc_imports=['sklearn'],
891
893
  )
892
894
  elif isinstance(dataset, pd.DataFrame):
@@ -951,11 +953,8 @@ class OneVsOneClassifier(BaseTransformer):
951
953
 
952
954
  if isinstance(dataset, DataFrame):
953
955
 
954
- self._deps = self._batch_inference_validate_snowpark(
955
- dataset=dataset,
956
- inference_method=inference_method,
957
-
958
- )
956
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
957
+ self._deps = self._get_dependencies()
959
958
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
960
959
  transform_kwargs = dict(
961
960
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.multiclass".replace("skl
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class OneVsRestClassifier(BaseTransformer):
70
64
  r"""One-vs-the-rest (OvR) multiclass strategy
71
65
  For more details on this class, see [sklearn.multiclass.OneVsRestClassifier]
@@ -280,20 +274,17 @@ class OneVsRestClassifier(BaseTransformer):
280
274
  self,
281
275
  dataset: DataFrame,
282
276
  inference_method: str,
283
- ) -> List[str]:
284
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
285
- return the available package that exists in the snowflake anaconda channel
277
+ ) -> None:
278
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
286
279
 
287
280
  Args:
288
281
  dataset: snowpark dataframe
289
282
  inference_method: the inference method such as predict, score...
290
-
283
+
291
284
  Raises:
292
285
  SnowflakeMLException: If the estimator is not fitted, raise error
293
286
  SnowflakeMLException: If the session is None, raise error
294
287
 
295
- Returns:
296
- A list of available package that exists in the snowflake anaconda channel
297
288
  """
298
289
  if not self._is_fitted:
299
290
  raise exceptions.SnowflakeMLException(
@@ -311,9 +302,7 @@ class OneVsRestClassifier(BaseTransformer):
311
302
  "Session must not specified for snowpark dataset."
312
303
  ),
313
304
  )
314
- # Validate that key package version in user workspace are supported in snowflake conda channel
315
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
316
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
305
+
317
306
 
318
307
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
319
308
  @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class OneVsRestClassifier(BaseTransformer):
361
350
 
362
351
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
363
352
 
364
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
353
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
354
+ self._deps = self._get_dependencies()
365
355
  assert isinstance(
366
356
  dataset._session, Session
367
357
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class OneVsRestClassifier(BaseTransformer):
444
434
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
445
435
  expected_dtype = convert_sp_to_sf_type(output_types[0])
446
436
 
447
- self._deps = self._batch_inference_validate_snowpark(
448
- dataset=dataset,
449
- inference_method=inference_method,
450
- )
437
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._deps = self._get_dependencies()
451
439
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
452
440
 
453
441
  transform_kwargs = dict(
@@ -514,16 +502,40 @@ class OneVsRestClassifier(BaseTransformer):
514
502
  self._is_fitted = True
515
503
  return output_result
516
504
 
505
+
506
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
507
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
508
+ """ Method not supported for this class.
517
509
 
518
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
519
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
520
- """
510
+
511
+ Raises:
512
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
513
+
514
+ Args:
515
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
516
+ Snowpark or Pandas DataFrame.
517
+ output_cols_prefix: Prefix for the response columns
521
518
  Returns:
522
519
  Transformed dataset.
523
520
  """
524
- self.fit(dataset)
525
- assert self._sklearn_object is not None
526
- return self._sklearn_object.embedding_
521
+ self._infer_input_output_cols(dataset)
522
+ super()._check_dataset_type(dataset)
523
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
524
+ estimator=self._sklearn_object,
525
+ dataset=dataset,
526
+ input_cols=self.input_cols,
527
+ label_cols=self.label_cols,
528
+ sample_weight_col=self.sample_weight_col,
529
+ autogenerated=self._autogenerated,
530
+ subproject=_SUBPROJECT,
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=self.output_cols,
535
+ )
536
+ self._sklearn_object = fitted_estimator
537
+ self._is_fitted = True
538
+ return output_result
527
539
 
528
540
 
529
541
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -616,10 +628,8 @@ class OneVsRestClassifier(BaseTransformer):
616
628
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
617
629
 
618
630
  if isinstance(dataset, DataFrame):
619
- self._deps = self._batch_inference_validate_snowpark(
620
- dataset=dataset,
621
- inference_method=inference_method,
622
- )
631
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
632
+ self._deps = self._get_dependencies()
623
633
  assert isinstance(
624
634
  dataset._session, Session
625
635
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -686,10 +696,8 @@ class OneVsRestClassifier(BaseTransformer):
686
696
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
687
697
 
688
698
  if isinstance(dataset, DataFrame):
689
- self._deps = self._batch_inference_validate_snowpark(
690
- dataset=dataset,
691
- inference_method=inference_method,
692
- )
699
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
700
+ self._deps = self._get_dependencies()
693
701
  assert isinstance(
694
702
  dataset._session, Session
695
703
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -753,10 +761,8 @@ class OneVsRestClassifier(BaseTransformer):
753
761
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
754
762
 
755
763
  if isinstance(dataset, DataFrame):
756
- self._deps = self._batch_inference_validate_snowpark(
757
- dataset=dataset,
758
- inference_method=inference_method,
759
- )
764
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
765
+ self._deps = self._get_dependencies()
760
766
  assert isinstance(
761
767
  dataset._session, Session
762
768
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -822,10 +828,8 @@ class OneVsRestClassifier(BaseTransformer):
822
828
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
823
829
 
824
830
  if isinstance(dataset, DataFrame):
825
- self._deps = self._batch_inference_validate_snowpark(
826
- dataset=dataset,
827
- inference_method=inference_method,
828
- )
831
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
832
+ self._deps = self._get_dependencies()
829
833
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
830
834
  transform_kwargs = dict(
831
835
  session=dataset._session,
@@ -889,17 +893,15 @@ class OneVsRestClassifier(BaseTransformer):
889
893
  transform_kwargs: ScoreKwargsTypedDict = dict()
890
894
 
891
895
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method="score",
895
- )
896
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
897
+ self._deps = self._get_dependencies()
896
898
  selected_cols = self._get_active_columns()
897
899
  if len(selected_cols) > 0:
898
900
  dataset = dataset.select(selected_cols)
899
901
  assert isinstance(dataset._session, Session) # keep mypy happy
900
902
  transform_kwargs = dict(
901
903
  session=dataset._session,
902
- dependencies=["snowflake-snowpark-python"] + self._deps,
904
+ dependencies=self._deps,
903
905
  score_sproc_imports=['sklearn'],
904
906
  )
905
907
  elif isinstance(dataset, pd.DataFrame):
@@ -964,11 +966,8 @@ class OneVsRestClassifier(BaseTransformer):
964
966
 
965
967
  if isinstance(dataset, DataFrame):
966
968
 
967
- self._deps = self._batch_inference_validate_snowpark(
968
- dataset=dataset,
969
- inference_method=inference_method,
970
-
971
- )
969
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
970
+ self._deps = self._get_dependencies()
972
971
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
973
972
  transform_kwargs = dict(
974
973
  session = dataset._session,