snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (206)
  1. snowflake/ml/_internal/env_utils.py +66 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
  7. snowflake/ml/dataset/__init__.py +10 -0
  8. snowflake/ml/dataset/dataset.py +454 -129
  9. snowflake/ml/dataset/dataset_factory.py +53 -0
  10. snowflake/ml/dataset/dataset_metadata.py +103 -0
  11. snowflake/ml/dataset/dataset_reader.py +202 -0
  12. snowflake/ml/feature_store/feature_store.py +408 -282
  13. snowflake/ml/feature_store/feature_view.py +37 -8
  14. snowflake/ml/fileset/embedded_stage_fs.py +146 -0
  15. snowflake/ml/fileset/sfcfs.py +0 -4
  16. snowflake/ml/fileset/snowfs.py +159 -0
  17. snowflake/ml/fileset/stage_fs.py +1 -4
  18. snowflake/ml/model/__init__.py +2 -2
  19. snowflake/ml/model/_api.py +16 -1
  20. snowflake/ml/model/_client/model/model_impl.py +27 -0
  21. snowflake/ml/model/_client/model/model_version_impl.py +135 -0
  22. snowflake/ml/model/_client/ops/model_ops.py +137 -67
  23. snowflake/ml/model/_client/sql/model.py +16 -14
  24. snowflake/ml/model/_client/sql/model_version.py +109 -1
  25. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  26. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  27. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  28. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  29. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  30. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  33. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  35. snowflake/ml/model/_packager/model_packager.py +0 -3
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  37. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  38. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  39. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  40. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  41. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
  42. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  43. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  44. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  45. snowflake/ml/modeling/cluster/birch.py +53 -52
  46. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  47. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  48. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  49. snowflake/ml/modeling/cluster/k_means.py +53 -52
  50. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  51. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  52. snowflake/ml/modeling/cluster/optics.py +51 -52
  53. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  54. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  55. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  56. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  57. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  58. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  59. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  60. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  61. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  62. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  63. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  64. snowflake/ml/modeling/covariance/oas.py +51 -52
  65. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  66. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  67. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  68. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  69. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  70. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  71. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  72. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  73. snowflake/ml/modeling/decomposition/pca.py +53 -52
  74. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  75. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  76. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  77. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  78. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  79. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  80. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  81. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  82. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  83. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  84. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  85. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  86. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  88. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  89. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  92. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  93. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  94. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  95. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  96. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  97. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  98. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  99. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  101. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  102. snowflake/ml/modeling/framework/base.py +63 -36
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  105. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  106. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  107. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  116. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  118. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  122. snowflake/ml/modeling/linear_model/lars.py +51 -52
  123. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  124. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  129. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  139. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  142. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  151. snowflake/ml/modeling/manifold/isomap.py +53 -52
  152. snowflake/ml/modeling/manifold/mds.py +53 -52
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  154. snowflake/ml/modeling/manifold/tsne.py +53 -52
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  157. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  158. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  159. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  160. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  161. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  162. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  163. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  164. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  165. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  166. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  167. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  168. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  169. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  170. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  171. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  172. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  173. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  174. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  175. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  176. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  177. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  178. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  179. snowflake/ml/modeling/pipeline/pipeline.py +514 -32
  180. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  182. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  183. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  184. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  185. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  186. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  187. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  188. snowflake/ml/modeling/svm/svc.py +51 -52
  189. snowflake/ml/modeling/svm/svr.py +51 -52
  190. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  191. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  192. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  193. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  194. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  195. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  196. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  197. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  198. snowflake/ml/registry/model_registry.py +3 -149
  199. snowflake/ml/version.py +1 -1
  200. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
  201. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
  202. snowflake/ml/registry/_artifact_manager.py +0 -156
  203. snowflake/ml/registry/artifact.py +0 -46
  204. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
  206. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class NeighborhoodComponentsAnalysis(BaseTransformer):
