snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.tree".replace("sklearn."
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class ExtraTreeRegressor(BaseTransformer):
70
64
  r"""An extremely randomized tree regressor
71
65
  For more details on this class, see [sklearn.tree.ExtraTreeRegressor]
@@ -365,20 +359,17 @@ class ExtraTreeRegressor(BaseTransformer):
365
359
  self,
366
360
  dataset: DataFrame,
367
361
  inference_method: str,
368
- ) -> List[str]:
369
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
370
- return the available package that exists in the snowflake anaconda channel
362
+ ) -> None:
363
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
371
364
 
372
365
  Args:
373
366
  dataset: snowpark dataframe
374
367
  inference_method: the inference method such as predict, score...
375
-
368
+
376
369
  Raises:
377
370
  SnowflakeMLException: If the estimator is not fitted, raise error
378
371
  SnowflakeMLException: If the session is None, raise error
379
372
 
380
- Returns:
381
- A list of available package that exists in the snowflake anaconda channel
382
373
  """
383
374
  if not self._is_fitted:
384
375
  raise exceptions.SnowflakeMLException(
@@ -396,9 +387,7 @@ class ExtraTreeRegressor(BaseTransformer):
396
387
  "Session must not specified for snowpark dataset."
397
388
  ),
398
389
  )
399
- # Validate that key package version in user workspace are supported in snowflake conda channel
400
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
401
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
390
+
402
391
 
