snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -224,6 +224,7 @@ class RandomizedSearchCV(BaseTransformer):
224
224
  expensive and is not strictly required to select the parameters that
225
225
  yield the best generalization performance.
226
226
  """
227
+
227
228
  _ENABLE_DISTRIBUTED = True
228
229
 
229
230
  def __init__( # type: ignore[no-untyped-def]
@@ -345,13 +346,7 @@ class RandomizedSearchCV(BaseTransformer):
345
346
  self._get_model_signatures(dataset)
346
347
  return self
347
348
 
348
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
349
- if self._drop_input_cols:
350
- return []
351
- else:
352
- return list(set(dataset.columns) - set(self.output_cols))
353
-
354
- def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> None:
349
+ def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
355
350
  """Util method to run validate that batch inference can be run on a snowpark dataframe."""
356
351
  if not self._is_fitted:
357
352
  raise exceptions.SnowflakeMLException(
@@ -368,7 +363,7 @@ class RandomizedSearchCV(BaseTransformer):
368
363
  original_exception=ValueError("Session must not specified for snowpark dataset."),
369
364
  )
370
365
  # Validate that key package version in user workspace are supported in snowflake conda channel
371
- pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
366
+ return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
372
367
  pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
373
368
  )
374
369
 
@@ -403,7 +398,7 @@ class RandomizedSearchCV(BaseTransformer):
403
398
  expected_type_inferred = convert_sp_to_sf_type(
404
399
  self.model_signatures["predict"].outputs[0].as_snowpark_type()
405
400
  )
406
- self._batch_inference_validate_snowpark(
401
+ self._deps = self._batch_inference_validate_snowpark(
407
402
  dataset=dataset,
408
403
  inference_method=inference_method,
409
404
  )
@@ -412,8 +407,8 @@ class RandomizedSearchCV(BaseTransformer):
412
407
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
413
408
  transform_kwargs = dict(
414
409
  session=dataset._session,
415
- dependencies=self._get_dependencies(),
416
- pass_through_cols=self._get_pass_through_columns(dataset),
410
+ dependencies=self._deps,
411
+ drop_input_cols=self._drop_input_cols,
417
412
  expected_output_cols_type=expected_type_inferred,
418
413
  )
419
414
 
@@ -462,14 +457,14 @@ class RandomizedSearchCV(BaseTransformer):
462
457
  inference_method = "transform"
463
458
 
464
459
  if isinstance(dataset, DataFrame):
465
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
460
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
466
461
  assert isinstance(
467
462
  dataset._session, Session
468
463
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
469
464
  transform_kwargs = dict(
470
465
  session=dataset._session,
471
- dependencies=self._get_dependencies(),
472
- pass_through_cols=self._get_pass_through_columns(dataset),
466
+ dependencies=self._deps,
467
+ drop_input_cols=self._drop_input_cols,
473
468
  )
474
469
 
475
470
  elif isinstance(dataset, pd.DataFrame):
@@ -491,36 +486,6 @@ class RandomizedSearchCV(BaseTransformer):
491
486
  )
492
487
  return output_df
493
488
 
494
- def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
495
- """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
496
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
497
-
498
- Args:
499
- output_cols_prefix (str): prefix according to the function
500
-
501
- Returns:
502
- List[str]: output cols with prefix
503
- """
504
- if getattr(self._sklearn_object, "classes_", None) is None:
505
- return [output_cols_prefix]
506
-
507
- assert self._sklearn_object is not None # keep mypy happy
508
- classes = self._sklearn_object.classes_
509
- if isinstance(classes, np.ndarray):
510
- return [f"{output_cols_prefix}{c}" for c in classes.tolist()]
511
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], np.ndarray):
512
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
513
- output_cols = []
514
- for i, cl in enumerate(classes):
515
- # For binary classification, there is only one output column for each class
516
- # ndarray as the two classes are complementary.
517
- if len(cl) == 2:
518
- output_cols.append(f"{output_cols_prefix}_{i}_{cl[0]}")
519
- else:
520
- output_cols.extend([f"{output_cols_prefix}_{i}_{c}" for c in cl.tolist()])
521
- return output_cols
522
- return []
523
-
524
489
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
525
490
  @telemetry.send_api_usage_telemetry(
526
491
  project=_PROJECT,
@@ -550,14 +515,14 @@ class RandomizedSearchCV(BaseTransformer):
550
515
  inference_method = "predict_proba"
551
516
 
552
517
  if isinstance(dataset, DataFrame):
553
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
518
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
554
519
  assert isinstance(
555
520
  dataset._session, Session
556
521
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
557
522
  transform_kwargs = dict(
558
523
  session=dataset._session,
559
- dependencies=self._get_dependencies(),
560
- pass_through_cols=self._get_pass_through_columns(dataset),
524
+ dependencies=self._deps,
525
+ drop_input_cols=self._drop_input_cols,
561
526
  expected_output_cols_type="float",
562
527
  )
563
528
 
@@ -610,14 +575,14 @@ class RandomizedSearchCV(BaseTransformer):
610
575
  inference_method = "predict_log_proba"
611
576
 
612
577
  if isinstance(dataset, DataFrame):
613
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
578
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
614
579
  assert isinstance(
615
580
  dataset._session, Session
616
581
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
617
582
  transform_kwargs = dict(
618
583
  session=dataset._session,
619
- dependencies=self._get_dependencies(),
620
- pass_through_cols=self._get_pass_through_columns(dataset),
584
+ dependencies=self._deps,
585
+ drop_input_cols=self._drop_input_cols,
621
586
  expected_output_cols_type="float",
622
587
  )
623
588
 
@@ -669,14 +634,14 @@ class RandomizedSearchCV(BaseTransformer):
669
634
  inference_method = "decision_function"
670
635
 
671
636
  if isinstance(dataset, DataFrame):
672
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
637
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
673
638
  assert isinstance(
674
639
  dataset._session, Session
675
640
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
676
641
  transform_kwargs = dict(
677
642
  session=dataset._session,
678
- dependencies=self._get_dependencies(),
679
- pass_through_cols=self._get_pass_through_columns(dataset),
643
+ dependencies=self._deps,
644
+ drop_input_cols=self._drop_input_cols,
680
645
  expected_output_cols_type="float",
681
646
  )
682
647
 
@@ -730,14 +695,14 @@ class RandomizedSearchCV(BaseTransformer):
730
695
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
731
696
 
732
697
  if isinstance(dataset, DataFrame):
733
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
698
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
734
699
  assert isinstance(
735
700
  dataset._session, Session
736
701
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
737
702
  transform_kwargs = dict(
738
703
  session=dataset._session,
739
- dependencies=self._get_dependencies(),
740
- pass_through_cols=self._get_pass_through_columns(dataset),
704
+ dependencies=self._deps,
705
+ drop_input_cols=self._drop_input_cols,
741
706
  expected_output_cols_type="float",
742
707
  )
743
708
 
@@ -780,6 +745,10 @@ class RandomizedSearchCV(BaseTransformer):
780
745
  transform_kwargs: ScoreKwargsTypedDict = dict()
781
746
 
782
747
  if isinstance(dataset, DataFrame):
748
+ self._deps = self._batch_inference_validate_snowpark(
749
+ dataset=dataset,
750
+ inference_method="score",
751
+ )
783
752
  selected_cols = self._get_active_columns()
784
753
  if len(selected_cols) > 0:
785
754
  dataset = dataset.select(selected_cols)
@@ -787,7 +756,7 @@ class RandomizedSearchCV(BaseTransformer):
787
756
  assert isinstance(dataset._session, Session) # keep mypy happy
788
757
  transform_kwargs = dict(
789
758
  session=dataset._session,
790
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
759
+ dependencies=["snowflake-snowpark-python"] + self._deps,
791
760
  score_sproc_imports=["sklearn"],
792
761
  )
793
762
  elif isinstance(dataset, pd.DataFrame):
@@ -271,18 +271,24 @@ class OneVsOneClassifier(BaseTransformer):
271
271
  self._get_model_signatures(dataset)
272
272
  return self
273
273
 
274
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
275
- if self._drop_input_cols:
276
- return []
277
- else:
278
- return list(set(dataset.columns) - set(self.output_cols))
279
-
280
274
  def _batch_inference_validate_snowpark(
281
275
  self,
282
276
  dataset: DataFrame,
283
277
  inference_method: str,
284
278
  ) -> List[str]:
285
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
279
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
280
+ return the available package that exists in the snowflake anaconda channel
281
+
282
+ Args:
283
+ dataset: snowpark dataframe
284
+ inference_method: the inference method such as predict, score...
285
+
286
+ Raises:
287
+ SnowflakeMLException: If the estimator is not fitted, raise error
288
+ SnowflakeMLException: If the session is None, raise error
289
+
290
+ Returns:
291
+ A list of available package that exists in the snowflake anaconda channel
286
292
  """
