autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250712__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23):
  1. autogluon/tabular/models/__init__.py +1 -1
  2. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  3. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  4. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  5. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  6. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  7. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  8. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  9. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  10. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
  11. autogluon/tabular/registry/_ag_model_registry.py +2 -2
  12. autogluon/tabular/version.py +1 -1
  13. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/METADATA +13 -15
  14. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/RECORD +21 -14
  15. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  16. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  17. /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth +0 -0
  18. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/LICENSE +0 -0
  19. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/NOTICE +0 -0
  20. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/WHEEL +0 -0
  21. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/namespace_packages.txt +0 -0
  22. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/top_level.txt +0 -0
  23. {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/zip-safe +0 -0
@@ -0,0 +1,863 @@
1
+ # mypy: ignore-errors
2
+ # taken from https://github.com/sklearn-compat/sklearn-compat
3
+ """Ease developer experience to support multiple versions of scikit-learn.
4
+
5
+ This file is intended to be vendored in your project if you do not want to depend on
6
+ `sklearn-compat` as a package. Then, you can import directly from this file.
7
+
8
+ Be aware that depending on `sklearn-compat` does not add any additional dependencies:
9
+ we are only depending on `scikit-learn`.
10
+
11
+ Version: 0.1.3
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import functools
17
+ import inspect
18
+ import platform
19
+ import sys
20
+ import types
21
+ from dataclasses import dataclass, field
22
+ from typing import Callable, Literal
23
+
24
+ import sklearn
25
+ from sklearn.utils.fixes import parse_version
26
+
27
# Keep only the release segment of the installed version (e.g. "1.6.dev0" -> "1.6")
# so the version gates below treat development/pre-release builds like the
# matching release.
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
28
+
29
+
30
+ ########################################################################################
31
+ # The following code does not depend on the sklearn version
32
+ ########################################################################################
33
+
34
+
35
+ # tags infrastructure
36
def _dataclass_args():
    """Return the extra keyword arguments to pass to ``@dataclass``.

    ``slots=True`` is only understood by ``dataclasses.dataclass`` on
    Python 3.10 and later; older interpreters get no extra arguments.
    """
    slots_supported = sys.version_info >= (3, 10)
    return {"slots": True} if slots_supported else {}
40
+
41
+
42
def get_tags(estimator):
    """Get estimator tags in a consistent format across different sklearn versions.

    This function provides compatibility between sklearn versions before and after 1.6.
    It returns either a Tags object (sklearn >= 1.6) or a converted Tags object from
    the dictionary format (sklearn < 1.6) containing metadata about the estimator's
    requirements and capabilities.

    Parameters
    ----------
    estimator : estimator object
        A scikit-learn estimator instance.

    Returns
    -------
    tags : Tags
        An object containing metadata about the estimator's requirements and
        capabilities (e.g., input types, fitting requirements, classifier/regressor
        specific tags).
    """
    # Guard only the import: an ImportError raised while the tags themselves are
    # being computed must propagate instead of silently triggering the legacy
    # fallback below.
    try:
        from sklearn.utils._tags import get_tags as _sklearn_get_tags
    except ImportError:
        # scikit-learn < 1.6 only offers the dictionary-based tags; convert them
        # into the dataclass-based representation.
        from sklearn.utils._tags import _safe_tags

        return _to_new_tags(_safe_tags(estimator), estimator)
    return _sklearn_get_tags(estimator)
70
+
71
+
72
def _to_new_tags(old_tags, estimator=None):
    """Translate legacy dictionary tags into a ``Tags`` dataclass instance.

    Parameters
    ----------
    old_tags : dict
        Tags in the dictionary format used before scikit-learn 1.6 (for example
        the output of ``sklearn.utils._tags._safe_tags``).

    estimator : estimator object, default=None
        The estimator described by the tags. Having a ``transform`` or
        ``fit_transform`` attribute enables transformer tags; its
        ``_estimator_type`` decides whether classifier or regressor tags are built.

    Returns
    -------
    tags : Tags
        The dataclass-based equivalent of ``old_tags``.
    """
    x_types = old_tags["X_types"]

    input_tags = InputTags(
        one_d_array="1darray" in x_types,
        two_d_array="2darray" in x_types,
        three_d_array="3darray" in x_types,
        sparse="sparse" in x_types,
        categorical="categorical" in x_types,
        string="string" in x_types,
        dict="dict" in x_types,
        positive_only=old_tags["requires_positive_X"],
        allow_nan=old_tags["allow_nan"],
        pairwise=old_tags["pairwise"],
    )

    # Label kinds are stored alongside the input kinds in the legacy X_types list.
    target_tags = TargetTags(
        required=old_tags["requires_y"],
        one_d_labels="1dlabels" in x_types,
        two_d_labels="2dlabels" in x_types,
        positive_only=old_tags["requires_positive_y"],
        multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
        single_output=not old_tags["multioutput_only"],
    )

    looks_like_transformer = estimator is not None and any(
        hasattr(estimator, method_name) for method_name in ("transform", "fit_transform")
    )
    transformer_tags = (
        TransformerTags(preserves_dtype=old_tags["preserves_dtype"])
        if looks_like_transformer
        else None
    )

    estimator_type = getattr(estimator, "_estimator_type", None)
    classifier_tags = None
    regressor_tags = None
    if estimator_type == "classifier":
        classifier_tags = ClassifierTags(
            poor_score=old_tags["poor_score"],
            multi_class=not old_tags["binary_only"],
            multi_label=old_tags["multilabel"],
        )
    elif estimator_type == "regressor":
        regressor_tags = RegressorTags(
            poor_score=old_tags["poor_score"],
            multi_label=old_tags["multilabel"],
        )

    return Tags(
        estimator_type=estimator_type,
        target_tags=target_tags,
        transformer_tags=transformer_tags,
        classifier_tags=classifier_tags,
        regressor_tags=regressor_tags,
        input_tags=input_tags,
        # The legacy tags only gained "array_api_support" in scikit-learn 1.3,
        # so treat a missing entry as False.
        array_api_support=old_tags.get("array_api_support", False),
        no_validation=old_tags["no_validation"],
        non_deterministic=old_tags["non_deterministic"],
        requires_fit=old_tags["requires_fit"],
        _skip_test=old_tags["_skip_test"],
    )
133
+
134
+
135
+ ########################################################################################
136
+ # Upgrading for scikit-learn 1.3
137
+ ########################################################################################
138
+
139
if sklearn_version >= parse_version("1.3"):
    # Parameter validation helpers are provided by scikit-learn itself.
    from sklearn.base import _fit_context  # noqa: F401
    from sklearn.utils._param_validation import validate_params  # noqa: F401
else:
    # Parameter validation backports for scikit-learn < 1.3.

    def _fit_context(*, prefer_skip_nested_validation):
        """Return a decorator that validates the estimator's parameters before fitting.

        The decorated method first calls ``estimator._validate_params()`` and
        then runs the original method. ``prefer_skip_nested_validation`` is
        accepted for API compatibility but is ignored.
        """

        def _decorate(fit_method):
            @functools.wraps(fit_method)
            def _validated_fit(estimator, *args, **kwargs):
                estimator._validate_params()
                return fit_method(estimator, *args, **kwargs)

            return _validated_fit

        return _decorate

    def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
        """Delegate to scikit-learn's ``validate_params`` with ``parameter_constraints``.

        ``prefer_skip_nested_validation`` is accepted for API compatibility but
        is not forwarded.
        """
        from sklearn.utils._param_validation import (
            validate_params as _sklearn_validate_params,
        )

        return _sklearn_validate_params(parameter_constraints)
165
+
166
+
167
+ ########################################################################################
168
+ # Upgrading for scikit-learn 1.4
169
+ ########################################################################################
170
+
171
+
172
if sklearn_version >= parse_version("1.4"):
    from sklearn.utils.metadata_routing import (
        _raise_for_params,  # noqa: F401
        process_routing,  # noqa: F401
    )
    from sklearn.utils.validation import (
        _is_fitted,  # noqa: F401
        _is_pandas_df,  # noqa: F401
    )
else:

    def _is_fitted(estimator, attributes=None, all_or_any=all):
        """Check whether ``estimator`` has been fitted.

        Parameters
        ----------
        estimator : estimator instance
            Estimator to inspect.

        attributes : str, list or tuple of str, default=None
            Name(s) of attributes whose presence marks the estimator as fitted,
            e.g. ``"coef_"`` or ``["coef_", "estimator_"]``.

            When None, the estimator's own ``__sklearn_is_fitted__`` is used if
            it defines one; otherwise the estimator counts as fitted when it has
            at least one attribute that ends with an underscore and does not
            start with a double underscore.

        all_or_any : callable, {all, any}, default=all
            Reduction applied to the per-attribute presence checks.

        Returns
        -------
        fitted : bool
            Whether the estimator is fitted.
        """
        if attributes is not None:
            if isinstance(attributes, (list, tuple)):
                names = attributes
            else:
                names = [attributes]
            return all_or_any([hasattr(estimator, name) for name in names])

        if hasattr(estimator, "__sklearn_is_fitted__"):
            return estimator.__sklearn_is_fitted__()

        return any(
            attr.endswith("_") and not attr.startswith("__") for attr in vars(estimator)
        )

    if sklearn_version >= parse_version("1.3"):

        def process_routing(_obj, _method, /, **kwargs):
            """Validate and route input parameters via scikit-learn's implementation."""
            from sklearn.utils._metadata_requests import (
                process_routing as _sklearn_process_routing,
            )

            return _sklearn_process_routing(_obj, _method, other_params=None, **kwargs)

        def _raise_for_params(params, owner, method):
            """Raise if extra params are passed while metadata routing is disabled."""
            from sklearn.utils._metadata_requests import _routing_enabled

            caller = owner.__class__.__name__
            if method:
                caller = f"{caller}.{method}"

            routing_disabled = not _routing_enabled()
            if routing_disabled and params:
                raise ValueError(
                    f"Passing extra keyword arguments to {caller} is only supported if"
                    " enable_metadata_routing=True, which you can set using"
                    " `sklearn.set_config`. See the User Guide"
                    " <https://scikit-learn.org/stable/metadata_routing.html> for more"
                    f" details. Extra parameters passed are: {set(params)}",
                )

    else:
        # Metadata routing does not exist before scikit-learn 1.3.

        def process_routing(_obj, _method, /, **kwargs):
            raise NotImplementedError(
                "Metadata routing is not implemented in scikit-learn < 1.3",
            )

        def _raise_for_params(params, owner, method):
            raise NotImplementedError(
                "Metadata routing is not implemented in scikit-learn < 1.3",
            )

    def _is_pandas_df(X):
        """Return True if the X is a pandas dataframe."""
        # Only consult an already-imported pandas; never import it here.
        try:
            pandas_module = sys.modules["pandas"]
        except KeyError:
            return False
        return isinstance(X, pandas_module.DataFrame)
265
+
266
+
267
+ ########################################################################################
268
+ # Upgrading for scikit-learn 1.5
269
+ ########################################################################################
270
+
271
+
272
if sklearn_version >= parse_version("1.5"):
    # From scikit-learn 1.5 on, these helpers are imported from dedicated
    # private submodules.

    # chunking
    from sklearn.utils._chunking import (
        chunk_generator,  # noqa: F401
        gen_batches,  # noqa: F401
        gen_even_slices,  # noqa: F401
        get_chunk_n_rows,  # noqa: F401
    )

    # indexing
    from sklearn.utils._indexing import (
        _determine_key_type,  # noqa: F401
        _get_column_indices,  # noqa: F401
        _safe_assign,  # noqa: F401
        _safe_indexing,  # noqa: F401
        resample,  # noqa: F401
        shuffle,  # noqa: F401
    )

    # mask
    from sklearn.utils._mask import (
        axis0_safe_slice,  # noqa: F401
        indices_to_mask,  # noqa: F401
        safe_mask,  # noqa: F401
    )

    # missing
    from sklearn.utils._missing import (
        is_pandas_na,  # noqa: F401
        is_scalar_nan,  # noqa: F401
    )

    # optional dependencies
    from sklearn.utils._optional_dependencies import (  # noqa: F401
        check_matplotlib_support,
        check_pandas_support,
    )

    # user interface
    from sklearn.utils._user_interface import _print_elapsed_time  # noqa: F401

    # extmath
    from sklearn.utils.extmath import (
        _approximate_mode,  # noqa: F401
        safe_sqr,  # noqa: F401
    )

    # fixes
    from sklearn.utils.fixes import (
        _IS_32BIT,  # noqa: F401
        _IS_WASM,  # noqa: F401
        _in_unstable_openblas_configuration,  # noqa: F401
    )

    # validation
    from sklearn.utils.validation import _to_object_array  # noqa: F401
else:
    # Before scikit-learn 1.5 the same helpers are importable directly from
    # `sklearn.utils`, some of them under an underscored name.
    from sklearn.utils import (
        _IS_32BIT,  # fixes
        _approximate_mode,  # extmath
        _chunk_generator as chunk_generator,  # chunking
        _determine_key_type,  # indexing
        _get_column_indices,  # indexing
        _in_unstable_openblas_configuration,  # fixes
        _is_pandas_na as is_pandas_na,  # missing
        _print_elapsed_time,  # user interface
        _safe_assign,  # indexing
        _safe_indexing,  # indexing
        _to_object_array,  # validation
        axis0_safe_slice,  # mask
        check_matplotlib_support,  # optional dependencies
        check_pandas_support,  # optional dependencies
        gen_batches,  # chunking
        gen_even_slices,  # chunking
        get_chunk_n_rows,  # chunking
        indices_to_mask,  # mask
        is_scalar_nan,  # missing
        resample,  # indexing
        safe_mask,  # mask
        safe_sqr,  # extmath
        shuffle,  # indexing
    )

    # `_IS_WASM` is not imported for these versions, so it is computed here.
    _IS_WASM = platform.machine() in ("wasm32", "wasm64")
367
+
368
+ ########################################################################################
369
+ # Upgrading for scikit-learn 1.6
370
+ ########################################################################################
371
+
372
+
373
if sklearn_version < parse_version("1.6"):
    # base
    def is_clusterer(estimator):
        """Return True if the given estimator is (probably) a clusterer."""
        return get_tags(estimator).estimator_type == "clusterer"

    # test_common
    from sklearn.utils.estimator_checks import _construct_instance

    def type_of_target(y, input_name="", *, raise_unknown=False):
        """Determine the type of data indicated by the target ``y``.

        Backport of the ``raise_unknown`` parameter introduced in scikit-learn
        1.6: when it is True and the detected type is ``"unknown"``, a
        ``ValueError`` is raised instead of returning ``"unknown"``.
        """
        from sklearn.utils.multiclass import type_of_target as _sklearn_type_of_target

        target_type = _sklearn_type_of_target(y, input_name=input_name)
        if raise_unknown and target_type == "unknown":
            # Named `data_name` rather than `input` to avoid shadowing the builtin.
            data_name = input_name if input_name else "data"
            raise ValueError(f"Unknown label type for {data_name}: {y!r}")
        return target_type

    def _construct_instances(Estimator):
        """Yield a single instance of ``Estimator`` built by ``_construct_instance``."""
        yield _construct_instance(Estimator)

    # validation
    def validate_data(
        _estimator,
        /,
        X="no_validation",
        y="no_validation",
        reset=True,
        validate_separately=False,
        skip_check_array=False,
        **kwargs,
    ):
        """Validate input data and set or check feature names and counts of the input.

        See the original scikit-learn documentation:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.validate_data.html#sklearn.utils.validation.validate_data
        """
        if skip_check_array:
            _check_n_features(_estimator, X, reset=reset)
            _check_feature_names(_estimator, X, reset=reset)

            no_val_X = isinstance(X, str) and X == "no_validation"
            no_val_y = y is None or (isinstance(y, str) and y == "no_validation")
            if not no_val_X and no_val_y:
                out = X
            elif no_val_X and not no_val_y:
                out = y
            else:
                out = X, y
            return out

        # `ensure_all_finite` is the newer spelling of `force_all_finite`.
        # As in the `check_array` backport in this section, None means the
        # default (True). A legacy `force_all_finite=` keyword is honoured rather
        # than being passed twice to `_validate_data`.
        legacy_force_all_finite = kwargs.pop("force_all_finite", True)
        ensure_all_finite = kwargs.pop("ensure_all_finite", None)
        force_all_finite = (
            legacy_force_all_finite if ensure_all_finite is None else ensure_all_finite
        )
        return _estimator._validate_data(
            X=X,
            y=y,
            reset=reset,
            validate_separately=validate_separately,
            force_all_finite=force_all_finite,
            **kwargs,
        )

    def _check_n_features(estimator, X, *, reset):
        """Set the `n_features_in_` attribute, or check against it on an estimator."""
        return estimator._check_n_features(X, reset=reset)

    def _check_feature_names(estimator, X, *, reset):
        """Check `input_features` and generate names if needed."""
        return estimator._check_feature_names(X, reset=reset)

    def check_array(
        array,
        accept_sparse=False,
        *,
        accept_large_sparse=True,
        dtype="numeric",
        order=None,
        copy=False,
        force_writeable=False,
        ensure_all_finite=None,
        ensure_non_negative=False,
        ensure_2d=True,
        allow_nd=False,
        ensure_min_samples=1,
        ensure_min_features=1,
        estimator=None,
        input_name="",
    ):
        """Input validation on an array, list, sparse matrix or similar.

        Check the original documentation for more details:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_array.html
        """
        from sklearn.utils.validation import check_array as _check_array

        force_all_finite = ensure_all_finite if ensure_all_finite is not None else True

        # Only forward the newer keywords when the installed version accepts them.
        check_array_params = inspect.signature(_check_array).parameters
        kwargs = {}
        if "force_writeable" in check_array_params:
            kwargs["force_writeable"] = force_writeable
        if "ensure_non_negative" in check_array_params:
            kwargs["ensure_non_negative"] = ensure_non_negative

        return _check_array(
            array,
            accept_sparse=accept_sparse,
            accept_large_sparse=accept_large_sparse,
            dtype=dtype,
            order=order,
            copy=copy,
            force_all_finite=force_all_finite,
            ensure_2d=ensure_2d,
            allow_nd=allow_nd,
            ensure_min_samples=ensure_min_samples,
            ensure_min_features=ensure_min_features,
            estimator=estimator,
            input_name=input_name,
            **kwargs,
        )

    def check_X_y(
        X,
        y,
        accept_sparse=False,
        *,
        accept_large_sparse=True,
        dtype="numeric",
        order=None,
        copy=False,
        force_writeable=False,
        ensure_all_finite=None,
        ensure_2d=True,
        allow_nd=False,
        multi_output=False,
        ensure_min_samples=1,
        ensure_min_features=1,
        y_numeric=False,
        estimator=None,
    ):
        """Input validation for standard estimators.

        Check the original documentation for more details:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_X_y.html
        """
        from sklearn.utils.validation import check_X_y as _check_X_y

        force_all_finite = ensure_all_finite if ensure_all_finite is not None else True

        # Only forward the newer keyword when the installed version accepts it.
        check_X_y_params = inspect.signature(_check_X_y).parameters
        kwargs = {}
        if "force_writeable" in check_X_y_params:
            kwargs["force_writeable"] = force_writeable

        return _check_X_y(
            X,
            y,
            accept_sparse=accept_sparse,
            accept_large_sparse=accept_large_sparse,
            dtype=dtype,
            order=order,
            copy=copy,
            force_all_finite=force_all_finite,
            ensure_2d=ensure_2d,
            allow_nd=allow_nd,
            multi_output=multi_output,
            ensure_min_samples=ensure_min_samples,
            ensure_min_features=ensure_min_features,
            y_numeric=y_numeric,
            estimator=estimator,
            **kwargs,
        )

    # tags infrastructure
    @dataclass(**_dataclass_args())
    class InputTags:
        """Tags for the input data.

        Parameters
        ----------
        one_d_array : bool, default=False
            Whether the input can be a 1D array.

        two_d_array : bool, default=True
            Whether the input can be a 2D array. Note that most common
            tests currently run only if this flag is set to ``True``.

        three_d_array : bool, default=False
            Whether the input can be a 3D array.

        sparse : bool, default=False
            Whether the input can be a sparse matrix.

        categorical : bool, default=False
            Whether the input can be categorical.

        string : bool, default=False
            Whether the input can be an array-like of strings.

        dict : bool, default=False
            Whether the input can be a dictionary.

        positive_only : bool, default=False
            Whether the estimator requires positive X.

        allow_nan : bool, default=False
            Whether the estimator supports data with missing values encoded as `np.nan`.

        pairwise : bool, default=False
            This boolean attribute indicates whether the data (`X`) passed to
            :term:`fit` and similar methods consists of pairwise measures
            over samples rather than a feature representation for each
            sample. It is usually `True` where an estimator has a
            `metric` or `affinity` or `kernel` parameter with value
            'precomputed'. Its primary purpose is to support a
            :term:`meta-estimator` or a cross validation procedure that
            extracts a sub-sample of data intended for a pairwise
            estimator, where the data needs to be indexed on both axes.
            Specifically, this tag is used by
            `sklearn.utils.metaestimators._safe_split` to slice rows and
            columns.
        """

        one_d_array: bool = False
        two_d_array: bool = True
        three_d_array: bool = False
        sparse: bool = False
        categorical: bool = False
        string: bool = False
        dict: bool = False
        positive_only: bool = False
        allow_nan: bool = False
        pairwise: bool = False

    @dataclass(**_dataclass_args())
    class TargetTags:
        """Tags for the target data.

        Parameters
        ----------
        required : bool
            Whether the estimator requires y to be passed to `fit`,
            `fit_predict` or `fit_transform` methods. The tag is ``True``
            for estimators inheriting from `~sklearn.base.RegressorMixin`
            and `~sklearn.base.ClassifierMixin`.

        one_d_labels : bool, default=False
            Whether the target (y) can be 1D labels.

        two_d_labels : bool, default=False
            Whether the target (y) can be 2D labels.

        positive_only : bool, default=False
            Whether the estimator requires a positive y (only applicable
            for regression).

        multi_output : bool, default=False
            Whether a regressor supports multi-target outputs or a classifier supports
            multi-class multi-output.

        single_output : bool, default=True
            Whether the target can be single-output. This can be ``False`` if the
            estimator supports only multi-output cases.
        """

        required: bool
        one_d_labels: bool = False
        two_d_labels: bool = False
        positive_only: bool = False
        multi_output: bool = False
        single_output: bool = True

    @dataclass(**_dataclass_args())
    class TransformerTags:
        """Tags for the transformer.

        Parameters
        ----------
        preserves_dtype : list[str], default=["float64"]
            Applies only on transformers. It corresponds to the data types
            which will be preserved such that `X_trans.dtype` is the same
            as `X.dtype` after calling `transformer.transform(X)`. If this
            list is empty, then the transformer is not expected to
            preserve the data type. The first value in the list is
            considered as the default data type, corresponding to the data
            type of the output when the input data type is not going to be
            preserved.
        """

        preserves_dtype: list[str] = field(default_factory=lambda: ["float64"])

    @dataclass(**_dataclass_args())
    class ClassifierTags:
        """Tags for the classifier.

        Parameters
        ----------
        poor_score : bool, default=False
            Whether the estimator fails to provide a "reasonable" test-set
            score, which currently for classification is an accuracy of
            0.83 on ``make_blobs(n_samples=300, random_state=0)``. The
            datasets and values are based on current estimators in scikit-learn
            and might be replaced by something more systematic.

        multi_class : bool, default=True
            Whether the classifier can handle multi-class
            classification. Note that all classifiers support binary
            classification. Therefore this flag indicates whether the
            classifier is a binary-classifier-only or not.

        multi_label : bool, default=False
            Whether the classifier supports multi-label output.
        """

        poor_score: bool = False
        multi_class: bool = True
        multi_label: bool = False

    @dataclass(**_dataclass_args())
    class RegressorTags:
        """Tags for the regressor.

        Parameters
        ----------
        poor_score : bool, default=False
            Whether the estimator fails to provide a "reasonable" test-set
            score, which currently for regression is an R2 of 0.5 on
            ``make_regression(n_samples=200, n_features=10,
            n_informative=1, bias=5.0, noise=20, random_state=42)``. The
            dataset and values are based on current estimators in scikit-learn
            and might be replaced by something more systematic.

        multi_label : bool, default=False
            Whether the regressor supports multilabel output.
        """

        poor_score: bool = False
        multi_label: bool = False

    @dataclass(**_dataclass_args())
    class Tags:
        """Tags for the estimator.

        See :ref:`estimator_tags` for more information.

        Parameters
        ----------
        estimator_type : str or None
            The type of the estimator. Can be one of:
            - "classifier"
            - "regressor"
            - "transformer"
            - "clusterer"
            - "outlier_detector"
            - "density_estimator"

        target_tags : :class:`TargetTags`
            The target(y) tags.

        transformer_tags : :class:`TransformerTags` or None
            The transformer tags.

        classifier_tags : :class:`ClassifierTags` or None
            The classifier tags.

        regressor_tags : :class:`RegressorTags` or None
            The regressor tags.

        array_api_support : bool, default=False
            Whether the estimator supports Array API compatible inputs.

        no_validation : bool, default=False
            Whether the estimator skips input-validation. This is only meant for
            stateless and dummy transformers!

        non_deterministic : bool, default=False
            Whether the estimator is not deterministic given a fixed ``random_state``.

        requires_fit : bool, default=True
            Whether the estimator requires to be fitted before calling one of
            `transform`, `predict`, `predict_proba`, or `decision_function`.

        _skip_test : bool, default=False
            Whether to skip common tests entirely. Don't use this unless
            you have a *very good* reason.

        input_tags : :class:`InputTags`
            The input data(X) tags.
        """

        estimator_type: str | None
        target_tags: TargetTags
        transformer_tags: TransformerTags | None = None
        classifier_tags: ClassifierTags | None = None
        regressor_tags: RegressorTags | None = None
        array_api_support: bool = False
        no_validation: bool = False
        non_deterministic: bool = False
        requires_fit: bool = True
        _skip_test: bool = False
        input_tags: InputTags = field(default_factory=InputTags)

    def _patched_more_tags(estimator, expected_failed_checks):
        """Expose ``expected_failed_checks`` as the legacy ``_xfail_checks`` tag.

        Both the estimator's class and the instance get a ``_more_tags`` that
        returns the original class-level tags plus ``_xfail_checks``. The class
        is modified in place, so the patch affects every instance of that class.

        When ``expected_failed_checks`` is None, nothing is patched. This keeps
        any ``_xfail_checks`` the class already declares.

        Returns the (possibly patched) ``estimator``.
        """
        if expected_failed_checks is None:
            return estimator

        original_class_more_tags = estimator.__class__._more_tags

        def patched_instance_more_tags(self):
            """Instance-level _more_tags that combines class tags with _xfail_checks."""
            # Build a new dict instead of calling `.update()` on the returned one:
            # the original `_more_tags` may hand back a shared (class- or
            # module-level) dict, which must not be mutated.
            return {
                **original_class_more_tags(self),
                "_xfail_checks": expected_failed_checks,
            }

        # Patch both class and instance level
        estimator.__class__._more_tags = patched_instance_more_tags
        estimator._more_tags = types.MethodType(patched_instance_more_tags, estimator)
        return estimator

    def check_estimator(
        estimator=None,
        generate_only=False,
        *,
        legacy: bool = True,
        expected_failed_checks: dict[str, str] | None = None,
        on_skip: Literal["warn"] | None = "warn",
        on_fail: Literal["raise", "warn"] | None = "raise",
        callback: Callable | None = None,
    ):
        """Run scikit-learn's ``check_estimator`` with the newer keyword arguments.

        ``expected_failed_checks`` is injected as the legacy ``_xfail_checks``
        tag. ``legacy``, ``on_skip``, ``on_fail`` and ``callback`` are accepted
        for API compatibility but are ignored.
        """
        from sklearn.utils.estimator_checks import (
            check_estimator as _sklearn_check_estimator,
        )

        return _sklearn_check_estimator(
            _patched_more_tags(estimator, expected_failed_checks),
            generate_only=generate_only,
        )

    def parametrize_with_checks(
        estimators,
        *,
        legacy: bool = True,
        expected_failed_checks: Callable | None = None,
    ):
        """Run scikit-learn's ``parametrize_with_checks`` with the newer keyword arguments.

        ``expected_failed_checks``, when given, is called with each estimator.
        Its result is injected as that estimator's legacy ``_xfail_checks`` tag.
        With the default of None, the estimators are used unchanged.
        ``legacy`` is accepted for API compatibility but is ignored.
        """
        from sklearn.utils.estimator_checks import (
            parametrize_with_checks as _sklearn_parametrize_with_checks,
        )

        # Guard the call: the default `expected_failed_checks=None` is not callable.
        estimators = [
            _patched_more_tags(estimator, expected_failed_checks(estimator))
            if expected_failed_checks is not None
            else estimator
            for estimator in estimators
        ]

        return _sklearn_parametrize_with_checks(estimators)

else:
    # base
    from sklearn.base import is_clusterer  # noqa: F401

    # tags infrastructure
    from sklearn.utils import (
        ClassifierTags,
        InputTags,
        RegressorTags,
        Tags,
        TargetTags,
        TransformerTags,
    )

    # test_common
    from sklearn.utils._test_common.instance_generator import (
        _construct_instances,  # noqa: F401
    )
    from sklearn.utils.estimator_checks import (
        check_estimator,  # noqa: F401
        parametrize_with_checks,  # noqa: F401
    )
    from sklearn.utils.multiclass import type_of_target  # noqa: F401

    # validation
    from sklearn.utils.validation import (
        _check_feature_names,
        _check_n_features,
        check_array,  # noqa: F401
        check_X_y,  # noqa: F401
        validate_data,  # noqa: F401
    )