snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -13,12 +13,12 @@ from snowflake.ml._internal.exceptions import (
13
13
  exceptions,
14
14
  modeling_error_messages,
15
15
  )
16
- from snowflake.ml._internal.utils import pkg_version_utils
16
+ from snowflake.ml._internal.utils import pkg_version_utils, temp_file_utils
17
17
  from snowflake.ml._internal.utils.query_result_checker import ResultValidator
18
18
  from snowflake.ml._internal.utils.snowpark_dataframe_utils import (
19
19
  cast_snowpark_dataframe,
20
20
  )
21
- from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path
21
+ from snowflake.ml.modeling._internal import estimator_utils
22
22
  from snowflake.ml.modeling._internal.model_specifications import (
23
23
  ModelSpecifications,
24
24
  ModelSpecificationsBuilder,
@@ -303,11 +303,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
303
303
  statement_params=statement_params,
304
304
  anonymous=True,
305
305
  imports=list(import_file_paths),
306
+ execute_as="caller",
306
307
  ) # type: ignore[misc]
307
308
  def fit_wrapper_sproc(
308
309
  session: Session,
309
- stage_transform_file_name: str,
310
- stage_result_file_name: str,
311
310
  dataset_stage_name: str,
312
311
  batch_size: int,
313
312
  input_cols: List[str],
@@ -320,9 +319,13 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
320
319
 
321
320
  import cloudpickle as cp
322
321
 
323
- local_transform_file_name = get_temp_file_path()
322
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
324
323
 
325
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
324
+ session.file.get(
325
+ stage_location=dataset_stage_name,
326
+ target_directory=local_transform_file_name,
327
+ statement_params=statement_params,
328
+ )
326
329
 
327
330
  local_transform_file_path = os.path.join(
328
331
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -345,13 +348,13 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
345
348
  sample_weight_col=sample_weight_col,
346
349
  )
347
350
 
348
- local_result_file_name = get_temp_file_path()
351
+ local_result_file_name = temp_file_utils.get_temp_file_path()
349
352
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
350
353
  cp.dump(estimator, local_result_file_obj)
351
354
 
352
355
  session.file.put(
353
- local_result_file_name,
354
- stage_result_file_name,
356
+ local_file_name=local_result_file_name,
357
+ stage_location=dataset_stage_name,
355
358
  auto_compress=False,
356
359
  overwrite=True,
357
360
  statement_params=statement_params,
@@ -394,11 +397,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
394
397
  SnowflakeMLException: For known types of user and system errors.
395
398
  e: For every unexpected exception from SnowflakeClient.
396
399
  """
397
- temp_stage_name = self._create_temp_stage()
398
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(stage_name=temp_stage_name)
399
- data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
400
-
401
- # Call fit sproc
402
400
  statement_params = telemetry.get_function_usage_statement_params(
403
401
  project=_PROJECT,
404
402
  subproject=self._subproject,
@@ -406,7 +404,16 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
406
404
  api_calls=[Session.call],
407
405
  custom_tags=None,
408
406
  )
407
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
408
+ estimator_utils.upload_model_to_stage(
409
+ stage_name=temp_stage_name,
410
+ estimator=self.estimator,
411
+ session=self.session,
412
+ statement_params=statement_params,
413
+ )
414
+ data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
409
415
 
416
+ # Call fit sproc
410
417
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
411
418
  fit_wrapper = self._get_xgb_external_memory_fit_wrapper_sproc(
412
419
  model_spec=model_spec,
@@ -418,8 +425,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
418
425
  try:
419
426
  sproc_export_file_name = fit_wrapper(
420
427
  self.session,
421
- stage_transform_file_name,
422
- stage_result_file_name,
423
428
  temp_stage_name,
424
429
  self._batch_size,
425
430
  self.input_cols,
@@ -440,7 +445,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
440
445
  sproc_export_file_name = fields[0]
441
446
 
442
447
  return self._fetch_model_from_stage(
443
- dir_path=stage_result_file_name,
448
+ dir_path=temp_stage_name,
444
449
  file_name=sproc_export_file_name,
445
450
  statement_params=statement_params,
446
451
  )
@@ -296,7 +296,7 @@ class CalibratedClassifierCV(BaseTransformer):
296
296
  inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
297
297
  ),
298
298
  api_calls=[Session.call],
299
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
299
+ custom_tags={"autogen": True} if self._autogenerated else None,
300
300
  )
301
301
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
302
302
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class AffinityPropagation(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -271,7 +273,7 @@ class AffinityPropagation(BaseTransformer):
271
273
  inspect.currentframe(), AffinityPropagation.__class__.__name__
272
274
  ),
273
275
  api_calls=[Session.call],
274
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
276
+ custom_tags={"autogen": True} if self._autogenerated else None,
275
277
  )
276
278
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
277
279
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class AgglomerativeClustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -304,7 +306,7 @@ class AgglomerativeClustering(BaseTransformer):
304
306
  inspect.currentframe(), AgglomerativeClustering.__class__.__name__
305
307
  ),
306
308
  api_calls=[Session.call],
307
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
309
+ custom_tags={"autogen": True} if self._autogenerated else None,
308
310
  )
309
311
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
310
312
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class Birch(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -262,7 +264,7 @@ class Birch(BaseTransformer):
262
264
  inspect.currentframe(), Birch.__class__.__name__
263
265
  ),
264
266
  api_calls=[Session.call],
265
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
267
+ custom_tags={"autogen": True} if self._autogenerated else None,
266
268
  )
267
269
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
268
270
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class BisectingKMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -311,7 +313,7 @@ class BisectingKMeans(BaseTransformer):
311
313
  inspect.currentframe(), BisectingKMeans.__class__.__name__
312
314
  ),
313
315
  api_calls=[Session.call],
314
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
316
+ custom_tags={"autogen": True} if self._autogenerated else None,
315
317
  )
316
318
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
317
319
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class DBSCAN(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -279,7 +281,7 @@ class DBSCAN(BaseTransformer):
279
281
  inspect.currentframe(), DBSCAN.__class__.__name__
280
282
  ),
281
283
  api_calls=[Session.call],
282
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
284
+ custom_tags={"autogen": True} if self._autogenerated else None,
283
285
  )
284
286
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
285
287
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class FeatureAgglomeration(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -311,7 +313,7 @@ class FeatureAgglomeration(BaseTransformer):
311
313
  inspect.currentframe(), FeatureAgglomeration.__class__.__name__
312
314
  ),
313
315
  api_calls=[Session.call],
314
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
316
+ custom_tags={"autogen": True} if self._autogenerated else None,
315
317
  )
316
318
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
317
319
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class KMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -306,7 +308,7 @@ class KMeans(BaseTransformer):
306
308
  inspect.currentframe(), KMeans.__class__.__name__
307
309
  ),
308
310
  api_calls=[Session.call],
309
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
311
+ custom_tags={"autogen": True} if self._autogenerated else None,
310
312
  )
311
313
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
312
314
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class MeanShift(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -282,7 +284,7 @@ class MeanShift(BaseTransformer):
282
284
  inspect.currentframe(), MeanShift.__class__.__name__
283
285
  ),
284
286
  api_calls=[Session.call],
285
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
287
+ custom_tags={"autogen": True} if self._autogenerated else None,
286
288
  )
287
289
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
288
290
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class MiniBatchKMeans(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -332,7 +334,7 @@ class MiniBatchKMeans(BaseTransformer):
332
334
  inspect.currentframe(), MiniBatchKMeans.__class__.__name__
333
335
  ),
334
336
  api_calls=[Session.call],
335
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
337
+ custom_tags={"autogen": True} if self._autogenerated else None,
336
338
  )
337
339
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
338
340
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class OPTICS(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -352,7 +354,7 @@ class OPTICS(BaseTransformer):
352
354
  inspect.currentframe(), OPTICS.__class__.__name__
353
355
  ),
354
356
  api_calls=[Session.call],
355
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
357
+ custom_tags={"autogen": True} if self._autogenerated else None,
356
358
  )
357
359
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
358
360
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class SpectralBiclustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -290,7 +292,7 @@ class SpectralBiclustering(BaseTransformer):
290
292
  inspect.currentframe(), SpectralBiclustering.__class__.__name__
291
293
  ),
292
294
  api_calls=[Session.call],
293
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
295
+ custom_tags={"autogen": True} if self._autogenerated else None,
294
296
  )
295
297
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
296
298
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class SpectralClustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -348,7 +350,7 @@ class SpectralClustering(BaseTransformer):
348
350
  inspect.currentframe(), SpectralClustering.__class__.__name__
349
351
  ),
350
352
  api_calls=[Session.call],
351
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
353
+ custom_tags={"autogen": True} if self._autogenerated else None,
352
354
  )
353
355
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
354
356
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class SpectralCoclustering(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -269,7 +271,7 @@ class SpectralCoclustering(BaseTransformer):
269
271
  inspect.currentframe(), SpectralCoclustering.__class__.__name__
270
272
  ),
271
273
  api_calls=[Session.call],
272
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
274
+ custom_tags={"autogen": True} if self._autogenerated else None,
273
275
  )
274
276
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
275
277
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class ColumnTransformer(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -299,7 +301,7 @@ class ColumnTransformer(BaseTransformer):
299
301
  inspect.currentframe(), ColumnTransformer.__class__.__name__
300
302
  ),
301
303
  api_calls=[Session.call],
302
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
304
+ custom_tags={"autogen": True} if self._autogenerated else None,
303
305
  )
304
306
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
305
307
  pd_df.columns = dataset.columns
@@ -260,7 +260,7 @@ class TransformedTargetRegressor(BaseTransformer):
260
260
  inspect.currentframe(), TransformedTargetRegressor.__class__.__name__
261
261
  ),
262
262
  api_calls=[Session.call],
263
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
263
+ custom_tags={"autogen": True} if self._autogenerated else None,
264
264
  )
265
265
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
266
266
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class EllipticEnvelope(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -255,7 +257,7 @@ class EllipticEnvelope(BaseTransformer):
255
257
  inspect.currentframe(), EllipticEnvelope.__class__.__name__
256
258
  ),
257
259
  api_calls=[Session.call],
258
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
260
+ custom_tags={"autogen": True} if self._autogenerated else None,
259
261
  )
260
262
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
261
263
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class EmpiricalCovariance(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -231,7 +233,7 @@ class EmpiricalCovariance(BaseTransformer):
231
233
  inspect.currentframe(), EmpiricalCovariance.__class__.__name__
232
234
  ),
233
235
  api_calls=[Session.call],
234
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
236
+ custom_tags={"autogen": True} if self._autogenerated else None,
235
237
  )
236
238
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
237
239
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class GraphicalLasso(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -279,7 +281,7 @@ class GraphicalLasso(BaseTransformer):
279
281
  inspect.currentframe(), GraphicalLasso.__class__.__name__
280
282
  ),
281
283
  api_calls=[Session.call],
282
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
284
+ custom_tags={"autogen": True} if self._autogenerated else None,
283
285
  )
284
286
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
285
287
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class GraphicalLassoCV(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -305,7 +307,7 @@ class GraphicalLassoCV(BaseTransformer):
305
307
  inspect.currentframe(), GraphicalLassoCV.__class__.__name__
306
308
  ),
307
309
  api_calls=[Session.call],
308
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
310
+ custom_tags={"autogen": True} if self._autogenerated else None,
309
311
  )
310
312
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
311
313
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class LedoitWolf(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -238,7 +240,7 @@ class LedoitWolf(BaseTransformer):
238
240
  inspect.currentframe(), LedoitWolf.__class__.__name__
239
241
  ),
240
242
  api_calls=[Session.call],
241
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
243
+ custom_tags={"autogen": True} if self._autogenerated else None,
242
244
  )
243
245
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
244
246
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class MinCovDet(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -250,7 +252,7 @@ class MinCovDet(BaseTransformer):
250
252
  inspect.currentframe(), MinCovDet.__class__.__name__
251
253
  ),
252
254
  api_calls=[Session.call],
253
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
255
+ custom_tags={"autogen": True} if self._autogenerated else None,
254
256
  )
255
257
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
256
258
  pd_df.columns = dataset.columns
@@ -76,8 +76,10 @@ class OAS(BaseTransformer):
76
76
  initialization with the `set_input_cols` method.
77
77
 
78
78
  label_cols: Optional[Union[str, List[str]]]
79
- This parameter is optional and will be ignored during fit. It is present here for API consistency by convention.
80
-
79
+ A string or list of strings representing column names that contain labels.
80
+ Label columns must be specified with this parameter during initialization
81
+ or with the `set_label_cols` method before fitting.
82
+
81
83
  output_cols: Optional[Union[str, List[str]]]
82
84
  A string or list of strings representing column names that will store the
83
85
  output of predict and transform operations. The length of output_cols must
@@ -231,7 +233,7 @@ class OAS(BaseTransformer):
231
233
  inspect.currentframe(), OAS.__class__.__name__
232
234
  ),
233
235
  api_calls=[Session.call],
234
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
236
+ custom_tags={"autogen": True} if self._autogenerated else None,
235
237
  )
236
238
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
237
239
  pd_df.columns = dataset.columns