287
293
  if not self._is_fitted:
288
294
  raise exceptions.SnowflakeMLException(
@@ -356,7 +362,7 @@ class OneVsOneClassifier(BaseTransformer):
356
362
  transform_kwargs = dict(
357
363
  session = dataset._session,
358
364
  dependencies = self._deps,
359
- pass_through_cols = self._get_pass_through_columns(dataset),
365
+ drop_input_cols = self._drop_input_cols,
360
366
  expected_output_cols_type = expected_type_inferred,
361
367
  )
362
368
 
@@ -416,16 +422,16 @@ class OneVsOneClassifier(BaseTransformer):
416
422
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
417
423
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
418
424
  # each row containing a list of values.
419
- expected_dtype = "ARRAY"
425
+ expected_dtype = "array"
420
426
 
421
427
  # If we were unable to assign a type to this transform in the factory, infer the type here.
422
428
  if expected_dtype == "":
423
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
429
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
424
430
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
425
- expected_dtype = "ARRAY"
426
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
431
+ expected_dtype = "array"
432
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
427
433
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
428
- expected_dtype = "ARRAY"
434
+ expected_dtype = "array"
429
435
  else:
430
436
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
431
437
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -443,7 +449,7 @@ class OneVsOneClassifier(BaseTransformer):
443
449
  transform_kwargs = dict(
444
450
  session = dataset._session,
445
451
  dependencies = self._deps,
446
- pass_through_cols = self._get_pass_through_columns(dataset),
452
+ drop_input_cols = self._drop_input_cols,
447
453
  expected_output_cols_type = expected_dtype,
448
454
  )
449
455
 
@@ -494,7 +500,7 @@ class OneVsOneClassifier(BaseTransformer):
494
500
  subproject=_SUBPROJECT,
495
501
  )
