autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -1
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/METADATA +13 -15
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/RECORD +21 -14
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250713.dist-info}/zip-safe +0 -0
@@ -0,0 +1,863 @@
|
|
1
|
+
# mypy: ignore-errors
|
2
|
+
# taken from https://github.com/sklearn-compat/sklearn-compat
|
3
|
+
"""Ease developer experience to support multiple versions of scikit-learn.
|
4
|
+
|
5
|
+
This file is intended to be vendored in your project if you do not want to depend on
|
6
|
+
`sklearn-compat` as a package. Then, you can import directly from this file.
|
7
|
+
|
8
|
+
Be aware that depending on `sklearn-compat` does not add any additional dependencies:
|
9
|
+
we are only depending on `scikit-learn`.
|
10
|
+
|
11
|
+
Version: 0.1.3
|
12
|
+
"""
|
13
|
+
|
14
|
+
from __future__ import annotations
|
15
|
+
|
16
|
+
import functools
|
17
|
+
import inspect
|
18
|
+
import platform
|
19
|
+
import sys
|
20
|
+
import types
|
21
|
+
from dataclasses import dataclass, field
|
22
|
+
from typing import Callable, Literal
|
23
|
+
|
24
|
+
import sklearn
|
25
|
+
from sklearn.utils.fixes import parse_version
|
26
|
+
|
27
|
+
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
|
28
|
+
|
29
|
+
|
30
|
+
########################################################################################
|
31
|
+
# The following code does not depend on the sklearn version
|
32
|
+
########################################################################################
|
33
|
+
|
34
|
+
|
35
|
+
# tags infrastructure
|
36
|
+
def _dataclass_args():
    """Return the keyword arguments to pass to ``dataclasses.dataclass``.

    Returns
    -------
    kwargs : dict
        ``{"slots": True}`` on Python 3.10 or newer, otherwise an empty
        dictionary.
    """
    supports_slots = sys.version_info >= (3, 10)
    return {"slots": True} if supports_slots else {}
|
40
|
+
|
41
|
+
|
42
|
+
def get_tags(estimator):
    """Get estimator tags in a consistent format across scikit-learn versions.

    When ``sklearn.utils._tags.get_tags`` can be imported its result is returned
    directly. Otherwise the dictionary returned by ``_safe_tags`` is converted
    with ``_to_new_tags``.

    Parameters
    ----------
    estimator : estimator object
        A scikit-learn estimator instance.

    Returns
    -------
    tags : Tags
        An object containing metadata about the estimator's requirements and
        capabilities (e.g., input types, fitting requirements,
        classifier/regressor specific tags).
    """
    try:
        from sklearn.utils._tags import get_tags as _sklearn_get_tags
    except ImportError:
        # ``get_tags`` is not importable: fall back to the dictionary-based API.
        from sklearn.utils._tags import _safe_tags

        return _to_new_tags(_safe_tags(estimator), estimator)

    # Deliberately outside the ``try``: an ImportError raised while computing
    # the estimator's tags must propagate instead of being mistaken for
    # "``get_tags`` is unavailable" and silently triggering the fallback.
    return _sklearn_get_tags(estimator)
|
70
|
+
|
71
|
+
|
72
|
+
def _to_new_tags(old_tags, estimator=None):
    """Convert legacy dictionary tags into a ``Tags`` dataclass.

    Parameters
    ----------
    old_tags : dict
        Legacy tags keyed by tag name.

    estimator : estimator object, default=None
        Estimator the tags belong to. It is inspected for ``transform`` /
        ``fit_transform`` and for ``_estimator_type`` to decide which of the
        optional transformer, classifier and regressor tag groups are filled
        in; the others are left as ``None``.

    Returns
    -------
    tags : Tags
        The converted tags.
    """
    x_types = old_tags["X_types"]

    input_tags = InputTags(
        one_d_array="1darray" in x_types,
        two_d_array="2darray" in x_types,
        three_d_array="3darray" in x_types,
        sparse="sparse" in x_types,
        categorical="categorical" in x_types,
        string="string" in x_types,
        dict="dict" in x_types,
        positive_only=old_tags["requires_positive_X"],
        allow_nan=old_tags["allow_nan"],
        pairwise=old_tags["pairwise"],
    )

    target_tags = TargetTags(
        required=old_tags["requires_y"],
        # The label kinds are looked up in "X_types", like the input kinds.
        one_d_labels="1dlabels" in x_types,
        two_d_labels="2dlabels" in x_types,
        positive_only=old_tags["requires_positive_y"],
        multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
        single_output=not old_tags["multioutput_only"],
    )

    is_transformer = estimator is not None and (
        hasattr(estimator, "transform") or hasattr(estimator, "fit_transform")
    )
    transformer_tags = None
    if is_transformer:
        transformer_tags = TransformerTags(
            preserves_dtype=old_tags["preserves_dtype"],
        )

    estimator_type = getattr(estimator, "_estimator_type", None)

    classifier_tags = None
    if estimator_type == "classifier":
        classifier_tags = ClassifierTags(
            poor_score=old_tags["poor_score"],
            multi_class=not old_tags["binary_only"],
            multi_label=old_tags["multilabel"],
        )

    regressor_tags = None
    if estimator_type == "regressor":
        regressor_tags = RegressorTags(
            poor_score=old_tags["poor_score"],
            multi_label=old_tags["multilabel"],
        )

    return Tags(
        estimator_type=estimator_type,
        target_tags=target_tags,
        transformer_tags=transformer_tags,
        classifier_tags=classifier_tags,
        regressor_tags=regressor_tags,
        input_tags=input_tags,
        # Array-API was introduced in 1.3, so the key may be absent from the
        # old tags; default to False in that case.
        array_api_support=old_tags.get("array_api_support", False),
        no_validation=old_tags["no_validation"],
        non_deterministic=old_tags["non_deterministic"],
        requires_fit=old_tags["requires_fit"],
        _skip_test=old_tags["_skip_test"],
    )
