snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class LocalOutlierFactor(BaseTransformer):
70
64
  r"""Unsupervised Outlier Detection using the Local Outlier Factor (LOF)
71
65
  For more details on this class, see [sklearn.neighbors.LocalOutlierFactor]
@@ -341,20 +335,17 @@ class LocalOutlierFactor(BaseTransformer):
341
335
  self,
342
336
  dataset: DataFrame,
343
337
  inference_method: str,
344
- ) -> List[str]:
345
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
346
- return the available package that exists in the snowflake anaconda channel
338
+ ) -> None:
339
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
347
340
 
348
341
  Args:
349
342
  dataset: snowpark dataframe
350
343
  inference_method: the inference method such as predict, score...
351
-
344
+
352
345
  Raises:
353
346
  SnowflakeMLException: If the estimator is not fitted, raise error
354
347
  SnowflakeMLException: If the session is None, raise error
355
348
 
356
- Returns:
357
- A list of available package that exists in the snowflake anaconda channel
358
349
  """
359
350
  if not self._is_fitted:
360
351
  raise exceptions.SnowflakeMLException(
@@ -372,9 +363,7 @@ class LocalOutlierFactor(BaseTransformer):
372
363
  "Session must not specified for snowpark dataset."
373
364
  ),
374
365
  )
375
- # Validate that key package version in user workspace are supported in snowflake conda channel
376
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
377
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
366
+
378
367
 
