snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
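
The hunks below show the recurring edit for three representative generated estimators (OutputCodeClassifier, BernoulliNB, CategoricalNB); the same +33/-61 change is applied mechanically across the roughly 200 snowflake/ml/modeling/... classes listed above. The headline change: each class's private _get_pass_through_columns helper is deleted, and every call site forwards the drop_input_cols flag instead of a precomputed column list. Two smaller recurring edits ride along: the expected_dtype literal changes case from "ARRAY" to "array", and score() gains an explicit dependency-validation step (discussed after the first class below). A minimal before/after sketch of the kwarg change; deriving the pass-through columns is assumed to move into snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py (also rewritten in this release), which these hunks do not show:

    # 1.3.0: every generated estimator computed its own pass-through list.
    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
        if self._drop_input_cols:
            return []
        return list(set(dataset.columns) - set(self.output_cols))

    transform_kwargs = dict(
        session=dataset._session,
        dependencies=self._deps,
        pass_through_cols=self._get_pass_through_columns(dataset),
        expected_output_cols_type=expected_type_inferred,
    )

    # 1.4.0: only the boolean flag is forwarded; the handler derives the
    # pass-through columns itself (assumed, based on the file list above).
    transform_kwargs = dict(
        session=dataset._session,
        dependencies=self._deps,
        drop_input_cols=self._drop_input_cols,
        expected_output_cols_type=expected_type_inferred,
    )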
snowflake/ml/modeling/multiclass/output_code_classifier.py
@@ -283,18 +283,24 @@ class OutputCodeClassifier(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -368,7 +374,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -428,16 +434,16 @@ class OutputCodeClassifier(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -455,7 +461,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -506,7 +512,7 @@ class OutputCodeClassifier(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -524,44 +530,6 @@ class OutputCodeClassifier(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -601,7 +569,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -666,7 +634,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -727,7 +695,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -792,7 +760,7 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -848,13 +816,17 @@ class OutputCodeClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -928,9 +900,9 @@ class OutputCodeClassifier(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
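
Beyond the kwarg rename, _batch_inference_validate_snowpark now doubles as the dependency resolver: per its expanded docstring it validates that the estimator is fitted and that the DataFrame carries a session, then returns the declared dependencies that are actually available in the Snowflake Anaconda channel. score() starts consuming it (see the @@ -848,13 +816,17 @@ hunk above), so the score sproc depends on self._deps instead of the raw self._get_dependencies() output. A hedged sketch of that contract; the method body is outside these hunks, and the resolution helper named here is an assumption about the package internals:

    from snowflake.ml._internal.utils import pkg_version_utils  # assumed module

    def _batch_inference_validate_snowpark(self, dataset, inference_method):
        # Docstring contract: raise if the estimator is not fitted or the
        # DataFrame has no session (the real code raises SnowflakeMLException).
        if not self._is_fitted:
            raise RuntimeError(f"Estimator must be fitted before calling {inference_method}.")
        if dataset._session is None:
            raise RuntimeError("Snowpark DataFrame has no active session.")
        # Keep only the declared dependencies that resolve in the Snowflake
        # Anaconda channel (helper assumed, not shown in this diff).
        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
            pkg_versions=self._get_dependencies(), session=dataset._session
        )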
snowflake/ml/modeling/naive_bayes/bernoulli_nb.py
@@ -283,18 +283,24 @@ class BernoulliNB(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -368,7 +374,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -428,16 +434,16 @@ class BernoulliNB(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -455,7 +461,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -506,7 +512,7 @@ class BernoulliNB(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -524,44 +530,6 @@ class BernoulliNB(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -603,7 +571,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -670,7 +638,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -731,7 +699,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -796,7 +764,7 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -852,13 +820,17 @@ class BernoulliNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -932,9 +904,9 @@ class BernoulliNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
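
Each class also drops its ~38-line copy of _get_output_column_names (removed in the @@ -524,44 +530,6 @@ hunks), yet train_fit_predict still calls self._get_output_column_names(output_cols_prefix), so the helper presumably moves to the shared base class; framework/base.py growing by 55 lines in this release is consistent with that reading. For the common single-output classifier, the removed logic named one output column per class label. A small illustrative mirror of just that rule, ignoring the Snowflake identifier normalization the real helper applies:

    import numpy

    def expected_proba_columns(classes: numpy.ndarray, prefix: str) -> list:
        # Removed naming rule: one column per class label, each carrying the
        # output_cols_prefix.
        return [f"{prefix}{c}" for c in classes.tolist()]

    print(expected_proba_columns(numpy.array([0, 1]), "PREDICT_PROBA_"))
    # ['PREDICT_PROBA_0', 'PREDICT_PROBA_1']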
snowflake/ml/modeling/naive_bayes/categorical_nb.py
@@ -289,18 +289,24 @@ class CategoricalNB(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -374,7 +380,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -434,16 +440,16 @@ class CategoricalNB(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -461,7 +467,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -512,7 +518,7 @@ class CategoricalNB(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -530,44 +536,6 @@ class CategoricalNB(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -609,7 +577,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -676,7 +644,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -737,7 +705,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -802,7 +770,7 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -858,13 +826,17 @@ class CategoricalNB(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -938,9 +910,9 @@ class CategoricalNB(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
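
None of these internal edits change the public estimator surface, so a 1.3.0 call site keeps working on 1.4.0. A usage sketch, assuming a Snowpark session named session and a feature table TRAIN_DATA with hypothetical column names:

    from snowflake.ml.modeling.naive_bayes import BernoulliNB

    df = session.table("TRAIN_DATA")  # hypothetical table
    clf = BernoulliNB(
        input_cols=["F1", "F2"],   # hypothetical feature columns
        label_cols=["LABEL"],
        output_cols=["PREDICTION"],
        drop_input_cols=False,     # the flag now forwarded directly to the handlers
    )
    clf.fit(df)
    predictions = clf.predict(df)
    accuracy = clf.score(df)  # in 1.4.0, score() first resolves deps against the Anaconda channel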