snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -305,18 +305,24 @@ class MultiTaskLasso(BaseTransformer):
305
305
  self._get_model_signatures(dataset)
306
306
  return self
307
307
 
308
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
309
- if self._drop_input_cols:
310
- return []
311
- else:
312
- return list(set(dataset.columns) - set(self.output_cols))
313
-
314
308
  def _batch_inference_validate_snowpark(
315
309
  self,
316
310
  dataset: DataFrame,
317
311
  inference_method: str,
318
312
  ) -> List[str]:
319
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
313
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
314
+ return the available package that exists in the snowflake anaconda channel
315
+
316
+ Args:
317
+ dataset: snowpark dataframe
318
+ inference_method: the inference method such as predict, score...
319
+
320
+ Raises:
321
+ SnowflakeMLException: If the estimator is not fitted, raise error
322
+ SnowflakeMLException: If the session is None, raise error
323
+
324
+ Returns:
325
+ A list of available package that exists in the snowflake anaconda channel
320
326
  """
321
327
  if not self._is_fitted:
322
328
  raise exceptions.SnowflakeMLException(
@@ -390,7 +396,7 @@ class MultiTaskLasso(BaseTransformer):
390
396
  transform_kwargs = dict(
391
397
  session = dataset._session,
392
398
  dependencies = self._deps,
393
- pass_through_cols = self._get_pass_through_columns(dataset),
399
+ drop_input_cols = self._drop_input_cols,
394
400
  expected_output_cols_type = expected_type_inferred,
395
401
  )
396
402
 
@@ -450,16 +456,16 @@ class MultiTaskLasso(BaseTransformer):
450
456
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
451
457
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
452
458
  # each row containing a list of values.
453
- expected_dtype = "ARRAY"
459
+ expected_dtype = "array"
454
460
 
455
461
  # If we were unable to assign a type to this transform in the factory, infer the type here.
456
462
  if expected_dtype == "":
457
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
463
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
458
464
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
459
- expected_dtype = "ARRAY"
460
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
465
+ expected_dtype = "array"
466
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
461
467
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
462
- expected_dtype = "ARRAY"
468
+ expected_dtype = "array"
463
469
  else:
464
470
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
465
471
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -477,7 +483,7 @@ class MultiTaskLasso(BaseTransformer):
477
483
  transform_kwargs = dict(
478
484
  session = dataset._session,
479
485
  dependencies = self._deps,
480
- pass_through_cols = self._get_pass_through_columns(dataset),
486
+ drop_input_cols = self._drop_input_cols,
481
487
  expected_output_cols_type = expected_dtype,
482
488
  )
483
489
 
@@ -528,7 +534,7 @@ class MultiTaskLasso(BaseTransformer):
528
534
  subproject=_SUBPROJECT,
529
535
  )
530
536
  output_result, fitted_estimator = model_trainer.train_fit_predict(
531
- pass_through_columns=self._get_pass_through_columns(dataset),
537
+ drop_input_cols=self._drop_input_cols,
532
538
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
533
539
  )
534
540
  self._sklearn_object = fitted_estimator
@@ -546,44 +552,6 @@ class MultiTaskLasso(BaseTransformer):
546
552
  assert self._sklearn_object is not None
547
553
  return self._sklearn_object.embedding_
548
554
 
549
-
550
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
551
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
552
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
553
- """
554
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
555
- if output_cols:
556
- output_cols = [
557
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
558
- for c in output_cols
559
- ]
560
- elif getattr(self._sklearn_object, "classes_", None) is None:
561
- output_cols = [output_cols_prefix]
562
- elif self._sklearn_object is not None:
563
- classes = self._sklearn_object.classes_
564
- if isinstance(classes, numpy.ndarray):
565
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
566
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
567
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
568
- output_cols = []
569
- for i, cl in enumerate(classes):
570
- # For binary classification, there is only one output column for each class
571
- # ndarray as the two classes are complementary.
572
- if len(cl) == 2:
573
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
574
- else:
575
- output_cols.extend([
576
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
577
- ])
578
- else:
579
- output_cols = []
580
-
581
- # Make sure column names are valid snowflake identifiers.
582
- assert output_cols is not None # Make MyPy happy
583
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
584
-
585
- return rv
586
-
587
555
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
556
  @telemetry.send_api_usage_telemetry(
589
557
  project=_PROJECT,
@@ -623,7 +591,7 @@ class MultiTaskLasso(BaseTransformer):
623
591
  transform_kwargs = dict(
624
592
  session=dataset._session,
625
593
  dependencies=self._deps,
626
- pass_through_cols=self._get_pass_through_columns(dataset),
594
+ drop_input_cols = self._drop_input_cols,
627
595
  expected_output_cols_type="float",
628
596
  )
629
597
 
@@ -688,7 +656,7 @@ class MultiTaskLasso(BaseTransformer):
688
656
  transform_kwargs = dict(
689
657
  session=dataset._session,
690
658
  dependencies=self._deps,
691
- pass_through_cols=self._get_pass_through_columns(dataset),
659
+ drop_input_cols = self._drop_input_cols,
692
660
  expected_output_cols_type="float",
693
661
  )
694
662
  elif isinstance(dataset, pd.DataFrame):
@@ -749,7 +717,7 @@ class MultiTaskLasso(BaseTransformer):
749
717
  transform_kwargs = dict(
750
718
  session=dataset._session,
751
719
  dependencies=self._deps,
752
- pass_through_cols=self._get_pass_through_columns(dataset),
720
+ drop_input_cols = self._drop_input_cols,
753
721
  expected_output_cols_type="float",
754
722
  )
755
723
 
@@ -814,7 +782,7 @@ class MultiTaskLasso(BaseTransformer):
814
782
  transform_kwargs = dict(
815
783
  session=dataset._session,
816
784
  dependencies=self._deps,
817
- pass_through_cols=self._get_pass_through_columns(dataset),
785
+ drop_input_cols = self._drop_input_cols,
818
786
  expected_output_cols_type="float",
819
787
  )
820
788
 
@@ -870,13 +838,17 @@ class MultiTaskLasso(BaseTransformer):
870
838
  transform_kwargs: ScoreKwargsTypedDict = dict()
871
839
 
872
840
  if isinstance(dataset, DataFrame):
841
+ self._deps = self._batch_inference_validate_snowpark(
842
+ dataset=dataset,
843
+ inference_method="score",
844
+ )
873
845
  selected_cols = self._get_active_columns()
874
846
  if len(selected_cols) > 0:
875
847
  dataset = dataset.select(selected_cols)
876
848
  assert isinstance(dataset._session, Session) # keep mypy happy
877
849
  transform_kwargs = dict(
878
850
  session=dataset._session,
879
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
851
+ dependencies=["snowflake-snowpark-python"] + self._deps,
880
852
  score_sproc_imports=['sklearn'],
881
853
  )
882
854
  elif isinstance(dataset, pd.DataFrame):
@@ -950,9 +922,9 @@ class MultiTaskLasso(BaseTransformer):
950
922
  transform_kwargs = dict(
951
923
  session = dataset._session,
952
924
  dependencies = self._deps,
953
- pass_through_cols = self._get_pass_through_columns(dataset),
954
- expected_output_cols_type = "array",
955
- n_neighbors = n_neighbors,
925
+ drop_input_cols = self._drop_input_cols,
926
+ expected_output_cols_type="array",
927
+ n_neighbors = n_neighbors,
956
928
  return_distance = return_distance
957
929
  )
958
930
  elif isinstance(dataset, pd.DataFrame):
@@ -340,18 +340,24 @@ class MultiTaskLassoCV(BaseTransformer):
340
340
  self._get_model_signatures(dataset)
341
341
  return self
342
342
 
343
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
344
- if self._drop_input_cols:
345
- return []
346
- else:
347
- return list(set(dataset.columns) - set(self.output_cols))
348
-
349
343
  def _batch_inference_validate_snowpark(
350
344
  self,
351
345
  dataset: DataFrame,
352
346
  inference_method: str,
353
347
  ) -> List[str]:
354
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
348
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
349
+ return the available package that exists in the snowflake anaconda channel
350
+
351
+ Args:
352
+ dataset: snowpark dataframe
353
+ inference_method: the inference method such as predict, score...
354
+
355
+ Raises:
356
+ SnowflakeMLException: If the estimator is not fitted, raise error
357
+ SnowflakeMLException: If the session is None, raise error
358
+
359
+ Returns:
360
+ A list of available package that exists in the snowflake anaconda channel
355
361
  """
356
362
  if not self._is_fitted:
357
363
  raise exceptions.SnowflakeMLException(
@@ -425,7 +431,7 @@ class MultiTaskLassoCV(BaseTransformer):
425
431
  transform_kwargs = dict(
426
432
  session = dataset._session,
427
433
  dependencies = self._deps,
428
- pass_through_cols = self._get_pass_through_columns(dataset),
434
+ drop_input_cols = self._drop_input_cols,
429
435
  expected_output_cols_type = expected_type_inferred,
430
436
  )
431
437
 
@@ -485,16 +491,16 @@ class MultiTaskLassoCV(BaseTransformer):
485
491
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
486
492
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
487
493
  # each row containing a list of values.
488
- expected_dtype = "ARRAY"
494
+ expected_dtype = "array"
489
495
 
490
496
  # If we were unable to assign a type to this transform in the factory, infer the type here.
491
497
  if expected_dtype == "":
492
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
498
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
493
499
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
494
- expected_dtype = "ARRAY"
495
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
500
+ expected_dtype = "array"
501
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
496
502
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
497
- expected_dtype = "ARRAY"
503
+ expected_dtype = "array"
498
504
  else:
499
505
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
500
506
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -512,7 +518,7 @@ class MultiTaskLassoCV(BaseTransformer):
512
518
  transform_kwargs = dict(
513
519
  session = dataset._session,
514
520
  dependencies = self._deps,
515
- pass_through_cols = self._get_pass_through_columns(dataset),
521
+ drop_input_cols = self._drop_input_cols,
516
522
  expected_output_cols_type = expected_dtype,
517
523
  )
518
524
 
@@ -563,7 +569,7 @@ class MultiTaskLassoCV(BaseTransformer):
563
569
  subproject=_SUBPROJECT,
564
570
  )
565
571
  output_result, fitted_estimator = model_trainer.train_fit_predict(
566
- pass_through_columns=self._get_pass_through_columns(dataset),
572
+ drop_input_cols=self._drop_input_cols,
567
573
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
568
574
  )
569
575
  self._sklearn_object = fitted_estimator
@@ -581,44 +587,6 @@ class MultiTaskLassoCV(BaseTransformer):
581
587
  assert self._sklearn_object is not None
582
588
  return self._sklearn_object.embedding_
583
589
 
584
-
585
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
586
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
587
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
588
- """
589
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
590
- if output_cols:
591
- output_cols = [
592
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
593
- for c in output_cols
594
- ]
595
- elif getattr(self._sklearn_object, "classes_", None) is None:
596
- output_cols = [output_cols_prefix]
597
- elif self._sklearn_object is not None:
598
- classes = self._sklearn_object.classes_
599
- if isinstance(classes, numpy.ndarray):
600
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
601
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
602
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
603
- output_cols = []
604
- for i, cl in enumerate(classes):
605
- # For binary classification, there is only one output column for each class
606
- # ndarray as the two classes are complementary.
607
- if len(cl) == 2:
608
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
609
- else:
610
- output_cols.extend([
611
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
612
- ])
613
- else:
614
- output_cols = []
615
-
616
- # Make sure column names are valid snowflake identifiers.
617
- assert output_cols is not None # Make MyPy happy
618
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
619
-
620
- return rv
621
-
622
590
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
623
591
  @telemetry.send_api_usage_telemetry(
624
592
  project=_PROJECT,
@@ -658,7 +626,7 @@ class MultiTaskLassoCV(BaseTransformer):
658
626
  transform_kwargs = dict(
659
627
  session=dataset._session,
660
628
  dependencies=self._deps,
661
- pass_through_cols=self._get_pass_through_columns(dataset),
629
+ drop_input_cols = self._drop_input_cols,
662
630
  expected_output_cols_type="float",
663
631
  )
664
632
 
@@ -723,7 +691,7 @@ class MultiTaskLassoCV(BaseTransformer):
723
691
  transform_kwargs = dict(
724
692
  session=dataset._session,
725
693
  dependencies=self._deps,
726
- pass_through_cols=self._get_pass_through_columns(dataset),
694
+ drop_input_cols = self._drop_input_cols,
727
695
  expected_output_cols_type="float",
728
696
  )
729
697
  elif isinstance(dataset, pd.DataFrame):
@@ -784,7 +752,7 @@ class MultiTaskLassoCV(BaseTransformer):
784
752
  transform_kwargs = dict(
785
753
  session=dataset._session,
786
754
  dependencies=self._deps,
787
- pass_through_cols=self._get_pass_through_columns(dataset),
755
+ drop_input_cols = self._drop_input_cols,
788
756
  expected_output_cols_type="float",
789
757
  )
790
758
 
@@ -849,7 +817,7 @@ class MultiTaskLassoCV(BaseTransformer):
849
817
  transform_kwargs = dict(
850
818
  session=dataset._session,
851
819
  dependencies=self._deps,
852
- pass_through_cols=self._get_pass_through_columns(dataset),
820
+ drop_input_cols = self._drop_input_cols,
853
821
  expected_output_cols_type="float",
854
822
  )
855
823
 
@@ -905,13 +873,17 @@ class MultiTaskLassoCV(BaseTransformer):
905
873
  transform_kwargs: ScoreKwargsTypedDict = dict()
906
874
 
907
875
  if isinstance(dataset, DataFrame):
876
+ self._deps = self._batch_inference_validate_snowpark(
877
+ dataset=dataset,
878
+ inference_method="score",
879
+ )
908
880
  selected_cols = self._get_active_columns()
909
881
  if len(selected_cols) > 0:
910
882
  dataset = dataset.select(selected_cols)
911
883
  assert isinstance(dataset._session, Session) # keep mypy happy
912
884
  transform_kwargs = dict(
913
885
  session=dataset._session,
914
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
886
+ dependencies=["snowflake-snowpark-python"] + self._deps,
915
887
  score_sproc_imports=['sklearn'],
916
888
  )
917
889
  elif isinstance(dataset, pd.DataFrame):
@@ -985,9 +957,9 @@ class MultiTaskLassoCV(BaseTransformer):
985
957
  transform_kwargs = dict(
986
958
  session = dataset._session,
987
959
  dependencies = self._deps,
988
- pass_through_cols = self._get_pass_through_columns(dataset),
989
- expected_output_cols_type = "array",
990
- n_neighbors = n_neighbors,
960
+ drop_input_cols = self._drop_input_cols,
961
+ expected_output_cols_type="array",
962
+ n_neighbors = n_neighbors,
991
963
  return_distance = return_distance
992
964
  )
993
965
  elif isinstance(dataset, pd.DataFrame):
@@ -288,18 +288,24 @@ class OrthogonalMatchingPursuit(BaseTransformer):
288
288
  self._get_model_signatures(dataset)
289
289
  return self
290
290
 
291
- def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
292
- if self._drop_input_cols:
293
- return []
294
- else:
295
- return list(set(dataset.columns) - set(self.output_cols))
296
-
297
291
  def _batch_inference_validate_snowpark(
298
292
  self,
299
293
  dataset: DataFrame,
300
294
  inference_method: str,
301
295
  ) -> List[str]:
302
- """Util method to run validate that batch inference can be run on a snowpark dataframe.
296
+ """Util method to run validate that batch inference can be run on a snowpark dataframe and
297
+ return the available package that exists in the snowflake anaconda channel
298
+
299
+ Args:
300
+ dataset: snowpark dataframe
301
+ inference_method: the inference method such as predict, score...
302
+
303
+ Raises:
304
+ SnowflakeMLException: If the estimator is not fitted, raise error
305
+ SnowflakeMLException: If the session is None, raise error
306
+
307
+ Returns:
308
+ A list of available package that exists in the snowflake anaconda channel
303
309
  """
304
310
  if not self._is_fitted:
305
311
  raise exceptions.SnowflakeMLException(
@@ -373,7 +379,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
373
379
  transform_kwargs = dict(
374
380
  session = dataset._session,
375
381
  dependencies = self._deps,
376
- pass_through_cols = self._get_pass_through_columns(dataset),
382
+ drop_input_cols = self._drop_input_cols,
377
383
  expected_output_cols_type = expected_type_inferred,
378
384
  )
379
385
 
@@ -433,16 +439,16 @@ class OrthogonalMatchingPursuit(BaseTransformer):
433
439
  # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
434
440
  # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
435
441
  # each row containing a list of values.
436
- expected_dtype = "ARRAY"
442
+ expected_dtype = "array"
437
443
 
438
444
  # If we were unable to assign a type to this transform in the factory, infer the type here.
439
445
  if expected_dtype == "":
440
- # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
446
+ # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
441
447
  if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
442
- expected_dtype = "ARRAY"
443
- # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
448
+ expected_dtype = "array"
449
+ # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
444
450
  elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
445
- expected_dtype = "ARRAY"
451
+ expected_dtype = "array"
446
452
  else:
447
453
  output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
448
454
  # We can only infer the output types from the input types if the following two statemetns are true:
@@ -460,7 +466,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
460
466
  transform_kwargs = dict(
461
467
  session = dataset._session,
462
468
  dependencies = self._deps,
463
- pass_through_cols = self._get_pass_through_columns(dataset),
469
+ drop_input_cols = self._drop_input_cols,
464
470
  expected_output_cols_type = expected_dtype,
465
471
  )
466
472
 
@@ -511,7 +517,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
511
517
  subproject=_SUBPROJECT,
512
518
  )
513
519
  output_result, fitted_estimator = model_trainer.train_fit_predict(
514
- pass_through_columns=self._get_pass_through_columns(dataset),
520
+ drop_input_cols=self._drop_input_cols,
515
521
  expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
516
522
  )
517
523
  self._sklearn_object = fitted_estimator
@@ -529,44 +535,6 @@ class OrthogonalMatchingPursuit(BaseTransformer):
529
535
  assert self._sklearn_object is not None
530
536
  return self._sklearn_object.embedding_
531
537
 
532
-
533
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
534
- """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
535
- Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
536
- """
537
- output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
538
- if output_cols:
539
- output_cols = [
540
- identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
541
- for c in output_cols
542
- ]
543
- elif getattr(self._sklearn_object, "classes_", None) is None:
544
- output_cols = [output_cols_prefix]
545
- elif self._sklearn_object is not None:
546
- classes = self._sklearn_object.classes_
547
- if isinstance(classes, numpy.ndarray):
548
- output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
549
- elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
550
- # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
551
- output_cols = []
552
- for i, cl in enumerate(classes):
553
- # For binary classification, there is only one output column for each class
554
- # ndarray as the two classes are complementary.
555
- if len(cl) == 2:
556
- output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
557
- else:
558
- output_cols.extend([
559
- f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
560
- ])
561
- else:
562
- output_cols = []
563
-
564
- # Make sure column names are valid snowflake identifiers.
565
- assert output_cols is not None # Make MyPy happy
566
- rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
567
-
568
- return rv
569
-
570
538
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
571
539
  @telemetry.send_api_usage_telemetry(
572
540
  project=_PROJECT,
@@ -606,7 +574,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
606
574
  transform_kwargs = dict(
607
575
  session=dataset._session,
608
576
  dependencies=self._deps,
609
- pass_through_cols=self._get_pass_through_columns(dataset),
577
+ drop_input_cols = self._drop_input_cols,
610
578
  expected_output_cols_type="float",
611
579
  )
612
580
 
@@ -671,7 +639,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
671
639
  transform_kwargs = dict(
672
640
  session=dataset._session,
673
641
  dependencies=self._deps,
674
- pass_through_cols=self._get_pass_through_columns(dataset),
642
+ drop_input_cols = self._drop_input_cols,
675
643
  expected_output_cols_type="float",
676
644
  )
677
645
  elif isinstance(dataset, pd.DataFrame):
@@ -732,7 +700,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
732
700
  transform_kwargs = dict(
733
701
  session=dataset._session,
734
702
  dependencies=self._deps,
735
- pass_through_cols=self._get_pass_through_columns(dataset),
703
+ drop_input_cols = self._drop_input_cols,
736
704
  expected_output_cols_type="float",
737
705
  )
738
706
 
@@ -797,7 +765,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
797
765
  transform_kwargs = dict(
798
766
  session=dataset._session,
799
767
  dependencies=self._deps,
800
- pass_through_cols=self._get_pass_through_columns(dataset),
768
+ drop_input_cols = self._drop_input_cols,
801
769
  expected_output_cols_type="float",
802
770
  )
803
771
 
@@ -853,13 +821,17 @@ class OrthogonalMatchingPursuit(BaseTransformer):
853
821
  transform_kwargs: ScoreKwargsTypedDict = dict()
854
822
 
855
823
  if isinstance(dataset, DataFrame):
824
+ self._deps = self._batch_inference_validate_snowpark(
825
+ dataset=dataset,
826
+ inference_method="score",
827
+ )
856
828
  selected_cols = self._get_active_columns()
857
829
  if len(selected_cols) > 0:
858
830
  dataset = dataset.select(selected_cols)
859
831
  assert isinstance(dataset._session, Session) # keep mypy happy
860
832
  transform_kwargs = dict(
861
833
  session=dataset._session,
862
- dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
834
+ dependencies=["snowflake-snowpark-python"] + self._deps,
863
835
  score_sproc_imports=['sklearn'],
864
836
  )
865
837
  elif isinstance(dataset, pd.DataFrame):
@@ -933,9 +905,9 @@ class OrthogonalMatchingPursuit(BaseTransformer):
933
905
  transform_kwargs = dict(
934
906
  session = dataset._session,
935
907
  dependencies = self._deps,
936
- pass_through_cols = self._get_pass_through_columns(dataset),
937
- expected_output_cols_type = "array",
938
- n_neighbors = n_neighbors,
908
+ drop_input_cols = self._drop_input_cols,
909
+ expected_output_cols_type="array",
910
+ n_neighbors = n_neighbors,
939
911
  return_distance = return_distance
940
912
  )
941
913
  elif isinstance(dataset, pd.DataFrame):