panelsplit 2.0.5__tar.gz → 2.0.5.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.github/workflows/ci.yml +1 -1
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/PKG-INFO +4 -4
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/README.md +1 -1
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/metrics.py +37 -152
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/model_selection/model_selection.py +117 -10
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/pipeline.py +25 -57
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/pyproject.toml +2 -2
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_metrics.py +0 -8
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_pipeline.py +0 -21
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_search.py +0 -18
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/uv.lock +1959 -1434
- panelsplit-2.0.5/tests/test_sequentialcvpipeline_indices.py +0 -148
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.github/workflows/lint.yml +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.github/workflows/pre-commit.yml +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.github/workflows/releases.yml +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.gitignore +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/.pre-commit-config.yaml +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/CHANGELOG.md +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/CITATION.cff +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/CNAME +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/CODE_OF_CONDUCT.md +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/LICENSE +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/examples/An introduction to PanelSplit.ipynb +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/__init__.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/application.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/cross_validation.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/model_selection/__init__.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/model_selection/_validation.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/plot.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/utils/__init__.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/utils/_response.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/utils/typing.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/utils/utils.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/panelsplit/utils/validation.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/__init__.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/df_generation.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_PanelSplit.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_check_fitted_fix.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_cross_validation.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_edge_cases.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_issue_59_fix.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_narwhals_compatibility.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_plot.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_scorer.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_set_params.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_utils.py +0 -0
- {panelsplit-2.0.5 → panelsplit-2.0.5.dev0}/tests/test_validation_coverage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: panelsplit
|
|
3
|
-
Version: 2.0.5
|
|
3
|
+
Version: 2.0.5.dev0
|
|
4
4
|
Summary: A tool for panel data analysis.
|
|
5
5
|
Project-URL: Homepage, https://github.com/4Freye/panelsplit
|
|
6
6
|
Project-URL: Repository, https://github.com/4Freye/panelsplit
|
|
@@ -11,13 +11,13 @@ License-File: LICENSE
|
|
|
11
11
|
Classifier: License :: OSI Approved :: MIT License
|
|
12
12
|
Classifier: Operating System :: OS Independent
|
|
13
13
|
Classifier: Programming Language :: Python :: 3
|
|
14
|
-
Requires-Python: >=3.
|
|
14
|
+
Requires-Python: >=3.10
|
|
15
15
|
Requires-Dist: joblib>=1.0.1
|
|
16
16
|
Requires-Dist: matplotlib>=3.4.3
|
|
17
17
|
Requires-Dist: narwhals>=1.42.1
|
|
18
18
|
Requires-Dist: numpy>=1.21.0
|
|
19
19
|
Requires-Dist: pandas>=1.3.0
|
|
20
|
-
Requires-Dist: scikit-learn>=
|
|
20
|
+
Requires-Dist: scikit-learn>=0.24.2
|
|
21
21
|
Requires-Dist: scipy>=1.10.1
|
|
22
22
|
Requires-Dist: tqdm>=4.67.1
|
|
23
23
|
Requires-Dist: typing-extensions>=4.13.2
|
|
@@ -32,7 +32,7 @@ panelsplit is a Python package designed to facilitate time series cross-validati
|
|
|
32
32
|
|
|
33
33
|
## Installation
|
|
34
34
|
|
|
35
|
-
panelsplit is tested for compatibility with python versions >= 3.
|
|
35
|
+
panelsplit is tested for compatibility with python versions >= 3.10. You can install panelsplit using pip:
|
|
36
36
|
|
|
37
37
|
```bash
|
|
38
38
|
pip install panelsplit
|
|
@@ -7,7 +7,7 @@ panelsplit is a Python package designed to facilitate time series cross-validati
|
|
|
7
7
|
|
|
8
8
|
## Installation
|
|
9
9
|
|
|
10
|
-
panelsplit is tested for compatibility with python versions >= 3.
|
|
10
|
+
panelsplit is tested for compatibility with python versions >= 3.10. You can install panelsplit using pip:
|
|
11
11
|
|
|
12
12
|
```bash
|
|
13
13
|
pip install panelsplit
|
|
@@ -1,42 +1,37 @@
|
|
|
1
|
-
|
|
2
|
-
Metrics that are equivalent their sklearn counterparts, except for the fact that they work with SequentialCVPipeline.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
# Standard library
|
|
6
|
-
import warnings
|
|
1
|
+
from .utils.validation import _safe_indexing
|
|
7
2
|
from inspect import signature
|
|
8
3
|
from collections.abc import Iterable
|
|
9
4
|
from functools import partial
|
|
5
|
+
from sklearn.metrics._scorer import _MultimetricScorer
|
|
6
|
+
from sklearn.utils._param_validation import (
|
|
7
|
+
validate_params,
|
|
8
|
+
)
|
|
9
|
+
from sklearn.metrics._scorer import _PassthroughScorer, _get_response_method_name
|
|
10
10
|
from copy import deepcopy
|
|
11
|
+
from sklearn.utils.validation import _check_response_method
|
|
12
|
+
import warnings
|
|
13
|
+
from sklearn.base import is_regressor
|
|
14
|
+
from panelsplit.utils._response import _get_response_values
|
|
15
|
+
from sklearn.utils.metadata_routing import (
|
|
16
|
+
_MetadataRequester,
|
|
17
|
+
_raise_for_params,
|
|
18
|
+
_routing_enabled,
|
|
19
|
+
MetadataRequest,
|
|
20
|
+
)
|
|
21
|
+
from .utils.typing import EstimatorLike, ArrayLike
|
|
22
|
+
from numpy.typing import NDArray
|
|
11
23
|
from typing import Callable, Optional, List, Union, Any, Dict
|
|
12
|
-
|
|
13
|
-
# Third-party / typing
|
|
14
24
|
from typing_extensions import Self
|
|
15
|
-
from numpy.typing import NDArray
|
|
16
25
|
|
|
17
|
-
#
|
|
18
|
-
from .utils.validation import _safe_indexing
|
|
19
|
-
from .utils.typing import EstimatorLike, ArrayLike
|
|
20
|
-
from panelsplit.utils._response import _get_response_values
|
|
21
|
-
|
|
22
|
-
# sklearn public metrics (single consolidated import)
|
|
26
|
+
# all the error scores:
|
|
23
27
|
from sklearn.metrics import (
|
|
24
28
|
accuracy_score,
|
|
25
|
-
adjusted_mutual_info_score,
|
|
26
|
-
adjusted_rand_score,
|
|
27
29
|
average_precision_score,
|
|
28
30
|
balanced_accuracy_score,
|
|
29
31
|
brier_score_loss,
|
|
30
32
|
class_likelihood_ratios,
|
|
31
|
-
completeness_score,
|
|
32
33
|
d2_absolute_error_score,
|
|
33
|
-
d2_brier_score,
|
|
34
|
-
d2_log_loss_score,
|
|
35
34
|
explained_variance_score,
|
|
36
|
-
f1_score,
|
|
37
|
-
fowlkes_mallows_score,
|
|
38
|
-
jaccard_score,
|
|
39
|
-
homogeneity_score,
|
|
40
35
|
log_loss,
|
|
41
36
|
matthews_corrcoef,
|
|
42
37
|
max_error,
|
|
@@ -47,35 +42,22 @@ from sklearn.metrics import (
|
|
|
47
42
|
mean_squared_error,
|
|
48
43
|
mean_squared_log_error,
|
|
49
44
|
median_absolute_error,
|
|
50
|
-
mutual_info_score,
|
|
51
|
-
normalized_mutual_info_score,
|
|
52
|
-
precision_score,
|
|
53
|
-
rand_score,
|
|
54
45
|
r2_score,
|
|
55
|
-
recall_score,
|
|
56
46
|
roc_auc_score,
|
|
57
47
|
root_mean_squared_error,
|
|
58
48
|
root_mean_squared_log_error,
|
|
59
49
|
top_k_accuracy_score,
|
|
60
|
-
v_measure_score,
|
|
61
50
|
)
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
# metadata routing utilities (used by some sklearn internals)
|
|
74
|
-
from sklearn.utils.metadata_routing import (
|
|
75
|
-
_MetadataRequester,
|
|
76
|
-
_raise_for_params,
|
|
77
|
-
_routing_enabled,
|
|
78
|
-
MetadataRequest,
|
|
51
|
+
from sklearn.metrics.cluster import (
|
|
52
|
+
adjusted_mutual_info_score,
|
|
53
|
+
adjusted_rand_score,
|
|
54
|
+
completeness_score,
|
|
55
|
+
fowlkes_mallows_score,
|
|
56
|
+
homogeneity_score,
|
|
57
|
+
mutual_info_score,
|
|
58
|
+
normalized_mutual_info_score,
|
|
59
|
+
rand_score,
|
|
60
|
+
v_measure_score,
|
|
79
61
|
)
|
|
80
62
|
|
|
81
63
|
|
|
@@ -106,63 +88,14 @@ def make_SequentialCV_scorer(
|
|
|
106
88
|
greater_is_better: bool = True,
|
|
107
89
|
**kwargs: Any,
|
|
108
90
|
) -> Callable[..., float]:
|
|
109
|
-
"""
|
|
110
|
-
Make a SequentialCVPipeline-compatible scorer from a performance metric.
|
|
111
|
-
|
|
112
|
-
A scorer is a wrapper around an arbitrary metric or loss function that is called
|
|
113
|
-
with the signature `scorer(estimator, X, y_true, **kwargs)`.
|
|
114
|
-
|
|
115
|
-
The parameter `response_method` allows to specify which method of the estimator
|
|
116
|
-
should be used to feed the scoring/loss function.
|
|
117
|
-
|
|
118
|
-
Parameters
|
|
119
|
-
----------
|
|
120
|
-
score_func : callable
|
|
121
|
-
Score function (or loss function) with signature
|
|
122
|
-
``score_func(y, y_pred, **kwargs)``.
|
|
123
|
-
|
|
124
|
-
response_method : {"predict_proba", "decision_function", "predict"} or \
|
|
125
|
-
list/tuple of such str, default="predict"
|
|
126
|
-
|
|
127
|
-
Specifies the response method to use get prediction from an estimator
|
|
128
|
-
(i.e. :term:`predict_proba`, :term:`decision_function` or
|
|
129
|
-
:term:`predict`). Possible choices are:
|
|
130
|
-
|
|
131
|
-
- if `str`, it corresponds to the name to the method to return;
|
|
132
|
-
- if a list or tuple of `str`, it provides the method names in order of
|
|
133
|
-
preference. The method returned corresponds to the first method in
|
|
134
|
-
the list and which is implemented by `estimator`.
|
|
135
|
-
|
|
136
|
-
greater_is_better : bool, default=True
|
|
137
|
-
Whether `score_func` is a score function (default), meaning high is
|
|
138
|
-
good, or a loss function, meaning low is good. In the latter case, the
|
|
139
|
-
scorer object will sign-flip the outcome of the `score_func`.
|
|
140
|
-
|
|
141
|
-
**kwargs : additional arguments
|
|
142
|
-
Additional parameters to be passed to `score_func`.
|
|
143
|
-
|
|
144
|
-
Returns
|
|
145
|
-
-------
|
|
146
|
-
Callable
|
|
147
|
-
Callable object that returns a scalar score; greater is better.
|
|
148
|
-
|
|
149
|
-
Examples
|
|
150
|
-
--------
|
|
151
|
-
>>> from panelsplit.metrics import make_SequentialCV_scorer
|
|
152
|
-
>>> from sklearn.metrics import brier_score_loss
|
|
153
|
-
>>> brier_loss_scorer= make_SequentialCV_scorer(brier_score_loss, response_method='predict_proba', greater_is_better=False)
|
|
154
|
-
|
|
155
|
-
>>> from panelsplit.pipeline import SequentialCVPipeline
|
|
156
|
-
>>> from sklearn.ensemble import RandomForestClassifier
|
|
157
|
-
>>> from sklearn.datasets import load_iris
|
|
158
|
-
>>> X, y = load_iris(return_X_y=True)
|
|
159
|
-
>>> p = SequentialCVPipeline(steps = [('rf', RandomForestClassifier())], cv_steps = [None])
|
|
160
|
-
>>> p.fit(X, y)
|
|
161
|
-
>>> brier_loss_scorer(p, X, y)
|
|
162
|
-
"""
|
|
163
91
|
sign = 1 if greater_is_better else -1
|
|
164
92
|
|
|
165
93
|
if response_method is None:
|
|
94
|
+
warnings.warn(
|
|
95
|
+
"response_method=None is deprecated in version 1.6 and will be removed "
|
|
96
|
+
"in version 1.8. Leave it to its default value to avoid this warning.",
|
|
97
|
+
FutureWarning,
|
|
98
|
+
)
|
|
166
99
|
response_method = "predict"
|
|
167
100
|
elif response_method == "default":
|
|
168
101
|
response_method = "predict"
|
|
@@ -225,6 +158,7 @@ class _BaseScorer(_MetadataRequester):
|
|
|
225
158
|
self._sign = sign
|
|
226
159
|
self._kwargs = kwargs
|
|
227
160
|
self._response_method = response_method
|
|
161
|
+
# TODO (1.8): remove in 1.8 (scoring="max_error" has been deprecated in 1.6)
|
|
228
162
|
self._deprecation_msg = None
|
|
229
163
|
|
|
230
164
|
def _get_pos_label(self) -> Optional[Any]:
|
|
@@ -236,6 +170,7 @@ class _BaseScorer(_MetadataRequester):
|
|
|
236
170
|
return None
|
|
237
171
|
|
|
238
172
|
def _accept_sample_weight(self) -> bool:
|
|
173
|
+
# TODO(slep006): remove when metadata routing is the only way
|
|
239
174
|
return "sample_weight" in signature(self._score_func).parameters
|
|
240
175
|
|
|
241
176
|
def __repr__(self) -> str:
|
|
@@ -282,6 +217,7 @@ class _BaseScorer(_MetadataRequester):
|
|
|
282
217
|
float
|
|
283
218
|
Score function applied to prediction of estimator on X.
|
|
284
219
|
"""
|
|
220
|
+
# TODO (1.8): remove in 1.8 (scoring="max_error" has been deprecated in 1.6)
|
|
285
221
|
if self._deprecation_msg is not None:
|
|
286
222
|
warnings.warn(
|
|
287
223
|
self._deprecation_msg, category=DeprecationWarning, stacklevel=2
|
|
@@ -378,7 +314,6 @@ class _Scorer(_BaseScorer):
|
|
|
378
314
|
X,
|
|
379
315
|
pos_label=pos_label,
|
|
380
316
|
)
|
|
381
|
-
|
|
382
317
|
# make lookup dict for fast matching
|
|
383
318
|
pred_dict = dict(zip(idx, y_pred))
|
|
384
319
|
|
|
@@ -405,36 +340,6 @@ class _Scorer(_BaseScorer):
|
|
|
405
340
|
prefer_skip_nested_validation=True,
|
|
406
341
|
)
|
|
407
342
|
def get_scorer(scoring: Union[str, Callable]) -> Any:
|
|
408
|
-
"""
|
|
409
|
-
Get a scorer from string.
|
|
410
|
-
|
|
411
|
-
`sklearn.metrics.get_scorer_names` can be used to retrieve the names
|
|
412
|
-
of all available scorers.
|
|
413
|
-
|
|
414
|
-
Parameters
|
|
415
|
-
----------
|
|
416
|
-
scoring : str, callable or None
|
|
417
|
-
Scoring method as string. If callable it is returned as is.
|
|
418
|
-
If None, returns None.
|
|
419
|
-
|
|
420
|
-
Returns
|
|
421
|
-
-------
|
|
422
|
-
callable
|
|
423
|
-
The scorer.
|
|
424
|
-
|
|
425
|
-
Notes
|
|
426
|
-
-----
|
|
427
|
-
When passed a string, this function always returns a copy of the scorer
|
|
428
|
-
object. Calling `get_scorer` twice for the same scorer results in two
|
|
429
|
-
separate scorer objects.
|
|
430
|
-
|
|
431
|
-
Examples
|
|
432
|
-
--------
|
|
433
|
-
>>> from panelsplit.metrics import get_scorer
|
|
434
|
-
>>> accuracy = get_scorer("accuracy")
|
|
435
|
-
>>> accuracy(classifier, X, y)
|
|
436
|
-
"""
|
|
437
|
-
|
|
438
343
|
if isinstance(scoring, str):
|
|
439
344
|
try:
|
|
440
345
|
scorer = deepcopy(_SCORERS[scoring])
|
|
@@ -584,11 +489,7 @@ neg_mean_poisson_deviance_scorer = make_SequentialCV_scorer(
|
|
|
584
489
|
neg_mean_gamma_deviance_scorer = make_SequentialCV_scorer(
|
|
585
490
|
mean_gamma_deviance, greater_is_better=False
|
|
586
491
|
)
|
|
587
|
-
# D^2 scorers (fraction of explained Brier / log-loss)
|
|
588
492
|
d2_absolute_error_scorer = make_SequentialCV_scorer(d2_absolute_error_score)
|
|
589
|
-
d2_brier_scorer = make_SequentialCV_scorer(d2_brier_score)
|
|
590
|
-
d2_log_loss_scorer = make_SequentialCV_scorer(d2_log_loss_score)
|
|
591
|
-
|
|
592
493
|
|
|
593
494
|
# Standard Classification Scores
|
|
594
495
|
accuracy_scorer = make_SequentialCV_scorer(accuracy_score)
|
|
@@ -682,8 +583,6 @@ _SCORERS = dict(
|
|
|
682
583
|
neg_mean_poisson_deviance=neg_mean_poisson_deviance_scorer,
|
|
683
584
|
neg_mean_gamma_deviance=neg_mean_gamma_deviance_scorer,
|
|
684
585
|
d2_absolute_error_score=d2_absolute_error_scorer,
|
|
685
|
-
d2_brier_score=d2_brier_scorer,
|
|
686
|
-
d2_log_loss_score=d2_log_loss_scorer,
|
|
687
586
|
accuracy=accuracy_scorer,
|
|
688
587
|
top_k_accuracy=top_k_accuracy_scorer,
|
|
689
588
|
roc_auc=roc_auc_scorer,
|
|
@@ -708,17 +607,3 @@ _SCORERS = dict(
|
|
|
708
607
|
normalized_mutual_info_score=normalized_mutual_info_scorer,
|
|
709
608
|
fowlkes_mallows_score=fowlkes_mallows_scorer,
|
|
710
609
|
)
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
for name, metric in [
|
|
714
|
-
("precision", precision_score),
|
|
715
|
-
("recall", recall_score),
|
|
716
|
-
("f1", f1_score),
|
|
717
|
-
("jaccard", jaccard_score),
|
|
718
|
-
]:
|
|
719
|
-
_SCORERS[name] = make_SequentialCV_scorer(metric, average="binary")
|
|
720
|
-
for average in ["macro", "micro", "samples", "weighted"]:
|
|
721
|
-
qualified_name = "{0}_{1}".format(name, average)
|
|
722
|
-
_SCORERS[qualified_name] = make_SequentialCV_scorer(
|
|
723
|
-
metric, pos_label=None, average=average
|
|
724
|
-
)
|
|
@@ -970,8 +970,8 @@ class GridSearch(BaseSearch):
|
|
|
970
970
|
|
|
971
971
|
If `scoring` represents a single score, one can use:
|
|
972
972
|
|
|
973
|
-
- a single string (see
|
|
974
|
-
- a callable (see
|
|
973
|
+
- a single string (see :ref:`scoring_string_names`);
|
|
974
|
+
- a callable (see :ref:`scoring_callable`) that returns a single value;
|
|
975
975
|
- `None`, the `estimator`'s default evaluation criterion is used.
|
|
976
976
|
|
|
977
977
|
If `scoring` represents multiple scores, one can use:
|
|
@@ -981,13 +981,16 @@ class GridSearch(BaseSearch):
|
|
|
981
981
|
names and the values are the metric scores;
|
|
982
982
|
- a dictionary with metric names as keys and callables as values.
|
|
983
983
|
|
|
984
|
-
See
|
|
984
|
+
See :ref:`multimetric_grid_search` for an example.
|
|
985
985
|
|
|
986
986
|
n_jobs : int, default=None
|
|
987
987
|
Number of jobs to run in parallel.
|
|
988
988
|
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
|
|
989
989
|
``-1`` means using all processors.
|
|
990
990
|
|
|
991
|
+
.. versionchanged:: v0.20
|
|
992
|
+
`n_jobs` default changed from 1 to None
|
|
993
|
+
|
|
991
994
|
refit : bool, str, or callable, default=True
|
|
992
995
|
Refit an estimator using the best found parameters on the whole
|
|
993
996
|
dataset.
|
|
@@ -1051,20 +1054,67 @@ class GridSearch(BaseSearch):
|
|
|
1051
1054
|
expensive and is not strictly required to select the parameters that
|
|
1052
1055
|
yield the best generalization performance.
|
|
1053
1056
|
|
|
1057
|
+
.. versionadded:: 0.19
|
|
1058
|
+
|
|
1059
|
+
.. versionchanged:: 0.21
|
|
1060
|
+
Default value was changed from ``True`` to ``False``
|
|
1061
|
+
|
|
1054
1062
|
Attributes
|
|
1055
1063
|
----------
|
|
1056
1064
|
cv_results_ : dict of numpy (masked) ndarrays
|
|
1057
1065
|
A dict with keys as column headers and values as columns, that can be
|
|
1058
1066
|
imported into a pandas ``DataFrame``.
|
|
1059
1067
|
|
|
1068
|
+
For instance the below given table
|
|
1069
|
+
|
|
1070
|
+
+------------+-----------+------------+-----------------+---+---------+
|
|
1071
|
+
|param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
|
|
1072
|
+
+============+===========+============+=================+===+=========+
|
|
1073
|
+
| 'poly' | -- | 2 | 0.80 |...| 2 |
|
|
1074
|
+
+------------+-----------+------------+-----------------+---+---------+
|
|
1075
|
+
| 'poly' | -- | 3 | 0.70 |...| 4 |
|
|
1076
|
+
+------------+-----------+------------+-----------------+---+---------+
|
|
1077
|
+
| 'rbf' | 0.1 | -- | 0.80 |...| 3 |
|
|
1078
|
+
+------------+-----------+------------+-----------------+---+---------+
|
|
1079
|
+
| 'rbf' | 0.2 | -- | 0.93 |...| 1 |
|
|
1080
|
+
+------------+-----------+------------+-----------------+---+---------+
|
|
1081
|
+
|
|
1082
|
+
will be represented by a ``cv_results_`` dict of::
|
|
1083
|
+
|
|
1084
|
+
{
|
|
1085
|
+
'param_kernel': masked_array(data = ['poly', 'poly', 'rbf', 'rbf'],
|
|
1086
|
+
mask = [False False False False]...)
|
|
1087
|
+
'param_gamma': masked_array(data = [-- -- 0.1 0.2],
|
|
1088
|
+
mask = [ True True False False]...),
|
|
1089
|
+
'param_degree': masked_array(data = [2.0 3.0 -- --],
|
|
1090
|
+
mask = [False False True True]...),
|
|
1091
|
+
'split0_test_score' : [0.80, 0.70, 0.80, 0.93],
|
|
1092
|
+
'split1_test_score' : [0.82, 0.50, 0.70, 0.78],
|
|
1093
|
+
'mean_test_score' : [0.81, 0.60, 0.75, 0.85],
|
|
1094
|
+
'std_test_score' : [0.01, 0.10, 0.05, 0.08],
|
|
1095
|
+
'rank_test_score' : [2, 4, 3, 1],
|
|
1096
|
+
'split0_train_score' : [0.80, 0.92, 0.70, 0.93],
|
|
1097
|
+
'split1_train_score' : [0.82, 0.55, 0.70, 0.87],
|
|
1098
|
+
'mean_train_score' : [0.81, 0.74, 0.70, 0.90],
|
|
1099
|
+
'std_train_score' : [0.01, 0.19, 0.00, 0.03],
|
|
1100
|
+
'mean_fit_time' : [0.73, 0.63, 0.43, 0.49],
|
|
1101
|
+
'std_fit_time' : [0.01, 0.02, 0.01, 0.01],
|
|
1102
|
+
'mean_score_time' : [0.01, 0.06, 0.04, 0.04],
|
|
1103
|
+
'std_score_time' : [0.00, 0.00, 0.00, 0.01],
|
|
1104
|
+
'params' : [{'kernel': 'poly', 'degree': 2}, ...],
|
|
1105
|
+
}
|
|
1106
|
+
|
|
1060
1107
|
For an example of visualization and interpretation of GridSearch results,
|
|
1061
|
-
see
|
|
1108
|
+
see :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_stats.py`.
|
|
1062
1109
|
|
|
1063
1110
|
NOTE
|
|
1064
1111
|
|
|
1065
1112
|
The key ``'params'`` is used to store a list of parameter
|
|
1066
1113
|
settings dicts for all the parameter candidates.
|
|
1067
1114
|
|
|
1115
|
+
The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
|
|
1116
|
+
``std_score_time`` are all in seconds.
|
|
1117
|
+
|
|
1068
1118
|
For multi-metric evaluation, the scores for all the scorers are
|
|
1069
1119
|
available in the ``cv_results_`` dict at the keys ending with that
|
|
1070
1120
|
scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
|
|
@@ -1117,6 +1167,8 @@ class GridSearch(BaseSearch):
|
|
|
1117
1167
|
|
|
1118
1168
|
This is present only if ``refit`` is not False.
|
|
1119
1169
|
|
|
1170
|
+
.. versionadded:: 0.20
|
|
1171
|
+
|
|
1120
1172
|
multimetric_ : bool
|
|
1121
1173
|
Whether or not the scorers compute several metrics.
|
|
1122
1174
|
|
|
@@ -1130,12 +1182,16 @@ class GridSearch(BaseSearch):
|
|
|
1130
1182
|
parameter for more details) and that `best_estimator_` exposes
|
|
1131
1183
|
`n_features_in_` when fit.
|
|
1132
1184
|
|
|
1185
|
+
.. versionadded:: 0.24
|
|
1186
|
+
|
|
1133
1187
|
feature_names_in_ : ndarray of shape (`n_features_in_`,)
|
|
1134
1188
|
Names of features seen during :term:`fit`. Only defined if
|
|
1135
1189
|
`best_estimator_` is defined (see the documentation for the `refit`
|
|
1136
1190
|
parameter for more details) and that `best_estimator_` exposes
|
|
1137
1191
|
`feature_names_in_` when fit.
|
|
1138
1192
|
|
|
1193
|
+
.. versionadded:: 1.0
|
|
1194
|
+
|
|
1139
1195
|
See Also
|
|
1140
1196
|
--------
|
|
1141
1197
|
ParameterGrid : Generates all the combinations of a hyperparameter grid.
|
|
@@ -1170,11 +1226,11 @@ class GridSearch(BaseSearch):
|
|
|
1170
1226
|
GridSearch(estimator=SVC(),
|
|
1171
1227
|
param_grid={'C': [1, 10], 'kernel': ('linear', 'rbf')})
|
|
1172
1228
|
>>> sorted(clf.cv_results_.keys())
|
|
1173
|
-
['mean_test_score',...
|
|
1229
|
+
['mean_fit_time', 'mean_score_time', 'mean_test_score',...
|
|
1174
1230
|
'param_C', 'param_kernel', 'params',...
|
|
1175
1231
|
'rank_test_score', 'split0_test_score',...
|
|
1176
1232
|
'split2_test_score', ...
|
|
1177
|
-
'std_test_score']
|
|
1233
|
+
'std_fit_time', 'std_score_time', 'std_test_score']
|
|
1178
1234
|
"""
|
|
1179
1235
|
|
|
1180
1236
|
_parameter_constraints: dict = {
|
|
@@ -1264,8 +1320,8 @@ class RandomizedSearch(BaseSearch):
|
|
|
1264
1320
|
|
|
1265
1321
|
If `scoring` represents a single score, one can use:
|
|
1266
1322
|
|
|
1267
|
-
- a single string (see
|
|
1268
|
-
- a callable (see
|
|
1323
|
+
- a single string (see :ref:`scoring_string_names`);
|
|
1324
|
+
- a callable (see :ref:`scoring_callable`) that returns a single value;
|
|
1269
1325
|
- `None`, the `estimator`'s default evaluation criterion is used.
|
|
1270
1326
|
|
|
1271
1327
|
If `scoring` represents multiple scores, one can use:
|
|
@@ -1275,7 +1331,7 @@ class RandomizedSearch(BaseSearch):
|
|
|
1275
1331
|
names and the values are the metric scores;
|
|
1276
1332
|
- a dictionary with metric names as keys and callables as values.
|
|
1277
1333
|
|
|
1278
|
-
See
|
|
1334
|
+
See :ref:`multimetric_grid_search` for an example.
|
|
1279
1335
|
|
|
1280
1336
|
If None, the estimator's score method is used.
|
|
1281
1337
|
|
|
@@ -1285,6 +1341,9 @@ class RandomizedSearch(BaseSearch):
|
|
|
1285
1341
|
``-1`` means using all processors.
|
|
1286
1342
|
for more details.
|
|
1287
1343
|
|
|
1344
|
+
.. versionchanged:: v0.20
|
|
1345
|
+
`n_jobs` default changed from 1 to None
|
|
1346
|
+
|
|
1288
1347
|
refit : bool, str, or callable, default=True
|
|
1289
1348
|
Refit an estimator using the best found parameters on the whole
|
|
1290
1349
|
dataset.
|
|
@@ -1354,20 +1413,62 @@ class RandomizedSearch(BaseSearch):
|
|
|
1354
1413
|
expensive and is not strictly required to select the parameters that
|
|
1355
1414
|
yield the best generalization performance.
|
|
1356
1415
|
|
|
1416
|
+
.. versionadded:: 0.19
|
|
1417
|
+
|
|
1418
|
+
.. versionchanged:: 0.21
|
|
1419
|
+
Default value was changed from ``True`` to ``False``
|
|
1420
|
+
|
|
1357
1421
|
Attributes
|
|
1358
1422
|
----------
|
|
1359
1423
|
cv_results_ : dict of numpy (masked) ndarrays
|
|
1360
1424
|
A dict with keys as column headers and values as columns, that can be
|
|
1361
1425
|
imported into a pandas ``DataFrame``.
|
|
1362
1426
|
|
|
1427
|
+
For instance the below given table
|
|
1428
|
+
|
|
1429
|
+
+--------------+-------------+-------------------+---+---------------+
|
|
1430
|
+
| param_kernel | param_gamma | split0_test_score |...|rank_test_score|
|
|
1431
|
+
+==============+=============+===================+===+===============+
|
|
1432
|
+
| 'rbf' | 0.1 | 0.80 |...| 1 |
|
|
1433
|
+
+--------------+-------------+-------------------+---+---------------+
|
|
1434
|
+
| 'rbf' | 0.2 | 0.84 |...| 3 |
|
|
1435
|
+
+--------------+-------------+-------------------+---+---------------+
|
|
1436
|
+
| 'rbf' | 0.3 | 0.70 |...| 2 |
|
|
1437
|
+
+--------------+-------------+-------------------+---+---------------+
|
|
1438
|
+
|
|
1439
|
+
will be represented by a ``cv_results_`` dict of::
|
|
1440
|
+
|
|
1441
|
+
{
|
|
1442
|
+
'param_kernel' : masked_array(data = ['rbf', 'rbf', 'rbf'],
|
|
1443
|
+
mask = False),
|
|
1444
|
+
'param_gamma' : masked_array(data = [0.1 0.2 0.3], mask = False),
|
|
1445
|
+
'split0_test_score' : [0.80, 0.84, 0.70],
|
|
1446
|
+
'split1_test_score' : [0.82, 0.50, 0.70],
|
|
1447
|
+
'mean_test_score' : [0.81, 0.67, 0.70],
|
|
1448
|
+
'std_test_score' : [0.01, 0.24, 0.00],
|
|
1449
|
+
'rank_test_score' : [1, 3, 2],
|
|
1450
|
+
'split0_train_score' : [0.80, 0.92, 0.70],
|
|
1451
|
+
'split1_train_score' : [0.82, 0.55, 0.70],
|
|
1452
|
+
'mean_train_score' : [0.81, 0.74, 0.70],
|
|
1453
|
+
'std_train_score' : [0.01, 0.19, 0.00],
|
|
1454
|
+
'mean_fit_time' : [0.73, 0.63, 0.43],
|
|
1455
|
+
'std_fit_time' : [0.01, 0.02, 0.01],
|
|
1456
|
+
'mean_score_time' : [0.01, 0.06, 0.04],
|
|
1457
|
+
'std_score_time' : [0.00, 0.00, 0.00],
|
|
1458
|
+
'params' : [{'kernel' : 'rbf', 'gamma' : 0.1}, ...],
|
|
1459
|
+
}
|
|
1460
|
+
|
|
1363
1461
|
For an example of analysing ``cv_results_``,
|
|
1364
|
-
see
|
|
1462
|
+
see :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_stats.py`.
|
|
1365
1463
|
|
|
1366
1464
|
NOTE
|
|
1367
1465
|
|
|
1368
1466
|
The key ``'params'`` is used to store a list of parameter
|
|
1369
1467
|
settings dicts for all the parameter candidates.
|
|
1370
1468
|
|
|
1469
|
+
The ``mean_fit_time``, ``std_fit_time``, ``mean_score_time`` and
|
|
1470
|
+
``std_score_time`` are all in seconds.
|
|
1471
|
+
|
|
1371
1472
|
For multi-metric evaluation, the scores for all the scorers are
|
|
1372
1473
|
available in the ``cv_results_`` dict at the keys ending with that
|
|
1373
1474
|
scorer's name (``'_<scorer_name>'``) instead of ``'_score'`` shown
|
|
@@ -1423,6 +1524,8 @@ class RandomizedSearch(BaseSearch):
|
|
|
1423
1524
|
|
|
1424
1525
|
This is present only if ``refit`` is not False.
|
|
1425
1526
|
|
|
1527
|
+
.. versionadded:: 0.20
|
|
1528
|
+
|
|
1426
1529
|
multimetric_ : bool
|
|
1427
1530
|
Whether or not the scorers compute several metrics.
|
|
1428
1531
|
|
|
@@ -1436,12 +1539,16 @@ class RandomizedSearch(BaseSearch):
|
|
|
1436
1539
|
parameter for more details) and that `best_estimator_` exposes
|
|
1437
1540
|
`n_features_in_` when fit.
|
|
1438
1541
|
|
|
1542
|
+
.. versionadded:: 0.24
|
|
1543
|
+
|
|
1439
1544
|
feature_names_in_ : ndarray of shape (`n_features_in_`,)
|
|
1440
1545
|
Names of features seen during :term:`fit`. Only defined if
|
|
1441
1546
|
`best_estimator_` is defined (see the documentation for the `refit`
|
|
1442
1547
|
parameter for more details) and that `best_estimator_` exposes
|
|
1443
1548
|
`feature_names_in_` when fit.
|
|
1444
1549
|
|
|
1550
|
+
.. versionadded:: 1.0
|
|
1551
|
+
|
|
1445
1552
|
See Also
|
|
1446
1553
|
--------
|
|
1447
1554
|
GridSearch : Does exhaustive search over a grid of parameters.
|