snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/cluster/k_means.py

@@ -338,18 +338,24 @@ class KMeans(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -423,7 +429,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )
 
@@ -485,16 +491,16 @@ class KMeans(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -512,7 +518,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )
 
@@ -565,7 +571,7 @@ class KMeans(BaseTransformer):
            subproject=_SUBPROJECT,
        )
        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
        )
        self._sklearn_object = fitted_estimator
@@ -583,44 +589,6 @@ class KMeans(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -660,7 +628,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -725,7 +693,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +754,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -851,7 +819,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -907,13 +875,17 @@ class KMeans(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -987,9 +959,9 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
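
The excerpt above (k_means.py) and the two below (mean_shift.py, mini_batch_k_means.py) show the refactor that repeats across the roughly 150 generated estimator files listed at +33 -61: the per-estimator _get_pass_through_columns helper is deleted, and every transform_kwargs now passes the raw drop_input_cols flag to the inference handlers instead of a precomputed column list. A minimal sketch of the equivalent selection logic, assuming the handlers now derive the pass-through columns themselves; the function name below is hypothetical, and only its body mirrors the deleted helper:

    from typing import List

    def derive_pass_through_columns(
        dataset_columns: List[str],
        output_cols: List[str],
        drop_input_cols: bool,
    ) -> List[str]:
        # Mirrors the deleted BaseTransformer._get_pass_through_columns: when
        # input columns are dropped nothing passes through; otherwise every
        # dataset column that is not an output column is carried along.
        if drop_input_cols:
            return []
        return list(set(dataset_columns) - set(output_cols))

    # With drop_input_cols=False the feature columns survive batch inference.
    cols = derive_pass_through_columns(
        dataset_columns=["F1", "F2", "CLUSTER"],
        output_cols=["CLUSTER"],
        drop_input_cols=False,
    )
    assert set(cols) == {"F1", "F2"}

Passing the flag rather than the resolved list plausibly lets the handler defer the set difference until it sees the actual schema at inference time.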
snowflake/ml/modeling/cluster/mean_shift.py

@@ -314,18 +314,24 @@ class MeanShift(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -399,7 +405,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )
 
@@ -459,16 +465,16 @@ class MeanShift(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -486,7 +492,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )
 
@@ -539,7 +545,7 @@ class MeanShift(BaseTransformer):
            subproject=_SUBPROJECT,
        )
        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
        )
        self._sklearn_object = fitted_estimator
@@ -557,44 +563,6 @@ class MeanShift(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -634,7 +602,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -699,7 +667,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -760,7 +728,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -825,7 +793,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -879,13 +847,17 @@ class MeanShift(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -959,9 +931,9 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
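
The large deletion in each estimator is _get_output_column_names, which expanded an output prefix into per-class column names; given the snowflake/ml/modeling/framework/base.py +55 -5 entry in the file list, the logic has most likely been centralized in the shared base class rather than removed outright. A condensed sketch of the deleted naming rules (the Snowflake identifier normalization from the original is omitted here for brevity):

    from typing import List, Optional
    import numpy

    def output_column_names(prefix: str, classes: Optional[object]) -> List[str]:
        # Non-classifiers (no classes_ attribute) get one column named by the prefix.
        if classes is None:
            return [prefix]
        # Plain classifiers: one column per class label.
        if isinstance(classes, numpy.ndarray):
            return [f"{prefix}{c}" for c in classes.tolist()]
        # Multioutput estimators: classes_ is a list of ndarrays.
        if isinstance(classes, list) and classes and isinstance(classes[0], numpy.ndarray):
            cols: List[str] = []
            for i, cl in enumerate(classes):
                if len(cl) == 2:
                    # Binary outputs are complementary: one column per output.
                    cols.append(f"{prefix}{i}_{cl[0]}")
                else:
                    cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
            return cols
        return []

    # A three-class classifier with the predict_proba prefix:
    assert output_column_names("PREDICT_PROBA_", numpy.array([0, 1, 2])) == [
        "PREDICT_PROBA_0", "PREDICT_PROBA_1", "PREDICT_PROBA_2",
    ]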
snowflake/ml/modeling/cluster/mini_batch_k_means.py

@@ -364,18 +364,24 @@ class MiniBatchKMeans(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -449,7 +455,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )
 
@@ -511,16 +517,16 @@ class MiniBatchKMeans(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -538,7 +544,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )
 
@@ -591,7 +597,7 @@ class MiniBatchKMeans(BaseTransformer):
            subproject=_SUBPROJECT,
        )
        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
            expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
        )
        self._sklearn_object = fitted_estimator
@@ -609,44 +615,6 @@ class MiniBatchKMeans(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -686,7 +654,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -751,7 +719,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -812,7 +780,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -877,7 +845,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
 
@@ -933,13 +901,17 @@ class MiniBatchKMeans(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -1013,9 +985,9 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
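
One behavioral change hides in the otherwise mechanical score() hunks: dependencies for the scoring stored procedure were previously taken straight from self._get_dependencies(), whereas score() now first calls _batch_inference_validate_snowpark, whose updated docstring says it returns the packages available in the Snowflake Anaconda channel, and caches the result on self._deps. A sketch of that flow as a free function; the function itself is illustrative, while the two estimator calls are the ones shown in the hunks above:

    from snowflake.snowpark import DataFrame, Session

    def build_score_sproc_kwargs(estimator, dataset: DataFrame) -> dict:
        # Validate that batch inference can run on this Snowpark dataframe and
        # keep only the dependency packages that resolve in the Snowflake
        # Anaconda channel (cached on the estimator, as in the hunks above).
        estimator._deps = estimator._batch_inference_validate_snowpark(
            dataset=dataset,
            inference_method="score",
        )
        assert isinstance(dataset._session, Session)  # keep mypy happy
        return dict(
            session=dataset._session,
            dependencies=["snowflake-snowpark-python"] + estimator._deps,
            score_sproc_imports=["sklearn"],
        )

Validating up front presumably surfaces a missing channel package on the client rather than inside the stored procedure.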