snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class AffinityPropagation(BaseTransformer):
70
64
  r"""Perform Affinity Propagation Clustering of data
71
65
  For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -303,20 +297,17 @@ class AffinityPropagation(BaseTransformer):
303
297
  self,
304
298
  dataset: DataFrame,
305
299
  inference_method: str,
306
- ) -> List[str]:
307
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
308
- return the available package that exists in the snowflake anaconda channel
300
+ ) -> None:
301
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
309
302
 
310
303
  Args:
311
304
  dataset: snowpark dataframe
312
305
  inference_method: the inference method such as predict, score...
313
-
306
+
314
307
  Raises:
315
308
  SnowflakeMLException: If the estimator is not fitted, raise error
316
309
  SnowflakeMLException: If the session is None, raise error
317
310
 
318
- Returns:
319
- A list of available package that exists in the snowflake anaconda channel
320
311
  """
321
312
  if not self._is_fitted:
322
313
  raise exceptions.SnowflakeMLException(
@@ -334,9 +325,7 @@ class AffinityPropagation(BaseTransformer):
334
325
  "Session must not specified for snowpark dataset."
335
326
  ),
336
327
  )
337
- # Validate that key package version in user workspace are supported in snowflake conda channel
338
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
339
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
328
+
340
329
 
341
330
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
342
331
  @telemetry.send_api_usage_telemetry(
@@ -384,7 +373,8 @@ class AffinityPropagation(BaseTransformer):
384
373
 
385
374
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
386
375
 
387
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
376
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
377
+ self._deps = self._get_dependencies()
388
378
  assert isinstance(
389
379
  dataset._session, Session
390
380
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -467,10 +457,8 @@ class AffinityPropagation(BaseTransformer):
467
457
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
468
458
  expected_dtype = convert_sp_to_sf_type(output_types[0])
469
459
 
470
- self._deps = self._batch_inference_validate_snowpark(
471
- dataset=dataset,
472
- inference_method=inference_method,
473
- )
460
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
461
+ self._deps = self._get_dependencies()
474
462
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
475
463
 
476
464
  transform_kwargs = dict(
@@ -539,16 +527,40 @@ class AffinityPropagation(BaseTransformer):
539
527
  self._is_fitted = True
540
528
  return output_result
541
529
 
530
+
531
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
532
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
533
+ """ Method not supported for this class.
534
+
542
535
 
543
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
544
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
545
- """
536
+ Raises:
537
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
538
+
539
+ Args:
540
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
541
+ Snowpark or Pandas DataFrame.
542
+ output_cols_prefix: Prefix for the response columns
546
543
  Returns:
547
544
  Transformed dataset.
548
545
  """
549
- self.fit(dataset)
550
- assert self._sklearn_object is not None
551
- return self._sklearn_object.embedding_
546
+ self._infer_input_output_cols(dataset)
547
+ super()._check_dataset_type(dataset)
548
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
549
+ estimator=self._sklearn_object,
550
+ dataset=dataset,
551
+ input_cols=self.input_cols,
552
+ label_cols=self.label_cols,
553
+ sample_weight_col=self.sample_weight_col,
554
+ autogenerated=self._autogenerated,
555
+ subproject=_SUBPROJECT,
556
+ )
557
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=self.output_cols,
560
+ )
561
+ self._sklearn_object = fitted_estimator
562
+ self._is_fitted = True
563
+ return output_result
552
564
 
553
565
 
554
566
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -639,10 +651,8 @@ class AffinityPropagation(BaseTransformer):
639
651
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
640
652
 
641
653
  if isinstance(dataset, DataFrame):
642
- self._deps = self._batch_inference_validate_snowpark(
643
- dataset=dataset,
644
- inference_method=inference_method,
645
- )
654
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
655
+ self._deps = self._get_dependencies()
646
656
  assert isinstance(
647
657
  dataset._session, Session
648
658
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -707,10 +717,8 @@ class AffinityPropagation(BaseTransformer):
707
717
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
708
718
 
709
719
  if isinstance(dataset, DataFrame):
710
- self._deps = self._batch_inference_validate_snowpark(
711
- dataset=dataset,
712
- inference_method=inference_method,
713
- )
720
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
721
+ self._deps = self._get_dependencies()
714
722
  assert isinstance(
715
723
  dataset._session, Session
716
724
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -772,10 +780,8 @@ class AffinityPropagation(BaseTransformer):
772
780
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
773
781
 
774
782
  if isinstance(dataset, DataFrame):
775
- self._deps = self._batch_inference_validate_snowpark(
776
- dataset=dataset,
777
- inference_method=inference_method,
778
- )
783
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
784
+ self._deps = self._get_dependencies()
779
785
  assert isinstance(
780
786
  dataset._session, Session
781
787
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -841,10 +847,8 @@ class AffinityPropagation(BaseTransformer):
841
847
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
842
848
 
843
849
  if isinstance(dataset, DataFrame):
844
- self._deps = self._batch_inference_validate_snowpark(
845
- dataset=dataset,
846
- inference_method=inference_method,
847
- )
850
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
851
+ self._deps = self._get_dependencies()
848
852
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
849
853
  transform_kwargs = dict(
850
854
  session=dataset._session,
@@ -906,17 +910,15 @@ class AffinityPropagation(BaseTransformer):
906
910
  transform_kwargs: ScoreKwargsTypedDict = dict()
907
911
 
908
912
  if isinstance(dataset, DataFrame):
909
- self._deps = self._batch_inference_validate_snowpark(
910
- dataset=dataset,
911
- inference_method="score",
912
- )
913
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
914
+ self._deps = self._get_dependencies()
913
915
  selected_cols = self._get_active_columns()
914
916
  if len(selected_cols) > 0:
915
917
  dataset = dataset.select(selected_cols)
916
918
  assert isinstance(dataset._session, Session) # keep mypy happy
917
919
  transform_kwargs = dict(
918
920
  session=dataset._session,
919
- dependencies=["snowflake-snowpark-python"] + self._deps,
921
+ dependencies=self._deps,
920
922
  score_sproc_imports=['sklearn'],
921
923
  )
922
924
  elif isinstance(dataset, pd.DataFrame):
@@ -981,11 +983,8 @@ class AffinityPropagation(BaseTransformer):
981
983
 
982
984
  if isinstance(dataset, DataFrame):
983
985
 
984
- self._deps = self._batch_inference_validate_snowpark(
985
- dataset=dataset,
986
- inference_method=inference_method,
987
-
988
- )
986
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
987
+ self._deps = self._get_dependencies()
989
988
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
990
989
  transform_kwargs = dict(
991
990
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class AgglomerativeClustering(BaseTransformer):
70
64
  r"""Agglomerative Clustering
71
65
  For more details on this class, see [sklearn.cluster.AgglomerativeClustering]
@@ -336,20 +330,17 @@ class AgglomerativeClustering(BaseTransformer):
336
330
  self,
337
331
  dataset: DataFrame,
338
332
  inference_method: str,
339
- ) -> List[str]:
340
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
341
- return the available package that exists in the snowflake anaconda channel
333
+ ) -> None:
334
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
342
335
 
343
336
  Args:
344
337
  dataset: snowpark dataframe
345
338
  inference_method: the inference method such as predict, score...
346
-
339
+
347
340
  Raises:
348
341
  SnowflakeMLException: If the estimator is not fitted, raise error
349
342
  SnowflakeMLException: If the session is None, raise error
350
343
 
351
- Returns:
352
- A list of available package that exists in the snowflake anaconda channel
353
344
  """
354
345
  if not self._is_fitted:
355
346
  raise exceptions.SnowflakeMLException(
@@ -367,9 +358,7 @@ class AgglomerativeClustering(BaseTransformer):
367
358
  "Session must not specified for snowpark dataset."
368
359
  ),
369
360
  )
370
- # Validate that key package version in user workspace are supported in snowflake conda channel
371
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
372
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
361
+
373
362
 
374
363
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
375
364
  @telemetry.send_api_usage_telemetry(
@@ -415,7 +404,8 @@ class AgglomerativeClustering(BaseTransformer):
415
404
 
416
405
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
417
406
 
418
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
407
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
408
+ self._deps = self._get_dependencies()
419
409
  assert isinstance(
420
410
  dataset._session, Session
421
411
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -498,10 +488,8 @@ class AgglomerativeClustering(BaseTransformer):
498
488
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
499
489
  expected_dtype = convert_sp_to_sf_type(output_types[0])
500
490
 
501
- self._deps = self._batch_inference_validate_snowpark(
502
- dataset=dataset,
503
- inference_method=inference_method,
504
- )
491
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
492
+ self._deps = self._get_dependencies()
505
493
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
506
494
 
507
495
  transform_kwargs = dict(
@@ -570,16 +558,40 @@ class AgglomerativeClustering(BaseTransformer):
570
558
  self._is_fitted = True
571
559
  return output_result
572
560
 
561
+
562
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
563
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
564
+ """ Method not supported for this class.
565
+
573
566
 
574
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
575
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
576
- """
567
+ Raises:
568
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
569
+
570
+ Args:
571
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
572
+ Snowpark or Pandas DataFrame.
573
+ output_cols_prefix: Prefix for the response columns
577
574
  Returns:
578
575
  Transformed dataset.
579
576
  """
580
- self.fit(dataset)
581
- assert self._sklearn_object is not None
582
- return self._sklearn_object.embedding_
577
+ self._infer_input_output_cols(dataset)
578
+ super()._check_dataset_type(dataset)
579
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
580
+ estimator=self._sklearn_object,
581
+ dataset=dataset,
582
+ input_cols=self.input_cols,
583
+ label_cols=self.label_cols,
584
+ sample_weight_col=self.sample_weight_col,
585
+ autogenerated=self._autogenerated,
586
+ subproject=_SUBPROJECT,
587
+ )
588
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=self.output_cols,
591
+ )
592
+ self._sklearn_object = fitted_estimator
593
+ self._is_fitted = True
594
+ return output_result
583
595
 
