snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -303,18 +303,24 @@ class AffinityPropagation(BaseTransformer):
303
303
  self._get_model_signatures(dataset)
304
304
  return self
305
305
 
306
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
307
- if self._drop_input_cols:
308
- return []
309
- else:
310
- return list(set(dataset.columns) - set(self.output_cols))
311
-
312
306
  def _batch_inference_validate_snowpark(
313
307
  self,
314
308
  dataset: DataFrame,
315
309
  inference_method: str,
316
310
  ) -> List[str]:
317
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
311
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
312
+ return the available package that exists in the snowflake anaconda channel
313
+
314
+ Args:
315
+ dataset: snowpark dataframe
316
+ inference_method: the inference method such as predict, score...
317
+
318
+ Raises:
319
+ SnowflakeMLException: If the estimator is not fitted, raise error
320
+ SnowflakeMLException: If the session is None, raise error
321
+
322
+ Returns:
323
+ A list of available package that exists in the snowflake anaconda channel
318
324
  """
319
325
  if not self._is_fitted:
320
326
  raise exceptions.SnowflakeMLException(
@@ -388,7 +394,7 @@ class AffinityPropagation(BaseTransformer):
388
394
  transform_kwargs = dict(
389
395
  session = dataset._session,
390
396
  dependencies = self._deps,
391
- pass_through_cols = self._get_pass_through_columns(dataset),
397
+ drop_input_cols = self._drop_input_cols,
392
398
  expected_output_cols_type = expected_type_inferred,
393
399
  )
394
400
 
@@ -448,16 +454,16 @@ class AffinityPropagation(BaseTransformer):
448
454
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
449
455
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
450
456
  # each row containing a list of values.
451
- expected_dtype = "ARRAY"
457
+ expected_dtype = "array"
452
458
 
453
459
  # If we were unable to assign a type to this transform in the factory, infer the type here.
454
460
  if expected_dtype == "":
455
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
461
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
456
462
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
457
- expected_dtype = "ARRAY"
458
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
463
+ expected_dtype = "array"
464
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
459
465
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
460
- expected_dtype = "ARRAY"
466
+ expected_dtype = "array"
461
467
  else:
462
468
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
463
469
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -475,7 +481,7 @@ class AffinityPropagation(BaseTransformer):
475
481
  transform_kwargs = dict(
476
482
  session = dataset._session,
477
483
  dependencies = self._deps,
478
- pass_through_cols = self._get_pass_through_columns(dataset),
484
+ drop_input_cols = self._drop_input_cols,
479
485
  expected_output_cols_type = expected_dtype,
480
486
  )
481
487
 
@@ -528,7 +534,7 @@ class AffinityPropagation(BaseTransformer):
528
534
  subproject=_SUBPROJECT,
529
535
  )
530
536
  output_result, fitted_estimator = model_trainer.train_fit_predict(
531
- pass_through_columns=self._get_pass_through_columns(dataset),
537
+ drop_input_cols=self._drop_input_cols,
532
538
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
533
539
  )
534
540
  self._sklearn_object = fitted_estimator
@@ -546,44 +552,6 @@ class AffinityPropagation(BaseTransformer):
546
552
  assert self._sklearn_object is not None
547
553
  return self._sklearn_object.embedding_
548
554
 
549
-
550
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
551
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
552
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
553
- """
554
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
555
- if output_cols:
556
- output_cols = [
557
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
558
- for c in output_cols
559
- ]
560
- elif getattr(self._sklearn_object, "classes_", None) is None:
561
- output_cols = [output_cols_prefix]
562
- elif self._sklearn_object is not None:
563
- classes = self._sklearn_object.classes_
564
- if isinstance(classes, numpy.ndarray):
565
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
566
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
567
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
568
- output_cols = []
569
- for i, cl in enumerate(classes):
570
- # For binary classification, there is only one output column for each class
571
- # ndarray as the two classes are complementary.
572
- if len(cl) == 2:
573
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
574
- else:
575
- output_cols.extend([
576
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
577
- ])
578
- else:
579
- output_cols = []
580
-
581
- # Make sure column names are valid snowflake identifiers.
582
- assert output_cols is not None # Make MyPy happy
583
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
584
-
585
- return rv
586
-
587
555
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
556
  @telemetry.send_api_usage_telemetry(
589
557
  project=_PROJECT,
@@ -623,7 +591,7 @@ class AffinityPropagation(BaseTransformer):
623
591
  transform_kwargs = dict(
624
592
  session=dataset._session,
625
593
  dependencies=self._deps,
626
- pass_through_cols=self._get_pass_through_columns(dataset),
594
+ drop_input_cols = self._drop_input_cols,
627
595
  expected_output_cols_type="float",
628
596
  )
629
597
 
@@ -688,7 +656,7 @@ class AffinityPropagation(BaseTransformer):
688
656
  transform_kwargs = dict(
689
657
  session=dataset._session,
690
658
  dependencies=self._deps,
691
- pass_through_cols=self._get_pass_through_columns(dataset),
659
+ drop_input_cols = self._drop_input_cols,
692
660
  expected_output_cols_type="float",
693
661
  )
694
662
  elif isinstance(dataset, pd.DataFrame):
@@ -749,7 +717,7 @@ class AffinityPropagation(BaseTransformer):
749
717
  transform_kwargs = dict(
750
718
  session=dataset._session,
751
719
  dependencies=self._deps,
752
- pass_through_cols=self._get_pass_through_columns(dataset),
720
+ drop_input_cols = self._drop_input_cols,
753
721
  expected_output_cols_type="float",
754
722
  )
755
723
 
@@ -814,7 +782,7 @@ class AffinityPropagation(BaseTransformer):
814
782
  transform_kwargs = dict(
815
783
  session=dataset._session,
816
784
  dependencies=self._deps,
817
- pass_through_cols=self._get_pass_through_columns(dataset),
785
+ drop_input_cols = self._drop_input_cols,
818
786
  expected_output_cols_type="float",
819
787
  )
820
788
 
@@ -868,13 +836,17 @@ class AffinityPropagation(BaseTransformer):
868
836
  transform_kwargs: ScoreKwargsTypedDict = dict()
869
837
 
870
838
  if isinstance(dataset, DataFrame):
839
+ self._deps = self._batch_inference_validate_snowpark(
840
+ dataset=dataset,
841
+ inference_method="score",
842
+ )
871
843
  selected_cols = self._get_active_columns()
872
844
  if len(selected_cols) > 0:
873
845
  dataset = dataset.select(selected_cols)
874
846
  assert isinstance(dataset._session, Session) # keep mypy happy
875
847
  transform_kwargs = dict(
876
848
  session=dataset._session,
877
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
849
+ dependencies=["snowflake-snowpark-python"] + self._deps,
878
850
  score_sproc_imports=['sklearn'],
879
851
  )
880
852
  elif isinstance(dataset, pd.DataFrame):
@@ -948,9 +920,9 @@ class AffinityPropagation(BaseTransformer):
948
920
  transform_kwargs = dict(
949
921
  session = dataset._session,
950
922
  dependencies = self._deps,
951
- pass_through_cols = self._get_pass_through_columns(dataset),
952
- expected_output_cols_type = "array",
953
- n_neighbors = n_neighbors,
923
+ drop_input_cols = self._drop_input_cols,
924
+ expected_output_cols_type="array",
925
+ n_neighbors = n_neighbors,
954
926
  return_distance = return_distance
955
927
  )
956
928
  elif isinstance(dataset, pd.DataFrame):
@@ -336,18 +336,24 @@ class AgglomerativeClustering(BaseTransformer):
336
336
  self._get_model_signatures(dataset)
337
337
  return self
338
338
 
339
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
340
- if self._drop_input_cols:
341
- return []
342
- else:
343
- return list(set(dataset.columns) - set(self.output_cols))
344
-
345
339
  def _batch_inference_validate_snowpark(
346
340
  self,
347
341
  dataset: DataFrame,
348
342
  inference_method: str,
349
343
  ) -> List[str]:
350
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
344
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
345
+ return the available package that exists in the snowflake anaconda channel
346
+
347
+ Args:
348
+ dataset: snowpark dataframe
349
+ inference_method: the inference method such as predict, score...
350
+
351
+ Raises:
352
+ SnowflakeMLException: If the estimator is not fitted, raise error
353
+ SnowflakeMLException: If the session is None, raise error
354
+
355
+ Returns:
356
+ A list of available package that exists in the snowflake anaconda channel
351
357
  """
352
358
  if not self._is_fitted:
353
359
  raise exceptions.SnowflakeMLException(
@@ -419,7 +425,7 @@ class AgglomerativeClustering(BaseTransformer):
419
425
  transform_kwargs = dict(
420
426
  session = dataset._session,
421
427
  dependencies = self._deps,
422
- pass_through_cols = self._get_pass_through_columns(dataset),
428
+ drop_input_cols = self._drop_input_cols,
423
429
  expected_output_cols_type = expected_type_inferred,
424
430
  )
425
431
 
@@ -479,16 +485,16 @@ class AgglomerativeClustering(BaseTransformer):
479
485
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
480
486
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
481
487
  # each row containing a list of values.
482
- expected_dtype = "ARRAY"
488
+ expected_dtype = "array"
483
489
 
484
490
  # If we were unable to assign a type to this transform in the factory, infer the type here.
485
491
  if expected_dtype == "":
486
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
492
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
487
493
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
488
- expected_dtype = "ARRAY"
489
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
494
+ expected_dtype = "array"
495
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
490
496
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
491
- expected_dtype = "ARRAY"
497
+ expected_dtype = "array"
492
498
  else:
493
499
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
494
500
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -506,7 +512,7 @@ class AgglomerativeClustering(BaseTransformer):
506
512
  transform_kwargs = dict(
507
513
  session = dataset._session,
508
514
  dependencies = self._deps,
509
- pass_through_cols = self._get_pass_through_columns(dataset),
515
+ drop_input_cols = self._drop_input_cols,
510
516
  expected_output_cols_type = expected_dtype,
511
517
  )
512
518
 
@@ -559,7 +565,7 @@ class AgglomerativeClustering(BaseTransformer):
559
565
  subproject=_SUBPROJECT,
560
566
  )
561
567
  output_result, fitted_estimator = model_trainer.train_fit_predict(
562
- pass_through_columns=self._get_pass_through_columns(dataset),
568
+ drop_input_cols=self._drop_input_cols,
563
569
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
564
570
  )
565
571
  self._sklearn_object = fitted_estimator
@@ -577,44 +583,6 @@ class AgglomerativeClustering(BaseTransformer):
577
583
  assert self._sklearn_object is not None
578
584
  return self._sklearn_object.embedding_
579
585
 
580
-
581
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
582
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
583
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
584
- """
585
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
586
- if output_cols:
587
- output_cols = [
588
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
589
- for c in output_cols
590
- ]
591
- elif getattr(self._sklearn_object, "classes_", None) is None:
592
- output_cols = [output_cols_prefix]
593
- elif self._sklearn_object is not None:
594
- classes = self._sklearn_object.classes_
595
- if isinstance(classes, numpy.ndarray):
596
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
597
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
598
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
599
- output_cols = []
600
- for i, cl in enumerate(classes):
601
- # For binary classification, there is only one output column for each class
602
- # ndarray as the two classes are complementary.
603
- if len(cl) == 2:
604
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
605
- else:
606
- output_cols.extend([
607
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
608
- ])
609
- else:
610
- output_cols = []
611
-
612
- # Make sure column names are valid snowflake identifiers.
613
- assert output_cols is not None # Make MyPy happy
614
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
615
-
616
- return rv
617
-
618
586
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
619
587
  @telemetry.send_api_usage_telemetry(
620
588
  project=_PROJECT,
@@ -654,7 +622,7 @@ class AgglomerativeClustering(BaseTransformer):
654
622
  transform_kwargs = dict(
655
623
  session=dataset._session,
656
624
  dependencies=self._deps,
657
- pass_through_cols=self._get_pass_through_columns(dataset),
625
+ drop_input_cols = self._drop_input_cols,
658
626
  expected_output_cols_type="float",
659
627
  )
660
628
 
@@ -719,7 +687,7 @@ class AgglomerativeClustering(BaseTransformer):
719
687
  transform_kwargs = dict(
720
688
  session=dataset._session,
721
689
  dependencies=self._deps,
722
- pass_through_cols=self._get_pass_through_columns(dataset),
690
+ drop_input_cols = self._drop_input_cols,
723
691
  expected_output_cols_type="float",
724
692
  )
725
693
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +748,7 @@ class AgglomerativeClustering(BaseTransformer):
780
748
  transform_kwargs = dict(
781
749
  session=dataset._session,
782
750
  dependencies=self._deps,
783
- pass_through_cols=self._get_pass_through_columns(dataset),
751
+ drop_input_cols = self._drop_input_cols,
784
752
  expected_output_cols_type="float",
785
753
  )
786
754
 
@@ -845,7 +813,7 @@ class AgglomerativeClustering(BaseTransformer):
845
813
  transform_kwargs = dict(
846
814
  session=dataset._session,
847
815
  dependencies=self._deps,
848
- pass_through_cols=self._get_pass_through_columns(dataset),
816
+ drop_input_cols = self._drop_input_cols,
849
817
  expected_output_cols_type="float",
850
818
  )
851
819
 
@@ -899,13 +867,17 @@ class AgglomerativeClustering(BaseTransformer):
899
867
  transform_kwargs: ScoreKwargsTypedDict = dict()
900
868
 
901
869
  if isinstance(dataset, DataFrame):
870
+ self._deps = self._batch_inference_validate_snowpark(
871
+ dataset=dataset,
872
+ inference_method="score",
873
+ )
902
874
  selected_cols = self._get_active_columns()
903
875
  if len(selected_cols) > 0:
904
876
  dataset = dataset.select(selected_cols)
905
877
  assert isinstance(dataset._session, Session) # keep mypy happy
906
878
  transform_kwargs = dict(
907
879
  session=dataset._session,
908
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
880
+ dependencies=["snowflake-snowpark-python"] + self._deps,
909
881
  score_sproc_imports=['sklearn'],
910
882
  )
911
883
  elif isinstance(dataset, pd.DataFrame):
@@ -979,9 +951,9 @@ class AgglomerativeClustering(BaseTransformer):
979
951
  transform_kwargs = dict(
980
952
  session = dataset._session,
981
953
  dependencies = self._deps,
982
- pass_through_cols = self._get_pass_through_columns(dataset),
983
- expected_output_cols_type = "array",
984
- n_neighbors = n_neighbors,
954
+ drop_input_cols = self._drop_input_cols,
955
+ expected_output_cols_type="array",
956
+ n_neighbors = n_neighbors,
985
957
  return_distance = return_distance
986
958
  )
987
959
  elif isinstance(dataset, pd.DataFrame):
@@ -294,18 +294,24 @@ class Birch(BaseTransformer):
294
294
  self._get_model_signatures(dataset)
295
295
  return self
296
296
 
297
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
298
- if self._drop_input_cols:
299
- return []
300
- else:
301
- return list(set(dataset.columns) - set(self.output_cols))
302
-
303
297
  def _batch_inference_validate_snowpark(
304
298
  self,
305
299
  dataset: DataFrame,
306
300
  inference_method: str,
307
301
  ) -> List[str]:
308
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
302
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
303
+ return the available package that exists in the snowflake anaconda channel
304
+
305
+ Args:
306
+ dataset: snowpark dataframe
307
+ inference_method: the inference method such as predict, score...
308
+
309
+ Raises:
310
+ SnowflakeMLException: If the estimator is not fitted, raise error
311
+ SnowflakeMLException: If the session is None, raise error
312
+
313
+ Returns:
314
+ A list of available package that exists in the snowflake anaconda channel
309
315
  """
310
316
  if not self._is_fitted:
311
317
  raise exceptions.SnowflakeMLException(
@@ -379,7 +385,7 @@ class Birch(BaseTransformer):
379
385
  transform_kwargs = dict(
380
386
  session = dataset._session,
381
387
  dependencies = self._deps,
382
- pass_through_cols = self._get_pass_through_columns(dataset),
388
+ drop_input_cols = self._drop_input_cols,
383
389
  expected_output_cols_type = expected_type_inferred,
384
390
  )
385
391
 
@@ -441,16 +447,16 @@ class Birch(BaseTransformer):
441
447
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
442
448
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
443
449
  # each row containing a list of values.
444
- expected_dtype = "ARRAY"
450
+ expected_dtype = "array"
445
451
 
446
452
  # If we were unable to assign a type to this transform in the factory, infer the type here.
447
453
  if expected_dtype == "":
448
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
454
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
449
455
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
450
- expected_dtype = "ARRAY"
451
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
456
+ expected_dtype = "array"
457
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
452
458
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
453
- expected_dtype = "ARRAY"
459
+ expected_dtype = "array"
454
460
  else:
455
461
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
456
462
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -468,7 +474,7 @@ class Birch(BaseTransformer):
468
474
  transform_kwargs = dict(
469
475
  session = dataset._session,
470
476
  dependencies = self._deps,
471
- pass_through_cols = self._get_pass_through_columns(dataset),
477
+ drop_input_cols = self._drop_input_cols,
472
478
  expected_output_cols_type = expected_dtype,
473
479
  )
474
480
 
@@ -521,7 +527,7 @@ class Birch(BaseTransformer):
521
527
  subproject=_SUBPROJECT,
522
528
  )
523
529
  output_result, fitted_estimator = model_trainer.train_fit_predict(
524
- pass_through_columns=self._get_pass_through_columns(dataset),
530
+ drop_input_cols=self._drop_input_cols,
525
531
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
526
532
  )
527
533
  self._sklearn_object = fitted_estimator
@@ -539,44 +545,6 @@ class Birch(BaseTransformer):
539
545
  assert self._sklearn_object is not None
540
546
  return self._sklearn_object.embedding_
541
547
 
542
-
543
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
544
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
545
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
546
- """
547
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
548
- if output_cols:
549
- output_cols = [
550
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
551
- for c in output_cols
552
- ]
553
- elif getattr(self._sklearn_object, "classes_", None) is None:
554
- output_cols = [output_cols_prefix]
555
- elif self._sklearn_object is not None:
556
- classes = self._sklearn_object.classes_
557
- if isinstance(classes, numpy.ndarray):
558
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
559
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
560
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
561
- output_cols = []
562
- for i, cl in enumerate(classes):
563
- # For binary classification, there is only one output column for each class
564
- # ndarray as the two classes are complementary.
565
- if len(cl) == 2:
566
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
567
- else:
568
- output_cols.extend([
569
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
570
- ])
571
- else:
572
- output_cols = []
573
-
574
- # Make sure column names are valid snowflake identifiers.
575
- assert output_cols is not None # Make MyPy happy
576
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
577
-
578
- return rv
579
-
580
548
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
581
549
  @telemetry.send_api_usage_telemetry(
582
550
  project=_PROJECT,
@@ -616,7 +584,7 @@ class Birch(BaseTransformer):
616
584
  transform_kwargs = dict(
617
585
  session=dataset._session,
618
586
  dependencies=self._deps,
619
- pass_through_cols=self._get_pass_through_columns(dataset),
587
+ drop_input_cols = self._drop_input_cols,
620
588
  expected_output_cols_type="float",
621
589
  )
622
590
 
@@ -681,7 +649,7 @@ class Birch(BaseTransformer):
681
649
  transform_kwargs = dict(
682
650
  session=dataset._session,
683
651
  dependencies=self._deps,
684
- pass_through_cols=self._get_pass_through_columns(dataset),
652
+ drop_input_cols = self._drop_input_cols,
685
653
  expected_output_cols_type="float",
686
654
  )
687
655
  elif isinstance(dataset, pd.DataFrame):
@@ -742,7 +710,7 @@ class Birch(BaseTransformer):
742
710
  transform_kwargs = dict(
743
711
  session=dataset._session,
744
712
  dependencies=self._deps,
745
- pass_through_cols=self._get_pass_through_columns(dataset),
713
+ drop_input_cols = self._drop_input_cols,
746
714
  expected_output_cols_type="float",
747
715
  )
748
716
 
@@ -807,7 +775,7 @@ class Birch(BaseTransformer):
807
775
  transform_kwargs = dict(
808
776
  session=dataset._session,
809
777
  dependencies=self._deps,
810
- pass_through_cols=self._get_pass_through_columns(dataset),
778
+ drop_input_cols = self._drop_input_cols,
811
779
  expected_output_cols_type="float",
812
780
  )
813
781
 
@@ -861,13 +829,17 @@ class Birch(BaseTransformer):
861
829
  transform_kwargs: ScoreKwargsTypedDict = dict()
862
830
 
863
831
  if isinstance(dataset, DataFrame):
832
+ self._deps = self._batch_inference_validate_snowpark(
833
+ dataset=dataset,
834
+ inference_method="score",
835
+ )
864
836
  selected_cols = self._get_active_columns()
865
837
  if len(selected_cols) > 0:
866
838
  dataset = dataset.select(selected_cols)
867
839
  assert isinstance(dataset._session, Session) # keep mypy happy
868
840
  transform_kwargs = dict(
869
841
  session=dataset._session,
870
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
842
+ dependencies=["snowflake-snowpark-python"] + self._deps,
871
843
  score_sproc_imports=['sklearn'],
872
844
  )
873
845
  elif isinstance(dataset, pd.DataFrame):
@@ -941,9 +913,9 @@ class Birch(BaseTransformer):
941
913
  transform_kwargs = dict(
942
914
  session = dataset._session,
943
915
  dependencies = self._deps,
944
- pass_through_cols = self._get_pass_through_columns(dataset),
945
- expected_output_cols_type = "array",
946
- n_neighbors = n_neighbors,
916
+ drop_input_cols = self._drop_input_cols,
917
+ expected_output_cols_type="array",
918
+ n_neighbors = n_neighbors,
947
919
  return_distance = return_distance
948
920
  )
949
921
  elif isinstance(dataset, pd.DataFrame):