snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
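The hunks below reproduce the change for three of the regenerated kernel_approximation wrappers (items 117-119 above); the identical +33/-61 counts on the other modeling wrappers suggest the same generated-code refactor applies there too.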
--- a/snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py
+++ b/snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py
@@ -284,18 +284,24 @@ class PolynomialCountSketch(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -367,7 +373,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -429,16 +435,16 @@ class PolynomialCountSketch(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -456,7 +462,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -507,7 +513,7 @@ class PolynomialCountSketch(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -525,44 +531,6 @@ class PolynomialCountSketch(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -602,7 +570,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -667,7 +635,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -728,7 +696,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -793,7 +761,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -847,13 +815,17 @@ class PolynomialCountSketch(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -927,9 +899,9 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
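
The core of this refactor: each wrapper used to compute its own pass-through column list and hand it to the inference layer as pass_through_cols; in 1.4.0 the wrapper forwards its _drop_input_cols flag and lets the handler decide which columns to keep. Presumably this is also why snowpark_implementations/snowpark_handlers.py shrinks (+67/-114 in the list above): the column bookkeeping moves to one place instead of being stamped into every generated class. A minimal sketch of the before/after shape, using simplified stand-in classes rather than the real BaseTransformer machinery (only _get_pass_through_columns, pass_through_cols, and drop_input_cols come from the diff; everything else is invented for illustration):

```python
from typing import Any, Dict, List


class OldStyleWrapper:
    """Pre-1.4.0 shape: every wrapper computed its own pass-through columns."""

    def __init__(self, output_cols: List[str], drop_input_cols: bool) -> None:
        self.output_cols = output_cols
        self._drop_input_cols = drop_input_cols

    def _get_pass_through_columns(self, dataset_columns: List[str]) -> List[str]:
        # Keep every input column that is not overwritten by an output column.
        if self._drop_input_cols:
            return []
        return list(set(dataset_columns) - set(self.output_cols))

    def build_transform_kwargs(self, dataset_columns: List[str]) -> Dict[str, Any]:
        return dict(pass_through_cols=self._get_pass_through_columns(dataset_columns))


class NewStyleWrapper:
    """1.4.0 shape: forward the flag and let the inference handler decide."""

    def __init__(self, drop_input_cols: bool) -> None:
        self._drop_input_cols = drop_input_cols

    def build_transform_kwargs(self) -> Dict[str, Any]:
        return dict(drop_input_cols=self._drop_input_cols)


if __name__ == "__main__":
    old = OldStyleWrapper(output_cols=["OUTPUT_1"], drop_input_cols=False)
    new = NewStyleWrapper(drop_input_cols=False)
    print(old.build_transform_kwargs(["FEATURE_1", "FEATURE_2"]))  # column list (set order varies)
    print(new.build_transform_kwargs())  # {'drop_input_cols': False}
```

The RBFSampler and SkewedChi2Sampler diffs below repeat the same hunks at shifted line offsets.
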
--- a/snowflake/ml/modeling/kernel_approximation/rbf_sampler.py
+++ b/snowflake/ml/modeling/kernel_approximation/rbf_sampler.py
@@ -271,18 +271,24 @@ class RBFSampler(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -354,7 +360,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -416,16 +422,16 @@ class RBFSampler(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -443,7 +449,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -494,7 +500,7 @@ class RBFSampler(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -512,44 +518,6 @@ class RBFSampler(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -589,7 +557,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -654,7 +622,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -715,7 +683,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -780,7 +748,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -834,13 +802,17 @@ class RBFSampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -914,9 +886,9 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
--- a/snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py
+++ b/snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py
@@ -269,18 +269,24 @@ class SkewedChi2Sampler(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -352,7 +358,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -414,16 +420,16 @@ class SkewedChi2Sampler(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -441,7 +447,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -492,7 +498,7 @@ class SkewedChi2Sampler(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -510,44 +516,6 @@ class SkewedChi2Sampler(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -587,7 +555,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -652,7 +620,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -713,7 +681,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -778,7 +746,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -832,13 +800,17 @@ class SkewedChi2Sampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -912,9 +884,9 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
            )
         elif isinstance(dataset, pd.DataFrame):
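
The score() hunks in all three files make the same two moves: validate up front via _batch_inference_validate_snowpark, which per the new docstring also returns the packages available in the Snowflake Anaconda channel, and then reuse that cached self._deps list when building the sproc dependencies instead of calling self._get_dependencies() inline. A rough sketch of that control flow; only _batch_inference_validate_snowpark, _deps, and the "snowflake-snowpark-python" prefix come from the diff, and the stub validator stands in for the real channel lookup:

```python
from typing import Dict, List, Optional


class ScorePathSketch:
    """Toy stand-in for the generated wrapper's score() plumbing.

    Names other than _batch_inference_validate_snowpark and _deps are
    invented for illustration.
    """

    def __init__(self, is_fitted: bool, session: Optional[object]) -> None:
        self._is_fitted = is_fitted
        self._session = session
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, inference_method: str) -> List[str]:
        # Validate preconditions first, then return the resolved package list.
        if not self._is_fitted:
            raise RuntimeError(f"estimator must be fitted before {inference_method}()")
        if self._session is None:
            raise RuntimeError("a Snowpark session is required")
        # The real method checks the Snowflake Anaconda channel; stubbed here.
        return ["scikit-learn"]

    def build_score_kwargs(self) -> Dict[str, List[str]]:
        # 1.4.0 pattern: validate and cache the deps once, then reuse them.
        self._deps = self._batch_inference_validate_snowpark(inference_method="score")
        return dict(dependencies=["snowflake-snowpark-python"] + self._deps)


if __name__ == "__main__":
    sketch = ScorePathSketch(is_fitted=True, session=object())
    print(sketch.build_score_kwargs())
    # {'dependencies': ['snowflake-snowpark-python', 'scikit-learn']}
```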