496
502
  output_result, fitted_estimator = model_trainer.train_fit_predict(
497
- pass_through_columns=self._get_pass_through_columns(dataset),
503
+ drop_input_cols=self._drop_input_cols,
498
504
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
499
505
  )
500
506
  self._sklearn_object = fitted_estimator
@@ -512,44 +518,6 @@ class OneVsOneClassifier(BaseTransformer):
512
518
  assert self._sklearn_object is not None
513
519
  return self._sklearn_object.embedding_
514
520
 
515
-
516
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
517
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
518
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
519
- """
520
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
521
- if output_cols:
522
- output_cols = [
523
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
524
- for c in output_cols
525
- ]
526
- elif getattr(self._sklearn_object, "classes_", None) is None:
527
- output_cols = [output_cols_prefix]
528
- elif self._sklearn_object is not None:
529
- classes = self._sklearn_object.classes_
530
- if isinstance(classes, numpy.ndarray):
531
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
532
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
533
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
534
- output_cols = []
535
- for i, cl in enumerate(classes):
536
- # For binary classification, there is only one output column for each class
537
- # ndarray as the two classes are complementary.
538
- if len(cl) == 2:
539
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
540
- else:
541
- output_cols.extend([
542
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
543
- ])
544
- else:
545
- output_cols = []
546
-
547
- # Make sure column names are valid snowflake identifiers.
548
- assert output_cols is not None # Make MyPy happy
549
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
550
-
551
- return rv
552
-
553
521
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
554
522
  @telemetry.send_api_usage_telemetry(
555
523
  project=_PROJECT,
@@ -589,7 +557,7 @@ class OneVsOneClassifier(BaseTransformer):
589
557
  transform_kwargs = dict(
590
558
  session=dataset._session,
591
559
  dependencies=self._deps,
592
- pass_through_cols=self._get_pass_through_columns(dataset),
560
+ drop_input_cols = self._drop_input_cols,
593
561
  expected_output_cols_type="float",
594
562
  )
595
563
 
@@ -654,7 +622,7 @@ class OneVsOneClassifier(BaseTransformer):
654
622
  transform_kwargs = dict(
655
623
  session=dataset._session,
656
624
  dependencies=self._deps,
657
- pass_through_cols=self._get_pass_through_columns(dataset),
625
+ drop_input_cols = self._drop_input_cols,
658
626
  expected_output_cols_type="float",
659
627
  )
660
628
  elif isinstance(dataset, pd.DataFrame):
@@ -717,7 +685,7 @@ class OneVsOneClassifier(BaseTransformer):
717
685
  transform_kwargs = dict(
718
686
  session=dataset._session,
719
687
  dependencies=self._deps,
720
- pass_through_cols=self._get_pass_through_columns(dataset),
688
+ drop_input_cols = self._drop_input_cols,
721
689
  expected_output_cols_type="float",
722
690
  )
723
691
 
@@ -782,7 +750,7 @@ class OneVsOneClassifier(BaseTransformer):
782
750
  transform_kwargs = dict(
783
751
  session=dataset._session,
784
752
  dependencies=self._deps,
785
- pass_through_cols=self._get_pass_through_columns(dataset),
753
+ drop_input_cols = self._drop_input_cols,
786
754
  expected_output_cols_type="float",
787
755
  )
788
756
 
@@ -838,13 +806,17 @@ class OneVsOneClassifier(BaseTransformer):
838
806
  transform_kwargs: ScoreKwargsTypedDict = dict()
839
807
 
840
808
  if isinstance(dataset, DataFrame):
809
+ self._deps = self._batch_inference_validate_snowpark(
810
+ dataset=dataset,
811
+ inference_method="score",
812
+ )
841
813
  selected_cols = self._get_active_columns()
842
814
  if len(selected_cols) > 0:
843
815
  dataset = dataset.select(selected_cols)
844
816
  assert isinstance(dataset._session, Session) # keep mypy happy
845
817
  transform_kwargs = dict(
846
818
  session=dataset._session,
847
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
819
+ dependencies=["snowflake-snowpark-python"] + self._deps,
848
820
  score_sproc_imports=['sklearn'],
849
821
  )
850
822
  elif isinstance(dataset, pd.DataFrame):
@@ -918,9 +890,9 @@ class OneVsOneClassifier(BaseTransformer):
918
890
  transform_kwargs = dict(
919
891
  session = dataset._session,
920
892
  dependencies = self._deps,
921
- pass_through_cols = self._get_pass_through_columns(dataset),
922
- expected_output_cols_type = "array",
923
- n_neighbors = n_neighbors,
893
+ drop_input_cols = self._drop_input_cols,
894
+ expected_output_cols_type="array",
895
+ n_neighbors = n_neighbors,
924
896
  return_distance = return_distance
925
897
  )
926
898
  elif isinstance(dataset, pd.DataFrame):
@@ -280,18 +280,24 @@ class OneVsRestClassifier(BaseTransformer):
280
280
  self._get_model_signatures(dataset)
281
281
  return self
282
282
 
283
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
284
- if self._drop_input_cols:
285
- return []
286
- else:
287
- return list(set(dataset.columns) - set(self.output_cols))
288
-
289
283
  def _batch_inference_validate_snowpark(
290
284
  self,
291
285
  dataset: DataFrame,
292
286
  inference_method: str,
293
287
  ) -> List[str]:
294
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
288
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
289
+ return the available package that exists in the snowflake anaconda channel
290
+
291
+ Args:
292
+ dataset: snowpark dataframe
293
+ inference_method: the inference method such as predict, score...
294
+
295
+ Raises:
296
+ SnowflakeMLException: If the estimator is not fitted, raise error
297
+ SnowflakeMLException: If the session is None, raise error
298
+
299
+ Returns:
300
+ A list of available package that exists in the snowflake anaconda channel
295
301
  """
