snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/preprocessing/one_hot_encoder.py
@@ -832,6 +832,18 @@ class OneHotEncoder(base.BaseTransformer):
 
         # columns: COLUMN_NAME, CATEGORY, COUNT, FITTED_CATEGORY, ENCODING, N_FEATURES_OUT, ENCODED_VALUE, OUTPUT_CATs
         assert dataset._session is not None
+
+        def convert_to_string_excluding_nan(item: Any) -> Union[None, str]:
+            if pd.isna(item):
+                return None  # or np.nan if you prefer to keep as NaN
+            else:
+                return str(item)
+
+        # In case of fitting with pandas dataframe and transforming with snowpark dataframe
+        # state_pandas cannot recognize the datatype of _CATEGORY and _FITTED_CATEGORY column
+        # Therefore, apply the convert_to_string_excluding_nan function to _CATEGORY and _FITTED_CATEGORY
+        state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)
+        state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan)
         state_df = dataset._session.create_dataframe(state_pandas)
 
         transformed_dataset = dataset
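For readers skimming the diff: the helper added to one_hot_encoder.py only normalizes the category columns of the fitted state to strings while leaving missing values as None, so Snowpark can infer one consistent column type when the state is uploaded via create_dataframe. A minimal, standalone sketch of that behavior (not the package's own code; the _CATEGORY name stands in for the module's constant):

from typing import Any, Union

import pandas as pd


def convert_to_string_excluding_nan(item: Any) -> Union[None, str]:
    # NaN/None stay as None so the column keeps a single, nullable string type.
    if pd.isna(item):
        return None
    return str(item)


state_pandas = pd.DataFrame({"_CATEGORY": ["a", 1, 2.5, None]})
# DataFrame.applymap applies the converter element-wise; on pandas >= 2.1 the
# non-deprecated equivalent is DataFrame.map.
state_pandas[["_CATEGORY"]] = state_pandas[["_CATEGORY"]].applymap(convert_to_string_excluding_nan)
print(state_pandas["_CATEGORY"].tolist())  # ['a', '1', '2.5', None]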
snowflake/ml/modeling/preprocessing/polynomial_features.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.preprocessing".replace("
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class PolynomialFeatures(BaseTransformer):
     r"""Generate polynomial and interaction features
     For more details on this class, see [sklearn.preprocessing.PolynomialFeatures]
@@ -283,20 +277,17 @@ class PolynomialFeatures(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -314,9 +305,7 @@ class PolynomialFeatures(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -362,7 +351,8 @@ class PolynomialFeatures(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -447,10 +437,8 @@ class PolynomialFeatures(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -517,16 +505,42 @@ class PolynomialFeatures(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.preprocessing.PolynomialFeatures.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.PolynomialFeatures.html#sklearn.preprocessing.PolynomialFeatures.fit_transform)
+
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -617,10 +631,8 @@ class PolynomialFeatures(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -685,10 +697,8 @@ class PolynomialFeatures(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -750,10 +760,8 @@ class PolynomialFeatures(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -819,10 +827,8 @@ class PolynomialFeatures(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -884,17 +890,15 @@ class PolynomialFeatures(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -959,11 +963,8 @@ class PolynomialFeatures(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
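Taken together, the polynomial_features.py changes mean two things for callers: _batch_inference_validate_snowpark now only validates (dependencies are fetched separately via _get_dependencies()), and fit_transform is a real method that trains and transforms in one pass through ModelTrainerBuilder.build_fit_transform instead of being permanently disabled. A rough usage sketch of the new fit_transform follows, assuming the usual snowflake.ml.modeling constructor arguments (input_cols plus the passthrough sklearn parameter degree); the column names are invented and output column naming is governed by output_cols / output_cols_prefix:

import pandas as pd

from snowflake.ml.modeling.preprocessing import PolynomialFeatures

# Wrapper construction following the common snowflake.ml.modeling pattern.
pf = PolynomialFeatures(
    degree=2,                 # forwarded to sklearn.preprocessing.PolynomialFeatures
    input_cols=["X1", "X2"],
)

df = pd.DataFrame({"X1": [1.0, 2.0, 3.0], "X2": [4.0, 5.0, 6.0]})

# In 1.5.0 this fits and transforms in one call; with a Snowpark DataFrame input
# the same call is expected to run remotely and return a Snowpark DataFrame.
transformed = pf.fit_transform(df)
print(type(transformed))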
snowflake/ml/modeling/semi_supervised/label_propagation.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LabelPropagation(BaseTransformer):
     r"""Label Propagation classifier
     For more details on this class, see [sklearn.semi_supervised.LabelPropagation]
@@ -289,20 +283,17 @@ class LabelPropagation(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -320,9 +311,7 @@ class LabelPropagation(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
@@ -370,7 +359,8 @@ class LabelPropagation(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -453,10 +443,8 @@ class LabelPropagation(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -523,16 +511,40 @@ class LabelPropagation(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -625,10 +637,8 @@ class LabelPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -695,10 +705,8 @@ class LabelPropagation(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -760,10 +768,8 @@ class LabelPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -829,10 +835,8 @@ class LabelPropagation(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -896,17 +900,15 @@ class LabelPropagation(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -971,11 +973,8 @@ class LabelPropagation(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/semi_supervised/label_spreading.py
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.semi_supervised".replace
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LabelSpreading(BaseTransformer):
     r"""LabelSpreading model for semi-supervised learning
     For more details on this class, see [sklearn.semi_supervised.LabelSpreading]
@@ -298,20 +292,17 @@ class LabelSpreading(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe and
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -329,9 +320,7 @@ class LabelSpreading(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-        # Validate that key package version in user workspace are supported in snowflake conda channel
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -379,7 +368,8 @@ class LabelSpreading(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -462,10 +452,8 @@ class LabelSpreading(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -532,16 +520,40 @@ class LabelSpreading(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
-    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
-        """
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.fit(dataset)
-        assert self._sklearn_object is not None
-        return self._sklearn_object.embedding_
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -634,10 +646,8 @@ class LabelSpreading(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -704,10 +714,8 @@ class LabelSpreading(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -769,10 +777,8 @@ class LabelSpreading(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -838,10 +844,8 @@ class LabelSpreading(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -905,17 +909,15 @@ class LabelSpreading(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._deps,
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -980,11 +982,8 @@ class LabelSpreading(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
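The same two edits repeat across label_propagation.py, label_spreading.py, and the other generated estimators in the file list. The fit_transform part is easiest to read as an available_if change: the removed module-level guard returned `False and callable(...)`, which is always False, so the old fit_transform was hidden on every wrapper; the new decorator, available_if(original_estimator_has_callable("fit_transform")), exposes the method only when the wrapped sklearn estimator actually implements it (hence the "Method not supported for this class." docstring on LabelPropagation and LabelSpreading, which are classifiers without fit_transform). A small self-contained illustration of that scikit-learn mechanism, using an invented Wrapper class rather than the package's BaseTransformer:

from sklearn.preprocessing import PolynomialFeatures as SkPolynomialFeatures
from sklearn.semi_supervised import LabelPropagation as SkLabelPropagation
from sklearn.utils.metaestimators import available_if


def _wrapped_has(attr: str):
    def check(self) -> bool:
        # available_if calls this predicate on attribute access; a falsy result
        # raises AttributeError, so hasattr() reports the method as missing.
        return callable(getattr(self.estimator, attr, None))
    return check


class Wrapper:
    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_wrapped_has("fit_transform"))
    def fit_transform(self, X):
        return self.estimator.fit_transform(X)


print(hasattr(Wrapper(SkPolynomialFeatures()), "fit_transform"))  # True
print(hasattr(Wrapper(SkLabelPropagation()), "fit_transform"))    # False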