panelsplit 1.1.1__tar.gz → 2.0.4.dev0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/.github/workflows/ci.yml +1 -1
  2. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/.github/workflows/lint.yml +29 -2
  3. panelsplit-2.0.4.dev0/.gitignore +28 -0
  4. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/.pre-commit-config.yaml +15 -0
  5. panelsplit-2.0.4.dev0/CHANGELOG.md +40 -0
  6. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/PKG-INFO +5 -3
  7. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/README.md +1 -1
  8. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/panelsplit/__init__.py +16 -1
  9. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/panelsplit/application.py +105 -124
  10. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/panelsplit/cross_validation.py +99 -107
  11. panelsplit-2.0.4.dev0/panelsplit/metrics.py +609 -0
  12. panelsplit-2.0.4.dev0/panelsplit/model_selection/__init__.py +6 -0
  13. panelsplit-2.0.4.dev0/panelsplit/model_selection/_validation.py +512 -0
  14. panelsplit-2.0.4.dev0/panelsplit/model_selection/model_selection.py +1630 -0
  15. panelsplit-2.0.4.dev0/panelsplit/pipeline.py +1212 -0
  16. panelsplit-2.0.4.dev0/panelsplit/plot.py +66 -0
  17. panelsplit-2.0.4.dev0/panelsplit/utils/_response.py +73 -0
  18. panelsplit-2.0.4.dev0/panelsplit/utils/typing.py +12 -0
  19. panelsplit-2.0.4.dev0/panelsplit/utils/utils.py +17 -0
  20. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/panelsplit/utils/validation.py +53 -30
  21. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/pyproject.toml +40 -2
  22. panelsplit-2.0.4.dev0/tests/df_generation.py +78 -0
  23. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_check_fitted_fix.py +22 -9
  24. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_edge_cases.py +8 -10
  25. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_issue_59_fix.py +28 -21
  26. panelsplit-2.0.4.dev0/tests/test_metrics.py +139 -0
  27. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_narwhals_compatibility.py +20 -9
  28. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_pipeline.py +136 -14
  29. panelsplit-2.0.4.dev0/tests/test_scorer.py +85 -0
  30. panelsplit-2.0.4.dev0/tests/test_search.py +260 -0
  31. panelsplit-2.0.4.dev0/tests/test_set_params.py +29 -0
  32. panelsplit-2.0.4.dev0/uv.lock +3104 -0
  33. panelsplit-1.1.1/.gitignore +0 -13
  34. panelsplit-1.1.1/CHANGELOG.md +0 -24
  35. panelsplit-1.1.1/panelsplit/pipeline.py +0 -826
  36. panelsplit-1.1.1/panelsplit/plot.py +0 -61
  37. panelsplit-1.1.1/panelsplit/utils/utils.py +0 -11
  38. panelsplit-1.1.1/tests/test_search.py +0 -155
  39. panelsplit-1.1.1/uv.lock +0 -5257
  40. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/.github/workflows/pre-commit.yml +0 -0
  41. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/.github/workflows/releases.yml +0 -0
  42. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/CITATION.cff +0 -0
  43. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/CNAME +0 -0
  44. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/CODE_OF_CONDUCT.md +0 -0
  45. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/LICENSE +0 -0
  46. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/examples/An introduction to PanelSplit.ipynb +0 -0
  47. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/panelsplit/utils/__init__.py +0 -0
  48. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/__init__.py +0 -0
  49. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_PanelSplit.py +0 -0
  50. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_cross_validation.py +0 -0
  51. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_plot.py +0 -0
  52. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_utils.py +0 -0
  53. {panelsplit-1.1.1 → panelsplit-2.0.4.dev0}/tests/test_validation_coverage.py +0 -0
@@ -14,7 +14,7 @@ jobs:
14
14
 
15
15
  strategy:
16
16
  matrix:
17
- python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
17
+ python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
18
18
  fail-fast: true
19
19
 
20
20
  steps:
@@ -16,7 +16,7 @@ jobs:
16
16
  - name: Setup Python
17
17
  uses: actions/setup-python@v4
18
18
  with:
19
- python-version: '3.9'
19
+ python-version: '3.11'
20
20
  - name: Checkout
21
21
  uses: actions/checkout@v3
22
22
  - name: Install uv
@@ -26,8 +26,35 @@ jobs:
26
26
 
27
27
  - name: Install dependencies