403
392
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
404
393
  @telemetry.send_api_usage_telemetry(
@@ -446,7 +435,8 @@ class ExtraTreeRegressor(BaseTransformer):
446
435
 
447
436
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
448
437
 
449
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
438
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
439
+ self._deps = self._get_dependencies()
450
440
  assert isinstance(
451
441
  dataset._session, Session
452
442
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -529,10 +519,8 @@ class ExtraTreeRegressor(BaseTransformer):
529
519
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
530
520
  expected_dtype = convert_sp_to_sf_type(output_types[0])
531
521
 
532
- self._deps = self._batch_inference_validate_snowpark(
533
- dataset=dataset,
534
- inference_method=inference_method,
535
- )
522
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
523
+ self._deps = self._get_dependencies()
536
524
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
537
525
 
538
526
  transform_kwargs = dict(
@@ -599,16 +587,40 @@ class ExtraTreeRegressor(BaseTransformer):
599
587
  self._is_fitted = True
600
588
  return output_result
601
589
 
590
+
591
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
592
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
593
+ """ Method not supported for this class.
602
594
 
603
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
604
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
605
- """
595
+
596
+ Raises:
597
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
598
+
599
+ Args:
600
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
601
+ Snowpark or Pandas DataFrame.
602
+ output_cols_prefix: Prefix for the response columns
606
603
  Returns:
607
604
  Transformed dataset.
608
605
  """
609
- self.fit(dataset)
610
- assert self._sklearn_object is not None
611
- return self._sklearn_object.embedding_
606
+ self._infer_input_output_cols(dataset)
607
+ super()._check_dataset_type(dataset)
608
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
609
+ estimator=self._sklearn_object,
610
+ dataset=dataset,
611
+ input_cols=self.input_cols,
612
+ label_cols=self.label_cols,
613
+ sample_weight_col=self.sample_weight_col,
614
+ autogenerated=self._autogenerated,
615
+ subproject=_SUBPROJECT,
616
+ )
617
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
618
+ drop_input_cols=self._drop_input_cols,
619
+ expected_output_cols_list=self.output_cols,
620
+ )
621
+ self._sklearn_object = fitted_estimator
622
+ self._is_fitted = True
623
+ return output_result
612
624
 
613
625
 
614
626
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -699,10 +711,8 @@ class ExtraTreeRegressor(BaseTransformer):
699
711
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
700
712
 
701
713
  if isinstance(dataset, DataFrame):
702
- self._deps = self._batch_inference_validate_snowpark(
703
- dataset=dataset,
704
- inference_method=inference_method,
705
- )
714
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
715
+ self._deps = self._get_dependencies()
706
716
  assert isinstance(
707
717
  dataset._session, Session
708
718
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -767,10 +777,8 @@ class ExtraTreeRegressor(BaseTransformer):
767
777
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
768
778
 
769
779
  if isinstance(dataset, DataFrame):
770
- self._deps = self._batch_inference_validate_snowpark(
771
- dataset=dataset,
772
- inference_method=inference_method,
773
- )
780
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
781
+ self._deps = self._get_dependencies()
774
782
  assert isinstance(
775
783
  dataset._session, Session
776
784
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -832,10 +840,8 @@ class ExtraTreeRegressor(BaseTransformer):
832
840
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
833
841
 
834
842
  if isinstance(dataset, DataFrame):
835
- self._deps = self._batch_inference_validate_snowpark(
836
- dataset=dataset,
837
- inference_method=inference_method,
838
- )
843
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
844
+ self._deps = self._get_dependencies()
839
845
  assert isinstance(
840
846
  dataset._session, Session
841
847
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -901,10 +907,8 @@ class ExtraTreeRegressor(BaseTransformer):
901
907
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
902
908
 
903
909
  if isinstance(dataset, DataFrame):
904
- self._deps = self._batch_inference_validate_snowpark(
905
- dataset=dataset,
906
- inference_method=inference_method,
907
- )
910
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
911
+ self._deps = self._get_dependencies()
908
912
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
909
913
  transform_kwargs = dict(
910
914
  session=dataset._session,
@@ -968,17 +972,15 @@ class ExtraTreeRegressor(BaseTransformer):
968
972
  transform_kwargs: ScoreKwargsTypedDict = dict()
969
973
 
970
974
  if isinstance(dataset, DataFrame):
971
- self._deps = self._batch_inference_validate_snowpark(
972
- dataset=dataset,
973
- inference_method="score",
974
- )
975
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
976
+ self._deps = self._get_dependencies()
975
977
  selected_cols = self._get_active_columns()
976
978
  if len(selected_cols) > 0:
977
979
  dataset = dataset.select(selected_cols)
978
980
  assert isinstance(dataset._session, Session) # keep mypy happy
979
981
  transform_kwargs = dict(
980
982
  session=dataset._session,
981
- dependencies=["snowflake-snowpark-python"] + self._deps,
983
+ dependencies=self._deps,
982
984
  score_sproc_imports=['sklearn'],
983
985
  )
984
986
  elif isinstance(dataset, pd.DataFrame):
@@ -1043,11 +1045,8 @@ class ExtraTreeRegressor(BaseTransformer):
1043
1045
 
1044
1046
  if isinstance(dataset, DataFrame):
1045
1047
 
1046
- self._deps = self._batch_inference_validate_snowpark(
1047
- dataset=dataset,
1048
- inference_method=inference_method,
1049
-
1050
- )
1048
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1049
+ self._deps = self._get_dependencies()
1051
1050
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1052
1051
  transform_kwargs = dict(
1053
1052
  session = dataset._session,
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBClassifier(BaseTransformer):
69
63
  r"""Implementation of the scikit-learn API for XGBoost classification
70
64
  For more details on this class, see [xgboost.XGBClassifier]
@@ -483,20 +477,17 @@ class XGBClassifier(BaseTransformer):
483
477
  self,
484
478
  dataset: DataFrame,
485
479
  inference_method: str,
486
- ) -> List[str]:
487
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
488
- return the available package that exists in the snowflake anaconda channel
480
+ ) -> None:
481
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
489
482
 
490
483
  Args:
491
484
  dataset: snowpark dataframe
492
485
  inference_method: the inference method such as predict, score...
493
-
486
+
494
487
  Raises:
495
488
  SnowflakeMLException: If the estimator is not fitted, raise error
496
489
  SnowflakeMLException: If the session is None, raise error
497
490
 
498
- Returns:
499
- A list of available package that exists in the snowflake anaconda channel
500
491
  """
501
492
  if not self._is_fitted:
502
493
  raise exceptions.SnowflakeMLException(
@@ -514,9 +505,7 @@ class XGBClassifier(BaseTransformer):
514
505
  "Session must not specified for snowpark dataset."
515
506
  ),
516
507
  )
517
- # Validate that key package version in user workspace are supported in snowflake conda channel
518
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
519
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
508
+
520
509
 
521
510
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
522
511
  @telemetry.send_api_usage_telemetry(
@@ -564,7 +553,8 @@ class XGBClassifier(BaseTransformer):
564
553
 
565
554
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
566
555
 
567
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
556
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
557
+ self._deps = self._get_dependencies()
568
558
  assert isinstance(
569
559
  dataset._session, Session
570
560
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -647,10 +637,8 @@ class XGBClassifier(BaseTransformer):
647
637
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
648
638
  expected_dtype = convert_sp_to_sf_type(output_types[0])
649
639
 
650
- self._deps = self._batch_inference_validate_snowpark(
651
- dataset=dataset,
652
- inference_method=inference_method,
653
- )
640
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
641
+ self._deps = self._get_dependencies()
654
642
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
655
643
 
656
644
  transform_kwargs = dict(
@@ -717,16 +705,40 @@ class XGBClassifier(BaseTransformer):
717
705
  self._is_fitted = True
718
706
  return output_result
719
707
 
708
+
709
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
710
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
711
+ """ Method not supported for this class.
720
712
 
721
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
722
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
723
- """
713
+
714
+ Raises:
715
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
716
+
717
+ Args:
718
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
719
+ Snowpark or Pandas DataFrame.
720
+ output_cols_prefix: Prefix for the response columns
724
721
  Returns:
725
722
  Transformed dataset.
726
723
  """
727
- self.fit(dataset)
728
- assert self._sklearn_object is not None
729
- return self._sklearn_object.embedding_
724
+ self._infer_input_output_cols(dataset)
725
+ super()._check_dataset_type(dataset)
726
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
727
+ estimator=self._sklearn_object,
728
+ dataset=dataset,
729
+ input_cols=self.input_cols,
730
+ label_cols=self.label_cols,
731
+ sample_weight_col=self.sample_weight_col,
732
+ autogenerated=self._autogenerated,
733
+ subproject=_SUBPROJECT,
734
+ )
735
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
736
+ drop_input_cols=self._drop_input_cols,
737
+ expected_output_cols_list=self.output_cols,
738
+ )
739
+ self._sklearn_object = fitted_estimator
740
+ self._is_fitted = True
741
+ return output_result
730
742
 