70
64
  r"""Neighborhood Components Analysis
71
65
  For more details on this class, see [sklearn.neighbors.NeighborhoodComponentsAnalysis]
@@ -345,20 +339,17 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
345
339
  self,
346
340
  dataset: DataFrame,
347
341
  inference_method: str,
348
- ) -> List[str]:
349
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
350
- return the available package that exists in the snowflake anaconda channel
342
+ ) -> None:
343
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
351
344
 
352
345
  Args:
353
346
  dataset: snowpark dataframe
354
347
  inference_method: the inference method such as predict, score...
355
-
348
+
356
349
  Raises:
357
350
  SnowflakeMLException: If the estimator is not fitted, raise error
358
351
  SnowflakeMLException: If the session is None, raise error
359
352
 
360
- Returns:
361
- A list of available package that exists in the snowflake anaconda channel
362
353
  """
363
354
  if not self._is_fitted:
364
355
  raise exceptions.SnowflakeMLException(
@@ -376,9 +367,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
376
367
  "Session must not specified for snowpark dataset."
377
368
  ),
378
369
  )
379
- # Validate that key package version in user workspace are supported in snowflake conda channel
380
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
381
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
370
+
382
371
 
383
372
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
384
373
  @telemetry.send_api_usage_telemetry(
@@ -424,7 +413,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
424
413
 
425
414
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
426
415
 
427
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
416
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
417
+ self._deps = self._get_dependencies()
428
418
  assert isinstance(
429
419
  dataset._session, Session
430
420
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -509,10 +499,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
509
499
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
510
500
  expected_dtype = convert_sp_to_sf_type(output_types[0])
511
501
 
512
- self._deps = self._batch_inference_validate_snowpark(
513
- dataset=dataset,
514
- inference_method=inference_method,
515
- )
502
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
503
+ self._deps = self._get_dependencies()
516
504
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
517
505
 
518
506
  transform_kwargs = dict(
@@ -579,16 +567,42 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
579
567
  self._is_fitted = True
580
568
  return output_result
581
569
 
570
+
571
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
572
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
573
+ """ Fit to data, then transform it
574
+ For more details on this function, see [sklearn.neighbors.NeighborhoodComponentsAnalysis.fit_transform]
575
+ (https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NeighborhoodComponentsAnalysis.html#sklearn.neighbors.NeighborhoodComponentsAnalysis.fit_transform)
576
+
582
577
 
583
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
584
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
585
- """
578
+ Raises:
579
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
580
+
581
+ Args:
582
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
583
+ Snowpark or Pandas DataFrame.
584
+ output_cols_prefix: Prefix for the response columns
586
585
  Returns:
587
586
  Transformed dataset.
588
587
  """
589
- self.fit(dataset)
590
- assert self._sklearn_object is not None
591
- return self._sklearn_object.embedding_
588
+ self._infer_input_output_cols(dataset)
589
+ super()._check_dataset_type(dataset)
590
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
591
+ estimator=self._sklearn_object,
592
+ dataset=dataset,
593
+ input_cols=self.input_cols,
594
+ label_cols=self.label_cols,
595
+ sample_weight_col=self.sample_weight_col,
596
+ autogenerated=self._autogenerated,
597
+ subproject=_SUBPROJECT,
598
+ )
599
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
600
+ drop_input_cols=self._drop_input_cols,
601
+ expected_output_cols_list=self.output_cols,
602
+ )
603
+ self._sklearn_object = fitted_estimator
604
+ self._is_fitted = True
605
+ return output_result
592
606
 
593
607
 
594
608
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -679,10 +693,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
679
693
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
680
694
 
681
695
  if isinstance(dataset, DataFrame):
682
- self._deps = self._batch_inference_validate_snowpark(
683
- dataset=dataset,
684
- inference_method=inference_method,
685
- )
696
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
697
+ self._deps = self._get_dependencies()
686
698
  assert isinstance(
687
699
  dataset._session, Session
688
700
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -747,10 +759,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
747
759
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
748
760
 
749
761
  if isinstance(dataset, DataFrame):
750
- self._deps = self._batch_inference_validate_snowpark(
751
- dataset=dataset,
752
- inference_method=inference_method,
753
- )
762
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
763
+ self._deps = self._get_dependencies()
754
764
  assert isinstance(
755
765
  dataset._session, Session
756
766
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -812,10 +822,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
812
822
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
813
823
 
814
824
  if isinstance(dataset, DataFrame):
815
- self._deps = self._batch_inference_validate_snowpark(
816
- dataset=dataset,
817
- inference_method=inference_method,
818
- )
825
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
826
+ self._deps = self._get_dependencies()
819
827
  assert isinstance(
820
828
  dataset._session, Session
821
829
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -881,10 +889,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
881
889
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
882
890
 
883
891
  if isinstance(dataset, DataFrame):
884
- self._deps = self._batch_inference_validate_snowpark(
885
- dataset=dataset,
886
- inference_method=inference_method,
887
- )
892
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
893
+ self._deps = self._get_dependencies()
888
894
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
889
895
  transform_kwargs = dict(
890
896
  session=dataset._session,
@@ -946,17 +952,15 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
946
952
  transform_kwargs: ScoreKwargsTypedDict = dict()
947
953
 
948
954
  if isinstance(dataset, DataFrame):
949
- self._deps = self._batch_inference_validate_snowpark(
950
- dataset=dataset,
951
- inference_method="score",
952
- )
955
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
956
+ self._deps = self._get_dependencies()
953
957
  selected_cols = self._get_active_columns()
954
958
  if len(selected_cols) > 0:
955
959
  dataset = dataset.select(selected_cols)
956
960
  assert isinstance(dataset._session, Session) # keep mypy happy
957
961
  transform_kwargs = dict(
958
962
  session=dataset._session,
959
- dependencies=["snowflake-snowpark-python"] + self._deps,
963
+ dependencies=self._deps,
960
964
  score_sproc_imports=['sklearn'],
961
965
  )
962
966
  elif isinstance(dataset, pd.DataFrame):
@@ -1021,11 +1025,8 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
1021
1025
 
1022
1026
  if isinstance(dataset, DataFrame):
1023
1027
 
1024
- self._deps = self._batch_inference_validate_snowpark(
1025
- dataset=dataset,
1026
- inference_method=inference_method,
1027
-
1028
- )
1028
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1029
+ self._deps = self._get_dependencies()
1029
1030
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1030
1031
  transform_kwargs = dict(
1031
1032
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RadiusNeighborsClassifier(BaseTransformer):
70
64
  r"""Classifier implementing a vote among neighbors within a given radius
71
65
  For more details on this class, see [sklearn.neighbors.RadiusNeighborsClassifier]
@@ -346,20 +340,17 @@ class RadiusNeighborsClassifier(BaseTransformer):
346
340
  self,
347
341
  dataset: DataFrame,
348
342
  inference_method: str,
349
- ) -> List[str]:
350
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
351
- return the available package that exists in the snowflake anaconda channel
343
+ ) -> None:
344
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
352
345
 
353
346
  Args:
354
347
  dataset: snowpark dataframe
355
348
  inference_method: the inference method such as predict, score...
356
-
349
+
357
350
  Raises:
358
351
  SnowflakeMLException: If the estimator is not fitted, raise error
359
352
  SnowflakeMLException: If the session is None, raise error
360
353
 
361
- Returns:
362
- A list of available package that exists in the snowflake anaconda channel
363
354
  """
364
355
  if not self._is_fitted:
365
356
  raise exceptions.SnowflakeMLException(
@@ -377,9 +368,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
377
368
  "Session must not specified for snowpark dataset."
378
369
  ),
379
370
  )
380
- # Validate that key package version in user workspace are supported in snowflake conda channel
381
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
382
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
371
+
383
372
 
384
373
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
385
374
  @telemetry.send_api_usage_telemetry(
@@ -427,7 +416,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
427
416
 
428
417
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
429
418
 
430
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
419
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
420
+ self._deps = self._get_dependencies()
431
421
  assert isinstance(
432
422
  dataset._session, Session
433
423
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -510,10 +500,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
510
500
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
511
501
  expected_dtype = convert_sp_to_sf_type(output_types[0])
512
502
 
513
- self._deps = self._batch_inference_validate_snowpark(
514
- dataset=dataset,
515
- inference_method=inference_method,
516
- )
503
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
504
+ self._deps = self._get_dependencies()
517
505
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
518
506
 
519
507
  transform_kwargs = dict(
@@ -580,16 +568,40 @@ class RadiusNeighborsClassifier(BaseTransformer):
580
568
  self._is_fitted = True
581
569
  return output_result
582
570
 
571
+
572
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
573
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
574
+ """ Method not supported for this class.
583
575
 
584
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
585
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
586
- """
576
+
577
+ Raises:
578
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
579
+
580
+ Args:
581
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
582
+ Snowpark or Pandas DataFrame.
583
+ output_cols_prefix: Prefix for the response columns
587
584
  Returns:
588
585
  Transformed dataset.
589
586
  """
590
- self.fit(dataset)
591
- assert self._sklearn_object is not None
592
- return self._sklearn_object.embedding_
587
+ self._infer_input_output_cols(dataset)
588
+ super()._check_dataset_type(dataset)
589
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
590
+ estimator=self._sklearn_object,
591
+ dataset=dataset,
592
+ input_cols=self.input_cols,
593
+ label_cols=self.label_cols,
594
+ sample_weight_col=self.sample_weight_col,
595
+ autogenerated=self._autogenerated,
596
+ subproject=_SUBPROJECT,
597
+ )
598
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
599
+ drop_input_cols=self._drop_input_cols,
600
+ expected_output_cols_list=self.output_cols,
601
+ )
602
+ self._sklearn_object = fitted_estimator
603
+ self._is_fitted = True
604
+ return output_result
593
605
 