379
368
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
380
369
  @telemetry.send_api_usage_telemetry(
@@ -422,7 +411,8 @@ class LocalOutlierFactor(BaseTransformer):
422
411
 
423
412
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
424
413
 
425
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
414
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
415
+ self._deps = self._get_dependencies()
426
416
  assert isinstance(
427
417
  dataset._session, Session
428
418
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -505,10 +495,8 @@ class LocalOutlierFactor(BaseTransformer):
505
495
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
506
496
  expected_dtype = convert_sp_to_sf_type(output_types[0])
507
497
 
508
- self._deps = self._batch_inference_validate_snowpark(
509
- dataset=dataset,
510
- inference_method=inference_method,
511
- )
498
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
499
+ self._deps = self._get_dependencies()
512
500
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
513
501
 
514
502
  transform_kwargs = dict(
@@ -577,16 +565,40 @@ class LocalOutlierFactor(BaseTransformer):
577
565
  self._is_fitted = True
578
566
  return output_result
579
567
 
568
+
569
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
570
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
571
+ """ Method not supported for this class.
572
+
580
573
 
581
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
582
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
583
- """
574
+ Raises:
575
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
576
+
577
+ Args:
578
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
579
+ Snowpark or Pandas DataFrame.
580
+ output_cols_prefix: Prefix for the response columns
584
581
  Returns:
585
582
  Transformed dataset.
586
583
  """
587
- self.fit(dataset)
588
- assert self._sklearn_object is not None
589
- return self._sklearn_object.embedding_
584
+ self._infer_input_output_cols(dataset)
585
+ super()._check_dataset_type(dataset)
586
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
587
+ estimator=self._sklearn_object,
588
+ dataset=dataset,
589
+ input_cols=self.input_cols,
590
+ label_cols=self.label_cols,
591
+ sample_weight_col=self.sample_weight_col,
592
+ autogenerated=self._autogenerated,
593
+ subproject=_SUBPROJECT,
594
+ )
595
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
596
+ drop_input_cols=self._drop_input_cols,
597
+ expected_output_cols_list=self.output_cols,
598
+ )
599
+ self._sklearn_object = fitted_estimator
600
+ self._is_fitted = True
601
+ return output_result
590
602
 
591
603
 
592
604
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -677,10 +689,8 @@ class LocalOutlierFactor(BaseTransformer):
677
689
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
678
690
 
679
691
  if isinstance(dataset, DataFrame):
680
- self._deps = self._batch_inference_validate_snowpark(
681
- dataset=dataset,
682
- inference_method=inference_method,
683
- )
692
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
693
+ self._deps = self._get_dependencies()
684
694
  assert isinstance(
685
695
  dataset._session, Session
686
696
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -745,10 +755,8 @@ class LocalOutlierFactor(BaseTransformer):
745
755
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
746
756
 
747
757
  if isinstance(dataset, DataFrame):
748
- self._deps = self._batch_inference_validate_snowpark(
749
- dataset=dataset,
750
- inference_method=inference_method,
751
- )
758
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
759
+ self._deps = self._get_dependencies()
752
760
  assert isinstance(
753
761
  dataset._session, Session
754
762
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -812,10 +820,8 @@ class LocalOutlierFactor(BaseTransformer):
812
820
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
813
821
 
814
822
  if isinstance(dataset, DataFrame):
815
- self._deps = self._batch_inference_validate_snowpark(
816
- dataset=dataset,
817
- inference_method=inference_method,
818
- )
823
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
824
+ self._deps = self._get_dependencies()
819
825
  assert isinstance(
820
826
  dataset._session, Session
821
827
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -883,10 +889,8 @@ class LocalOutlierFactor(BaseTransformer):
883
889
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
884
890
 
885
891
  if isinstance(dataset, DataFrame):
886
- self._deps = self._batch_inference_validate_snowpark(
887
- dataset=dataset,
888
- inference_method=inference_method,
889
- )
892
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
893
+ self._deps = self._get_dependencies()
890
894
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
891
895
  transform_kwargs = dict(
892
896
  session=dataset._session,
@@ -948,17 +952,15 @@ class LocalOutlierFactor(BaseTransformer):
948
952
  transform_kwargs: ScoreKwargsTypedDict = dict()
949
953
 
950
954
  if isinstance(dataset, DataFrame):
951
- self._deps = self._batch_inference_validate_snowpark(
952
- dataset=dataset,
953
- inference_method="score",
954
- )
955
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
956
+ self._deps = self._get_dependencies()
955
957
  selected_cols = self._get_active_columns()
956
958
  if len(selected_cols) > 0:
957
959
  dataset = dataset.select(selected_cols)
958
960
  assert isinstance(dataset._session, Session) # keep mypy happy
959
961
  transform_kwargs = dict(
960
962
  session=dataset._session,
961
- dependencies=["snowflake-snowpark-python"] + self._deps,
963
+ dependencies=self._deps,
962
964
  score_sproc_imports=['sklearn'],
963
965
  )
964
966
  elif isinstance(dataset, pd.DataFrame):
@@ -1025,11 +1027,8 @@ class LocalOutlierFactor(BaseTransformer):
1025
1027
 
1026
1028
  if isinstance(dataset, DataFrame):
1027
1029
 
1028
- self._deps = self._batch_inference_validate_snowpark(
1029
- dataset=dataset,
1030
- inference_method=inference_method,
1031
-
1032
- )
1030
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1031
+ self._deps = self._get_dependencies()
1033
1032
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1034
1033
  transform_kwargs = dict(
1035
1034
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class NearestCentroid(BaseTransformer):
70
64
  r"""Nearest centroid classifier
71
65
  For more details on this class, see [sklearn.neighbors.NearestCentroid]
@@ -274,20 +268,17 @@ class NearestCentroid(BaseTransformer):
274
268
  self,
275
269
  dataset: DataFrame,
276
270
  inference_method: str,
277
- ) -> List[str]:
278
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
279
- return the available package that exists in the snowflake anaconda channel
271
+ ) -> None:
272
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
280
273
 
281
274
  Args:
282
275
  dataset: snowpark dataframe
283
276
  inference_method: the inference method such as predict, score...
284
-
277
+
285
278
  Raises:
286
279
  SnowflakeMLException: If the estimator is not fitted, raise error
287
280
  SnowflakeMLException: If the session is None, raise error
288
281
 
289
- Returns:
290
- A list of available package that exists in the snowflake anaconda channel
291
282
  """
292
283
  if not self._is_fitted:
293
284
  raise exceptions.SnowflakeMLException(
@@ -305,9 +296,7 @@ class NearestCentroid(BaseTransformer):
305
296
  "Session must not specified for snowpark dataset."
306
297
  ),
307
298
  )
308
- # Validate that key package version in user workspace are supported in snowflake conda channel
309
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
310
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
299
+
311
300
 
312
301
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
313
302
  @telemetry.send_api_usage_telemetry(
@@ -355,7 +344,8 @@ class NearestCentroid(BaseTransformer):
355
344
 
356
345
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
357
346
 
358
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
347
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
348
+ self._deps = self._get_dependencies()
359
349
  assert isinstance(
360
350
  dataset._session, Session
361
351
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -438,10 +428,8 @@ class NearestCentroid(BaseTransformer):
438
428
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
439
429
  expected_dtype = convert_sp_to_sf_type(output_types[0])
440
430
 
441
- self._deps = self._batch_inference_validate_snowpark(
442
- dataset=dataset,
443
- inference_method=inference_method,
444
- )
431
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
432
+ self._deps = self._get_dependencies()
445
433
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
446
434
 
447
435
  transform_kwargs = dict(
@@ -508,16 +496,40 @@ class NearestCentroid(BaseTransformer):
508
496
  self._is_fitted = True
509
497
  return output_result
510
498
 
499
+
500
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
501
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
502
+ """ Method not supported for this class.
511
503
 
512
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
513
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
514
- """
504
+
505
+ Raises:
506
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
507
+
508
+ Args:
509
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
510
+ Snowpark or Pandas DataFrame.
511
+ output_cols_prefix: Prefix for the response columns
515
512
  Returns:
516
513
  Transformed dataset.
517
514
  """
518
- self.fit(dataset)
519
- assert self._sklearn_object is not None
520
- return self._sklearn_object.embedding_
515
+ self._infer_input_output_cols(dataset)
516
+ super()._check_dataset_type(dataset)
517
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
518
+ estimator=self._sklearn_object,
519
+ dataset=dataset,
520
+ input_cols=self.input_cols,
521
+ label_cols=self.label_cols,
522
+ sample_weight_col=self.sample_weight_col,
523
+ autogenerated=self._autogenerated,
524
+ subproject=_SUBPROJECT,
525
+ )
526
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
527
+ drop_input_cols=self._drop_input_cols,
528
+ expected_output_cols_list=self.output_cols,
529
+ )
530
+ self._sklearn_object = fitted_estimator
531
+ self._is_fitted = True
532
+ return output_result
521
533
 
