snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/covariance/oas.py

@@ -263,18 +263,24 @@ class OAS(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -346,7 +352,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -406,16 +412,16 @@ class OAS(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "ARRAY"
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -433,7 +439,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -484,7 +490,7 @@ class OAS(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -502,44 +508,6 @@ class OAS(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -579,7 +547,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -644,7 +612,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -705,7 +673,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -770,7 +738,7 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -826,13 +794,17 @@ class OAS(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -906,9 +878,9 @@ class OAS(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
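
The recurring change across these generated estimator files replaces the per-class _get_pass_through_columns() helper with a plain drop_input_cols flag handed to the inference handlers, so the pass-through computation presumably now lives once in the shared handler code rather than being stamped into every wrapper. A minimal sketch of the equivalent logic, mirroring the removed helper (the free-standing function name is hypothetical):

from typing import List

def infer_pass_through_columns(
    dataset_columns: List[str],
    output_cols: List[str],
    drop_input_cols: bool,
) -> List[str]:
    # Mirrors the removed _get_pass_through_columns(): when input columns are
    # dropped, nothing is passed through; otherwise every dataset column that
    # is not an output column rides along with the predictions (order
    # unspecified, as in the original set-difference implementation).
    if drop_input_cols:
        return []
    return list(set(dataset_columns) - set(output_cols))

# Example: with drop_input_cols=False the feature columns are carried through.
print(infer_pass_through_columns(["F1", "F2", "OUTPUT"], ["OUTPUT"], False))
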
snowflake/ml/modeling/covariance/shrunk_covariance.py

@@ -269,18 +269,24 @@ class ShrunkCovariance(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -352,7 +358,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -412,16 +418,16 @@ class ShrunkCovariance(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "ARRAY"
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -439,7 +445,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -490,7 +496,7 @@ class ShrunkCovariance(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -508,44 +514,6 @@ class ShrunkCovariance(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -585,7 +553,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -650,7 +618,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -711,7 +679,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -776,7 +744,7 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -832,13 +800,17 @@ class ShrunkCovariance(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -912,9 +884,9 @@ class ShrunkCovariance(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
snowflake/ml/modeling/decomposition/dictionary_learning.py

@@ -375,18 +375,24 @@ class DictionaryLearning(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe.
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -458,7 +464,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -520,16 +526,16 @@ class DictionaryLearning(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "ARRAY"
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "ARRAY"
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "ARRAY"
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -547,7 +553,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -598,7 +604,7 @@ class DictionaryLearning(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-            pass_through_columns=self._get_pass_through_columns(dataset),
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -616,44 +622,6 @@ class DictionaryLearning(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -693,7 +661,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -758,7 +726,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +787,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -884,7 +852,7 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -938,13 +906,17 @@ class DictionaryLearning(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1018,9 +990,9 @@ class DictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "array",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
            )
         elif isinstance(dataset, pd.DataFrame):
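
The same pattern repeats in the score() path of every generated estimator: the wrapper now calls _batch_inference_validate_snowpark() up front, caches the packages it finds in the Snowflake Anaconda channel on self._deps, and builds the stored-procedure dependency list from that cache instead of calling _get_dependencies() inline. A hedged sketch of the new control flow (method names taken from the diff; the bodies are simplified stand-ins, not the real implementation):

from typing import Any, List

class SketchTransformer:
    """Illustrative stand-in for the generated BaseTransformer subclasses."""

    def __init__(self) -> None:
        self._is_fitted = True
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> List[str]:
        # The real method raises SnowflakeMLException when the estimator is not
        # fitted or the dataset has no session, then returns the declared
        # dependencies that exist in the Snowflake Anaconda channel.
        if not self._is_fitted:
            raise RuntimeError(f"Cannot run {inference_method}: estimator is not fitted")
        return ["scikit-learn", "numpy"]  # placeholder result

    def score(self, dataset: Any) -> List[str]:
        # New in 1.4.0: validate first and cache the resolved packages, then
        # assemble the dependency list that the scoring stored procedure uses.
        self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
        return ["snowflake-snowpark-python"] + self._deps

print(SketchTransformer().score(dataset=None))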