594
606
 
595
607
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -682,10 +694,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
682
694
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
683
695
 
684
696
  if isinstance(dataset, DataFrame):
685
- self._deps = self._batch_inference_validate_snowpark(
686
- dataset=dataset,
687
- inference_method=inference_method,
688
- )
697
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
698
+ self._deps = self._get_dependencies()
689
699
  assert isinstance(
690
700
  dataset._session, Session
691
701
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -752,10 +762,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
752
762
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
753
763
 
754
764
  if isinstance(dataset, DataFrame):
755
- self._deps = self._batch_inference_validate_snowpark(
756
- dataset=dataset,
757
- inference_method=inference_method,
758
- )
765
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
766
+ self._deps = self._get_dependencies()
759
767
  assert isinstance(
760
768
  dataset._session, Session
761
769
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -817,10 +825,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
817
825
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
818
826
 
819
827
  if isinstance(dataset, DataFrame):
820
- self._deps = self._batch_inference_validate_snowpark(
821
- dataset=dataset,
822
- inference_method=inference_method,
823
- )
828
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
829
+ self._deps = self._get_dependencies()
824
830
  assert isinstance(
825
831
  dataset._session, Session
826
832
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -886,10 +892,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
886
892
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
887
893
 
888
894
  if isinstance(dataset, DataFrame):
889
- self._deps = self._batch_inference_validate_snowpark(
890
- dataset=dataset,
891
- inference_method=inference_method,
892
- )
895
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
896
+ self._deps = self._get_dependencies()
893
897
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
894
898
  transform_kwargs = dict(
895
899
  session=dataset._session,
@@ -953,17 +957,15 @@ class RadiusNeighborsClassifier(BaseTransformer):
953
957
  transform_kwargs: ScoreKwargsTypedDict = dict()
954
958
 
955
959
  if isinstance(dataset, DataFrame):
956
- self._deps = self._batch_inference_validate_snowpark(
957
- dataset=dataset,
958
- inference_method="score",
959
- )
960
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
961
+ self._deps = self._get_dependencies()
960
962
  selected_cols = self._get_active_columns()
961
963
  if len(selected_cols) > 0:
962
964
  dataset = dataset.select(selected_cols)
963
965
  assert isinstance(dataset._session, Session) # keep mypy happy
964
966
  transform_kwargs = dict(
965
967
  session=dataset._session,
966
- dependencies=["snowflake-snowpark-python"] + self._deps,
968
+ dependencies=self._deps,
967
969
  score_sproc_imports=['sklearn'],
968
970
  )
969
971
  elif isinstance(dataset, pd.DataFrame):
@@ -1028,11 +1030,8 @@ class RadiusNeighborsClassifier(BaseTransformer):
1028
1030
 
1029
1031
  if isinstance(dataset, DataFrame):
1030
1032
 
1031
- self._deps = self._batch_inference_validate_snowpark(
1032
- dataset=dataset,
1033
- inference_method=inference_method,
1034
-
1035
- )
1033
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1034
+ self._deps = self._get_dependencies()
1036
1035
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1037
1036
  transform_kwargs = dict(
1038
1037
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.neighbors".replace("skle
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class RadiusNeighborsRegressor(BaseTransformer):
70
64
  r"""Regression based on neighbors within a fixed radius
71
65
  For more details on this class, see [sklearn.neighbors.RadiusNeighborsRegressor]
@@ -336,20 +330,17 @@ class RadiusNeighborsRegressor(BaseTransformer):
336
330
  self,
337
331
  dataset: DataFrame,
338
332
  inference_method: str,
339
- ) -> List[str]:
340
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
341
- return the available package that exists in the snowflake anaconda channel
333
+ ) -> None:
334
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
342
335
 
343
336
  Args:
344
337
  dataset: snowpark dataframe
345
338
  inference_method: the inference method such as predict, score...
346
-
339
+
347
340
  Raises:
348
341
  SnowflakeMLException: If the estimator is not fitted, raise error
349
342
  SnowflakeMLException: If the session is None, raise error
350
343
 
351
- Returns:
352
- A list of available package that exists in the snowflake anaconda channel
353
344
  """
354
345
  if not self._is_fitted:
355
346
  raise exceptions.SnowflakeMLException(
@@ -367,9 +358,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
367
358
  "Session must not specified for snowpark dataset."
368
359
  ),
369
360
  )
370
- # Validate that key package version in user workspace are supported in snowflake conda channel
371
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
372
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
361
+
373
362
 
374
363
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
375
364
  @telemetry.send_api_usage_telemetry(
@@ -417,7 +406,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
417
406
 
418
407
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
419
408
 
420
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
409
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
410
+ self._deps = self._get_dependencies()
421
411
  assert isinstance(
422
412
  dataset._session, Session
423
413
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -500,10 +490,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
500
490
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
501
491
  expected_dtype = convert_sp_to_sf_type(output_types[0])
502
492
 
503
- self._deps = self._batch_inference_validate_snowpark(
504
- dataset=dataset,
505
- inference_method=inference_method,
506
- )
493
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
494
+ self._deps = self._get_dependencies()
507
495
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
508
496
 
509
497
  transform_kwargs = dict(
@@ -570,16 +558,40 @@ class RadiusNeighborsRegressor(BaseTransformer):
570
558
  self._is_fitted = True
571
559
  return output_result
572
560
 
561
+
562
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
563
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
564
+ """ Method not supported for this class.
573
565
 
574
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
575
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
576
- """
566
+
567
+ Raises:
568
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
569
+
570
+ Args:
571
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
572
+ Snowpark or Pandas DataFrame.
573
+ output_cols_prefix: Prefix for the response columns
577
574
  Returns:
578
575
  Transformed dataset.
579
576
  """
580
- self.fit(dataset)
581
- assert self._sklearn_object is not None
582
- return self._sklearn_object.embedding_
577
+ self._infer_input_output_cols(dataset)
578
+ super()._check_dataset_type(dataset)
579
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
580
+ estimator=self._sklearn_object,
581
+ dataset=dataset,
582
+ input_cols=self.input_cols,
583
+ label_cols=self.label_cols,
584
+ sample_weight_col=self.sample_weight_col,
585
+ autogenerated=self._autogenerated,
586
+ subproject=_SUBPROJECT,
587
+ )
588
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=self.output_cols,
591
+ )
592
+ self._sklearn_object = fitted_estimator
593
+ self._is_fitted = True
594
+ return output_result
583
595
 
584
596
 
585
597
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -670,10 +682,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
670
682
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
671
683
 
672
684
  if isinstance(dataset, DataFrame):
673
- self._deps = self._batch_inference_validate_snowpark(
674
- dataset=dataset,
675
- inference_method=inference_method,
676
- )
685
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
686
+ self._deps = self._get_dependencies()
677
687
  assert isinstance(
678
688
  dataset._session, Session
679
689
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -738,10 +748,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
738
748
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
739
749
 
740
750
  if isinstance(dataset, DataFrame):
741
- self._deps = self._batch_inference_validate_snowpark(
742
- dataset=dataset,
743
- inference_method=inference_method,
744
- )
751
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
752
+ self._deps = self._get_dependencies()
745
753
  assert isinstance(
746
754
  dataset._session, Session
747
755
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -803,10 +811,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
803
811
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
804
812
 
805
813
  if isinstance(dataset, DataFrame):
806
- self._deps = self._batch_inference_validate_snowpark(
807
- dataset=dataset,
808
- inference_method=inference_method,
809
- )
814
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
815
+ self._deps = self._get_dependencies()
810
816
  assert isinstance(
811
817
  dataset._session, Session
812
818
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -872,10 +878,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
872
878
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
873
879
 
874
880
  if isinstance(dataset, DataFrame):
875
- self._deps = self._batch_inference_validate_snowpark(
876
- dataset=dataset,
877
- inference_method=inference_method,
878
- )
881
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
882
+ self._deps = self._get_dependencies()
879
883
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
880
884
  transform_kwargs = dict(
881
885
  session=dataset._session,
@@ -939,17 +943,15 @@ class RadiusNeighborsRegressor(BaseTransformer):
939
943
  transform_kwargs: ScoreKwargsTypedDict = dict()
940
944
 
941
945
  if isinstance(dataset, DataFrame):
942
- self._deps = self._batch_inference_validate_snowpark(
943
- dataset=dataset,
944
- inference_method="score",
945
- )
946
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
947
+ self._deps = self._get_dependencies()
946
948
  selected_cols = self._get_active_columns()
947
949
  if len(selected_cols) > 0:
948
950
  dataset = dataset.select(selected_cols)
949
951
  assert isinstance(dataset._session, Session) # keep mypy happy
950
952
  transform_kwargs = dict(
951
953
  session=dataset._session,
952
- dependencies=["snowflake-snowpark-python"] + self._deps,
954
+ dependencies=self._deps,
953
955
  score_sproc_imports=['sklearn'],
954
956
  )
955
957
  elif isinstance(dataset, pd.DataFrame):
@@ -1014,11 +1016,8 @@ class RadiusNeighborsRegressor(BaseTransformer):
1014
1016
 
1015
1017
  if isinstance(dataset, DataFrame):
1016
1018
 
1017
- self._deps = self._batch_inference_validate_snowpark(
1018
- dataset=dataset,
1019
- inference_method=inference_method,
1020
-
1021
- )
1019
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1020
+ self._deps = self._get_dependencies()
1022
1021
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1023
1022
  transform_kwargs = dict(
1024
1023
  session = dataset._session,