731
743
 
732
744
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -819,10 +831,8 @@ class XGBClassifier(BaseTransformer):
819
831
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
820
832
 
821
833
  if isinstance(dataset, DataFrame):
822
- self._deps = self._batch_inference_validate_snowpark(
823
- dataset=dataset,
824
- inference_method=inference_method,
825
- )
834
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
835
+ self._deps = self._get_dependencies()
826
836
  assert isinstance(
827
837
  dataset._session, Session
828
838
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -889,10 +899,8 @@ class XGBClassifier(BaseTransformer):
889
899
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
890
900
 
891
901
  if isinstance(dataset, DataFrame):
892
- self._deps = self._batch_inference_validate_snowpark(
893
- dataset=dataset,
894
- inference_method=inference_method,
895
- )
902
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
903
+ self._deps = self._get_dependencies()
896
904
  assert isinstance(
897
905
  dataset._session, Session
898
906
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -954,10 +962,8 @@ class XGBClassifier(BaseTransformer):
954
962
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
955
963
 
956
964
  if isinstance(dataset, DataFrame):
957
- self._deps = self._batch_inference_validate_snowpark(
958
- dataset=dataset,
959
- inference_method=inference_method,
960
- )
965
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
966
+ self._deps = self._get_dependencies()
961
967
  assert isinstance(
962
968
  dataset._session, Session
963
969
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1023,10 +1029,8 @@ class XGBClassifier(BaseTransformer):
1023
1029
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1024
1030
 
1025
1031
  if isinstance(dataset, DataFrame):
1026
- self._deps = self._batch_inference_validate_snowpark(
1027
- dataset=dataset,
1028
- inference_method=inference_method,
1029
- )
1032
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1033
+ self._deps = self._get_dependencies()
1030
1034
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1031
1035
  transform_kwargs = dict(
1032
1036
  session=dataset._session,
@@ -1090,17 +1094,15 @@ class XGBClassifier(BaseTransformer):
1090
1094
  transform_kwargs: ScoreKwargsTypedDict = dict()
1091
1095
 
1092
1096
  if isinstance(dataset, DataFrame):
1093
- self._deps = self._batch_inference_validate_snowpark(
1094
- dataset=dataset,
1095
- inference_method="score",
1096
- )
1097
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1098
+ self._deps = self._get_dependencies()
1097
1099
  selected_cols = self._get_active_columns()
1098
1100
  if len(selected_cols) > 0:
1099
1101
  dataset = dataset.select(selected_cols)
1100
1102
  assert isinstance(dataset._session, Session) # keep mypy happy
1101
1103
  transform_kwargs = dict(
1102
1104
  session=dataset._session,
1103
- dependencies=["snowflake-snowpark-python"] + self._deps,
1105
+ dependencies=self._deps,
1104
1106
  score_sproc_imports=['xgboost'],
1105
1107
  )
1106
1108
  elif isinstance(dataset, pd.DataFrame):
@@ -1165,11 +1167,8 @@ class XGBClassifier(BaseTransformer):
1165
1167
 
1166
1168
  if isinstance(dataset, DataFrame):
1167
1169
 
1168
- self._deps = self._batch_inference_validate_snowpark(
1169
- dataset=dataset,
1170
- inference_method=inference_method,
1171
-
1172
- )
1170
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1171
+ self._deps = self._get_dependencies()
1173
1172
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1174
1173
  transform_kwargs = dict(
1175
1174
  session = dataset._session,
@@ -59,12 +59,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "xgboost".replace("sklearn.", "")
59
59
 
60
60
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
61
61
 
62
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
63
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
64
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
65
- return check
66
-
67
-
68
62
  class XGBRegressor(BaseTransformer):
69
63
  r"""Implementation of the scikit-learn API for XGBoost regression
70
64
  For more details on this class, see [xgboost.XGBRegressor]
@@ -482,20 +476,17 @@ class XGBRegressor(BaseTransformer):
482
476
  self,
483
477
  dataset: DataFrame,
484
478
  inference_method: str,
485
- ) -> List[str]:
486
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
487
- return the available package that exists in the snowflake anaconda channel
479
+ ) -> None:
480
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
488
481
 
489
482
  Args:
490
483
  dataset: snowpark dataframe
491
484
  inference_method: the inference method such as predict, score...
492
-
485
+
493
486
  Raises:
494
487
  SnowflakeMLException: If the estimator is not fitted, raise error
495
488
  SnowflakeMLException: If the session is None, raise error
496
489
 
497
- Returns:
498
- A list of available package that exists in the snowflake anaconda channel
499
490
  """
500
491
  if not self._is_fitted:
501
492
  raise exceptions.SnowflakeMLException(
@@ -513,9 +504,7 @@ class XGBRegressor(BaseTransformer):
513
504
  "Session must not specified for snowpark dataset."
514
505
  ),
515
506
  )
516
- # Validate that key package version in user workspace are supported in snowflake conda channel
517
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
518
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
507
+
519
508
 
520
509
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
521
510
  @telemetry.send_api_usage_telemetry(
@@ -563,7 +552,8 @@ class XGBRegressor(BaseTransformer):
563
552
 
564
553
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
565
554
 
566
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
555
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
556
+ self._deps = self._get_dependencies()
567
557
  assert isinstance(
568
558
  dataset._session, Session
569
559
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -646,10 +636,8 @@ class XGBRegressor(BaseTransformer):
646
636
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
647
637
  expected_dtype = convert_sp_to_sf_type(output_types[0])
648
638
 
649
- self._deps = self._batch_inference_validate_snowpark(
650
- dataset=dataset,
651
- inference_method=inference_method,
652
- )
639
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
640
+ self._deps = self._get_dependencies()
653
641
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
654
642
 
655
643
  transform_kwargs = dict(
@@ -716,16 +704,40 @@ class XGBRegressor(BaseTransformer):
716
704
  self._is_fitted = True
717
705
  return output_result
718
706
 
707
+
708
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
709
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
710
+ """ Method not supported for this class.
719
711
 
720
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
721
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
722
- """
712
+
713
+ Raises:
714
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
715
+
716
+ Args:
717
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
718
+ Snowpark or Pandas DataFrame.
719
+ output_cols_prefix: Prefix for the response columns
723
720
  Returns:
724
721
  Transformed dataset.
725
722
  """
726
- self.fit(dataset)
727
- assert self._sklearn_object is not None
728
- return self._sklearn_object.embedding_
723
+ self._infer_input_output_cols(dataset)
724
+ super()._check_dataset_type(dataset)
725
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
726
+ estimator=self._sklearn_object,
727
+ dataset=dataset,
728
+ input_cols=self.input_cols,
729
+ label_cols=self.label_cols,
730
+ sample_weight_col=self.sample_weight_col,
731
+ autogenerated=self._autogenerated,
732
+ subproject=_SUBPROJECT,
733
+ )
734
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
735
+ drop_input_cols=self._drop_input_cols,
736
+ expected_output_cols_list=self.output_cols,
737
+ )
738
+ self._sklearn_object = fitted_estimator
739
+ self._is_fitted = True
740
+ return output_result
729
741
 
730
742
 
731
743
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -816,10 +828,8 @@ class XGBRegressor(BaseTransformer):
816
828
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
817
829
 
818
830
  if isinstance(dataset, DataFrame):
819
- self._deps = self._batch_inference_validate_snowpark(
820
- dataset=dataset,
821
- inference_method=inference_method,
822
- )
831
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
832
+ self._deps = self._get_dependencies()
823
833
  assert isinstance(
824
834
  dataset._session, Session
825
835
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -884,10 +894,8 @@ class XGBRegressor(BaseTransformer):
884
894
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
885
895
 
886
896
  if isinstance(dataset, DataFrame):
887
- self._deps = self._batch_inference_validate_snowpark(
888
- dataset=dataset,
889
- inference_method=inference_method,
890
- )
897
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
898
+ self._deps = self._get_dependencies()
891
899
  assert isinstance(
892
900
  dataset._session, Session
893
901
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -949,10 +957,8 @@ class XGBRegressor(BaseTransformer):
949
957
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
950
958
 
951
959
  if isinstance(dataset, DataFrame):
952
- self._deps = self._batch_inference_validate_snowpark(
953
- dataset=dataset,
954
- inference_method=inference_method,
955
- )
960
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
961
+ self._deps = self._get_dependencies()
956
962
  assert isinstance(
957
963
  dataset._session, Session
958
964
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -1018,10 +1024,8 @@ class XGBRegressor(BaseTransformer):
1018
1024
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
1019
1025
 
1020
1026
  if isinstance(dataset, DataFrame):
1021
- self._deps = self._batch_inference_validate_snowpark(
1022
- dataset=dataset,
1023
- inference_method=inference_method,
1024
- )
1027
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1028
+ self._deps = self._get_dependencies()
1025
1029
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1026
1030
  transform_kwargs = dict(
1027
1031
  session=dataset._session,
@@ -1085,17 +1089,15 @@ class XGBRegressor(BaseTransformer):
1085
1089
  transform_kwargs: ScoreKwargsTypedDict = dict()
1086
1090
 
1087
1091
  if isinstance(dataset, DataFrame):
1088
- self._deps = self._batch_inference_validate_snowpark(
1089
- dataset=dataset,
1090
- inference_method="score",
1091
- )
1092
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
1093
+ self._deps = self._get_dependencies()
1092
1094
  selected_cols = self._get_active_columns()
1093
1095
  if len(selected_cols) > 0:
1094
1096
  dataset = dataset.select(selected_cols)
1095
1097
  assert isinstance(dataset._session, Session) # keep mypy happy
1096
1098
  transform_kwargs = dict(
1097
1099
  session=dataset._session,
1098
- dependencies=["snowflake-snowpark-python"] + self._deps,
1100
+ dependencies=self._deps,
1099
1101
  score_sproc_imports=['xgboost'],
1100
1102
  )
1101
1103
  elif isinstance(dataset, pd.DataFrame):
@@ -1160,11 +1162,8 @@ class XGBRegressor(BaseTransformer):
1160
1162
 
1161
1163
  if isinstance(dataset, DataFrame):
1162
1164
 
1163
- self._deps = self._batch_inference_validate_snowpark(
1164
- dataset=dataset,
1165
- inference_method=inference_method,
1166
-
1167
- )
1165
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1166
+ self._deps = self._get_dependencies()
1168
1167
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1169
1168
  transform_kwargs = dict(
1170
1169
  session = dataset._session,