|
133
|
+
|
134
|
+
|
135
|
+
########################################################################################
|
136
|
+
# Upgrading for scikit-learn 1.3
|
137
|
+
########################################################################################
|
138
|
+
|
139
|
+
if sklearn_version < parse_version("1.3"):
|
140
|
+
# parameter validation
|
141
|
+
def _fit_context(*, prefer_skip_nested_validation):
    """Create a decorator that validates estimator parameters before fitting.

    Parameters
    ----------
    prefer_skip_nested_validation : bool
        Accepted for signature compatibility; this fallback does not use it.

    Returns
    -------
    decorator : callable
        Decorator for a fit method. The decorated method first calls
        ``estimator._validate_params()`` and then the original method with
        the same arguments, returning its result.
    """

    def decorator(fit_method):
        def validated_fit(estimator, *args, **kwargs):
            estimator._validate_params()
            return fit_method(estimator, *args, **kwargs)

        # Copy the name, docstring, etc. of the wrapped fit method.
        return functools.wraps(fit_method)(validated_fit)

    return decorator
|
153
|
+
|
154
|
+
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Return scikit-learn's ``validate_params`` decorator for the constraints.

    Parameters
    ----------
    parameter_constraints : dict
        Constraints forwarded unchanged to
        ``sklearn.utils._param_validation.validate_params``.

    prefer_skip_nested_validation : bool
        Accepted for signature compatibility; it is not forwarded.

    Returns
    -------
    decorator : callable
        Whatever scikit-learn's ``validate_params`` returns.
    """
    from sklearn.utils._param_validation import (
        validate_params as _sklearn_validate_params,
    )

    decorator = _sklearn_validate_params(parameter_constraints)
    return decorator
|
159
|
+
|
160
|
+
else:
|
161
|
+
# parameter validation
|
162
|
+
|
163
|
+
from sklearn.base import _fit_context # noqa: F401
|
164
|
+
from sklearn.utils._param_validation import validate_params # noqa: F401
|
165
|
+
|
166
|
+
|
167
|
+
########################################################################################
|
168
|
+
# Upgrading for scikit-learn 1.4
|
169
|
+
########################################################################################
|
170
|
+
|
171
|
+
|
172
|
+
if sklearn_version < parse_version("1.4"):
|
173
|
+
|
174
|
+
def _is_fitted(estimator, attributes=None, all_or_any=all):
    """Determine if an estimator is fitted.

    Parameters
    ----------
    estimator : estimator instance
        Estimator instance for which the check is performed.

    attributes : str, list or tuple of str, default=None
        Attribute name(s) to look for, e.g. ``["coef_", "estimator_", ...]``
        or ``"coef_"``.

        If `None`, the estimator's ``__sklearn_is_fitted__`` method is used
        when it exists; otherwise the estimator is considered fitted if it has
        an instance attribute that ends with an underscore and does not start
        with a double underscore.

    all_or_any : callable, {all, any}, default=all
        Specify whether all or any of the given attributes must exist.

    Returns
    -------
    fitted : bool
        Whether the estimator is fitted.
    """
    if attributes is None:
        if hasattr(estimator, "__sklearn_is_fitted__"):
            return estimator.__sklearn_is_fitted__()

        trailing_underscore = [
            name
            for name in vars(estimator)
            if name.endswith("_") and not name.startswith("__")
        ]
        return bool(trailing_underscore)

    # A single name is treated like a one-element list.
    names = attributes if isinstance(attributes, (list, tuple)) else [attributes]
    return all_or_any([hasattr(estimator, name) for name in names])
|
210
|
+
|
211
|
+
if sklearn_version < parse_version("1.3"):
|
212
|
+
|
213
|
+
def process_routing(_obj, _method, /, **kwargs):
    """Stand-in for ``process_routing`` that always raises.

    Raises
    ------
    NotImplementedError
        Always; the message says metadata routing is not implemented in
        scikit-learn < 1.3.
    """
    message = "Metadata routing is not implemented in scikit-learn < 1.3"
    raise NotImplementedError(message)
|
217
|
+
|
218
|
+
def _raise_for_params(params, owner, method):
    """Stand-in for ``_raise_for_params`` that always raises.

    Raises
    ------
    NotImplementedError
        Always; the message says metadata routing is not implemented in
        scikit-learn < 1.3.
    """
    message = "Metadata routing is not implemented in scikit-learn < 1.3"
    raise NotImplementedError(message)
|
222
|
+
else:
|
223
|
+
|
224
|
+
def process_routing(_obj, _method, /, **kwargs):
    """Validate and route input parameters.

    Delegates to ``sklearn.utils._metadata_requests.process_routing``, passing
    ``other_params=None`` together with the given keyword arguments, and
    returns its result.
    """
    from sklearn.utils._metadata_requests import (
        process_routing as _sklearn_process_routing,
    )

    routed = _sklearn_process_routing(_obj, _method, other_params=None, **kwargs)
    return routed
|
229
|
+
|
230
|
+
def _raise_for_params(params, owner, method):
    """Raise an error if metadata routing is not enabled and params are passed.

    Parameters
    ----------
    params : dict
        Extra parameters that were passed; their names are listed in the
        error message.

    owner : object
        Object whose class name is used in the error message.

    method : str
        Name of the called method. When falsy, only the class name of
        ``owner`` appears in the message.

    Raises
    ------
    ValueError
        If ``_routing_enabled()`` is false and ``params`` is non-empty.
    """
    from sklearn.utils._metadata_requests import _routing_enabled

    owner_name = owner.__class__.__name__
    if method:
        caller = f"{owner_name}.{method}"
    else:
        caller = owner_name

    if _routing_enabled() or not params:
        return

    raise ValueError(
        f"Passing extra keyword arguments to {caller} is only supported if"
        " enable_metadata_routing=True, which you can set using"
        " `sklearn.set_config`. See the User Guide"
        " <https://scikit-learn.org/stable/metadata_routing.html> for more"
        f" details. Extra parameters passed are: {set(params)}",
    )
|
247
|
+
|
248
|
+
def _is_pandas_df(X):
    """Return True if ``X`` is a pandas dataframe.

    pandas is looked up in ``sys.modules`` rather than imported, so the
    answer is ``False`` whenever pandas has not been imported yet.
    """
    if "pandas" not in sys.modules:
        return False
    pandas_module = sys.modules["pandas"]
    return isinstance(X, pandas_module.DataFrame)
|
255
|
+
|
256
|
+
else:
|
257
|
+
from sklearn.utils.metadata_routing import (
|
258
|
+
_raise_for_params, # noqa: F401
|
259
|
+
process_routing, # noqa: F401
|
260
|
+
)
|
261
|
+
from sklearn.utils.validation import (
|
262
|
+
_is_fitted, # noqa: F401
|
263
|
+
_is_pandas_df, # noqa: F401
|
264
|
+
)
|
265
|
+
|
266
|
+
|
267
|
+
########################################################################################
|
268
|
+
# Upgrading for scikit-learn 1.5
|
269
|
+
########################################################################################
|
270
|
+
|
271
|
+
|
272
|
+
if sklearn_version < parse_version("1.5"):
|
273
|
+
# chunking
|
274
|
+
# extmath
|
275
|
+
# fixes
|
276
|
+
from sklearn.utils import (
|
277
|
+
_IS_32BIT,
|
278
|
+
_approximate_mode,
|
279
|
+
_chunk_generator as chunk_generator,
|
280
|
+
_in_unstable_openblas_configuration,
|
281
|
+
gen_batches,
|
282
|
+
gen_even_slices,
|
283
|
+
get_chunk_n_rows,
|
284
|
+
safe_sqr,
|
285
|
+
)
|
286
|
+
|
287
|
+
_IS_WASM = platform.machine() in ["wasm32", "wasm64"]
|
288
|
+
# indexing
|
289
|
+
# mask
|
290
|
+
# missing
|
291
|
+
# optional dependencies
|
292
|
+
# user interface
|
293
|
+
# validation
|
294
|
+
from sklearn.utils import (
|
295
|
+
_determine_key_type,
|
296
|
+
_get_column_indices,
|
297
|
+
_is_pandas_na as is_pandas_na,
|
298
|
+
_print_elapsed_time,
|
299
|
+
_safe_assign,
|
300
|
+
_safe_indexing,
|
301
|
+
_to_object_array,
|
302
|
+
axis0_safe_slice,
|
303
|
+
check_matplotlib_support,
|
304
|
+
check_pandas_support,
|
305
|
+
indices_to_mask,
|
306
|
+
is_scalar_nan,
|
307
|
+
resample,
|
308
|
+
safe_mask,
|
309
|
+
shuffle,
|
310
|
+
)
|
311
|
+
else:
|
312
|
+
# chunking
|
313
|
+
from sklearn.utils._chunking import (
|
314
|
+
chunk_generator, # noqa: F401
|
315
|
+
gen_batches, # noqa: F401
|
316
|
+
gen_even_slices, # noqa: F401
|
317
|
+
get_chunk_n_rows, # noqa: F401
|
318
|
+
)
|
319
|
+
|
320
|
+
# indexing
|
321
|
+
from sklearn.utils._indexing import (
|
322
|
+
_determine_key_type, # noqa: F401
|
323
|
+
_get_column_indices, # noqa: F401
|
324
|
+
_safe_assign, # noqa: F401
|
325
|
+
_safe_indexing, # noqa: F401
|
326
|
+
resample, # noqa: F401
|
327
|
+
shuffle, # noqa: F401
|
328
|
+
)
|
329
|
+
|
330
|
+
# mask
|
331
|
+
from sklearn.utils._mask import (
|
332
|
+
axis0_safe_slice, # noqa: F401
|
333
|
+
indices_to_mask, # noqa: F401
|
334
|
+
safe_mask, # noqa: F401
|
335
|
+
)
|
336
|
+
|
337
|
+
# missing
|
338
|
+
from sklearn.utils._missing import (
|
339
|
+
is_pandas_na, # noqa: F401
|
340
|
+
is_scalar_nan, # noqa: F401
|
341
|
+
)
|
342
|
+
|
343
|
+
# optional dependencies
|
344
|
+
from sklearn.utils._optional_dependencies import ( # noqa: F401
|
345
|
+
check_matplotlib_support,
|
346
|
+
check_pandas_support,
|
347
|
+
)
|
348
|
+
|
349
|
+
# user interface
|
350
|
+
from sklearn.utils._user_interface import _print_elapsed_time # noqa: F401
|
351
|
+
|
352
|
+
# extmath
|
353
|
+
from sklearn.utils.extmath import (
|
354
|
+
_approximate_mode, # noqa: F401
|
355
|
+
safe_sqr, # noqa: F401
|
356
|
+
)
|
357
|
+
|
358
|
+
# fixes
|
359
|
+
from sklearn.utils.fixes import (
|
360
|
+
_IS_32BIT, # noqa: F401
|
361
|
+
_IS_WASM, # noqa: F401
|
362
|
+
_in_unstable_openblas_configuration, # noqa: F401
|
363
|
+
)
|
364
|
+
|
365
|
+
# validation
|
366
|
+
from sklearn.utils.validation import _to_object_array # noqa: F401
|
367
|
+
|
368
|
+
########################################################################################
|
369
|
+
# Upgrading for scikit-learn 1.6
|
370
|
+
########################################################################################
|
371
|
+
|
372
|
+
|
373
|
+
if sklearn_version < parse_version("1.6"):
|
374
|
+
# base
|
375
|
+
def is_clusterer(estimator):
    """Return True if the given estimator is (probably) a clusterer.

    The decision is based on the ``estimator_type`` reported by
    ``get_tags(estimator)``.
    """
    tags = get_tags(estimator)
    return tags.estimator_type == "clusterer"
|
378
|
+
|
379
|
+
# test_common
|
380
|
+
from sklearn.utils.estimator_checks import _construct_instance
|
381
|
+
|
382
|
+
def type_of_target(y, input_name="", *, raise_unknown=False):
    """Determine the type of target ``y``, optionally rejecting unknown types.

    Wraps ``sklearn.utils.multiclass.type_of_target`` and adds the
    ``raise_unknown`` option, which is introduced in scikit-learn 1.6.

    Parameters
    ----------
    y : array-like
        Target values, forwarded to scikit-learn.

    input_name : str, default=""
        Name of the data, forwarded to scikit-learn and used in the error
        message ("data" is used when it is empty).

    raise_unknown : bool, default=False
        If true, raise instead of returning ``"unknown"``.

    Returns
    -------
    target_type : str
        The value returned by scikit-learn's ``type_of_target``.

    Raises
    ------
    ValueError
        If ``raise_unknown`` is true and the target type is ``"unknown"``.
    """
    from sklearn.utils.multiclass import type_of_target as _sklearn_type_of_target

    target_type = _sklearn_type_of_target(y, input_name=input_name)
    if raise_unknown and target_type == "unknown":
        label = input_name or "data"
        raise ValueError(f"Unknown label type for {label}: {y!r}")
    return target_type
|
397
|
+
|
398
|
+
def _construct_instances(Estimator):
    """Yield the single instance built by ``_construct_instance(Estimator)``."""
    instance = _construct_instance(Estimator)
    yield instance
|
400
|
+
|
401
|
+
# validation
|
402
|
+
def validate_data(
    _estimator,
    /,
    X="no_validation",
    y="no_validation",
    reset=True,
    validate_separately=False,
    skip_check_array=False,
    **kwargs,
):
    """Validate input data and set or check feature names and counts of the input.

    See the original scikit-learn documentation:
    https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.validate_data.html#sklearn.utils.validation.validate_data

    When ``skip_check_array`` is false, the call is forwarded to
    ``_estimator._validate_data``; an ``ensure_all_finite`` keyword is passed
    on as ``force_all_finite`` (``True`` when it is not given).

    When ``skip_check_array`` is true, only the feature count and feature names
    are checked, and the inputs are returned as given: ``X`` alone, ``y``
    alone, or the tuple ``(X, y)``, depending on which of them is set to
    ``"no_validation"`` (``y=None`` counts as not validated).
    """
    if not skip_check_array:
        force_all_finite = kwargs.pop("ensure_all_finite", True)
        return _estimator._validate_data(
            X=X,
            y=y,
            reset=reset,
            validate_separately=validate_separately,
            force_all_finite=force_all_finite,
            **kwargs,
        )

    _check_n_features(_estimator, X, reset=reset)
    _check_feature_names(_estimator, X, reset=reset)

    x_unvalidated = isinstance(X, str) and X == "no_validation"
    y_unvalidated = y is None or (isinstance(y, str) and y == "no_validation")
    if y_unvalidated and not x_unvalidated:
        return X
    if x_unvalidated and not y_unvalidated:
        return y
    return X, y
|
442
|
+
|
443
|
+
def _check_n_features(estimator, X, *, reset):
    """Set the `n_features_in_` attribute, or check against it on an estimator.

    Delegates to ``estimator._check_n_features(X, reset=reset)`` and returns
    its result.
    """
    bound_check = estimator._check_n_features
    return bound_check(X, reset=reset)
|
446
|
+
|
447
|
+
def _check_feature_names(estimator, X, *, reset):
    """Check `input_features` and generate names if needed.

    Delegates to ``estimator._check_feature_names(X, reset=reset)`` and
    returns its result.
    """
    bound_check = estimator._check_feature_names
    return bound_check(X, reset=reset)
|
450
|
+
|
451
|
+
def check_array(
    array,
    accept_sparse=False,
    *,
    accept_large_sparse=True,
    dtype="numeric",
    order=None,
    copy=False,
    force_writeable=False,
    ensure_all_finite=None,
    ensure_non_negative=False,
    ensure_2d=True,
    allow_nd=False,
    ensure_min_samples=1,
    ensure_min_features=1,
    estimator=None,
    input_name="",
):
    """Input validation on an array, list, sparse matrix or similar.

    Check the original documentation for more details:
    https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_array.html

    ``ensure_all_finite`` is passed to scikit-learn as ``force_all_finite``
    (``True`` when it is ``None``). ``force_writeable`` and
    ``ensure_non_negative`` are only forwarded when the installed
    ``check_array`` has a parameter of that name.
    """
    from sklearn.utils.validation import check_array as _sklearn_check_array

    force_all_finite = True if ensure_all_finite is None else ensure_all_finite

    # Forward the optional parameters only if the installed function has them.
    accepted = inspect.signature(_sklearn_check_array).parameters
    optional = {
        "force_writeable": force_writeable,
        "ensure_non_negative": ensure_non_negative,
    }
    forwarded = {name: value for name, value in optional.items() if name in accepted}

    return _sklearn_check_array(
        array,
        accept_sparse=accept_sparse,
        accept_large_sparse=accept_large_sparse,
        dtype=dtype,
        order=order,
        copy=copy,
        force_all_finite=force_all_finite,
        ensure_2d=ensure_2d,
        allow_nd=allow_nd,
        ensure_min_samples=ensure_min_samples,
        ensure_min_features=ensure_min_features,
        estimator=estimator,
        input_name=input_name,
        **forwarded,
    )
|
501
|
+
|
502
|
+
def check_X_y(
    X,
    y,
    accept_sparse=False,
    *,
    accept_large_sparse=True,
    dtype="numeric",
    order=None,
    copy=False,
    force_writeable=False,
    ensure_all_finite=None,
    ensure_2d=True,
    allow_nd=False,
    multi_output=False,
    ensure_min_samples=1,
    ensure_min_features=1,
    y_numeric=False,
    estimator=None,
):
    """Input validation for standard estimators.

    Check the original documentation for more details:
    https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_X_y.html

    ``ensure_all_finite`` is passed to scikit-learn as ``force_all_finite``
    (``True`` when it is ``None``). ``force_writeable`` is only forwarded when
    the installed ``check_X_y`` has a parameter of that name.
    """
    from sklearn.utils.validation import check_X_y as _sklearn_check_X_y

    force_all_finite = True if ensure_all_finite is None else ensure_all_finite

    # Forward ``force_writeable`` only if the installed function has it.
    accepted = inspect.signature(_sklearn_check_X_y).parameters
    forwarded = {}
    if "force_writeable" in accepted:
        forwarded["force_writeable"] = force_writeable

    return _sklearn_check_X_y(
        X,
        y,
        accept_sparse=accept_sparse,
        accept_large_sparse=accept_large_sparse,
        dtype=dtype,
        order=order,
        copy=copy,
        force_all_finite=force_all_finite,
        ensure_2d=ensure_2d,
        allow_nd=allow_nd,
        multi_output=multi_output,
        ensure_min_samples=ensure_min_samples,
        ensure_min_features=ensure_min_features,
        y_numeric=y_numeric,
        estimator=estimator,
        **forwarded,
    )
|
553
|
+
|
554
|
+
# tags infrastructure
|
555
|
+
@dataclass(**_dataclass_args())
class InputTags:
    """Describe which kinds of input data (``X``) an estimator accepts.

    Parameters
    ----------
    one_d_array : bool, default=False
        True if a 1D array is accepted as input.

    two_d_array : bool, default=True
        True if a 2D array is accepted as input. Most of the common tests
        currently run only when this flag is ``True``.

    three_d_array : bool, default=False
        True if a 3D array is accepted as input.

    sparse : bool, default=False
        True if a sparse matrix is accepted as input.

    categorical : bool, default=False
        True if categorical input is accepted.

    string : bool, default=False
        True if an array-like of strings is accepted as input.

    dict : bool, default=False
        True if a dictionary is accepted as input.

    positive_only : bool, default=False
        True if the estimator requires positive X.

    allow_nan : bool, default=False
        True if the estimator supports data with missing values encoded as
        `np.nan`.

    pairwise : bool, default=False
        True if the data (`X`) given to :term:`fit` and similar methods holds
        pairwise measures over samples rather than a feature representation
        for each sample. This is usually the case when an estimator has a
        `metric`, `affinity` or `kernel` parameter with value 'precomputed'.
        The tag mainly serves a :term:`meta-estimator` or a cross validation
        procedure that extracts a sub-sample of data intended for a pairwise
        estimator, where the data needs to be indexed on both axes; in
        particular `sklearn.utils.metaestimators._safe_split` uses it to slice
        rows and columns.
    """

    one_d_array: bool = False
    two_d_array: bool = True
    three_d_array: bool = False
    sparse: bool = False
    categorical: bool = False
    string: bool = False
    dict: bool = False
    positive_only: bool = False
    allow_nan: bool = False
    pairwise: bool = False
|
614
|
+
|
615
|
+
@dataclass(**_dataclass_args())
class TargetTags:
    """Describe the target data (``y``) an estimator works with.

    Parameters
    ----------
    required : bool
        True if y must be passed to the `fit`, `fit_predict` or
        `fit_transform` methods. The tag is ``True`` for estimators inheriting
        from `~sklearn.base.RegressorMixin` and
        `~sklearn.base.ClassifierMixin`.

    one_d_labels : bool, default=False
        True if the input is a 1D labels (y).

    two_d_labels : bool, default=False
        True if the input is a 2D labels (y).

    positive_only : bool, default=False
        True if the estimator requires a positive y (only applicable for
        regression).

    multi_output : bool, default=False
        True if a regressor supports multi-target outputs or a classifier
        supports multi-class multi-output.

    single_output : bool, default=True
        True if the target can be single-output. This can be ``False`` if the
        estimator supports only multi-output cases.
    """

    required: bool
    one_d_labels: bool = False
    two_d_labels: bool = False
    positive_only: bool = False
    multi_output: bool = False
    single_output: bool = True
|
652
|
+
|
653
|
+
@dataclass(**_dataclass_args())
class TransformerTags:
    """Describe transformer-specific properties of an estimator.

    Parameters
    ----------
    preserves_dtype : list[str], default=["float64"]
        Applies only to transformers. Lists the data types that are preserved,
        i.e. for which `X_trans.dtype` equals `X.dtype` after calling
        `transformer.transform(X)`. An empty list means the transformer is not
        expected to preserve the data type. The first entry is considered the
        default data type: the type of the output when the input data type is
        not going to be preserved.
    """

    preserves_dtype: list[str] = field(default_factory=lambda: ["float64"])
|
671
|
+
|
672
|
+
@dataclass(**_dataclass_args())
class ClassifierTags:
    """Describe classifier-specific properties of an estimator.

    Parameters
    ----------
    poor_score : bool, default=False
        True if the estimator fails to provide a "reasonable" test-set score,
        which currently for classification is an accuracy of 0.83 on
        ``make_blobs(n_samples=300, random_state=0)``. The datasets and values
        are based on current estimators in scikit-learn and might be replaced
        by something more systematic.

    multi_class : bool, default=True
        True if the classifier can handle multi-class classification. All
        classifiers support binary classification, so this flag tells whether
        the classifier is a binary-classifier-only or not.

    multi_label : bool, default=False
        True if the classifier supports multi-label output.
    """

    poor_score: bool = False
    multi_class: bool = True
    multi_label: bool = False
|
698
|
+
|
699
|
+
@dataclass(**_dataclass_args())
class RegressorTags:
    """Describe regressor-specific properties of an estimator.

    Parameters
    ----------
    poor_score : bool, default=False
        True if the estimator fails to provide a "reasonable" test-set score,
        which currently for regression is an R2 of 0.5 on
        ``make_regression(n_samples=200, n_features=10,
        n_informative=1, bias=5.0, noise=20, random_state=42)``. The dataset
        and values are based on current estimators in scikit-learn and might
        be replaced by something more systematic.

    multi_label : bool, default=False
        True if the regressor supports multilabel output.
    """

    poor_score: bool = False
    multi_label: bool = False
|
719
|
+
|
720
|
+
@dataclass(**_dataclass_args())
class Tags:
    """Top-level container for all tags of an estimator.

    See :ref:`estimator_tags` for more information.

    Parameters
    ----------
    estimator_type : str or None
        The type of the estimator. Can be one of:
        - "classifier"
        - "regressor"
        - "transformer"
        - "clusterer"
        - "outlier_detector"
        - "density_estimator"

    target_tags : :class:`TargetTags`
        The target(y) tags.

    transformer_tags : :class:`TransformerTags` or None
        The transformer tags.

    classifier_tags : :class:`ClassifierTags` or None
        The classifier tags.

    regressor_tags : :class:`RegressorTags` or None
        The regressor tags.

    array_api_support : bool, default=False
        True if the estimator supports Array API compatible inputs.

    no_validation : bool, default=False
        True if the estimator skips input-validation. This is only meant for
        stateless and dummy transformers!

    non_deterministic : bool, default=False
        True if the estimator is not deterministic given a fixed
        ``random_state``.

    requires_fit : bool, default=True
        True if the estimator requires to be fitted before calling one of
        `transform`, `predict`, `predict_proba`, or `decision_function`.

    _skip_test : bool, default=False
        True to skip common tests entirely. Don't use this unless you have a
        *very good* reason.

    input_tags : :class:`InputTags`
        The input data(X) tags. Defaults to a freshly created
        :class:`InputTags`.
    """

    estimator_type: str | None
    target_tags: TargetTags
    transformer_tags: TransformerTags | None = None
    classifier_tags: ClassifierTags | None = None
    regressor_tags: RegressorTags | None = None
    array_api_support: bool = False
    no_validation: bool = False
    non_deterministic: bool = False
    requires_fit: bool = True
    _skip_test: bool = False
    input_tags: InputTags = field(default_factory=InputTags)
|
782
|
+
|
783
|
+
def _patched_more_tags(estimator, expected_failed_checks):
    """Make ``estimator`` report ``expected_failed_checks`` as ``_xfail_checks``.

    ``_more_tags`` is replaced both on the class of ``estimator`` and on the
    instance itself. The replacement calls the class's original ``_more_tags``
    and adds the ``_xfail_checks`` entry to a copy of its result.

    Parameters
    ----------
    estimator : estimator object
        Instance to patch. Its class is patched as well, so other instances
        of that class are affected too.

    expected_failed_checks : dict or None
        Value stored under the ``_xfail_checks`` tag.

    Returns
    -------
    estimator : estimator object
        The same instance that was passed in.
    """
    original_class_more_tags = estimator.__class__._more_tags

    def patched_instance_more_tags(self):
        """Instance-level _more_tags that combines class tags with _xfail_checks."""
        # Copy before updating: the original ``_more_tags`` may return a
        # dictionary it shares with other estimators (for instance a
        # module-level default), and updating that object in place would leak
        # ``_xfail_checks`` into every one of them.
        tags = dict(original_class_more_tags(self))
        tags["_xfail_checks"] = expected_failed_checks
        return tags

    # Patch both class and instance level
    estimator.__class__._more_tags = patched_instance_more_tags
    estimator._more_tags = types.MethodType(patched_instance_more_tags, estimator)
    return estimator
|
798
|
+
|
799
|
+
def check_estimator(
    estimator=None,
    generate_only=False,
    *,
    legacy: bool = True,
    expected_failed_checks: dict[str, str] | None = None,
    on_skip: Literal["warn"] | None = "warn",
    on_fail: Literal["raise", "warn"] | None = "raise",
    callback: Callable | None = None,
):
    """Run scikit-learn's ``check_estimator`` with expected failures attached.

    The estimator is first passed through ``_patched_more_tags`` so that
    ``expected_failed_checks`` is reported as its ``_xfail_checks`` tag; the
    patched estimator and ``generate_only`` are then handed to
    ``sklearn.utils.estimator_checks.check_estimator``.

    ``legacy``, ``on_skip``, ``on_fail`` and ``callback`` are not supported
    and ignored.
    """
    from sklearn.utils.estimator_checks import (
        check_estimator as _sklearn_check_estimator,
    )

    patched = _patched_more_tags(estimator, expected_failed_checks)
    return _sklearn_check_estimator(patched, generate_only=generate_only)
|
816
|
+
|
817
|
+
def parametrize_with_checks(
    estimators,
    *,
    legacy: bool = True,
    expected_failed_checks: Callable | None = None,
):
    """Build scikit-learn's ``parametrize_with_checks`` decorator.

    Every estimator is passed through ``_patched_more_tags`` so that its
    expected failures are reported as the ``_xfail_checks`` tag; the patched
    estimators are then handed to
    ``sklearn.utils.estimator_checks.parametrize_with_checks``.

    Parameters
    ----------
    estimators : iterable of estimator objects
        Estimators to generate checks for.

    legacy : bool, default=True
        Not supported and ignored.

    expected_failed_checks : callable or None, default=None
        Called with each estimator to obtain its expected failed checks.
        When ``None``, ``None`` is stored as the ``_xfail_checks`` tag of
        every estimator, which is what ``check_estimator`` in this module
        does with its own ``None`` default.

    Returns
    -------
    decorator
        Whatever scikit-learn's ``parametrize_with_checks`` returns.
    """
    from sklearn.utils.estimator_checks import (
        parametrize_with_checks as _sklearn_parametrize_with_checks,
    )

    def _expected_for(estimator):
        # The default ``None`` is not callable; calling it unconditionally
        # raised ``TypeError`` whenever the argument was omitted.
        if expected_failed_checks is None:
            return None
        return expected_failed_checks(estimator)

    patched = [
        _patched_more_tags(estimator, _expected_for(estimator))
        for estimator in estimators
    ]

    return _sklearn_parametrize_with_checks(patched)
|
832
|
+
|
833
|
+
else:
|
834
|
+
# base
|
835
|
+
from sklearn.base import is_clusterer # noqa: F401
|
836
|
+
|
837
|
+
# test_common
|
838
|
+
# tags infrastructure
|
839
|
+
from sklearn.utils import (
|
840
|
+
ClassifierTags,
|
841
|
+
InputTags,
|
842
|
+
RegressorTags,
|
843
|
+
Tags,
|
844
|
+
TargetTags,
|
845
|
+
TransformerTags,
|
846
|
+
)
|
847
|
+
from sklearn.utils._test_common.instance_generator import (
|
848
|
+
_construct_instances, # noqa: F401
|
849
|
+
)
|
850
|
+
from sklearn.utils.estimator_checks import (
|
851
|
+
check_estimator, # noqa: F401
|
852
|
+
parametrize_with_checks, # noqa: F401
|
853
|
+
)
|
854
|
+
from sklearn.utils.multiclass import type_of_target # noqa: F401
|
855
|
+
|
856
|
+
# validation
|
857
|
+
from sklearn.utils.validation import (
|
858
|
+
_check_feature_names,
|
859
|
+
_check_n_features,
|
860
|
+
check_array, # noqa: F401
|
861
|
+
check_X_y, # noqa: F401
|
862
|
+
validate_data, # noqa: F401
|
863
|
+
)
|