snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -391,18 +391,24 @@ class DecisionTreeClassifier(BaseTransformer):
391
391
  self._get_model_signatures(dataset)
392
392
  return self
393
393
 
394
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
395
- if self._drop_input_cols:
396
- return []
397
- else:
398
- return list(set(dataset.columns) - set(self.output_cols))
399
-
400
394
  def _batch_inference_validate_snowpark(
401
395
  self,
402
396
  dataset: DataFrame,
403
397
  inference_method: str,
404
398
  ) -> List[str]:
405
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
399
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
400
+ return the available package that exists in the snowflake anaconda channel
401
+
402
+ Args:
403
+ dataset: snowpark dataframe
404
+ inference_method: the inference method such as predict, score...
405
+
406
+ Raises:
407
+ SnowflakeMLException: If the estimator is not fitted, raise error
408
+ SnowflakeMLException: If the session is None, raise error
409
+
410
+ Returns:
411
+ A list of available package that exists in the snowflake anaconda channel
406
412
  """
407
413
  if not self._is_fitted:
408
414
  raise exceptions.SnowflakeMLException(
@@ -476,7 +482,7 @@ class DecisionTreeClassifier(BaseTransformer):
476
482
  transform_kwargs = dict(
477
483
  session = dataset._session,
478
484
  dependencies = self._deps,
479
- pass_through_cols = self._get_pass_through_columns(dataset),
485
+ drop_input_cols = self._drop_input_cols,
480
486
  expected_output_cols_type = expected_type_inferred,
481
487
  )
482
488
 
@@ -536,16 +542,16 @@ class DecisionTreeClassifier(BaseTransformer):
536
542
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
537
543
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
538
544
  # each row containing a list of values.
539
- expected_dtype = "ARRAY"
545
+ expected_dtype = "array"
540
546
 
541
547
  # If we were unable to assign a type to this transform in the factory, infer the type here.
542
548
  if expected_dtype == "":
543
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
549
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
544
550
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
545
- expected_dtype = "ARRAY"
546
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
551
+ expected_dtype = "array"
552
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
547
553
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
548
- expected_dtype = "ARRAY"
554
+ expected_dtype = "array"
549
555
  else:
550
556
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
551
557
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -563,7 +569,7 @@ class DecisionTreeClassifier(BaseTransformer):
563
569
  transform_kwargs = dict(
564
570
  session = dataset._session,
565
571
  dependencies = self._deps,
566
- pass_through_cols = self._get_pass_through_columns(dataset),
572
+ drop_input_cols = self._drop_input_cols,
567
573
  expected_output_cols_type = expected_dtype,
568
574
  )
569
575
 
@@ -614,7 +620,7 @@ class DecisionTreeClassifier(BaseTransformer):
614
620
  subproject=_SUBPROJECT,
615
621
  )
616
622
  output_result, fitted_estimator = model_trainer.train_fit_predict(
617
- pass_through_columns=self._get_pass_through_columns(dataset),
623
+ drop_input_cols=self._drop_input_cols,
618
624
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
619
625
  )
620
626
  self._sklearn_object = fitted_estimator
@@ -632,44 +638,6 @@ class DecisionTreeClassifier(BaseTransformer):
632
638
  assert self._sklearn_object is not None
633
639
  return self._sklearn_object.embedding_
634
640
 
635
-
636
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
637
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
638
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
639
- """
640
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
641
- if output_cols:
642
- output_cols = [
643
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
644
- for c in output_cols
645
- ]
646
- elif getattr(self._sklearn_object, "classes_", None) is None:
647
- output_cols = [output_cols_prefix]
648
- elif self._sklearn_object is not None:
649
- classes = self._sklearn_object.classes_
650
- if isinstance(classes, numpy.ndarray):
651
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
652
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
653
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
654
- output_cols = []
655
- for i, cl in enumerate(classes):
656
- # For binary classification, there is only one output column for each class
657
- # ndarray as the two classes are complementary.
658
- if len(cl) == 2:
659
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
660
- else:
661
- output_cols.extend([
662
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
663
- ])
664
- else:
665
- output_cols = []
666
-
667
- # Make sure column names are valid snowflake identifiers.
668
- assert output_cols is not None # Make MyPy happy
669
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
670
-
671
- return rv
672
-
673
641
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
674
642
  @telemetry.send_api_usage_telemetry(
675
643
  project=_PROJECT,
@@ -711,7 +679,7 @@ class DecisionTreeClassifier(BaseTransformer):
711
679
  transform_kwargs = dict(
712
680
  session=dataset._session,
713
681
  dependencies=self._deps,
714
- pass_through_cols=self._get_pass_through_columns(dataset),
682
+ drop_input_cols = self._drop_input_cols,
715
683
  expected_output_cols_type="float",
716
684
  )
717
685
 
@@ -778,7 +746,7 @@ class DecisionTreeClassifier(BaseTransformer):
778
746
  transform_kwargs = dict(
779
747
  session=dataset._session,
780
748
  dependencies=self._deps,
781
- pass_through_cols=self._get_pass_through_columns(dataset),
749
+ drop_input_cols = self._drop_input_cols,
782
750
  expected_output_cols_type="float",
783
751
  )
784
752
  elif isinstance(dataset, pd.DataFrame):
@@ -839,7 +807,7 @@ class DecisionTreeClassifier(BaseTransformer):
839
807
  transform_kwargs = dict(
840
808
  session=dataset._session,
841
809
  dependencies=self._deps,
842
- pass_through_cols=self._get_pass_through_columns(dataset),
810
+ drop_input_cols = self._drop_input_cols,
843
811
  expected_output_cols_type="float",
844
812
  )
845
813
 
@@ -904,7 +872,7 @@ class DecisionTreeClassifier(BaseTransformer):
904
872
  transform_kwargs = dict(
905
873
  session=dataset._session,
906
874
  dependencies=self._deps,
907
- pass_through_cols=self._get_pass_through_columns(dataset),
875
+ drop_input_cols = self._drop_input_cols,
908
876
  expected_output_cols_type="float",
909
877
  )
910
878
 
@@ -960,13 +928,17 @@ class DecisionTreeClassifier(BaseTransformer):
960
928
  transform_kwargs: ScoreKwargsTypedDict = dict()
961
929
 
962
930
  if isinstance(dataset, DataFrame):
931
+ self._deps = self._batch_inference_validate_snowpark(
932
+ dataset=dataset,
933
+ inference_method="score",
934
+ )
963
935
  selected_cols = self._get_active_columns()
964
936
  if len(selected_cols) > 0:
965
937
  dataset = dataset.select(selected_cols)
966
938
  assert isinstance(dataset._session, Session) # keep mypy happy
967
939
  transform_kwargs = dict(
968
940
  session=dataset._session,
969
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
941
+ dependencies=["snowflake-snowpark-python"] + self._deps,
970
942
  score_sproc_imports=['sklearn'],
971
943
  )
972
944
  elif isinstance(dataset, pd.DataFrame):
@@ -1040,9 +1012,9 @@ class DecisionTreeClassifier(BaseTransformer):
1040
1012
  transform_kwargs = dict(
1041
1013
  session = dataset._session,
1042
1014
  dependencies = self._deps,
1043
- pass_through_cols = self._get_pass_through_columns(dataset),
1044
- expected_output_cols_type = "array",
1045
- n_neighbors = n_neighbors,
1015
+ drop_input_cols = self._drop_input_cols,
1016
+ expected_output_cols_type="array",
1017
+ n_neighbors = n_neighbors,
1046
1018
  return_distance = return_distance
1047
1019
  )
1048
1020
  elif isinstance(dataset, pd.DataFrame):
@@ -373,18 +373,24 @@ class DecisionTreeRegressor(BaseTransformer):
373
373
  self._get_model_signatures(dataset)
374
374
  return self
375
375
 
376
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
377
- if self._drop_input_cols:
378
- return []
379
- else:
380
- return list(set(dataset.columns) - set(self.output_cols))
381
-
382
376
  def _batch_inference_validate_snowpark(
383
377
  self,
384
378
  dataset: DataFrame,
385
379
  inference_method: str,
386
380
  ) -> List[str]:
387
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
381
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
382
+ return the available package that exists in the snowflake anaconda channel
383
+
384
+ Args:
385
+ dataset: snowpark dataframe
386
+ inference_method: the inference method such as predict, score...
387
+
388
+ Raises:
389
+ SnowflakeMLException: If the estimator is not fitted, raise error
390
+ SnowflakeMLException: If the session is None, raise error
391
+
392
+ Returns:
393
+ A list of available package that exists in the snowflake anaconda channel
388
394
  """
389
395
  if not self._is_fitted:
390
396
  raise exceptions.SnowflakeMLException(
@@ -458,7 +464,7 @@ class DecisionTreeRegressor(BaseTransformer):
458
464
  transform_kwargs = dict(
459
465
  session = dataset._session,
460
466
  dependencies = self._deps,
461
- pass_through_cols = self._get_pass_through_columns(dataset),
467
+ drop_input_cols = self._drop_input_cols,
462
468
  expected_output_cols_type = expected_type_inferred,
463
469
  )
464
470
 
@@ -518,16 +524,16 @@ class DecisionTreeRegressor(BaseTransformer):
518
524
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
519
525
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
520
526
  # each row containing a list of values.
521
- expected_dtype = "ARRAY"
527
+ expected_dtype = "array"
522
528
 
523
529
  # If we were unable to assign a type to this transform in the factory, infer the type here.
524
530
  if expected_dtype == "":
525
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
531
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
526
532
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
527
- expected_dtype = "ARRAY"
528
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
533
+ expected_dtype = "array"
534
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
529
535
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
530
- expected_dtype = "ARRAY"
536
+ expected_dtype = "array"
531
537
  else:
532
538
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
533
539
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -545,7 +551,7 @@ class DecisionTreeRegressor(BaseTransformer):
545
551
  transform_kwargs = dict(
546
552
  session = dataset._session,
547
553
  dependencies = self._deps,
548
- pass_through_cols = self._get_pass_through_columns(dataset),
554
+ drop_input_cols = self._drop_input_cols,
549
555
  expected_output_cols_type = expected_dtype,
550
556
  )
551
557
 
@@ -596,7 +602,7 @@ class DecisionTreeRegressor(BaseTransformer):
596
602
  subproject=_SUBPROJECT,
597
603
  )
598
604
  output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- pass_through_columns=self._get_pass_through_columns(dataset),
605
+ drop_input_cols=self._drop_input_cols,
600
606
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
601
607
  )
602
608
  self._sklearn_object = fitted_estimator
@@ -614,44 +620,6 @@ class DecisionTreeRegressor(BaseTransformer):
614
620
  assert self._sklearn_object is not None
615
621
  return self._sklearn_object.embedding_
616
622
 
617
-
618
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
619
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
620
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
621
- """
622
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
623
- if output_cols:
624
- output_cols = [
625
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
626
- for c in output_cols
627
- ]
628
- elif getattr(self._sklearn_object, "classes_", None) is None:
629
- output_cols = [output_cols_prefix]
630
- elif self._sklearn_object is not None:
631
- classes = self._sklearn_object.classes_
632
- if isinstance(classes, numpy.ndarray):
633
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
634
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
635
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
636
- output_cols = []
637
- for i, cl in enumerate(classes):
638
- # For binary classification, there is only one output column for each class
639
- # ndarray as the two classes are complementary.
640
- if len(cl) == 2:
641
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
642
- else:
643
- output_cols.extend([
644
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
645
- ])
646
- else:
647
- output_cols = []
648
-
649
- # Make sure column names are valid snowflake identifiers.
650
- assert output_cols is not None # Make MyPy happy
651
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
652
-
653
- return rv
654
-
655
623
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
656
624
  @telemetry.send_api_usage_telemetry(
657
625
  project=_PROJECT,
@@ -691,7 +659,7 @@ class DecisionTreeRegressor(BaseTransformer):
691
659
  transform_kwargs = dict(
692
660
  session=dataset._session,
693
661
  dependencies=self._deps,
694
- pass_through_cols=self._get_pass_through_columns(dataset),
662
+ drop_input_cols = self._drop_input_cols,
695
663
  expected_output_cols_type="float",
696
664
  )
697
665
 
@@ -756,7 +724,7 @@ class DecisionTreeRegressor(BaseTransformer):
756
724
  transform_kwargs = dict(
757
725
  session=dataset._session,
758
726
  dependencies=self._deps,
759
- pass_through_cols=self._get_pass_through_columns(dataset),
727
+ drop_input_cols = self._drop_input_cols,
760
728
  expected_output_cols_type="float",
761
729
  )
762
730
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +785,7 @@ class DecisionTreeRegressor(BaseTransformer):
817
785
  transform_kwargs = dict(
818
786
  session=dataset._session,
819
787
  dependencies=self._deps,
820
- pass_through_cols=self._get_pass_through_columns(dataset),
788
+ drop_input_cols = self._drop_input_cols,
821
789
  expected_output_cols_type="float",
822
790
  )
823
791
 
@@ -882,7 +850,7 @@ class DecisionTreeRegressor(BaseTransformer):
882
850
  transform_kwargs = dict(
883
851
  session=dataset._session,
884
852
  dependencies=self._deps,
885
- pass_through_cols=self._get_pass_through_columns(dataset),
853
+ drop_input_cols = self._drop_input_cols,
886
854
  expected_output_cols_type="float",
887
855
  )
888
856
 
@@ -938,13 +906,17 @@ class DecisionTreeRegressor(BaseTransformer):
938
906
  transform_kwargs: ScoreKwargsTypedDict = dict()
939
907
 
940
908
  if isinstance(dataset, DataFrame):
909
+ self._deps = self._batch_inference_validate_snowpark(
910
+ dataset=dataset,
911
+ inference_method="score",
912
+ )
941
913
  selected_cols = self._get_active_columns()
942
914
  if len(selected_cols) > 0:
943
915
  dataset = dataset.select(selected_cols)
944
916
  assert isinstance(dataset._session, Session) # keep mypy happy
945
917
  transform_kwargs = dict(
946
918
  session=dataset._session,
947
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
919
+ dependencies=["snowflake-snowpark-python"] + self._deps,
948
920
  score_sproc_imports=['sklearn'],
949
921
  )
950
922
  elif isinstance(dataset, pd.DataFrame):
@@ -1018,9 +990,9 @@ class DecisionTreeRegressor(BaseTransformer):
1018
990
  transform_kwargs = dict(
1019
991
  session = dataset._session,
1020
992
  dependencies = self._deps,
1021
- pass_through_cols = self._get_pass_through_columns(dataset),
1022
- expected_output_cols_type = "array",
1023
- n_neighbors = n_neighbors,
993
+ drop_input_cols = self._drop_input_cols,
994
+ expected_output_cols_type="array",
995
+ n_neighbors = n_neighbors,
1024
996
  return_distance = return_distance
1025
997
  )
1026
998
  elif isinstance(dataset, pd.DataFrame):
@@ -383,18 +383,24 @@ class ExtraTreeClassifier(BaseTransformer):
383
383
  self._get_model_signatures(dataset)
384
384
  return self
385
385
 
386
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
387
- if self._drop_input_cols:
388
- return []
389
- else:
390
- return list(set(dataset.columns) - set(self.output_cols))
391
-
392
386
  def _batch_inference_validate_snowpark(
393
387
  self,
394
388
  dataset: DataFrame,
395
389
  inference_method: str,
396
390
  ) -> List[str]:
397
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
391
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
392
+ return the available package that exists in the snowflake anaconda channel
393
+
394
+ Args:
395
+ dataset: snowpark dataframe
396
+ inference_method: the inference method such as predict, score...
397
+
398
+ Raises:
399
+ SnowflakeMLException: If the estimator is not fitted, raise error
400
+ SnowflakeMLException: If the session is None, raise error
401
+
402
+ Returns:
403
+ A list of available package that exists in the snowflake anaconda channel
398
404
  """
399
405
  if not self._is_fitted:
400
406
  raise exceptions.SnowflakeMLException(
@@ -468,7 +474,7 @@ class ExtraTreeClassifier(BaseTransformer):
468
474
  transform_kwargs = dict(
469
475
  session = dataset._session,
470
476
  dependencies = self._deps,
471
- pass_through_cols = self._get_pass_through_columns(dataset),
477
+ drop_input_cols = self._drop_input_cols,
472
478
  expected_output_cols_type = expected_type_inferred,
473
479
  )
474
480
 
@@ -528,16 +534,16 @@ class ExtraTreeClassifier(BaseTransformer):
528
534
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
529
535
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
530
536
  # each row containing a list of values.
531
- expected_dtype = "ARRAY"
537
+ expected_dtype = "array"
532
538
 
533
539
  # If we were unable to assign a type to this transform in the factory, infer the type here.
534
540
  if expected_dtype == "":
535
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
541
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
536
542
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
537
- expected_dtype = "ARRAY"
538
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
543
+ expected_dtype = "array"
544
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
539
545
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
540
- expected_dtype = "ARRAY"
546
+ expected_dtype = "array"
541
547
  else:
542
548
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
543
549
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -555,7 +561,7 @@ class ExtraTreeClassifier(BaseTransformer):
555
561
  transform_kwargs = dict(
556
562
  session = dataset._session,
557
563
  dependencies = self._deps,
558
- pass_through_cols = self._get_pass_through_columns(dataset),
564
+ drop_input_cols = self._drop_input_cols,
559
565
  expected_output_cols_type = expected_dtype,
560
566
  )
561
567
 
@@ -606,7 +612,7 @@ class ExtraTreeClassifier(BaseTransformer):
606
612
  subproject=_SUBPROJECT,
607
613
  )
608
614
  output_result, fitted_estimator = model_trainer.train_fit_predict(
609
- pass_through_columns=self._get_pass_through_columns(dataset),
615
+ drop_input_cols=self._drop_input_cols,
610
616
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
611
617
  )
612
618
  self._sklearn_object = fitted_estimator
@@ -624,44 +630,6 @@ class ExtraTreeClassifier(BaseTransformer):
624
630
  assert self._sklearn_object is not None
625
631
  return self._sklearn_object.embedding_
626
632
 
627
-
628
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
629
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
630
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
631
- """
632
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
633
- if output_cols:
634
- output_cols = [
635
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
636
- for c in output_cols
637
- ]
638
- elif getattr(self._sklearn_object, "classes_", None) is None:
639
- output_cols = [output_cols_prefix]
640
- elif self._sklearn_object is not None:
641
- classes = self._sklearn_object.classes_
642
- if isinstance(classes, numpy.ndarray):
643
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
644
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
645
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
646
- output_cols = []
647
- for i, cl in enumerate(classes):
648
- # For binary classification, there is only one output column for each class
649
- # ndarray as the two classes are complementary.
650
- if len(cl) == 2:
651
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
652
- else:
653
- output_cols.extend([
654
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
655
- ])
656
- else:
657
- output_cols = []
658
-
659
- # Make sure column names are valid snowflake identifiers.
660
- assert output_cols is not None # Make MyPy happy
661
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
662
-
663
- return rv
664
-
665
633
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
666
634
  @telemetry.send_api_usage_telemetry(
667
635
  project=_PROJECT,
@@ -703,7 +671,7 @@ class ExtraTreeClassifier(BaseTransformer):
703
671
  transform_kwargs = dict(
704
672
  session=dataset._session,
705
673
  dependencies=self._deps,
706
- pass_through_cols=self._get_pass_through_columns(dataset),
674
+ drop_input_cols = self._drop_input_cols,
707
675
  expected_output_cols_type="float",
708
676
  )
709
677
 
@@ -770,7 +738,7 @@ class ExtraTreeClassifier(BaseTransformer):
770
738
  transform_kwargs = dict(
771
739
  session=dataset._session,
772
740
  dependencies=self._deps,
773
- pass_through_cols=self._get_pass_through_columns(dataset),
741
+ drop_input_cols = self._drop_input_cols,
774
742
  expected_output_cols_type="float",
775
743
  )
776
744
  elif isinstance(dataset, pd.DataFrame):
@@ -831,7 +799,7 @@ class ExtraTreeClassifier(BaseTransformer):
831
799
  transform_kwargs = dict(
832
800
  session=dataset._session,
833
801
  dependencies=self._deps,
834
- pass_through_cols=self._get_pass_through_columns(dataset),
802
+ drop_input_cols = self._drop_input_cols,
835
803
  expected_output_cols_type="float",
836
804
  )
837
805
 
@@ -896,7 +864,7 @@ class ExtraTreeClassifier(BaseTransformer):
896
864
  transform_kwargs = dict(
897
865
  session=dataset._session,
898
866
  dependencies=self._deps,
899
- pass_through_cols=self._get_pass_through_columns(dataset),
867
+ drop_input_cols = self._drop_input_cols,
900
868
  expected_output_cols_type="float",
901
869
  )
902
870
 
@@ -952,13 +920,17 @@ class ExtraTreeClassifier(BaseTransformer):
952
920
  transform_kwargs: ScoreKwargsTypedDict = dict()
953
921
 
954
922
  if isinstance(dataset, DataFrame):
923
+ self._deps = self._batch_inference_validate_snowpark(
924
+ dataset=dataset,
925
+ inference_method="score",
926
+ )
955
927
  selected_cols = self._get_active_columns()
956
928
  if len(selected_cols) > 0:
957
929
  dataset = dataset.select(selected_cols)
958
930
  assert isinstance(dataset._session, Session) # keep mypy happy
959
931
  transform_kwargs = dict(
960
932
  session=dataset._session,
961
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
933
+ dependencies=["snowflake-snowpark-python"] + self._deps,
962
934
  score_sproc_imports=['sklearn'],
963
935
  )
964
936
  elif isinstance(dataset, pd.DataFrame):
@@ -1032,9 +1004,9 @@ class ExtraTreeClassifier(BaseTransformer):
1032
1004
  transform_kwargs = dict(
1033
1005
  session = dataset._session,
1034
1006
  dependencies = self._deps,
1035
- pass_through_cols = self._get_pass_through_columns(dataset),
1036
- expected_output_cols_type = "array",
1037
- n_neighbors = n_neighbors,
1007
+ drop_input_cols = self._drop_input_cols,
1008
+ expected_output_cols_type="array",
1009
+ n_neighbors = n_neighbors,
1038
1010
  return_distance = return_distance
1039
1011
  )
1040
1012
  elif isinstance(dataset, pd.DataFrame):