sklearn-compat 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
@@ -0,0 +1,673 @@
1
+ """Ease developer experience to support multiple versions of scikit-learn.
2
+
3
+ This file is intended to be vendored in your project if you do not want to depend on
4
+ `sklearn-compat` as a package. Then, you can import directly from this file.
5
+
6
+ Be aware that depending on `sklearn-compat` does not add any additional dependencies:
7
+ we are only depending on `scikit-learn`.
8
+
9
+ Version: 0.1.0
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import functools
15
+ import platform
16
+ import sys
17
+ from dataclasses import dataclass, field
18
+ from typing import Callable, Literal
19
+
20
+ import sklearn
21
+ from sklearn.utils._param_validation import validate_parameter_constraints
22
+ from sklearn.utils.fixes import parse_version
23
+
24
+ sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
25
+
26
+
27
+ ########################################################################################
28
+ # The following code does not depend on the sklearn version
29
+ ########################################################################################
30
+
31
+
32
+ # tags infrastructure
33
def _dataclass_args():
    """Return the keyword arguments to pass to ``dataclasses.dataclass``.

    ``slots=True`` is only understood by ``dataclass`` on Python 3.10 and
    newer, so it is requested there and omitted on older interpreters.
    """
    return {"slots": True} if sys.version_info >= (3, 10) else {}
37
+
38
+
39
def get_tags(estimator):
    """Get estimator tags in a consistent format across different sklearn versions.

    This function provides compatibility between sklearn versions before and after 1.6.
    It returns either a Tags object (sklearn >= 1.6) or a converted Tags object from
    the dictionary format (sklearn < 1.6) containing metadata about the estimator's
    requirements and capabilities.

    Parameters
    ----------
    estimator : estimator object
        A scikit-learn estimator instance.

    Returns
    -------
    tags : Tags
        An object containing metadata about the estimator's requirements and
        capabilities (e.g., input types, fitting requirements, classifier/regressor
        specific tags).
    """
    # Only the import is guarded: an ImportError raised *while computing* the
    # tags on sklearn >= 1.6 must propagate instead of silently falling back to
    # the legacy dictionary-based path.
    try:
        from sklearn.utils._tags import get_tags as _sklearn_get_tags
    except ImportError:
        # scikit-learn < 1.6: no `get_tags`; convert the legacy tag dictionary.
        from sklearn.utils._tags import _safe_tags

        return _to_new_tags(_safe_tags(estimator), estimator)
    return _sklearn_get_tags(estimator)
67
+
68
+
69
def _to_new_tags(old_tags, estimator=None):
    """Convert a legacy tag dictionary into a :class:`Tags` dataclass.

    Parameters
    ----------
    old_tags : dict
        Legacy tags (the dictionary format used by scikit-learn < 1.6).

    estimator : estimator instance, default=None
        Decides which optional sub-tags are built. Transformer tags are built
        only when it exposes ``transform`` or ``fit_transform``. Classifier and
        regressor tags are built according to its ``_estimator_type``
        attribute (``None`` when absent).

    Returns
    -------
    tags : Tags
        The equivalent dataclass-based tags.
    """
    x_types = old_tags["X_types"]

    input_tags = InputTags(
        one_d_array="1darray" in x_types,
        two_d_array="2darray" in x_types,
        three_d_array="3darray" in x_types,
        sparse="sparse" in x_types,
        categorical="categorical" in x_types,
        string="string" in x_types,
        dict="dict" in x_types,
        positive_only=old_tags["requires_positive_X"],
        allow_nan=old_tags["allow_nan"],
        pairwise=old_tags["pairwise"],
    )

    # Label layouts were historically stored alongside the X types.
    target_tags = TargetTags(
        required=old_tags["requires_y"],
        one_d_labels="1dlabels" in x_types,
        two_d_labels="2dlabels" in x_types,
        positive_only=old_tags["requires_positive_y"],
        multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
        single_output=not old_tags["multioutput_only"],
    )

    transformer_tags = None
    looks_like_transformer = estimator is not None and (
        hasattr(estimator, "transform") or hasattr(estimator, "fit_transform")
    )
    if looks_like_transformer:
        transformer_tags = TransformerTags(preserves_dtype=old_tags["preserves_dtype"])

    estimator_type = getattr(estimator, "_estimator_type", None)
    classifier_tags = None
    regressor_tags = None
    if estimator_type == "classifier":
        classifier_tags = ClassifierTags(
            poor_score=old_tags["poor_score"],
            multi_class=not old_tags["binary_only"],
            multi_label=old_tags["multilabel"],
        )
    elif estimator_type == "regressor":
        regressor_tags = RegressorTags(
            poor_score=old_tags["poor_score"],
            multi_label=old_tags["multilabel"],
        )

    return Tags(
        estimator_type=estimator_type,
        target_tags=target_tags,
        transformer_tags=transformer_tags,
        classifier_tags=classifier_tags,
        regressor_tags=regressor_tags,
        input_tags=input_tags,
        # Array-API support only exists in tag dicts of sklearn >= 1.3; default
        # to False when the key is missing.
        array_api_support=old_tags.get("array_api_support", False),
        no_validation=old_tags["no_validation"],
        non_deterministic=old_tags["non_deterministic"],
        requires_fit=old_tags["requires_fit"],
        _skip_test=old_tags["_skip_test"],
    )
130
+
131
+
132
+ ########################################################################################
133
+ # Upgrading for scikit-learn 1.3
134
+ ########################################################################################
135
+
136
+ if sklearn_version < parse_version("1.3"):
137
+ # parameter validation
138
def _fit_context(*, prefer_skip_nested_validation):
    """Decorator that validates an estimator's parameters before its fit method.

    Fallback used with scikit-learn < 1.3, which does not provide
    ``sklearn.base._fit_context``. This is *not* a no-op: every call of the
    decorated method first runs ``estimator._validate_params()`` and only then
    the original method. No configuration context is entered around the call.
    The decorated estimator must therefore provide a ``_validate_params``
    method.

    Parameters
    ----------
    prefer_skip_nested_validation : bool
        Accepted so that the signature matches scikit-learn >= 1.3. It is
        ignored by this fallback: nested validation is never skipped.

    Returns
    -------
    decorator : callable
        A decorator turning a fit method into one that validates parameters
        first. The wrapper keeps the wrapped method's metadata
        (``functools.wraps``).
    """

    def decorator(fit_method):
        @functools.wraps(fit_method)
        def wrapper(estimator, *args, **kwargs):
            estimator._validate_params()
            return fit_method(estimator, *args, **kwargs)

        return wrapper

    return decorator
173
+ else:
174
+ # parameter validation
175
+
176
+ from sklearn.base import _fit_context # noqa: F401
177
+
178
+
179
+ ########################################################################################
180
+ # Upgrading for scikit-learn 1.4
181
+ ########################################################################################
182
+
183
+
184
+ if sklearn_version < parse_version("1.4"):
185
+
186
def _is_fitted(estimator, attributes=None, all_or_any=all):
    """Tell whether ``estimator`` looks fitted.

    Parameters
    ----------
    estimator : estimator instance
        The object to inspect.

    attributes : str, list or tuple of str, default=None
        Attribute name(s) whose presence signals a fitted estimator, e.g.
        ``"coef_"`` or ``["coef_", "estimator_"]``. When ``None``, the
        estimator's ``__sklearn_is_fitted__`` is used if defined. Otherwise the
        estimator counts as fitted as soon as one instance attribute ends with
        an underscore without starting with a double underscore.

    all_or_any : callable, {all, any}, default=all
        Combines the per-attribute presence checks when ``attributes`` is given.

    Returns
    -------
    fitted : bool
        Whether the estimator is considered fitted.
    """
    if attributes is None:
        if hasattr(estimator, "__sklearn_is_fitted__"):
            return estimator.__sklearn_is_fitted__()
        return any(
            name.endswith("_") and not name.startswith("__")
            for name in vars(estimator)
        )

    names = attributes if isinstance(attributes, (list, tuple)) else [attributes]
    # Keep a list (not a generator) so every attribute is probed, as before.
    return all_or_any([hasattr(estimator, name) for name in names])
222
+
223
+ else:
224
+ from sklearn.utils.validation import _is_fitted # noqa: F401
225
+
226
+
227
+ ########################################################################################
228
+ # Upgrading for scikit-learn 1.5
229
+ ########################################################################################
230
+
231
+
232
+ if sklearn_version < parse_version("1.5"):
233
+ # chunking
234
+ # extmath
235
+ # fixes
236
+ from sklearn.utils import (
237
+ _IS_32BIT, # noqa: F401
238
+ _approximate_mode, # noqa: F401
239
+ _in_unstable_openblas_configuration, # noqa: F401
240
+ gen_batches, # noqa: F401
241
+ gen_even_slices, # noqa: F401
242
+ get_chunk_n_rows, # noqa: F401
243
+ safe_sqr, # noqa: F401
244
+ )
245
+ from sklearn.utils import _chunk_generator as chunk_generator # noqa: F401
246
+
247
+ _IS_WASM = platform.machine() in ["wasm32", "wasm64"]
248
+ # indexing
249
+ # mask
250
+ # missing
251
+ # optional dependencies
252
+ # user interface
253
+ # validation
254
+ from sklearn.utils import (
255
+ _determine_key_type, # noqa: F401
256
+ _get_column_indices, # noqa: F401
257
+ _print_elapsed_time, # noqa: F401
258
+ _safe_assign, # noqa: F401
259
+ _safe_indexing, # noqa: F401
260
+ _to_object_array, # noqa: F401
261
+ axis0_safe_slice, # noqa: F401
262
+ check_matplotlib_support, # noqa: F401
263
+ check_pandas_support, # noqa: F401
264
+ indices_to_mask, # noqa: F401
265
+ is_scalar_nan, # noqa: F401
266
+ resample, # noqa: F401
267
+ safe_mask, # noqa: F401
268
+ shuffle, # noqa: F401
269
+ )
270
+ from sklearn.utils import _is_pandas_na as is_pandas_na # noqa: F401
271
+ else:
272
+ # chunking
273
+ from sklearn.utils._chunking import (
274
+ chunk_generator, # noqa: F401
275
+ gen_batches, # noqa: F401
276
+ gen_even_slices, # noqa: F401
277
+ get_chunk_n_rows, # noqa: F401
278
+ )
279
+
280
+ # indexing
281
+ from sklearn.utils._indexing import (
282
+ _determine_key_type, # noqa: F401
283
+ _get_column_indices, # noqa: F401
284
+ _safe_assign, # noqa: F401
285
+ _safe_indexing, # noqa: F401
286
+ resample, # noqa: F401
287
+ shuffle, # noqa: F401
288
+ )
289
+
290
+ # mask
291
+ from sklearn.utils._mask import (
292
+ axis0_safe_slice, # noqa: F401
293
+ indices_to_mask, # noqa: F401
294
+ safe_mask, # noqa: F401
295
+ )
296
+
297
+ # missing
298
+ from sklearn.utils._missing import (
299
+ is_pandas_na, # noqa: F401
300
+ is_scalar_nan, # noqa: F401
301
+ )
302
+
303
+ # optional dependencies
304
+ from sklearn.utils._optional_dependencies import ( # noqa: F401
305
+ check_matplotlib_support,
306
+ check_pandas_support, # noqa: F401
307
+ )
308
+
309
+ # user interface
310
+ from sklearn.utils._user_interface import _print_elapsed_time # noqa: F401
311
+
312
+ # extmath
313
+ from sklearn.utils.extmath import (
314
+ _approximate_mode, # noqa: F401
315
+ safe_sqr, # noqa: F401
316
+ )
317
+
318
+ # fixes
319
+ from sklearn.utils.fixes import (
320
+ _IS_32BIT, # noqa: F401
321
+ _IS_WASM, # noqa: F401
322
+ _in_unstable_openblas_configuration, # noqa: F401
323
+ )
324
+
325
+ # validation
326
+ from sklearn.utils.validation import _to_object_array # noqa: F401
327
+
328
+ ########################################################################################
329
+ # Upgrading for scikit-learn 1.6
330
+ ########################################################################################
331
+
332
+
333
+ if sklearn_version < parse_version("1.6"):
334
+ # test_common
335
+ from sklearn.utils.estimator_checks import _construct_instance
336
+
337
def type_of_target(y, input_name="", *, raise_unknown=False):
    """Determine the type of target ``y``, backporting ``raise_unknown``.

    Delegates to ``sklearn.utils.multiclass.type_of_target`` and adds the
    ``raise_unknown`` option introduced in scikit-learn 1.6.

    Parameters
    ----------
    y : array-like
        Target values.

    input_name : str, default=""
        Name of the input, used in error messages (``"data"`` when empty).

    raise_unknown : bool, default=False
        When True, raise instead of returning ``"unknown"``.

    Returns
    -------
    target_type : str
        The target type reported by scikit-learn.

    Raises
    ------
    ValueError
        If ``raise_unknown`` is True and the target type is ``"unknown"``.
    """
    from sklearn.utils.multiclass import type_of_target as _legacy_type_of_target

    target_type = _legacy_type_of_target(y, input_name=input_name)
    if not (raise_unknown and target_type == "unknown"):
        return target_type

    data_name = input_name or "data"
    raise ValueError(f"Unknown label type for {data_name}: {y!r}")
353
+
354
def _construct_instances(Estimator):
    """Generate instances of ``Estimator`` for the common checks.

    Older scikit-learn only offers ``_construct_instance`` (a single instance),
    so this generator yields exactly one item, built lazily on first iteration.
    """
    yield from (_construct_instance(Estimator),)
356
+
357
+ # validation
358
def validate_data(
    _estimator,
    /,
    X="no_validation",
    y="no_validation",
    reset=True,
    validate_separately=False,
    **check_params,
):
    """Validate input data, backporting ``sklearn.utils.validation.validate_data``.

    On scikit-learn < 1.6 the work is delegated to the private
    ``_estimator._validate_data`` method.

    Parameters
    ----------
    _estimator : estimator instance
        Estimator exposing a ``_validate_data`` method.

    X, y, reset, validate_separately
        Forwarded unchanged to ``_estimator._validate_data``. They are explicit
        parameters so that ``X`` and ``y`` can be passed positionally, as with
        the scikit-learn 1.6 function.

    **check_params : dict
        Extra keyword arguments. Names introduced in 1.6 are translated to
        their pre-1.6 spelling:

        - ``ensure_all_finite`` becomes ``force_all_finite``;
        - ``skip_check_array`` becomes ``cast_to_ndarray=not skip_check_array``.

        Everything else is forwarded as is.

    Returns
    -------
    out
        Whatever ``_estimator._validate_data`` returns.
    """
    if "ensure_all_finite" in check_params:
        # Only forward the flag when the caller set it. Injecting a default
        # would break y-only validation, whose pre-1.6 code path does not
        # accept `force_all_finite`.
        check_params["force_all_finite"] = check_params.pop("ensure_all_finite")
    if "skip_check_array" in check_params:
        check_params["cast_to_ndarray"] = not check_params.pop("skip_check_array")
    return _estimator._validate_data(
        X=X,
        y=y,
        reset=reset,
        validate_separately=validate_separately,
        **check_params,
    )
364
+
365
def _check_n_features(estimator, X, *, reset):
    """Function-style shim around the estimator's ``_check_n_features`` method.

    On scikit-learn < 1.6 the check is only available as an estimator method.
    """
    check = estimator._check_n_features
    return check(X, reset=reset)
367
+
368
def _check_feature_names(estimator, X, *, reset):
    """Function-style shim around the estimator's ``_check_feature_names`` method.

    On scikit-learn < 1.6 the check is only available as an estimator method.
    """
    check = estimator._check_feature_names
    return check(X, reset=reset)
370
+
371
+ # tags infrastructure
372
@dataclass(**_dataclass_args())
class InputTags:
    """What an estimator accepts as input data ``X``.

    Parameters
    ----------
    one_d_array : bool, default=False
        ``X`` may be a 1D array.

    two_d_array : bool, default=True
        ``X`` may be a 2D array. Most common tests currently only run when
        this flag is ``True``.

    three_d_array : bool, default=False
        ``X`` may be a 3D array.

    sparse : bool, default=False
        ``X`` may be a sparse matrix.

    categorical : bool, default=False
        ``X`` may be categorical.

    string : bool, default=False
        ``X`` may be an array-like of strings.

    dict : bool, default=False
        ``X`` may be a dictionary.

    positive_only : bool, default=False
        The estimator requires a positive ``X``.

    allow_nan : bool, default=False
        The estimator supports missing values encoded as `np.nan`.

    pairwise : bool, default=False
        ``X`` (as passed to :term:`fit` and similar methods) holds pairwise
        measures between samples rather than one feature representation per
        sample. This is usually ``True`` when an estimator's `metric`,
        `affinity` or `kernel` parameter is 'precomputed'. It lets a
        :term:`meta-estimator` or a cross-validation procedure sub-sample such
        data along both axes. Specifically,
        `sklearn.utils.metaestimators._safe_split` reads it to slice rows and
        columns.
    """

    one_d_array: bool = False
    two_d_array: bool = True
    three_d_array: bool = False
    sparse: bool = False
    categorical: bool = False
    string: bool = False
    dict: bool = False
    positive_only: bool = False
    allow_nan: bool = False
    pairwise: bool = False
431
+
432
@dataclass(**_dataclass_args())
class TargetTags:
    """What an estimator accepts or requires as target ``y``.

    Parameters
    ----------
    required : bool
        ``y`` must be passed to `fit`, `fit_predict` or `fit_transform`. This
        is ``True`` for estimators inheriting from
        `~sklearn.base.RegressorMixin` and `~sklearn.base.ClassifierMixin`.

    one_d_labels : bool, default=False
        ``y`` may be given as 1D labels.

    two_d_labels : bool, default=False
        ``y`` may be given as 2D labels.

    positive_only : bool, default=False
        The estimator requires a positive ``y`` (only applicable to
        regression).

    multi_output : bool, default=False
        A regressor supports multi-target outputs, or a classifier supports
        multi-class multi-output.

    single_output : bool, default=True
        ``y`` may be single-output. ``False`` when the estimator only
        supports multi-output targets.
    """

    required: bool  # no default: must always be given explicitly
    one_d_labels: bool = False
    two_d_labels: bool = False
    positive_only: bool = False
    multi_output: bool = False
    single_output: bool = True
469
+
470
@dataclass(**_dataclass_args())
class TransformerTags:
    """Tags that only apply to transformers.

    Parameters
    ----------
    preserves_dtype : list[str], default=["float64"]
        Data types kept by the transformer, i.e. those for which
        `X_trans.dtype` equals `X.dtype` after `transformer.transform(X)`.
        An empty list means the transformer is not expected to preserve the
        data type. The first entry is the default data type: the output dtype
        used when the input dtype is not going to be preserved.
    """

    # default_factory so that every instance gets its own, independent list.
    preserves_dtype: list[str] = field(default_factory=lambda: ["float64"])
488
+
489
@dataclass(**_dataclass_args())
class ClassifierTags:
    """Tags that only apply to classifiers.

    Parameters
    ----------
    poor_score : bool, default=False
        The estimator fails to reach a "reasonable" test-set score, which for
        classification currently means an accuracy of 0.83 on
        ``make_blobs(n_samples=300, random_state=0)``. Dataset and threshold
        reflect the current scikit-learn estimators and might be replaced by
        something more systematic.

    multi_class : bool, default=True
        The classifier handles multi-class classification. Every classifier
        supports binary classification, so ``False`` marks a binary-only
        classifier.

    multi_label : bool, default=False
        The classifier supports multi-label output.
    """

    poor_score: bool = False
    multi_class: bool = True
    multi_label: bool = False
515
+
516
@dataclass(**_dataclass_args())
class RegressorTags:
    """Tags that only apply to regressors.

    Parameters
    ----------
    poor_score : bool, default=False
        The estimator fails to reach a "reasonable" test-set score, which for
        regression currently means an R2 of 0.5 on
        ``make_regression(n_samples=200, n_features=10, n_informative=1,
        bias=5.0, noise=20, random_state=42)``. Dataset and threshold reflect
        the current scikit-learn estimators and might be replaced by something
        more systematic.

    multi_label : bool, default=False
        The regressor supports multilabel output.
    """

    poor_score: bool = False
    multi_label: bool = False
536
+
537
@dataclass(**_dataclass_args())
class Tags:
    """All tags describing an estimator.

    See :ref:`estimator_tags` for more information.

    Parameters
    ----------
    estimator_type : str or None
        Kind of estimator, one of:
        - "classifier"
        - "regressor"
        - "transformer"
        - "clusterer"
        - "outlier_detector"
        - "density_estimator"

    target_tags : :class:`TargetTags`
        Tags about the target ``y``.

    transformer_tags : :class:`TransformerTags` or None
        Transformer-specific tags.

    classifier_tags : :class:`ClassifierTags` or None
        Classifier-specific tags.

    regressor_tags : :class:`RegressorTags` or None
        Regressor-specific tags.

    array_api_support : bool, default=False
        The estimator accepts Array API compatible inputs.

    no_validation : bool, default=False
        The estimator skips input validation. Only meant for stateless and
        dummy transformers!

    non_deterministic : bool, default=False
        The estimator is not deterministic even with a fixed ``random_state``.

    requires_fit : bool, default=True
        The estimator must be fitted before calling any of `transform`,
        `predict`, `predict_proba` or `decision_function`.

    _skip_test : bool, default=False
        Skip the common tests entirely. Don't use this unless you have a
        *very good* reason.

    input_tags : :class:`InputTags`
        Tags about the input data ``X``.
    """

    # Fields without defaults must come first.
    estimator_type: str | None
    target_tags: TargetTags
    transformer_tags: TransformerTags | None = None
    classifier_tags: ClassifierTags | None = None
    regressor_tags: RegressorTags | None = None
    array_api_support: bool = False
    no_validation: bool = False
    non_deterministic: bool = False
    requires_fit: bool = True
    _skip_test: bool = False
    # default_factory: every Tags instance gets its own InputTags object.
    input_tags: InputTags = field(default_factory=InputTags)
599
+
600
def _patched_more_tags(estimator, expected_failed_checks):
    """Report ``expected_failed_checks`` through the legacy ``_xfail_checks`` tag.

    scikit-learn < 1.6 has no ``expected_failed_checks`` argument; the common
    checks read expected failures from the ``_xfail_checks`` tag instead. This
    helper replaces ``_more_tags`` on the estimator's *class* with a version
    that reports the given checks. A class-level override is presumably
    required so that it survives cloning — TODO confirm.

    Tags are still computed per instance by delegating to the class's previous
    ``_more_tags``, so parameter-dependent tags stay correct for every
    instance. The override is shared by all instances of the class (and its
    subclasses); the most recent call wins.

    Parameters
    ----------
    estimator : estimator instance
        Estimator whose class is patched.

    expected_failed_checks : dict or None
        Mapping from check name to the reason it is expected to fail. When
        ``None`` nothing is patched, so any ``_xfail_checks`` the estimator
        declares itself is kept.

    Returns
    -------
    estimator : estimator instance
        The same object that was passed in.
    """
    if expected_failed_checks is None:
        return estimator

    klass = estimator.__class__
    previous_more_tags = getattr(klass, "_more_tags", None)

    def patched_more_tags(self):
        tags = {} if previous_more_tags is None else previous_more_tags(self)
        # Build a new dict: the previous implementation may return a shared
        # (e.g. module-level default) dictionary that must not be mutated.
        return {**tags, "_xfail_checks": expected_failed_checks}

    klass._more_tags = patched_more_tags
    return estimator
613
+
614
def check_estimator(
    estimator=None,
    generate_only=False,
    *,
    legacy: bool = True,
    expected_failed_checks: dict[str, str] | None = None,
    on_skip: Literal["warn"] | None = "warn",
    on_fail: Literal["raise", "warn"] | None = "raise",
    callback: Callable | None = None,
):
    """Run scikit-learn's common checks with the 1.6 signature on older versions.

    Parameters
    ----------
    estimator : estimator instance, default=None
        Estimator to check.

    generate_only : bool, default=False
        Forwarded to scikit-learn's ``check_estimator``.

    legacy, on_skip, on_fail, callback
        Accepted for signature compatibility with scikit-learn 1.6; they have
        no equivalent on older versions and are ignored.

    expected_failed_checks : dict or None, default=None
        Mapping from check name to reason. When given, it is exposed through
        the legacy ``_xfail_checks`` tag. When ``None``, the estimator is left
        untouched so any ``_xfail_checks`` it declares itself still applies.

    Returns
    -------
    Whatever scikit-learn's ``check_estimator`` returns.
    """
    from sklearn.utils.estimator_checks import check_estimator as _check_estimator

    if expected_failed_checks is not None:
        estimator = _patched_more_tags(estimator, expected_failed_checks)
    return _check_estimator(estimator, generate_only=generate_only)
631
+
632
def parametrize_with_checks(
    estimators,
    *,
    legacy: bool = True,
    expected_failed_checks: Callable | None = None,
):
    """Pytest parametrization of the common checks, with the 1.6 signature.

    Parameters
    ----------
    estimators : iterable of estimator instances
        Estimators to generate checks for. The iterable is materialized into a
        list before being handed to scikit-learn.

    legacy : bool, default=True
        Accepted for signature compatibility with scikit-learn 1.6 and ignored.

    expected_failed_checks : callable or None, default=None
        Called with each estimator; must return a mapping from check name to
        reason, or ``None``. The mapping is exposed through the legacy
        ``_xfail_checks`` tag. When this argument is ``None`` (the default),
        estimators are passed through unchanged instead of raising
        ``TypeError`` from calling ``None``.

    Returns
    -------
    The decorator returned by scikit-learn's ``parametrize_with_checks``.
    """
    from sklearn.utils.estimator_checks import (
        parametrize_with_checks as _parametrize_with_checks,
    )

    estimators = list(estimators)
    if expected_failed_checks is not None:
        estimators = [
            _patched_more_tags(estimator, expected_failed_checks(estimator))
            for estimator in estimators
        ]
    return _parametrize_with_checks(estimators)
647
+
648
+ else:
649
+ # test_common
650
+ # tags infrastructure
651
+ from sklearn.utils import (
652
+ ClassifierTags,
653
+ InputTags,
654
+ RegressorTags,
655
+ Tags,
656
+ TargetTags,
657
+ TransformerTags,
658
+ )
659
+ from sklearn.utils._test_common.instance_generator import (
660
+ _construct_instances, # noqa: F401
661
+ )
662
+ from sklearn.utils.estimator_checks import (
663
+ check_estimator, # noqa: F401
664
+ parametrize_with_checks, # noqa: F401
665
+ )
666
+ from sklearn.utils.multiclass import type_of_target # noqa: F401
667
+
668
+ # validation
669
+ from sklearn.utils.validation import (
670
+ _check_feature_names, # noqa: F401
671
+ _check_n_features, # noqa: F401
672
+ validate_data, # noqa: F401
673
+ )
sklearn_compat/base.py ADDED
@@ -0,0 +1,3 @@
1
+ from sklearn_compat._sklearn_compat import (
2
+ _fit_context, # noqa: F401
3
+ )
@@ -0,0 +1,19 @@
1
+ from sklearn_compat.utils._tags import (
2
+ ClassifierTags,
3
+ InputTags,
4
+ RegressorTags,
5
+ Tags,
6
+ TargetTags,
7
+ TransformerTags,
8
+ get_tags,
9
+ )
10
+
11
+ __all__ = [
12
+ "Tags",
13
+ "InputTags",
14
+ "TargetTags",
15
+ "ClassifierTags",
16
+ "RegressorTags",
17
+ "TransformerTags",
18
+ "get_tags",
19
+ ]
@@ -0,0 +1,6 @@
1
+ from sklearn_compat._sklearn_compat import (
2
+ chunk_generator, # noqa: F401
3
+ gen_batches, # noqa: F401
4
+ gen_even_slices, # noqa: F401
5
+ get_chunk_n_rows, # noqa: F401
6
+ )