snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/neighbors/local_outlier_factor.py

@@ -341,18 +341,24 @@ class LocalOutlierFactor(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to validate that batch inference can be run on a Snowpark DataFrame.
+        """Util method to validate that batch inference can be run on a Snowpark DataFrame and
+        return the available packages that exist in the Snowflake Anaconda channel.
+
+        Args:
+            dataset: Snowpark DataFrame.
+            inference_method: The inference method, such as predict or score.
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted.
+            SnowflakeMLException: If the session is None.
+
+        Returns:
+            A list of the available packages that exist in the Snowflake Anaconda channel.
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
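The hunk above introduces the new contract for _batch_inference_validate_snowpark: check the preconditions, then hand back the dependency list that the Snowflake Anaconda channel can actually serve. A minimal runnable sketch of that contract, for context only (every name below is a stand-in, not the library's API):

from typing import List, Optional


class NotFittedError(Exception):
    """Stand-in for SnowflakeMLException raised with a not-fitted error code."""


def validate_batch_inference(
    is_fitted: bool,
    session: Optional[object],
    declared_deps: List[str],
    anaconda_channel: List[str],
) -> List[str]:
    # Mirror the two documented failure modes: unfitted estimator, missing session.
    if not is_fitted:
        raise NotFittedError("Estimator is not fitted; call fit() before inference.")
    if session is None:
        raise ValueError("A Snowpark session is required for batch inference.")
    # Return only the declared dependencies the channel can serve.
    available = set(anaconda_channel)
    return [pkg for pkg in declared_deps if pkg in available]


print(validate_batch_inference(True, object(), ["scikit-learn"], ["scikit-learn", "numpy"]))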
@@ -426,7 +432,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -486,16 +492,16 @@ class LocalOutlierFactor(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to a pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statements are true:

@@ -513,7 +519,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -566,7 +572,7 @@ class LocalOutlierFactor(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
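Every pass_through_cols hunk in this release follows the same pattern: instead of precomputing the pass-through column list with the deleted _get_pass_through_columns helper, the estimator now hands the raw _drop_input_cols flag down to the transform/trainer layer. A sketch of the equivalent computation, reconstructed directly from the deleted helper (the standalone function name is illustrative):

from typing import List


def pass_through_columns(
    dataset_columns: List[str], output_cols: List[str], drop_input_cols: bool
) -> List[str]:
    # Same logic as the deleted _get_pass_through_columns: keep nothing when
    # input columns are dropped, otherwise keep every dataset column that is
    # not an output column. Set difference, so the order is not guaranteed.
    if drop_input_cols:
        return []
    return list(set(dataset_columns) - set(output_cols))


print(pass_through_columns(["ID", "FEATURE", "OUTPUT"], ["OUTPUT"], drop_input_cols=False))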
@@ -584,44 +590,6 @@ class LocalOutlierFactor(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -661,7 +629,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -726,7 +694,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -789,7 +757,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -856,7 +824,7 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -910,13 +878,17 @@ class LocalOutlierFactor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -992,9 +964,9 @@ class LocalOutlierFactor(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
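The expected_dtype hunks above only lower-case the sentinel from "ARRAY" to "array"; the fallback inference itself is unchanged. A compressed, runnable sketch of that fallback, assuming a fake estimator object (the class below is a stand-in for self._sklearn_object; the attribute names follow scikit-learn conventions):

class FakeKMeans:
    # Fake estimator standing in for self._sklearn_object in the diff above.
    n_clusters = 8


def infer_expected_dtype(sklearn_object, output_cols):
    # Fallback from the diff: when the factory assigned no type, a clustering
    # or decomposition transformer whose output-column count does not match
    # n_clusters / n_components must emit one "array" column per row.
    expected_dtype = ""
    if hasattr(sklearn_object, "n_clusters") and sklearn_object.n_clusters != len(output_cols):
        expected_dtype = "array"
    elif hasattr(sklearn_object, "n_components") and sklearn_object.n_components != len(output_cols):
        expected_dtype = "array"
    return expected_dtype


print(infer_expected_dtype(FakeKMeans(), ["OUTPUT_0"]))  # prints: array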
snowflake/ml/modeling/neighbors/nearest_centroid.py

@@ -274,18 +274,24 @@ class NearestCentroid(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to validate that batch inference can be run on a Snowpark DataFrame.
+        """Util method to validate that batch inference can be run on a Snowpark DataFrame and
+        return the available packages that exist in the Snowflake Anaconda channel.
+
+        Args:
+            dataset: Snowpark DataFrame.
+            inference_method: The inference method, such as predict or score.
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted.
+            SnowflakeMLException: If the session is None.
+
+        Returns:
+            A list of the available packages that exist in the Snowflake Anaconda channel.
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -359,7 +365,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -419,16 +425,16 @@ class NearestCentroid(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to a pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statements are true:

@@ -446,7 +452,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -497,7 +503,7 @@ class NearestCentroid(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator

@@ -515,44 +521,6 @@ class NearestCentroid(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,

@@ -592,7 +560,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -657,7 +625,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -718,7 +686,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -783,7 +751,7 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -839,13 +807,17 @@ class NearestCentroid(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -919,9 +891,9 @@ class NearestCentroid(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
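In every score hunk, the stored-procedure dependency list switches from the unvalidated self._get_dependencies() to self._deps, which is now populated by the validation call added at the top of the Snowpark branch. A sketch of the resulting flow (the class and the method bodies are stand-ins; only the method names come from the diff):

from typing import List


class ScoringMixin:
    def __init__(self, declared_deps: List[str], channel: List[str]) -> None:
        self._declared_deps = declared_deps
        self._channel = set(channel)
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, inference_method: str) -> List[str]:
        # Stand-in for the real validation: filter the declared dependencies
        # down to what the Snowflake Anaconda channel can actually serve.
        return [d for d in self._declared_deps if d in self._channel]

    def score_dependencies(self) -> List[str]:
        # New behavior: validate first, then reuse the validated list.
        self._deps = self._batch_inference_validate_snowpark("score")
        return ["snowflake-snowpark-python"] + self._deps


print(ScoringMixin(["scikit-learn", "private-pkg"], ["scikit-learn"]).score_dependencies())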
snowflake/ml/modeling/neighbors/nearest_neighbors.py

@@ -324,18 +324,24 @@ class NearestNeighbors(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to validate that batch inference can be run on a Snowpark DataFrame.
+        """Util method to validate that batch inference can be run on a Snowpark DataFrame and
+        return the available packages that exist in the Snowflake Anaconda channel.
+
+        Args:
+            dataset: Snowpark DataFrame.
+            inference_method: The inference method, such as predict or score.
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted.
+            SnowflakeMLException: If the session is None.
+
+        Returns:
+            A list of the available packages that exist in the Snowflake Anaconda channel.
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(

@@ -407,7 +413,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -467,16 +473,16 @@ class NearestNeighbors(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to a pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "ARRAY"
+                # If this is a clustering transformer and the number of output columns does not equal the number of clusters, the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer and the number of output columns does not equal the number of components, the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statements are true:

@@ -494,7 +500,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -545,7 +551,7 @@ class NearestNeighbors(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator

@@ -563,44 +569,6 @@ class NearestNeighbors(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,

@@ -640,7 +608,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -705,7 +673,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -766,7 +734,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -831,7 +799,7 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -885,13 +853,17 @@ class NearestNeighbors(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):

@@ -967,9 +939,9 @@ class NearestNeighbors(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
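Finally, the 38-line _get_output_column_names helper is deleted from each estimator in identical hunks. Given that snowflake/ml/modeling/framework/base.py grows by 55 lines in this same release, the helper was plausibly consolidated there, though this excerpt does not show that file. Its naming rule, reproduced as a standalone sketch from the deleted code (the function name and simplified signature are illustrative):

import numpy


def output_column_names(prefix, classes):
    # Naming rule from the deleted helper: one column per class for a flat
    # classifier; for multi-output classifiers (classes is a list of ndarrays),
    # prefix{i}_{class} per output, collapsing binary outputs to one column.
    if classes is None:
        return [prefix]  # not a classifier: the prefix is the only column
    if isinstance(classes, numpy.ndarray):
        return [f"{prefix}{c}" for c in classes.tolist()]
    cols = []
    for i, cl in enumerate(classes):
        if len(cl) == 2:
            cols.append(f"{prefix}{i}_{cl[0]}")  # binary: classes are complementary
        else:
            cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
    return cols


print(output_column_names("PREDICT_PROBA_", numpy.array([0, 1, 2])))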