snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/ensemble/isolation_forest.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class IsolationForest(BaseTransformer):
     r"""Isolation Forest Algorithm
     For more details on this class, see [sklearn.ensemble.IsolationForest]
@@ -324,20 +318,17 @@ class IsolationForest(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class IsolationForest(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -405,7 +394,8 @@ class IsolationForest(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -488,10 +478,8 @@ class IsolationForest(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -560,16 +548,40 @@ class IsolationForest(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
+
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -660,10 +672,8 @@ class IsolationForest(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -728,10 +738,8 @@ class IsolationForest(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -795,10 +803,8 @@ class IsolationForest(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -866,10 +872,8 @@ class IsolationForest(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -931,17 +935,15 @@ class IsolationForest(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1006,11 +1008,8 @@ class IsolationForest(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/ensemble/random_forest_classifier.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class RandomForestClassifier(BaseTransformer):
     r"""A random forest classifier
     For more details on this class, see [sklearn.ensemble.RandomForestClassifier]
@@ -436,20 +430,17 @@ class RandomForestClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -467,9 +458,7 @@ class RandomForestClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -517,7 +506,8 @@ class RandomForestClassifier(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -600,10 +590,8 @@ class RandomForestClassifier(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -670,16 +658,40 @@ class RandomForestClassifier(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -772,10 +784,8 @@ class RandomForestClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -842,10 +852,8 @@ class RandomForestClassifier(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -907,10 +915,8 @@ class RandomForestClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -976,10 +982,8 @@ class RandomForestClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -1043,17 +1047,15 @@ class RandomForestClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1118,11 +1120,8 @@ class RandomForestClassifier(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/ensemble/random_forest_regressor.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class RandomForestRegressor(BaseTransformer):
     r"""A random forest regressor
     For more details on this class, see [sklearn.ensemble.RandomForestRegressor]
@@ -415,20 +409,17 @@ class RandomForestRegressor(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -446,9 +437,7 @@ class RandomForestRegressor(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -496,7 +485,8 @@ class RandomForestRegressor(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -579,10 +569,8 @@ class RandomForestRegressor(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -649,16 +637,40 @@ class RandomForestRegressor(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -749,10 +761,8 @@ class RandomForestRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -817,10 +827,8 @@ class RandomForestRegressor(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -882,10 +890,8 @@ class RandomForestRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -951,10 +957,8 @@ class RandomForestRegressor(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -1018,17 +1022,15 @@ class RandomForestRegressor(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1093,11 +1095,8 @@ class RandomForestRegressor(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,