snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -386,18 +386,24 @@ class BayesianGaussianMixture(BaseTransformer):
386
386
  self._get_model_signatures(dataset)
387
387
  return self
388
388
 
389
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
390
- if self._drop_input_cols:
391
- return []
392
- else:
393
- return list(set(dataset.columns) - set(self.output_cols))
394
-
395
389
  def _batch_inference_validate_snowpark(
396
390
  self,
397
391
  dataset: DataFrame,
398
392
  inference_method: str,
399
393
  ) -> List[str]:
400
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
394
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
395
+ return the available package that exists in the snowflake anaconda channel
396
+
397
+ Args:
398
+ dataset: snowpark dataframe
399
+ inference_method: the inference method such as predict, score...
400
+
401
+ Raises:
402
+ SnowflakeMLException: If the estimator is not fitted, raise error
403
+ SnowflakeMLException: If the session is None, raise error
404
+
405
+ Returns:
406
+ A list of available package that exists in the snowflake anaconda channel
401
407
  """
402
408
  if not self._is_fitted:
403
409
  raise exceptions.SnowflakeMLException(
@@ -471,7 +477,7 @@ class BayesianGaussianMixture(BaseTransformer):
471
477
  transform_kwargs = dict(
472
478
  session = dataset._session,
473
479
  dependencies = self._deps,
474
- pass_through_cols = self._get_pass_through_columns(dataset),
480
+ drop_input_cols = self._drop_input_cols,
475
481
  expected_output_cols_type = expected_type_inferred,
476
482
  )
477
483
 
@@ -531,16 +537,16 @@ class BayesianGaussianMixture(BaseTransformer):
531
537
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
532
538
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
533
539
  # each row containing a list of values.
534
- expected_dtype = "ARRAY"
540
+ expected_dtype = "array"
535
541
 
536
542
  # If we were unable to assign a type to this transform in the factory, infer the type here.
537
543
  if expected_dtype == "":
538
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
544
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
539
545
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
540
- expected_dtype = "ARRAY"
541
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
546
+ expected_dtype = "array"
547
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
542
548
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
543
- expected_dtype = "ARRAY"
549
+ expected_dtype = "array"
544
550
  else:
545
551
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
546
552
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -558,7 +564,7 @@ class BayesianGaussianMixture(BaseTransformer):
558
564
  transform_kwargs = dict(
559
565
  session = dataset._session,
560
566
  dependencies = self._deps,
561
- pass_through_cols = self._get_pass_through_columns(dataset),
567
+ drop_input_cols = self._drop_input_cols,
562
568
  expected_output_cols_type = expected_dtype,
563
569
  )
564
570
 
@@ -611,7 +617,7 @@ class BayesianGaussianMixture(BaseTransformer):
611
617
  subproject=_SUBPROJECT,
612
618
  )
613
619
  output_result, fitted_estimator = model_trainer.train_fit_predict(
614
- pass_through_columns=self._get_pass_through_columns(dataset),
620
+ drop_input_cols=self._drop_input_cols,
615
621
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
616
622
  )
617
623
  self._sklearn_object = fitted_estimator
@@ -629,44 +635,6 @@ class BayesianGaussianMixture(BaseTransformer):
629
635
  assert self._sklearn_object is not None
630
636
  return self._sklearn_object.embedding_
631
637
 
632
-
633
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
634
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
635
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
636
- """
637
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
638
- if output_cols:
639
- output_cols = [
640
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
641
- for c in output_cols
642
- ]
643
- elif getattr(self._sklearn_object, "classes_", None) is None:
644
- output_cols = [output_cols_prefix]
645
- elif self._sklearn_object is not None:
646
- classes = self._sklearn_object.classes_
647
- if isinstance(classes, numpy.ndarray):
648
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
649
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
650
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
651
- output_cols = []
652
- for i, cl in enumerate(classes):
653
- # For binary classification, there is only one output column for each class
654
- # ndarray as the two classes are complementary.
655
- if len(cl) == 2:
656
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
657
- else:
658
- output_cols.extend([
659
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
660
- ])
661
- else:
662
- output_cols = []
663
-
664
- # Make sure column names are valid snowflake identifiers.
665
- assert output_cols is not None # Make MyPy happy
666
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
667
-
668
- return rv
669
-
670
638
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
671
639
  @telemetry.send_api_usage_telemetry(
672
640
  project=_PROJECT,
@@ -708,7 +676,7 @@ class BayesianGaussianMixture(BaseTransformer):
708
676
  transform_kwargs = dict(
709
677
  session=dataset._session,
710
678
  dependencies=self._deps,
711
- pass_through_cols=self._get_pass_through_columns(dataset),
679
+ drop_input_cols = self._drop_input_cols,
712
680
  expected_output_cols_type="float",
713
681
  )
714
682
 
@@ -775,7 +743,7 @@ class BayesianGaussianMixture(BaseTransformer):
775
743
  transform_kwargs = dict(
776
744
  session=dataset._session,
777
745
  dependencies=self._deps,
778
- pass_through_cols=self._get_pass_through_columns(dataset),
746
+ drop_input_cols = self._drop_input_cols,
779
747
  expected_output_cols_type="float",
780
748
  )
781
749
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +804,7 @@ class BayesianGaussianMixture(BaseTransformer):
836
804
  transform_kwargs = dict(
837
805
  session=dataset._session,
838
806
  dependencies=self._deps,
839
- pass_through_cols=self._get_pass_through_columns(dataset),
807
+ drop_input_cols = self._drop_input_cols,
840
808
  expected_output_cols_type="float",
841
809
  )
842
810
 
@@ -903,7 +871,7 @@ class BayesianGaussianMixture(BaseTransformer):
903
871
  transform_kwargs = dict(
904
872
  session=dataset._session,
905
873
  dependencies=self._deps,
906
- pass_through_cols=self._get_pass_through_columns(dataset),
874
+ drop_input_cols = self._drop_input_cols,
907
875
  expected_output_cols_type="float",
908
876
  )
909
877
 
@@ -959,13 +927,17 @@ class BayesianGaussianMixture(BaseTransformer):
959
927
  transform_kwargs: ScoreKwargsTypedDict = dict()
960
928
 
961
929
  if isinstance(dataset, DataFrame):
930
+ self._deps = self._batch_inference_validate_snowpark(
931
+ dataset=dataset,
932
+ inference_method="score",
933
+ )
962
934
  selected_cols = self._get_active_columns()
963
935
  if len(selected_cols) > 0:
964
936
  dataset = dataset.select(selected_cols)
965
937
  assert isinstance(dataset._session, Session) # keep mypy happy
966
938
  transform_kwargs = dict(
967
939
  session=dataset._session,
968
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
940
+ dependencies=["snowflake-snowpark-python"] + self._deps,
969
941
  score_sproc_imports=['sklearn'],
970
942
  )
971
943
  elif isinstance(dataset, pd.DataFrame):
@@ -1039,9 +1011,9 @@ class BayesianGaussianMixture(BaseTransformer):
1039
1011
  transform_kwargs = dict(
1040
1012
  session = dataset._session,
1041
1013
  dependencies = self._deps,
1042
- pass_through_cols = self._get_pass_through_columns(dataset),
1043
- expected_output_cols_type = "array",
1044
- n_neighbors = n_neighbors,
1014
+ drop_input_cols = self._drop_input_cols,
1015
+ expected_output_cols_type="array",
1016
+ n_neighbors = n_neighbors,
1045
1017
  return_distance = return_distance
1046
1018
  )
1047
1019
  elif isinstance(dataset, pd.DataFrame):
@@ -359,18 +359,24 @@ class GaussianMixture(BaseTransformer):
359
359
  self._get_model_signatures(dataset)
360
360
  return self
361
361
 
362
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
363
- if self._drop_input_cols:
364
- return []
365
- else:
366
- return list(set(dataset.columns) - set(self.output_cols))
367
-
368
362
  def _batch_inference_validate_snowpark(
369
363
  self,
370
364
  dataset: DataFrame,
371
365
  inference_method: str,
372
366
  ) -> List[str]:
373
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
367
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
368
+ return the available package that exists in the snowflake anaconda channel
369
+
370
+ Args:
371
+ dataset: snowpark dataframe
372
+ inference_method: the inference method such as predict, score...
373
+
374
+ Raises:
375
+ SnowflakeMLException: If the estimator is not fitted, raise error
376
+ SnowflakeMLException: If the session is None, raise error
377
+
378
+ Returns:
379
+ A list of available package that exists in the snowflake anaconda channel
374
380
  """
375
381
  if not self._is_fitted:
376
382
  raise exceptions.SnowflakeMLException(
@@ -444,7 +450,7 @@ class GaussianMixture(BaseTransformer):
444
450
  transform_kwargs = dict(
445
451
  session = dataset._session,
446
452
  dependencies = self._deps,
447
- pass_through_cols = self._get_pass_through_columns(dataset),
453
+ drop_input_cols = self._drop_input_cols,
448
454
  expected_output_cols_type = expected_type_inferred,
449
455
  )
450
456
 
@@ -504,16 +510,16 @@ class GaussianMixture(BaseTransformer):
504
510
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
505
511
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
506
512
  # each row containing a list of values.
507
- expected_dtype = "ARRAY"
513
+ expected_dtype = "array"
508
514
 
509
515
  # If we were unable to assign a type to this transform in the factory, infer the type here.
510
516
  if expected_dtype == "":
511
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
517
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
512
518
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
513
- expected_dtype = "ARRAY"
514
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
519
+ expected_dtype = "array"
520
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
515
521
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
516
- expected_dtype = "ARRAY"
522
+ expected_dtype = "array"
517
523
  else:
518
524
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
519
525
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -531,7 +537,7 @@ class GaussianMixture(BaseTransformer):
531
537
  transform_kwargs = dict(
532
538
  session = dataset._session,
533
539
  dependencies = self._deps,
534
- pass_through_cols = self._get_pass_through_columns(dataset),
540
+ drop_input_cols = self._drop_input_cols,
535
541
  expected_output_cols_type = expected_dtype,
536
542
  )
537
543
 
@@ -584,7 +590,7 @@ class GaussianMixture(BaseTransformer):
584
590
  subproject=_SUBPROJECT,
585
591
  )
586
592
  output_result, fitted_estimator = model_trainer.train_fit_predict(
587
- pass_through_columns=self._get_pass_through_columns(dataset),
593
+ drop_input_cols=self._drop_input_cols,
588
594
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
589
595
  )
590
596
  self._sklearn_object = fitted_estimator
@@ -602,44 +608,6 @@ class GaussianMixture(BaseTransformer):
602
608
  assert self._sklearn_object is not None
603
609
  return self._sklearn_object.embedding_
604
610
 
605
-
606
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
607
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
608
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
609
- """
610
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
611
- if output_cols:
612
- output_cols = [
613
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
614
- for c in output_cols
615
- ]
616
- elif getattr(self._sklearn_object, "classes_", None) is None:
617
- output_cols = [output_cols_prefix]
618
- elif self._sklearn_object is not None:
619
- classes = self._sklearn_object.classes_
620
- if isinstance(classes, numpy.ndarray):
621
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
622
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
623
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
624
- output_cols = []
625
- for i, cl in enumerate(classes):
626
- # For binary classification, there is only one output column for each class
627
- # ndarray as the two classes are complementary.
628
- if len(cl) == 2:
629
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
630
- else:
631
- output_cols.extend([
632
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
633
- ])
634
- else:
635
- output_cols = []
636
-
637
- # Make sure column names are valid snowflake identifiers.
638
- assert output_cols is not None # Make MyPy happy
639
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
640
-
641
- return rv
642
-
643
611
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
644
612
  @telemetry.send_api_usage_telemetry(
645
613
  project=_PROJECT,
@@ -681,7 +649,7 @@ class GaussianMixture(BaseTransformer):
681
649
  transform_kwargs = dict(
682
650
  session=dataset._session,
683
651
  dependencies=self._deps,
684
- pass_through_cols=self._get_pass_through_columns(dataset),
652
+ drop_input_cols = self._drop_input_cols,
685
653
  expected_output_cols_type="float",
686
654
  )
687
655
 
@@ -748,7 +716,7 @@ class GaussianMixture(BaseTransformer):
748
716
  transform_kwargs = dict(
749
717
  session=dataset._session,
750
718
  dependencies=self._deps,
751
- pass_through_cols=self._get_pass_through_columns(dataset),
719
+ drop_input_cols = self._drop_input_cols,
752
720
  expected_output_cols_type="float",
753
721
  )
754
722
  elif isinstance(dataset, pd.DataFrame):
@@ -809,7 +777,7 @@ class GaussianMixture(BaseTransformer):
809
777
  transform_kwargs = dict(
810
778
  session=dataset._session,
811
779
  dependencies=self._deps,
812
- pass_through_cols=self._get_pass_through_columns(dataset),
780
+ drop_input_cols = self._drop_input_cols,
813
781
  expected_output_cols_type="float",
814
782
  )
815
783
 
@@ -876,7 +844,7 @@ class GaussianMixture(BaseTransformer):
876
844
  transform_kwargs = dict(
877
845
  session=dataset._session,
878
846
  dependencies=self._deps,
879
- pass_through_cols=self._get_pass_through_columns(dataset),
847
+ drop_input_cols = self._drop_input_cols,
880
848
  expected_output_cols_type="float",
881
849
  )
882
850
 
@@ -932,13 +900,17 @@ class GaussianMixture(BaseTransformer):
932
900
  transform_kwargs: ScoreKwargsTypedDict = dict()
933
901
 
934
902
  if isinstance(dataset, DataFrame):
903
+ self._deps = self._batch_inference_validate_snowpark(
904
+ dataset=dataset,
905
+ inference_method="score",
906
+ )
935
907
  selected_cols = self._get_active_columns()
936
908
  if len(selected_cols) > 0:
937
909
  dataset = dataset.select(selected_cols)
938
910
  assert isinstance(dataset._session, Session) # keep mypy happy
939
911
  transform_kwargs = dict(
940
912
  session=dataset._session,
941
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
913
+ dependencies=["snowflake-snowpark-python"] + self._deps,
942
914
  score_sproc_imports=['sklearn'],
943
915
  )
944
916
  elif isinstance(dataset, pd.DataFrame):
@@ -1012,9 +984,9 @@ class GaussianMixture(BaseTransformer):
1012
984
  transform_kwargs = dict(
1013
985
  session = dataset._session,
1014
986
  dependencies = self._deps,
1015
- pass_through_cols = self._get_pass_through_columns(dataset),
1016
- expected_output_cols_type = "array",
1017
- n_neighbors = n_neighbors,
987
+ drop_input_cols = self._drop_input_cols,
988
+ expected_output_cols_type="array",
989
+ n_neighbors = n_neighbors,
1018
990
  return_distance = return_distance
1019
991
  )
1020
992
  elif isinstance(dataset, pd.DataFrame):
@@ -216,6 +216,7 @@ class GridSearchCV(BaseTransformer):
216
216
  expensive and is not strictly required to select the parameters that
217
217
  yield the best generalization performance.
218
218
  """
219
+
219
220
  _ENABLE_DISTRIBUTED = True
220
221
 
221
222
  def __init__( # type: ignore[no-untyped-def]
@@ -332,14 +333,21 @@ class GridSearchCV(BaseTransformer):
332
333
  self._get_model_signatures(dataset)
333
334
  return self
334
335
 
335
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
336
- if self._drop_input_cols:
337
- return []
338
- else:
339
- return list(set(dataset.columns) - set(self.output_cols))
336
+ def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> List[str]:
337
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
338
+ return the available package that exists in the snowflake anaconda channel
339
+
340
+ Args:
341
+ dataset: snowpark dataframe
342
+ inference_method: the inference method such as predict, score...
343
+
344
+ Raises:
345
+ SnowflakeMLException: If the estimator is not fitted, raise error
346
+ SnowflakeMLException: If the session is None, raise error
340
347
 
341
- def _batch_inference_validate_snowpark(self, dataset: DataFrame, inference_method: str) -> None:
342
- """Util method to run validate that batch inference can be run on a snowpark dataframe."""
348
+ Returns:
349
+ A list of available package that exists in the snowflake anaconda channel
350
+ """
343
351
  if not self._is_fitted:
344
352
  raise exceptions.SnowflakeMLException(
345
353
  error_code=error_codes.METHOD_NOT_ALLOWED,
@@ -355,7 +363,7 @@ class GridSearchCV(BaseTransformer):
355
363
  original_exception=ValueError("Session must not specified for snowpark dataset."),
356
364
  )
357
365
  # Validate that key package version in user workspace are supported in snowflake conda channel
358
- pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
366
+ return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
359
367
  pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
360
368
  )
361
369
 
@@ -391,7 +399,7 @@ class GridSearchCV(BaseTransformer):
391
399
  expected_type_inferred = convert_sp_to_sf_type(
392
400
  self.model_signatures["predict"].outputs[0].as_snowpark_type()
393
401
  )
394
- self._batch_inference_validate_snowpark(
402
+ self._deps = self._batch_inference_validate_snowpark(
395
403
  dataset=dataset,
396
404
  inference_method=inference_method,
397
405
  )
@@ -402,8 +410,8 @@ class GridSearchCV(BaseTransformer):
402
410
 
403
411
  transform_kwargs = dict(
404
412
  session=dataset._session,
405
- dependencies=self._get_dependencies(),
406
- pass_through_cols=self._get_pass_through_columns(dataset),
413
+ dependencies=self._deps,
414
+ drop_input_cols=self._drop_input_cols,
407
415
  expected_output_cols_type=expected_type_inferred,
408
416
  )
409
417
 
@@ -452,15 +460,15 @@ class GridSearchCV(BaseTransformer):
452
460
  inference_method = "transform"
453
461
 
454
462
  if isinstance(dataset, DataFrame):
455
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
463
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
456
464
  assert isinstance(
457
465
  dataset._session, Session
458
466
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
459
467
 
460
468
  transform_kwargs = dict(
461
469
  session=dataset._session,
462
- dependencies=self._get_dependencies(),
463
- pass_through_cols=self._get_pass_through_columns(dataset),
470
+ dependencies=self._deps,
471
+ drop_input_cols=self._drop_input_cols,
464
472
  )
465
473
 
466
474
  elif isinstance(dataset, pd.DataFrame):
@@ -482,36 +490,6 @@ class GridSearchCV(BaseTransformer):
482
490
  )
483
491
  return output_df
484
492
 
485
- def _get_output_column_names(self, output_cols_prefix: str) -> List[str]:
486
- """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
487
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
488
-
489
- Args:
490
- output_cols_prefix (str): prefix according to the function
491
-
492
- Returns:
493
- List[str]: output cols with prefix
494
- """
495
- if getattr(self._sklearn_object, "classes_", None) is None:
496
- return [output_cols_prefix]
497
-
498
- assert self._sklearn_object is not None # keep mypy happy
499
- classes = self._sklearn_object.classes_
500
- if isinstance(classes, np.ndarray):
501
- return [f"{output_cols_prefix}{c}" for c in classes.tolist()]
502
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], np.ndarray):
503
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
504
- output_cols = []
505
- for i, cl in enumerate(classes):
506
- # For binary classification, there is only one output column for each class
507
- # ndarray as the two classes are complementary.
508
- if len(cl) == 2:
509
- output_cols.append(f"{output_cols_prefix}_{i}_{cl[0]}")
510
- else:
511
- output_cols.extend([f"{output_cols_prefix}_{i}_{c}" for c in cl.tolist()])
512
- return output_cols
513
- return []
514
-
515
493
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
516
494
  @telemetry.send_api_usage_telemetry(
517
495
  project=_PROJECT,
@@ -541,14 +519,14 @@ class GridSearchCV(BaseTransformer):
541
519
  inference_method = "predict_proba"
542
520
 
543
521
  if isinstance(dataset, DataFrame):
544
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
522
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
545
523
  assert isinstance(
546
524
  dataset._session, Session
547
525
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
548
526
  transform_kwargs = dict(
549
527
  session=dataset._session,
550
- dependencies=self._get_dependencies(),
551
- pass_through_cols=self._get_pass_through_columns(dataset),
528
+ dependencies=self._deps,
529
+ drop_input_cols=self._drop_input_cols,
552
530
  expected_output_cols_type="float",
553
531
  )
554
532
 
@@ -601,14 +579,14 @@ class GridSearchCV(BaseTransformer):
601
579
  inference_method = "predict_log_proba"
602
580
 
603
581
  if isinstance(dataset, DataFrame):
604
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
582
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
605
583
  assert isinstance(
606
584
  dataset._session, Session
607
585
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
608
586
  transform_kwargs = dict(
609
587
  session=dataset._session,
610
- dependencies=self._get_dependencies(),
611
- pass_through_cols=self._get_pass_through_columns(dataset),
588
+ drop_input_cols=self._drop_input_cols,
589
+ dependencies=self._deps,
612
590
  expected_output_cols_type="float",
613
591
  )
614
592
 
@@ -661,14 +639,14 @@ class GridSearchCV(BaseTransformer):
661
639
  inference_method = "decision_function"
662
640
 
663
641
  if isinstance(dataset, DataFrame):
664
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
642
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
665
643
  assert isinstance(
666
644
  dataset._session, Session
667
645
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
668
646
  transform_kwargs = dict(
669
647
  session=dataset._session,
670
- dependencies=self._get_dependencies(),
671
- pass_through_cols=self._get_pass_through_columns(dataset),
648
+ dependencies=self._deps,
649
+ drop_input_cols=self._drop_input_cols,
672
650
  expected_output_cols_type="float",
673
651
  )
674
652
 
@@ -722,14 +700,14 @@ class GridSearchCV(BaseTransformer):
722
700
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
723
701
 
724
702
  if isinstance(dataset, DataFrame):
725
- self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
703
+ self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
726
704
  assert isinstance(
727
705
  dataset._session, Session
728
706
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
729
707
  transform_kwargs = dict(
730
708
  session=dataset._session,
731
- dependencies=self._get_dependencies(),
732
- pass_through_cols=self._get_pass_through_columns(dataset),
709
+ dependencies=self._deps,
710
+ drop_input_cols=self._drop_input_cols,
733
711
  expected_output_cols_type="float",
734
712
  )
735
713
 
@@ -773,13 +751,17 @@ class GridSearchCV(BaseTransformer):
773
751
  transform_kwargs: ScoreKwargsTypedDict = dict()
774
752
 
775
753
  if isinstance(dataset, DataFrame):
754
+ self._deps = self._batch_inference_validate_snowpark(
755
+ dataset=dataset,
756
+ inference_method="score",
757
+ )
776
758
  selected_cols = self._get_active_columns()
777
759
  if len(selected_cols) > 0:
778
760
  dataset = dataset.select(selected_cols)
779
761
  assert isinstance(dataset._session, Session) # keep mypy happy
780
762
  transform_kwargs = dict(
781
763
  session=dataset._session,
782
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
764
+ dependencies=["snowflake-snowpark-python"] + self._deps,
783
765
  score_sproc_imports=["sklearn"],
784
766
  )
785
767
  elif isinstance(dataset, pd.DataFrame):