522
534
 
523
535
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -608,10 +620,8 @@ class NearestCentroid(BaseTransformer):
608
620
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
609
621
 
610
622
  if isinstance(dataset, DataFrame):
611
- self._deps = self._batch_inference_validate_snowpark(
612
- dataset=dataset,
613
- inference_method=inference_method,
614
- )
623
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
624
+ self._deps = self._get_dependencies()
615
625
  assert isinstance(
616
626
  dataset._session, Session
617
627
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -676,10 +686,8 @@ class NearestCentroid(BaseTransformer):
676
686
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
677
687
 
678
688
  if isinstance(dataset, DataFrame):
679
- self._deps = self._batch_inference_validate_snowpark(
680
- dataset=dataset,
681
- inference_method=inference_method,
682
- )
689
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
690
+ self._deps = self._get_dependencies()
683
691
  assert isinstance(
684
692
  dataset._session, Session
685
693
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -741,10 +749,8 @@ class NearestCentroid(BaseTransformer):
741
749
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
742
750
 
743
751
  if isinstance(dataset, DataFrame):
744
- self._deps = self._batch_inference_validate_snowpark(
745
- dataset=dataset,
746
- inference_method=inference_method,
747
- )
752
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
753
+ self._deps = self._get_dependencies()
748
754
  assert isinstance(
749
755
  dataset._session, Session
750
756
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -810,10 +816,8 @@ class NearestCentroid(BaseTransformer):
810
816
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
811
817
 
812
818
  if isinstance(dataset, DataFrame):
813
- self._deps = self._batch_inference_validate_snowpark(
814
- dataset=dataset,
815
- inference_method=inference_method,
816
- )
819
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
820
+ self._deps = self._get_dependencies()
817
821
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
818
822
  transform_kwargs = dict(
819
823
  session=dataset._session,
@@ -877,17 +881,15 @@ class NearestCentroid(BaseTransformer):
877
881
  transform_kwargs: ScoreKwargsTypedDict = dict()
878
882
 
879
883
  if isinstance(dataset, DataFrame):
880
- self._deps = self._batch_inference_validate_snowpark(
881
- dataset=dataset,
882
- inference_method="score",
883
- )
884
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
885
+ self._deps = self._get_dependencies()
884
886
  selected_cols = self._get_active_columns()
885
887
  if len(selected_cols) > 0:
886
888
  dataset = dataset.select(selected_cols)
887
889
  assert isinstance(dataset._session, Session) # keep mypy happy
888
890
  transform_kwargs = dict(
889
891
  session=dataset._session,
890
- dependencies=["snowflake-snowpark-python"] + self._deps,
892
+ dependencies=self._deps,
891
893
  score_sproc_imports=['sklearn'],
892
894
  )
893
895
  elif isinstance(dataset, pd.DataFrame):
@@ -952,11 +954,8 @@ class NearestCentroid(BaseTransformer):
952
954
 
953
955
  if isinstance(dataset, DataFrame):
954
956
 
955
- self._deps = self._batch_inference_validate_snowpark(
956
- dataset=dataset,
957
- inference_method=inference_method,
958
-
959
- )
957
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
958
+ self._deps = self._get_dependencies()
960
959
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
961
960
  transform_kwargs = dict(
962
961
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class NearestNeighbors(BaseTransformer):
70
64
  r"""Unsupervised learner for implementing neighbor searches
71
65
  For more details on this class, see [sklearn.neighbors.NearestNeighbors]
@@ -324,20 +318,17 @@ class NearestNeighbors(BaseTransformer):
324
318
  self,
325
319
  dataset: DataFrame,
326
320
  inference_method: str,
327
- ) -> List[str]:
328
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
329
- return the available package that exists in the snowflake anaconda channel
321
+ ) -> None:
322
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
330
323
 
331
324
  Args:
332
325
  dataset: snowpark dataframe
333
326
  inference_method: the inference method such as predict, score...
334
-
327
+
335
328
  Raises:
336
329
  SnowflakeMLException: If the estimator is not fitted, raise error
337
330
  SnowflakeMLException: If the session is None, raise error
338
331
 
339
- Returns:
340
- A list of available package that exists in the snowflake anaconda channel
341
332
  """
342
333
  if not self._is_fitted:
343
334
  raise exceptions.SnowflakeMLException(
@@ -355,9 +346,7 @@ class NearestNeighbors(BaseTransformer):
355
346
  "Session must not specified for snowpark dataset."
356
347
  ),
357
348
  )
358
- # Validate that key package version in user workspace are supported in snowflake conda channel
359
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
360
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
349
+
361
350
 
362
351
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
363
352
  @telemetry.send_api_usage_telemetry(
@@ -403,7 +392,8 @@ class NearestNeighbors(BaseTransformer):
403
392
 
404
393
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
405
394
 
406
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
395
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
396
+ self._deps = self._get_dependencies()
407
397
  assert isinstance(
408
398
  dataset._session, Session
409
399
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -486,10 +476,8 @@ class NearestNeighbors(BaseTransformer):
486
476
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
487
477
  expected_dtype = convert_sp_to_sf_type(output_types[0])
488
478
 
489
- self._deps = self._batch_inference_validate_snowpark(
490
- dataset=dataset,
491
- inference_method=inference_method,
492
- )
479
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
480
+ self._deps = self._get_dependencies()
493
481
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
494
482
 
495
483
  transform_kwargs = dict(
@@ -556,16 +544,40 @@ class NearestNeighbors(BaseTransformer):
556
544
  self._is_fitted = True
557
545
  return output_result
558
546
 
547
+
548
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
549
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
550
+ """ Method not supported for this class.
559
551
 
560
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
561
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
562
- """
552
+
553
+ Raises:
554
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
555
+
556
+ Args:
557
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
558
+ Snowpark or Pandas DataFrame.
559
+ output_cols_prefix: Prefix for the response columns
563
560
  Returns:
564
561
  Transformed dataset.
565
562
  """
566
- self.fit(dataset)
567
- assert self._sklearn_object is not None
568
- return self._sklearn_object.embedding_
563
+ self._infer_input_output_cols(dataset)
564
+ super()._check_dataset_type(dataset)
565
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
566
+ estimator=self._sklearn_object,
567
+ dataset=dataset,
568
+ input_cols=self.input_cols,
569
+ label_cols=self.label_cols,
570
+ sample_weight_col=self.sample_weight_col,
571
+ autogenerated=self._autogenerated,
572
+ subproject=_SUBPROJECT,
573
+ )
574
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=self.output_cols,
577
+ )
578
+ self._sklearn_object = fitted_estimator
579
+ self._is_fitted = True
580
+ return output_result
569
581
 
570
582
 
571
583
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +668,8 @@ class NearestNeighbors(BaseTransformer):
656
668
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
657
669
 
658
670
  if isinstance(dataset, DataFrame):
659
- self._deps = self._batch_inference_validate_snowpark(
660
- dataset=dataset,
661
- inference_method=inference_method,
662
- )
671
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
672
+ self._deps = self._get_dependencies()
663
673
  assert isinstance(
664
674
  dataset._session, Session
665
675
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -724,10 +734,8 @@ class NearestNeighbors(BaseTransformer):
724
734
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
725
735
 
726
736
  if isinstance(dataset, DataFrame):
727
- self._deps = self._batch_inference_validate_snowpark(
728
- dataset=dataset,
729
- inference_method=inference_method,
730
- )
737
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
738
+ self._deps = self._get_dependencies()
731
739
  assert isinstance(
732
740
  dataset._session, Session
733
741
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -789,10 +797,8 @@ class NearestNeighbors(BaseTransformer):
789
797
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
790
798
 
791
799
  if isinstance(dataset, DataFrame):
792
- self._deps = self._batch_inference_validate_snowpark(
793
- dataset=dataset,
794
- inference_method=inference_method,
795
- )
800
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
801
+ self._deps = self._get_dependencies()
796
802
  assert isinstance(
797
803
  dataset._session, Session
798
804
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -858,10 +864,8 @@ class NearestNeighbors(BaseTransformer):
858
864
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
859
865
 
860
866
  if isinstance(dataset, DataFrame):
861
- self._deps = self._batch_inference_validate_snowpark(
862
- dataset=dataset,
863
- inference_method=inference_method,
864
- )
867
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
868
+ self._deps = self._get_dependencies()
865
869
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
866
870
  transform_kwargs = dict(
867
871
  session=dataset._session,
@@ -923,17 +927,15 @@ class NearestNeighbors(BaseTransformer):
923
927
  transform_kwargs: ScoreKwargsTypedDict = dict()
924
928
 
925
929
  if isinstance(dataset, DataFrame):
926
- self._deps = self._batch_inference_validate_snowpark(
927
- dataset=dataset,
928
- inference_method="score",
929
- )
930
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
931
+ self._deps = self._get_dependencies()
930
932
  selected_cols = self._get_active_columns()
931
933
  if len(selected_cols) > 0:
932
934
  dataset = dataset.select(selected_cols)
933
935
  assert isinstance(dataset._session, Session) # keep mypy happy
934
936
  transform_kwargs = dict(
935
937
  session=dataset._session,
936
- dependencies=["snowflake-snowpark-python"] + self._deps,
938
+ dependencies=self._deps,
937
939
  score_sproc_imports=['sklearn'],
938
940
  )
939
941
  elif isinstance(dataset, pd.DataFrame):
@@ -1000,11 +1002,8 @@ class NearestNeighbors(BaseTransformer):
1000
1002
 
1001
1003
  if isinstance(dataset, DataFrame):
1002
1004
 
1003
- self._deps = self._batch_inference_validate_snowpark(
1004
- dataset=dataset,
1005
- inference_method=inference_method,
1006
-
1007
- )
1005
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1006
+ self._deps = self._get_dependencies()
1008
1007
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1009
1008
  transform_kwargs = dict(
1010
1009
  session = dataset._session,