28
28
  run: uv sync --dev
29
+
30
+ - name: Install mypy (match pre-commit)
31
+ run: uv run pip install mypy==1.18.2
32
+
29
33
  - name: Run mypy
30
- run: uv run mypy panelsplit
34
+ run: |
35
+ uv run mypy panelsplit \
36
+ --disallow-untyped-defs \
37
+ --disallow-incomplete-defs \
38
+ --ignore-missing-imports
39
+ numpydoc:
40
+ runs-on: ubuntu-latest
41
+ steps:
42
+ - name: Setup Python
43
+ uses: actions/setup-python@v4
44
+ with:
45
+ python-version: '3.11'
46
+ - name: Checkout
47
+ uses: actions/checkout@v3
48
+ - name: Install uv
49
+ run: |
50
+ curl -LsSf https://astral.sh/uv/install.sh | sh
51
+ echo "$HOME/.local/bin" >> $GITHUB_PATH
52
+
53
+ - name: Install dependencies
54
+ run: uv sync --dev
55
+
56
+ - name: Run numpydoc
57
+ run: uv run pre-commit run numpydoc-validation --all-files
31
58
 
32
59
  ruff:
33
60
  runs-on: ubuntu-latest
@@ -0,0 +1,28 @@
1
+ # Ignore compiled Python files
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.pyirc
7
+ *.mypy_cache
8
+ *.pytest_cache
9
+ *.ruff_cache
10
+ *docs/
11
+ *typecov/
12
+ dist/
13
+
14
+ .DS_Store
15
+
16
+ debug*
17
+
18
+ panelsplit.egg-info
19
+
20
+ .venv
21
+ .python-version
22
+
23
+ # drafts
24
+ examples/Intro to PanelSplit.ipynb
25
+ examples/PanelSplit explanation.ipynb
26
+ htmlcov
27
+ .coverage
28
+ .vscode/
@@ -23,3 +23,18 @@ repos:
23
23
  args: [--fix]
24
24
  - id: ruff-format
25
25
  types_or: [python, pyi]