584
596
 
585
597
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -670,10 +682,8 @@ class AgglomerativeClustering(BaseTransformer):
670
682
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
671
683
 
672
684
  if isinstance(dataset, DataFrame):
673
- self._deps = self._batch_inference_validate_snowpark(
674
- dataset=dataset,
675
- inference_method=inference_method,
676
- )
685
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
686
+ self._deps = self._get_dependencies()
677
687
  assert isinstance(
678
688
  dataset._session, Session
679
689
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -738,10 +748,8 @@ class AgglomerativeClustering(BaseTransformer):
738
748
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
739
749
 
740
750
  if isinstance(dataset, DataFrame):
741
- self._deps = self._batch_inference_validate_snowpark(
742
- dataset=dataset,
743
- inference_method=inference_method,
744
- )
751
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
752
+ self._deps = self._get_dependencies()
745
753
  assert isinstance(
746
754
  dataset._session, Session
747
755
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -803,10 +811,8 @@ class AgglomerativeClustering(BaseTransformer):
803
811
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
804
812
 
805
813
  if isinstance(dataset, DataFrame):
806
- self._deps = self._batch_inference_validate_snowpark(
807
- dataset=dataset,
808
- inference_method=inference_method,
809
- )
814
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
815
+ self._deps = self._get_dependencies()
810
816
  assert isinstance(
811
817
  dataset._session, Session
812
818
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -872,10 +878,8 @@ class AgglomerativeClustering(BaseTransformer):
872
878
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
873
879
 
874
880
  if isinstance(dataset, DataFrame):
875
- self._deps = self._batch_inference_validate_snowpark(
876
- dataset=dataset,
877
- inference_method=inference_method,
878
- )
881
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
882
+ self._deps = self._get_dependencies()
879
883
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
880
884
  transform_kwargs = dict(
881
885
  session=dataset._session,
@@ -937,17 +941,15 @@ class AgglomerativeClustering(BaseTransformer):
937
941
  transform_kwargs: ScoreKwargsTypedDict = dict()
938
942
 
939
943
  if isinstance(dataset, DataFrame):
940
- self._deps = self._batch_inference_validate_snowpark(
941
- dataset=dataset,
942
- inference_method="score",
943
- )
944
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
945
+ self._deps = self._get_dependencies()
944
946
  selected_cols = self._get_active_columns()
945
947
  if len(selected_cols) > 0:
946
948
  dataset = dataset.select(selected_cols)
947
949
  assert isinstance(dataset._session, Session) # keep mypy happy
948
950
  transform_kwargs = dict(
949
951
  session=dataset._session,
950
- dependencies=["snowflake-snowpark-python"] + self._deps,
952
+ dependencies=self._deps,
951
953
  score_sproc_imports=['sklearn'],
952
954
  )
953
955
  elif isinstance(dataset, pd.DataFrame):
@@ -1012,11 +1014,8 @@ class AgglomerativeClustering(BaseTransformer):
1012
1014
 
1013
1015
  if isinstance(dataset, DataFrame):
1014
1016
 
1015
- self._deps = self._batch_inference_validate_snowpark(
1016
- dataset=dataset,
1017
- inference_method=inference_method,
1018
-
1019
- )
1017
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
1018
+ self._deps = self._get_dependencies()
1020
1019
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
1021
1020
  transform_kwargs = dict(
1022
1021
  session = dataset._session,
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklear
60
60
 
61
61
  DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
62
62
 
63
- def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
64
- def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
65
- return False and callable(getattr(self._sklearn_object, "fit_transform", None))
66
- return check
67
-
68
-
69
63
  class Birch(BaseTransformer):
70
64
  r"""Implements the BIRCH clustering algorithm
71
65
  For more details on this class, see [sklearn.cluster.Birch]
@@ -294,20 +288,17 @@ class Birch(BaseTransformer):
294
288
  self,
295
289
  dataset: DataFrame,
296
290
  inference_method: str,
297
- ) -> List[str]:
298
- """Util method to run validate that batch inference can be run on a snowpark dataframe and
299
- return the available package that exists in the snowflake anaconda channel
291
+ ) -> None:
292
+ """Util method to run validate that batch inference can be run on a snowpark dataframe.
300
293
 
301
294
  Args:
302
295
  dataset: snowpark dataframe
303
296
  inference_method: the inference method such as predict, score...
304
-
297
+
305
298
  Raises:
306
299
  SnowflakeMLException: If the estimator is not fitted, raise error
307
300
  SnowflakeMLException: If the session is None, raise error
308
301
 
309
- Returns:
310
- A list of available package that exists in the snowflake anaconda channel
311
302
  """
312
303
  if not self._is_fitted:
313
304
  raise exceptions.SnowflakeMLException(
@@ -325,9 +316,7 @@ class Birch(BaseTransformer):
325
316
  "Session must not specified for snowpark dataset."
326
317
  ),
327
318
  )
328
- # Validate that key package version in user workspace are supported in snowflake conda channel
329
- return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
330
- pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
319
+
331
320
 
332
321
  @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
333
322
  @telemetry.send_api_usage_telemetry(
@@ -375,7 +364,8 @@ class Birch(BaseTransformer):
375
364
 
376
365
  expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
377
366
 
378
- self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
367
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
368
+ self._deps = self._get_dependencies()
379
369
  assert isinstance(
380
370
  dataset._session, Session
381
371
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -460,10 +450,8 @@ class Birch(BaseTransformer):
460
450
  if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
461
451
  expected_dtype = convert_sp_to_sf_type(output_types[0])
462
452
 
463
- self._deps = self._batch_inference_validate_snowpark(
464
- dataset=dataset,
465
- inference_method=inference_method,
466
- )
453
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
454
+ self._deps = self._get_dependencies()
467
455
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
468
456
 
469
457
  transform_kwargs = dict(
@@ -532,16 +520,42 @@ class Birch(BaseTransformer):
532
520
  self._is_fitted = True
533
521
  return output_result
534
522
 
523
+
524
+ @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
525
+ def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
526
+ """ Fit to data, then transform it
527
+ For more details on this function, see [sklearn.cluster.Birch.fit_transform]
528
+ (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.Birch.html#sklearn.cluster.Birch.fit_transform)
529
+
535
530
 
536
- @available_if(_is_fit_transform_method_enabled()) # type: ignore[misc]
537
- def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
538
- """
531
+ Raises:
532
+ TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
533
+
534
+ Args:
535
+ dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
536
+ Snowpark or Pandas DataFrame.
537
+ output_cols_prefix: Prefix for the response columns
539
538
  Returns:
540
539
  Transformed dataset.
541
540
  """
542
- self.fit(dataset)
543
- assert self._sklearn_object is not None
544
- return self._sklearn_object.embedding_
541
+ self._infer_input_output_cols(dataset)
542
+ super()._check_dataset_type(dataset)
543
+ model_trainer = ModelTrainerBuilder.build_fit_transform(
544
+ estimator=self._sklearn_object,
545
+ dataset=dataset,
546
+ input_cols=self.input_cols,
547
+ label_cols=self.label_cols,
548
+ sample_weight_col=self.sample_weight_col,
549
+ autogenerated=self._autogenerated,
550
+ subproject=_SUBPROJECT,
551
+ )
552
+ output_result, fitted_estimator = model_trainer.train_fit_transform(
553
+ drop_input_cols=self._drop_input_cols,
554
+ expected_output_cols_list=self.output_cols,
555
+ )
556
+ self._sklearn_object = fitted_estimator
557
+ self._is_fitted = True
558
+ return output_result
545
559
 
546
560
 
547
561
  def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -632,10 +646,8 @@ class Birch(BaseTransformer):
632
646
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
633
647
 
634
648
  if isinstance(dataset, DataFrame):
635
- self._deps = self._batch_inference_validate_snowpark(
636
- dataset=dataset,
637
- inference_method=inference_method,
638
- )
649
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
650
+ self._deps = self._get_dependencies()
639
651
  assert isinstance(
640
652
  dataset._session, Session
641
653
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -700,10 +712,8 @@ class Birch(BaseTransformer):
700
712
  transform_kwargs: BatchInferenceKwargsTypedDict = dict()
701
713
 
702
714
  if isinstance(dataset, DataFrame):
703
- self._deps = self._batch_inference_validate_snowpark(
704
- dataset=dataset,
705
- inference_method=inference_method,
706
- )
715
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
716
+ self._deps = self._get_dependencies()
707
717
  assert isinstance(
708
718
  dataset._session, Session
709
719
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -765,10 +775,8 @@ class Birch(BaseTransformer):
765
775
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
766
776
 
767
777
  if isinstance(dataset, DataFrame):
768
- self._deps = self._batch_inference_validate_snowpark(
769
- dataset=dataset,
770
- inference_method=inference_method,
771
- )
778
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
779
+ self._deps = self._get_dependencies()
772
780
  assert isinstance(
773
781
  dataset._session, Session
774
782
  ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -834,10 +842,8 @@ class Birch(BaseTransformer):
834
842
  expected_output_cols = self._get_output_column_names(output_cols_prefix)
835
843
 
836
844
  if isinstance(dataset, DataFrame):
837
- self._deps = self._batch_inference_validate_snowpark(
838
- dataset=dataset,
839
- inference_method=inference_method,
840
- )
845
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
846
+ self._deps = self._get_dependencies()
841
847
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
842
848
  transform_kwargs = dict(
843
849
  session=dataset._session,
@@ -899,17 +905,15 @@ class Birch(BaseTransformer):
899
905
  transform_kwargs: ScoreKwargsTypedDict = dict()
900
906
 
901
907
  if isinstance(dataset, DataFrame):
902
- self._deps = self._batch_inference_validate_snowpark(
903
- dataset=dataset,
904
- inference_method="score",
905
- )
908
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
909
+ self._deps = self._get_dependencies()
906
910
  selected_cols = self._get_active_columns()
907
911
  if len(selected_cols) > 0:
908
912
  dataset = dataset.select(selected_cols)
909
913
  assert isinstance(dataset._session, Session) # keep mypy happy
910
914
  transform_kwargs = dict(
911
915
  session=dataset._session,
912
- dependencies=["snowflake-snowpark-python"] + self._deps,
916
+ dependencies=self._deps,
913
917
  score_sproc_imports=['sklearn'],
914
918
  )
915
919
  elif isinstance(dataset, pd.DataFrame):
@@ -974,11 +978,8 @@ class Birch(BaseTransformer):
974
978
 
975
979
  if isinstance(dataset, DataFrame):
976
980
 
977
- self._deps = self._batch_inference_validate_snowpark(
978
- dataset=dataset,
979
- inference_method=inference_method,
980
-
981
- )
981
+ self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
982
+ self._deps = self._get_dependencies()
982
983
  assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
983
984
  transform_kwargs = dict(
984
985
  session = dataset._session,