296
302
  if not self._is_fitted:
297
303
  raise exceptions.SnowflakeMLException(
@@ -365,7 +371,7 @@ class OneVsRestClassifier(BaseTransformer):
365
371
  transform_kwargs = dict(
366
372
  session = dataset._session,
367
373
  dependencies = self._deps,
368
- pass_through_cols = self._get_pass_through_columns(dataset),
374
+ drop_input_cols = self._drop_input_cols,
369
375
  expected_output_cols_type = expected_type_inferred,
370
376
  )
371
377
 
@@ -425,16 +431,16 @@ class OneVsRestClassifier(BaseTransformer):
425
431
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
426
432
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
427
433
  # each row containing a list of values.
428
- expected_dtype = "ARRAY"
434
+ expected_dtype = "array"
429
435
 
430
436
  # If we were unable to assign a type to this transform in the factory, infer the type here.
431
437
  if expected_dtype == "":
432
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
438
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
433
439
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
434
- expected_dtype = "ARRAY"
435
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
440
+ expected_dtype = "array"
441
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
436
442
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
437
- expected_dtype = "ARRAY"
443
+ expected_dtype = "array"
438
444
  else:
439
445
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
440
446
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -452,7 +458,7 @@ class OneVsRestClassifier(BaseTransformer):
452
458
  transform_kwargs = dict(
453
459
  session = dataset._session,
454
460
  dependencies = self._deps,
455
- pass_through_cols = self._get_pass_through_columns(dataset),
461
+ drop_input_cols = self._drop_input_cols,
456
462
  expected_output_cols_type = expected_dtype,
457
463
  )
458
464
 
@@ -503,7 +509,7 @@ class OneVsRestClassifier(BaseTransformer):
503
509
  subproject=_SUBPROJECT,
504
510
  )