26
+
27
+ - repo: https://github.com/numpy/numpydoc
28
+ rev: v1.10.0
29
+ hooks:
30
+ - id: numpydoc-validation
31
+ files: ^panelsplit/.*\.py$
32
+ exclude: ^panelsplit/(_|.*/_)
33
+
34
+
35
+ - repo: https://github.com/pre-commit/mirrors-mypy
36
+ rev: v1.18.2 # Use the sha / tag you want to point at
37
+ hooks:
38
+ - id: mypy
39
+ files: panelsplit
40
+ args: [--disallow-untyped-defs, --disallow-incomplete-defs, --ignore-missing-imports]
@@ -0,0 +1,40 @@
1
+ # Changelog
2
+
3
+ All notable changes to panelsplit will be documented in this file.
4
+
5
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6
+ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
+
8
+ ## [2.0.0] - 2025-12-31
9
+ ### Changed
10
+ - Switched from `pydoclint` to `numpydoc` for better docstring readability.
11
+ - `SequentialCVPipeline` made more sklearn-like (e.g. added functionality for `set_params` and `get_params`). Its arguments were also made to match sklearn; `steps` is a list of tuples of length 2 (name, estimator) instead of (name, estimator, cv). CVs were moved to a separate argument, `cv_steps`
12
+ ### Added
13
+ - `panelsplit.model_selection`. This includes `RandomizedSearch` and `GridSearch` in order to allow for hyper-parameter searching with `SequentialCVPipeline`
14
+ - `panelsplit.metrics`. This module includes metrics which work with the `model_selection` module.
15
+
16
+
17
+ ## [1.1.2] - 2025-10-28
18
+ ### Added
19
+ - Consistent type hints with more restrictions (E.g. `--disallow-untyped-defs` `--disallow-incomplete-defs`), addressing [#85](https://github.com/4Freye/panelsplit/issues/85)
20
+ - Consistent docstrings addressing [#94](https://github.com/4Freye/panelsplit/issues/94)
21
+ - mypy and pydoclint checks on `pre-commit-config.yaml` and `.github/workflows/lint.yml`
22
+
23
+
24
+ ## [1.1.1] - 2025-10-23
25
+ ### Changed
26
+ - Migrated from boolean indexing to purely integer-based indexing, as mentioned in [#86](https://github.com/4Freye/panelsplit/issues/86)
27
+ ### Added
28
+ - Consistent type hints throughout the Python codebase, addressing [#85](https://github.com/4Freye/panelsplit/issues/85)
29
+ - mypy to CI, addressing [#85](https://github.com/4Freye/panelsplit/issues/85)
30
+
31
+ ## [1.1.0] - 2025-10-21
32
+ ### Added
33
+ - Support for more DataFrame types (e.g. polars) via narwhals
34
+
35
+ ## [1.0.4] - 2025-10-16
36
+ ### Added
37
+ - `CHANGELOG.md` - marking changes to the project
38
+ - Automation of publishing to pypi
39
+ - Dynamic versioning
40
+ - Automation of GitHub Releases via `CHANGELOG.md`
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: panelsplit
3
- Version: 1.1.1
3
+ Version: 2.0.4.dev0
4
4
  Summary: A tool for panel data analysis.
5
5
  Project-URL: Homepage, https://github.com/4Freye/panelsplit
6
6
  Project-URL: Repository, https://github.com/4Freye/panelsplit
@@ -11,14 +11,16 @@ License-File: LICENSE
11
11
  Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Operating System :: OS Independent
13
13
  Classifier: Programming Language :: Python :: 3
14
- Requires-Python: >=3.8
14
+ Requires-Python: >=3.10
15
15
  Requires-Dist: joblib>=1.0.1
16
16
  Requires-Dist: matplotlib>=3.4.3
17
17
  Requires-Dist: narwhals>=1.42.1
18
18
  Requires-Dist: numpy>=1.21.0
19
19
  Requires-Dist: pandas>=1.3.0
20
20
  Requires-Dist: scikit-learn>=0.24.2
21
+ Requires-Dist: scipy>=1.10.1
21
22
  Requires-Dist: tqdm>=4.67.1
23
+ Requires-Dist: typing-extensions>=4.13.2
22
24
  Description-Content-Type: text/markdown
23
25
 
24
26
  ![PyPI - Version](https://img.shields.io/pypi/v/panelsplit)
@@ -30,7 +32,7 @@ panelsplit is a Python package designed to facilitate time series cross-validati
30
32
 
31
33
  ## Installation
32
34
 
33
- panelsplit is tested for compatibility with python versions >= 3.8. You can install panelsplit using pip:
35
+ panelsplit is tested for compatibility with python versions >= 3.10. You can install panelsplit using pip:
34
36
 
35
37
  ```bash
36
38
  pip install panelsplit
@@ -7,7 +7,7 @@ panelsplit is a Python package designed to facilitate time series cross-validati
7
7
 
8
8
  ## Installation
9
9
 
10
- panelsplit is tested for compatibility with python versions >= 3.8. You can install panelsplit using pip:
10
+ panelsplit is tested for compatibility with python versions >= 3.10. You can install panelsplit using pip:
11
11
 
12
12
  ```bash
13
13
  pip install panelsplit
@@ -42,6 +42,21 @@ Explore the modules in detail by clicking on the links below to see full documen
42
42
 
43
43
  ### `panelsplit.plot`
44
44
  - Visualize time series splits easily.
45
+
46
+ ### `panelsplit.model_selection`
47
+ - **Hyperparameter tuning:** Provides GridSearch and RandomizedSearch classes for optimizing model parameters using panel data cross-validation.
48
+ - **Efficient search:** Supports parallel processing and integrates with panelsplit's cross-validation framework.
49
+
50
+ ### `panelsplit.metrics`
51
+ - **Scoring functions:** Offers a range of metrics for evaluating model performance on panel data.
52
+ - **Sequential CV scorers:** Specialized scorers designed for sequential cross-validation splits.
45
53
  """
46
54
 
47
- __all__ = ["application", "cross_validation", "pipeline", "plot"]
55
+ __all__ = [
56
+ "application",
57
+ "cross_validation",
58
+ "metrics",
59
+ "model_selection",
60
+ "pipeline",
61
+ "plot",
62
+ ]
@@ -1,11 +1,15 @@
1
1
  import inspect
2
- from typing import Tuple, List, Union, Optional
2
+ from typing import Tuple, List, Optional, Iterable, Union
3
+ from numpy.typing import NDArray
3
4
 
4
5
  import narwhals as nw
5
6
  import numpy as np
6
7
  from joblib import Parallel, delayed
7
8
  from narwhals.typing import IntoDataFrame, IntoSeries
8
- from sklearn.base import clone, BaseEstimator
9
+ from sklearn.base import clone
10
+ from .utils.typing import ArrayLike, EstimatorLike
11
+ from .cross_validation import PanelSplit
12
+ from typing import Literal
9
13
 
10
14
  from .utils.utils import _split_wrapper
11
15
  from .utils.validation import (
@@ -17,20 +21,22 @@ from .utils.validation import (
17
21
  )
18
22
 
19
23
 
20
- def _get_non_null_mask(data):
24
+ def _get_non_null_mask(data: IntoSeries) -> IntoSeries:
21
25
  """Get non-null mask for any data type."""
22
26
  return ~nw.from_native(data, series_only=True).is_null()
23
27
 
24
28
 
25
- def _predict_split(model, X_test: IntoDataFrame, method: str = "predict") -> np.ndarray:
29
+ def _predict_split(
30
+ model: EstimatorLike, X_test: ArrayLike, method: str = "predict"
31
+ ) -> np.ndarray:
26
32
  """
27
33
  Perform predictions for a single split.
28
34
 
29
35
  Parameters
30
36
  ----------
31
- model : object
37
+ model : EstimatorLike
32
38
  The machine learning model used for prediction.
33
- X_test : IntoDataFrame
39
+ X_test : ArrayLike
34
40
  The input features for testing.
35
41
  method : str, optional
36
42
  The method to use for prediction. It can be 'predict', 'predict_proba',
@@ -46,34 +52,34 @@ def _predict_split(model, X_test: IntoDataFrame, method: str = "predict") -> np.
46
52
 
47
53
 
48
54
  def _fit_split(
49
- estimator,
55
+ estimator: EstimatorLike,
50
56
  X: IntoDataFrame,
51
57
  y: Optional[IntoSeries],
52
- train_indices: np.ndarray,
53
- sample_weight: Optional[Union[IntoSeries, np.ndarray]] = None,
58
+ train_indices: NDArray,
59
+ sample_weight: Optional[Union[IntoSeries, NDArray]] = None,
54
60
  drop_na_in_y: bool = False,
55
- ):
61
+ ) -> EstimatorLike:
56
62
  """
57
63
  Fit a cloned estimator on the given training indices.
58
64
 
59
65
  Parameters
60
66
  ----------
61
- estimator : object
67
+ estimator : EstimatorLike
62
68
  The machine learning model to be fitted.
63
69
  X : IntoDataFrame
64
70
  The input features for the estimator.
65
- y : IntoSeries or None
66
- The target variable for the estimator.
67
- train_indices : np.ndarray
71
+ y : Optional[IntoSeries]
72
+ The target variable for the estimator. Default is None.
73
+ train_indices : NDArray
68
74
  Integer indices indicating the training data.
69
- sample_weight : IntoSeries or np.ndarray, optional
75
+ sample_weight : Optional[Union[IntoSeries, NDArray]]
70
76
  Sample weights for the training data. Default is None.
71
- drop_na_in_y : bool, default=False
72
- Whether to drop rows with null values in y.
77
+ drop_na_in_y : bool
78
+ Whether to drop rows with null values in y. Default is False
73
79
 
74
80
  Returns
75
81
  -------
76
- object
82
+ EstimatorLike
77
83
  A fitted estimator.
78
84
  """
79
85
  local_estimator = clone(estimator)
@@ -152,41 +158,41 @@ def _prediction_order_to_original_order(indices: List[np.ndarray]) -> np.ndarray
152
158
 
153
159
 
154
160
  def cross_val_fit(
155
- estimator,
161
+ estimator: EstimatorLike,
156
162
  X: IntoDataFrame,
157
163
  y: IntoSeries,
158
- cv,
164
+ cv: Union[PanelSplit, Iterable],
159
165
  sample_weight: Optional[Union[IntoSeries, np.ndarray]] = None,
160
166
  n_jobs: int = 1,
161
167
  progress_bar: bool = False,
162
168
  drop_na_in_y: bool = False,
163
- ) -> List[BaseEstimator]:
169
+ ) -> List[EstimatorLike]:
164
170
  """
165
171
  Fit the estimator using cross-validation.
166
172
 
167
173
  Parameters
168
174
  ----------
169
- estimator : object
175
+ estimator : EstimatorLike
170
176
  The machine learning model to be fitted.
171
177
  X : IntoDataFrame
172
178
  The input features for the estimator.
173
179
  y : IntoSeries
174
180
  The target variable for the estimator.
175
- cv : object or iterable
181
+ cv : Union[PanelSplit, Iterable]
176
182
  Cross-validation splitter; either an object that generates train/test splits (e.g., an instance of PanelSplit)
177
183
  or an iterable of splits.
178
- sample_weight : IntoSeries or np.ndarray, optional
184
+ sample_weight : Optional[Union[IntoSeries, np.ndarray]]
179
185
  Sample weights for the training data. Default is None.
180
- n_jobs : int, optional
186
+ n_jobs : int
181
187
  The number of jobs to run in parallel. Default is 1.
182
- progress_bar : bool, optional
188
+ progress_bar : bool
183
189
  Whether to display a progress bar. Default is False.
184
- drop_na_in_y : bool, optional
190
+ drop_na_in_y : bool
185
191
  Whether to drop observations where y is na. Default is False.
186
192
 
187
193
  Returns
188
194
  -------
189
- list
195
+ List[EstimatorLike]
190
196
  List containing fitted models for each split.
191
197
 
192
198
  Examples
@@ -195,14 +201,11 @@ def cross_val_fit(
195
201
  >>> from sklearn.linear_model import LinearRegression
196
202
  >>> from panelsplit.cross_validation import PanelSplit
197
203
  >>> # Create sample data
198
- >>> df = pd.DataFrame({
199
- ... 'feature': [1, 2, 3, 4, 5, 6],
200
- ... 'period': [1, 1, 2, 2, 3, 3]
201
- ... })
202
- >>> X = df[['feature']]
204
+ >>> df = pd.DataFrame({"feature": [1, 2, 3, 4, 5, 6], "period": [1, 1, 2, 2, 3, 3]})
205
+ >>> X = df[["feature"]]
203
206
  >>> y = pd.Series([2, 4, 6, 8, 10, 12])
204
207
  >>> # Create a PanelSplit instance for cross-validation
205
- >>> ps = PanelSplit(periods=df['period'], n_splits=2)
208
+ >>> ps = PanelSplit(periods=df["period"], n_splits=2)
206
209
  >>> fitted_models = cross_val_fit(LinearRegression(), X, y, ps)
207
210
  >>> len(fitted_models)
208
211
  2
@@ -223,142 +226,136 @@ def cross_val_fit(
223
226
 
224
227
 
225
228
  def cross_val_predict(
226
- fitted_estimators,
229
+ fitted_estimators: List[EstimatorLike],
227
230
  X: IntoDataFrame,
228
- cv,
231
+ cv: Union[PanelSplit, Iterable],
229
232
  method: str = "predict",
230
233
  n_jobs: int = 1,
231
- return_train_preds: bool = False,
232
- ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
234
+ return_group: Literal["test", "train"] = "test",
235
+ ) -> np.ndarray:
233
236
  """
234
237
  Perform cross-validated predictions using a given predictor model.
235
238
 
236
239
  Parameters
237
240
  ----------
238
- fitted_estimators : list
241
+ fitted_estimators : List[EstimatorLike]
239
242
  List of fitted machine learning models used for prediction.
240
243
  X : IntoDataFrame
241
244
  The input features for prediction.
242
- cv : object or iterable
245
+ cv : Union[PanelSplit, Iterable]
243
246
  Cross-validation splitter; either an object that generates train/test splits or an iterable of splits.
244
- method : str, optional
245
- The method to use for prediction. It can be whatever methods are available to the estimator
247
+ method : str
248
+ The method to use for prediction. It can be whatever methods are available to the estimator.
246
249
  (e.g. predict_proba in the case of a classifier or transform in the case of a transformer). Default is 'predict'.
247
- n_jobs : int, optional
250
+ n_jobs : int
248
251
  The number of jobs to run in parallel. Default is 1.
249
- return_train_preds : bool, optional
250
- If True, return predictions for the training set as well. Default is False.
252
+ return_group : {"test","train"}
253
+ Whether to return the train or test predictions. Default is "test".
251
254
 
252
255
  Returns
253
256
  -------
254
- test_preds : np.ndarray
255
- Array containing test predictions made by the model during cross-validation.
256
- train_preds : np.ndarray, optional
257
- Array containing train predictions made by the model during cross-validation.
258
- Returned only if `return_train_preds` is True.
257
+ np.ndarray
258
+ Predictions (either train or test depending on return_group).
259
+
260
+ Examples
261
+ --------
262
+ >>> from sklearn.linear_model import LinearRegression
263
+ >>> import numpy as np
264
+ >>> from panelsplit.cross_validation import PanelSplit,
265
+ >>> from panelsplit.application import cross_val_predict, cross_val_fit
266
+ >>> X = np.arange(12).reshape(6, 2)
267
+ >>> y = np.array([1, 2, 3, 4, 5, 6])
268
+ >>> ps = PanelSplit(periods=np.array([1, 1, 2, 2, 3, 3]), n_splits=2)
269
+ >>> estimators = cross_val_fit(LinearRegression(), X, y, ps)
270
+ >>> preds = cross_val_predict(estimators, X, ps)
271
+ >>> preds.shape
272
+ (4,)
259
273
  """
260
274
  check_fitted_estimators(fitted_estimators)
261
275
  splits = check_cv(cv)
276
+ if return_group not in ["train", "test"]:
277
+ raise ValueError(
278
+ f"return_group must be train or test. Got {return_group} instead."
279
+ )
262
280
 
263
- test_splits = [split[1] for split in splits]
264
- test_indices = _prediction_order_to_original_order(test_splits)
281
+ group = 0 if return_group == "train" else 1
282
+ group_splits = [split[group] for split in splits]
283
+ group_indices = _prediction_order_to_original_order(group_splits)
265
284
 
266
285
  # Use narwhals for dataframe-agnostic operations
267
286
  X_nw = nw.from_native(X, pass_through=True)
268
287
 
269
- test_preds = Parallel(n_jobs=n_jobs)(
288
+ preds = Parallel(n_jobs=n_jobs)(
270
289
  delayed(_predict_split)(
271
290
  fitted_estimators[i],
272
291
  _safe_indexing(X_nw, test_idx, to_native=True),
273
292
  method,
274
293
  )
275
- for i, test_idx in enumerate(test_splits)
294
+ for i, test_idx in enumerate(group_splits)
276
295
  )
277
296
 
278
- if return_train_preds:
279
- train_splits = [split[0] for split in splits]
280
- train_indices = _prediction_order_to_original_order(train_splits)
281
-
282
- train_preds = Parallel(n_jobs=n_jobs)(
283
- delayed(_predict_split)(
284
- fitted_estimators[i],
285
- _safe_indexing(X_nw, train_idx, to_native=True),
286
- method,
287
- )
288
- for i, train_idx in enumerate(train_splits)
289
- )
290
-
291
- return np.concatenate(test_preds, axis=0)[test_indices], np.concatenate(
292
- train_preds, axis=0
293
- )[train_indices]
294
- else:
295
- return np.concatenate(test_preds, axis=0)[test_indices]
297
+ return np.concatenate(preds, axis=0)[group_indices]
296
298
 
297
299
 
298
300
  def cross_val_fit_predict(
299
- estimator,
301
+ estimator: EstimatorLike,
300
302
  X: IntoDataFrame,
301
303
  y: IntoSeries,
302
- cv,
304
+ cv: Union[PanelSplit, Iterable],
303
305
  method: str = "predict",
304
306
  sample_weight: Optional[Union[IntoSeries, np.ndarray]] = None,
305
307
  n_jobs: int = 1,
306
- return_train_preds: bool = False,
307
- drop_na_in_y=False,
308
- ) -> Union[
309
- Tuple[np.ndarray, List[BaseEstimator]],
310
- Tuple[np.ndarray, np.ndarray, List[BaseEstimator]],
311
- ]:
308
+ return_group: Literal["test", "train"] = "test",
309
+ drop_na_in_y: bool = False,
310
+ ) -> Tuple[np.ndarray, List[EstimatorLike]]:
312
311
  """
313
312
  Fit the estimator using cross-validation and then make predictions.
314
313
 
315
314
  Parameters
316
315
  ----------
317
- estimator : object
316
+ estimator : EstimatorLike
318
317
  The machine learning model to be fitted.
319
318
  X : IntoDataFrame
320
319
  The input features for the estimator.
321
320
  y : IntoSeries
322
321
  The target variable for the estimator.
323
- cv : object
322
+ cv : Union[PanelSplit, Iterable]
324
323
  Cross-validation splitter; an object that generates train/test splits.
325
- method : str, optional
326
- The method to use for prediction. It can be whatever methods are available to the estimator
327
- (e.g. predict_proba in the case of a classifier or transform in the case of a transformer). Default is 'predict'.
328
- sample_weight : IntoSeries or np.ndarray, optional
324
+ method : str
325
+ The method to use for prediction. It can be any method available on the estimator
326
+ (e.g., ``predict_proba`` for classifiers or ``transform`` for transformers). Default is predict.
327
+ sample_weight : Optional[Union[IntoSeries, np.ndarray]]
329
328
  Sample weights for the training data. Default is None.
330
- n_jobs : int, optional
329
+ n_jobs : int
331
330
  The number of jobs to run in parallel. Default is 1.
332
- return_train_preds : bool, optional
333
- If True, return predictions for the training set as well. Default is False.
334
- drop_na_in_y : bool, optional
335
- Whether to drop observations where y is na. Default is False.
331
+ return_group : {"test","train"}
332
+ Whether to return the train or test predictions. Default is test.
333
+ drop_na_in_y : bool
334
+ Whether to drop observations where ``y`` is NA. Default is False.
336
335
 
337
336
  Returns
338
337
  -------
339
- tuple
340
- If `return_train_preds` is False, returns a tuple of:
341
- - preds (np.ndarray): Array containing predictions made by the model during cross-validation.
342
- - fitted_estimators (list): List containing fitted models for each split.
343
- If `return_train_preds` is True, returns a tuple of:
344
- - preds (np.ndarray): Array containing test predictions made by the model during cross-validation.
345
- - train_preds (np.ndarray): Array containing train predictions made by the model during cross-validation.
346
- - fitted_estimators (list): List containing fitted models for each split.
338
+ Tuple[np.ndarray, List[EstimatorLike]]
339
+ (predictions (either train or test depending on return_group), fitted_estimators).
340
+
341
+ Raises
342
+ ------
343
+ TypeError
344
+ If the provided estimator does not implement the specified ``method`` or has invalid type.
347
345
 
348
346
  Examples
349
347
  --------
350
348
  >>> import pandas as pd
351
349
  >>> from sklearn.linear_model import LinearRegression
352
- >>> from panelsplit.cross_validation import PanelSplit # assuming PanelSplit is imported from your module
350
+ >>> from panelsplit.cross_validation import (
351
+ ... PanelSplit,
352
+ ... ) # assuming PanelSplit is imported from your module
353
353
  >>> # Create sample data
354
- >>> df = pd.DataFrame({
355
- ... 'feature': [1, 2, 3, 4, 5, 6],
356
- ... 'period': [1, 1, 2, 2, 3, 3]
357
- ... })
358
- >>> X = df[['feature']]
354
+ >>> df = pd.DataFrame({"feature": [1, 2, 3, 4, 5, 6], "period": [1, 1, 2, 2, 3, 3]})
355
+ >>> X = df[["feature"]]
359
356
  >>> y = pd.Series([2, 4, 6, 8, 10, 12])
360
357
  >>> # Create a PanelSplit instance for cross-validation
361
- >>> ps = PanelSplit(periods=df['period'], n_splits=2)
358
+ >>> ps = PanelSplit(periods=df["period"], n_splits=2)
362
359
  >>> # Get test predictions and fitted models
363
360
  >>> preds, models = cross_val_fit_predict(LinearRegression(), X, y, ps)
364
361
  >>> preds.shape
@@ -369,22 +366,6 @@ def cross_val_fit_predict(
369
366
  estimator, X, y, cv, sample_weight, n_jobs, drop_na_in_y=drop_na_in_y
370
367
  )
371
368
 
372
- res = cross_val_predict(
373
- fitted_estimators, X, cv, method, n_jobs, return_train_preds
374
- )
369
+ preds = cross_val_predict(fitted_estimators, X, cv, method, n_jobs, return_group)
375
370
 
376
- if return_train_preds:
377
- # res should be Tuple[np.ndarray, np.ndarray]
378
- if isinstance(res, tuple):
379
- preds, train_preds = res
380
- else:
381
- # defensive: unexpected type at runtime
382
- raise TypeError("cross_val_predict returned ndarray but expected tuple")
383
- return preds, train_preds, fitted_estimators
384
- else:
385
- # res should be np.ndarray
386
- if isinstance(res, tuple):
387
- preds = res[0]
388
- else:
389
- preds = res
390
- return preds, fitted_estimators
371
+ return preds, fitted_estimators