505
511
  output_result, fitted_estimator = model_trainer.train_fit_predict(
506
- pass_through_columns=self._get_pass_through_columns(dataset),
512
+ drop_input_cols=self._drop_input_cols,
507
513
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
508
514
  )
509
515
  self._sklearn_object = fitted_estimator
@@ -521,44 +527,6 @@ class OneVsRestClassifier(BaseTransformer):
521
527
  assert self._sklearn_object is not None
522
528
  return self._sklearn_object.embedding_
523
529
 
524
-
525
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
526
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
527
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
528
- """
529
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
530
- if output_cols:
531
- output_cols = [
532
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
533
- for c in output_cols
534
- ]
535
- elif getattr(self._sklearn_object, "classes_", None) is None:
536
- output_cols = [output_cols_prefix]
537
- elif self._sklearn_object is not None:
538
- classes = self._sklearn_object.classes_
539
- if isinstance(classes, numpy.ndarray):
540
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
541
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
542
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
543
- output_cols = []
544
- for i, cl in enumerate(classes):
545
- # For binary classification, there is only one output column for each class
546
- # ndarray as the two classes are complementary.
547
- if len(cl) == 2:
548
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
549
- else:
550
- output_cols.extend([
551
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
552
- ])
553
- else:
554
- output_cols = []
555
-
556
- # Make sure column names are valid snowflake identifiers.
557
- assert output_cols is not None # Make MyPy happy
558
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
559
-
560
- return rv
561
-
562
530
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
563
531
  @telemetry.send_api_usage_telemetry(
564
532
  project=_PROJECT,
@@ -600,7 +568,7 @@ class OneVsRestClassifier(BaseTransformer):
600
568
  transform_kwargs = dict(
601
569
  session=dataset._session,
602
570
  dependencies=self._deps,
603
- pass_through_cols=self._get_pass_through_columns(dataset),
571
+ drop_input_cols = self._drop_input_cols,
604
572
  expected_output_cols_type="float",
605
573
  )
606
574
 
@@ -667,7 +635,7 @@ class OneVsRestClassifier(BaseTransformer):
667
635
  transform_kwargs = dict(
668
636
  session=dataset._session,
669
637
  dependencies=self._deps,
670
- pass_through_cols=self._get_pass_through_columns(dataset),
638
+ drop_input_cols = self._drop_input_cols,
671
639
  expected_output_cols_type="float",
672
640
  )
673
641
  elif isinstance(dataset, pd.DataFrame):
@@ -730,7 +698,7 @@ class OneVsRestClassifier(BaseTransformer):
730
698
  transform_kwargs = dict(
731
699
  session=dataset._session,
732
700
  dependencies=self._deps,
733
- pass_through_cols=self._get_pass_through_columns(dataset),
701
+ drop_input_cols = self._drop_input_cols,
734
702
  expected_output_cols_type="float",
735
703
  )
736
704
 
@@ -795,7 +763,7 @@ class OneVsRestClassifier(BaseTransformer):
795
763
  transform_kwargs = dict(
796
764
  session=dataset._session,
797
765
  dependencies=self._deps,
798
- pass_through_cols=self._get_pass_through_columns(dataset),
766
+ drop_input_cols = self._drop_input_cols,
799
767
  expected_output_cols_type="float",
800
768
  )
801
769
 
@@ -851,13 +819,17 @@ class OneVsRestClassifier(BaseTransformer):
851
819
  transform_kwargs: ScoreKwargsTypedDict = dict()
852
820
 
853
821
  if isinstance(dataset, DataFrame):
822
+ self._deps = self._batch_inference_validate_snowpark(
823
+ dataset=dataset,
824
+ inference_method="score",
825
+ )
854
826
  selected_cols = self._get_active_columns()
855
827
  if len(selected_cols) > 0:
856
828
  dataset = dataset.select(selected_cols)
857
829
  assert isinstance(dataset._session, Session) # keep mypy happy
858
830
  transform_kwargs = dict(
859
831
  session=dataset._session,
860
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
832
+ dependencies=["snowflake-snowpark-python"] + self._deps,
861
833
  score_sproc_imports=['sklearn'],
862
834
  )
863
835
  elif isinstance(dataset, pd.DataFrame):
@@ -931,9 +903,9 @@ class OneVsRestClassifier(BaseTransformer):
931
903
  transform_kwargs = dict(
932
904
  session = dataset._session,
933
905
  dependencies = self._deps,
934
- pass_through_cols = self._get_pass_through_columns(dataset),
935
- expected_output_cols_type = "array",
936
- n_neighbors = n_neighbors,
906
+ drop_input_cols = self._drop_input_cols,
907
+ expected_output_cols_type="array",
908
+ n_neighbors = n_neighbors,
937
909
  return_distance = return_distance
938
910
  )
939
911
  elif isinstance(dataset, pd.DataFrame):