skfolio 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. skfolio/__init__.py +2 -2
  2. skfolio/cluster/__init__.py +1 -1
  3. skfolio/cluster/_hierarchical.py +1 -1
  4. skfolio/datasets/__init__.py +1 -1
  5. skfolio/datasets/_base.py +2 -2
  6. skfolio/datasets/data/__init__.py +1 -0
  7. skfolio/distance/__init__.py +1 -1
  8. skfolio/distance/_base.py +2 -2
  9. skfolio/distance/_distance.py +4 -4
  10. skfolio/distribution/__init__.py +56 -0
  11. skfolio/distribution/_base.py +203 -0
  12. skfolio/distribution/copula/__init__.py +35 -0
  13. skfolio/distribution/copula/_base.py +456 -0
  14. skfolio/distribution/copula/_clayton.py +539 -0
  15. skfolio/distribution/copula/_gaussian.py +407 -0
  16. skfolio/distribution/copula/_gumbel.py +560 -0
  17. skfolio/distribution/copula/_independent.py +196 -0
  18. skfolio/distribution/copula/_joe.py +609 -0
  19. skfolio/distribution/copula/_selection.py +111 -0
  20. skfolio/distribution/copula/_student_t.py +486 -0
  21. skfolio/distribution/copula/_utils.py +509 -0
  22. skfolio/distribution/multivariate/__init__.py +11 -0
  23. skfolio/distribution/multivariate/_base.py +241 -0
  24. skfolio/distribution/multivariate/_utils.py +632 -0
  25. skfolio/distribution/multivariate/_vine_copula.py +1254 -0
  26. skfolio/distribution/univariate/__init__.py +19 -0
  27. skfolio/distribution/univariate/_base.py +308 -0
  28. skfolio/distribution/univariate/_gaussian.py +136 -0
  29. skfolio/distribution/univariate/_johnson_su.py +152 -0
  30. skfolio/distribution/univariate/_normal_inverse_gaussian.py +153 -0
  31. skfolio/distribution/univariate/_selection.py +85 -0
  32. skfolio/distribution/univariate/_student_t.py +144 -0
  33. skfolio/exceptions.py +6 -6
  34. skfolio/measures/__init__.py +1 -1
  35. skfolio/measures/_enums.py +7 -7
  36. skfolio/measures/_measures.py +4 -7
  37. skfolio/metrics/__init__.py +2 -0
  38. skfolio/metrics/_scorer.py +4 -4
  39. skfolio/model_selection/__init__.py +2 -2
  40. skfolio/model_selection/_combinatorial.py +15 -12
  41. skfolio/model_selection/_validation.py +2 -2
  42. skfolio/model_selection/_walk_forward.py +3 -3
  43. skfolio/moments/covariance/_base.py +1 -1
  44. skfolio/moments/covariance/_denoise_covariance.py +1 -1
  45. skfolio/moments/covariance/_detone_covariance.py +1 -1
  46. skfolio/moments/covariance/_empirical_covariance.py +1 -1
  47. skfolio/moments/covariance/_ew_covariance.py +1 -1
  48. skfolio/moments/covariance/_gerber_covariance.py +1 -1
  49. skfolio/moments/covariance/_graphical_lasso_cv.py +1 -1
  50. skfolio/moments/covariance/_implied_covariance.py +2 -7
  51. skfolio/moments/covariance/_ledoit_wolf.py +1 -1
  52. skfolio/moments/covariance/_oas.py +1 -1
  53. skfolio/moments/covariance/_shrunk_covariance.py +1 -1
  54. skfolio/moments/expected_returns/_base.py +1 -1
  55. skfolio/moments/expected_returns/_empirical_mu.py +1 -1
  56. skfolio/moments/expected_returns/_equilibrium_mu.py +1 -1
  57. skfolio/moments/expected_returns/_ew_mu.py +1 -1
  58. skfolio/moments/expected_returns/_shrunk_mu.py +2 -2
  59. skfolio/optimization/__init__.py +2 -0
  60. skfolio/optimization/_base.py +2 -2
  61. skfolio/optimization/cluster/__init__.py +2 -0
  62. skfolio/optimization/cluster/_nco.py +7 -7
  63. skfolio/optimization/cluster/hierarchical/__init__.py +2 -0
  64. skfolio/optimization/cluster/hierarchical/_base.py +1 -2
  65. skfolio/optimization/cluster/hierarchical/_herc.py +2 -2
  66. skfolio/optimization/cluster/hierarchical/_hrp.py +2 -2
  67. skfolio/optimization/convex/__init__.py +2 -0
  68. skfolio/optimization/convex/_base.py +8 -8
  69. skfolio/optimization/convex/_distributionally_robust.py +4 -4
  70. skfolio/optimization/convex/_maximum_diversification.py +5 -5
  71. skfolio/optimization/convex/_mean_risk.py +5 -6
  72. skfolio/optimization/convex/_risk_budgeting.py +3 -3
  73. skfolio/optimization/ensemble/__init__.py +2 -0
  74. skfolio/optimization/ensemble/_base.py +2 -2
  75. skfolio/optimization/ensemble/_stacking.py +1 -1
  76. skfolio/optimization/naive/__init__.py +2 -0
  77. skfolio/optimization/naive/_naive.py +1 -1
  78. skfolio/population/__init__.py +2 -0
  79. skfolio/population/_population.py +35 -9
  80. skfolio/portfolio/_base.py +42 -8
  81. skfolio/portfolio/_multi_period_portfolio.py +3 -2
  82. skfolio/portfolio/_portfolio.py +4 -4
  83. skfolio/pre_selection/__init__.py +2 -0
  84. skfolio/pre_selection/_drop_correlated.py +2 -2
  85. skfolio/pre_selection/_select_complete.py +25 -26
  86. skfolio/pre_selection/_select_k_extremes.py +2 -2
  87. skfolio/pre_selection/_select_non_dominated.py +2 -2
  88. skfolio/pre_selection/_select_non_expiring.py +2 -2
  89. skfolio/preprocessing/__init__.py +2 -0
  90. skfolio/preprocessing/_returns.py +2 -2
  91. skfolio/prior/__init__.py +4 -0
  92. skfolio/prior/_base.py +2 -2
  93. skfolio/prior/_black_litterman.py +5 -3
  94. skfolio/prior/_empirical.py +3 -1
  95. skfolio/prior/_factor_model.py +8 -4
  96. skfolio/prior/_synthetic_data.py +239 -0
  97. skfolio/synthetic_returns/__init__.py +1 -0
  98. skfolio/typing.py +1 -1
  99. skfolio/uncertainty_set/__init__.py +2 -0
  100. skfolio/uncertainty_set/_base.py +2 -2
  101. skfolio/uncertainty_set/_bootstrap.py +1 -1
  102. skfolio/uncertainty_set/_empirical.py +1 -1
  103. skfolio/utils/__init__.py +1 -0
  104. skfolio/utils/bootstrap.py +2 -2
  105. skfolio/utils/equations.py +13 -10
  106. skfolio/utils/sorting.py +2 -2
  107. skfolio/utils/stats.py +7 -7
  108. skfolio/utils/tools.py +76 -12
  109. {skfolio-0.7.0.dist-info → skfolio-0.8.1.dist-info}/METADATA +99 -24
  110. skfolio-0.8.1.dist-info/RECORD +120 -0
  111. {skfolio-0.7.0.dist-info → skfolio-0.8.1.dist-info}/WHEEL +1 -1
  112. skfolio-0.7.0.dist-info/RECORD +0 -95
  113. {skfolio-0.7.0.dist-info → skfolio-0.8.1.dist-info/licenses}/LICENSE +0 -0
  114. {skfolio-0.7.0.dist-info → skfolio-0.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,239 @@
1
+ """Synthetic Data Prior Model estimator."""
2
+
3
+ # Copyright (c) 2025
4
+ # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
+
7
+ import inspect
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import sklearn.base as skb
12
+ import sklearn.utils.metadata_routing as skm
13
+ import sklearn.utils.validation as skv
14
+
15
+ from skfolio.distribution import VineCopula
16
+ from skfolio.prior._base import BasePrior
17
+ from skfolio.prior._empirical import EmpiricalPrior
18
+ from skfolio.utils.tools import check_estimator
19
+
20
+
21
class SyntheticData(BasePrior):
    """Synthetic Data Estimator.

    The Synthetic Data model estimates a :class:`~skfolio.prior.PriorModel` by
    fitting a `distribution_estimator` and sampling new returns data from it.

    The default ``distribution_estimator`` is a Regular Vine Copula model. Other common
    choices are Generative Adversarial Networks (GANs) or Variational Autoencoders
    (VAEs).

    This class is particularly useful when the historical distribution tail dependencies
    are sparse and need extrapolation for tail optimizations or when optimizing under
    conditional or stressed scenarios.

    Parameters
    ----------
    distribution_estimator : BaseEstimator, optional
        Estimator to model the distribution of asset returns. It must inherit from
        `BaseEstimator` and implement a `sample` method accepting an `n_samples`
        keyword argument. If None, the default `VineCopula()` model is used.

    n_samples : int, default=1000
        Number of samples to generate from the `distribution_estimator`, default is
        1000.

    sample_args : dict, optional
        Additional keyword arguments to pass to the `sample` method of the
        `distribution_estimator`.

    Attributes
    ----------
    prior_model_ : PriorModel
        The assets :class:`~skfolio.prior.PriorModel`.

    distribution_estimator_ : BaseEstimator
        The fitted distribution estimator.

    n_features_in_ : int
        Number of assets seen during `fit`.

    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during `fit`. Defined only when `X`
        has feature names that are all strings.

    Examples
    --------
    >>> import numpy as np
    >>> from skfolio.datasets import load_sp500_dataset, load_factors_dataset
    >>> from skfolio.preprocessing import prices_to_returns
    >>> from skfolio.distribution import VineCopula
    >>> from skfolio.optimization import MeanRisk
    >>> from skfolio.prior import FactorModel, SyntheticData
    >>> from skfolio import RiskMeasure
    >>>
    >>> # Load historical prices and convert them to returns
    >>> prices = load_sp500_dataset()
    >>> factors = load_factors_dataset()
    >>> X, y = prices_to_returns(prices, factors)
    >>>
    >>> # Instantiate the SyntheticData model and fit it
    >>> model = SyntheticData()
    >>> model.fit(X)
    >>> print(model.prior_model_)
    >>>
    >>> # Minimum CVaR optimization on synthetic returns
    >>> model = MeanRisk(
    ...     risk_measure=RiskMeasure.CVAR,
    ...     prior_estimator=SyntheticData(
    ...         distribution_estimator=VineCopula(log_transform=True, n_jobs=-1),
    ...         n_samples=2000,
    ...     )
    ... )
    >>> model.fit(X)
    >>> print(model.weights_)
    >>>
    >>> # Minimum CVaR optimization on Stressed Factors
    >>> factor_model = FactorModel(
    ...     factor_prior_estimator=SyntheticData(
    ...         distribution_estimator=VineCopula(
    ...             central_assets=["QUAL"],
    ...             log_transform=True,
    ...             n_jobs=-1,
    ...         ),
    ...         n_samples=5000,
    ...         sample_args=dict(conditioning={"QUAL": -0.2}),
    ...     )
    ... )
    >>> model = MeanRisk(risk_measure=RiskMeasure.CVAR, prior_estimator=factor_model)
    >>> model.fit(X, y)
    >>> print(model.weights_)
    >>>
    >>> # Stress Test the Portfolio
    >>> factor_model.set_params(factor_prior_estimator__sample_args=dict(
    ...     conditioning={"QUAL": -0.5}
    ... ))
    >>> factor_model.fit(X,y)
    >>> stressed_X = factor_model.prior_model_.returns
    >>> stressed_ptf = model.predict(stressed_X)
    """

    distribution_estimator_: skb.BaseEstimator
    n_features_in_: int
    feature_names_in_: np.ndarray

    def __init__(
        self,
        distribution_estimator: skb.BaseEstimator | None = None,
        n_samples: int = 1000,
        sample_args: dict | None = None,
    ):
        self.distribution_estimator = distribution_estimator
        self.n_samples = n_samples
        self.sample_args = sample_args

    def get_metadata_routing(self):
        """Get the metadata routing of this estimator.

        ``fit`` metadata is routed to ``distribution_estimator.fit``. The router key
        must be ``distribution_estimator`` because `fit` reads
        ``routed_params.distribution_estimator``.

        Returns
        -------
        routing : MetadataRouter
            A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
            routing information.
        """
        # noinspection PyTypeChecker
        router = skm.MetadataRouter(owner=self.__class__.__name__).add(
            distribution_estimator=self.distribution_estimator,
            method_mapping=skm.MethodMapping().add(caller="fit", callee="fit"),
        )
        return router

    def fit(self, X: npt.ArrayLike, y=None, **fit_params) -> "SyntheticData":
        """Fit the Synthetic Data estimator.

        Parameters
        ----------
        X : array-like of shape (n_observations, n_assets)
            Price returns of the assets.

        y : array-like, optional
            Forwarded unchanged to ``distribution_estimator.fit``.

        **fit_params : dict
            Parameters to pass to the underlying estimators.
            Only available if `enable_metadata_routing=True`, which can be
            set by using ``sklearn.set_config(enable_metadata_routing=True)``.
            See :ref:`Metadata Routing User Guide <metadata_routing>` for
            more details.

        Returns
        -------
        self : SyntheticData
            Fitted estimator.

        Raises
        ------
        ValueError
            If the distribution estimator has no usable `sample` method.
        """
        routed_params = skm.process_routing(self, "fit", **fit_params)

        self.distribution_estimator_ = check_estimator(
            self.distribution_estimator,
            default=VineCopula(),
            check_type=skb.BaseEstimator,
        )
        _check_sample_method(self.distribution_estimator_)

        # Fit the distribution estimator on the prior returns.
        # noinspection PyUnresolvedReferences
        self.distribution_estimator_.fit(
            X, y, **routed_params.distribution_estimator.fit
        )

        # We validate after all models have been fitted to keep feature names
        # information.
        skv.validate_data(self, X)

        # Sample from the fitted distribution estimator.
        sample_args = self.sample_args if self.sample_args is not None else {}
        # noinspection PyUnresolvedReferences
        synthetic_data = self.distribution_estimator_.sample(
            n_samples=self.n_samples, **sample_args
        )

        # When performing conditional sampling, the conditioning samples are often
        # constant. To avoid null variance, we add a small white noise.
        # An independent noise column is drawn for each constant asset. Sharing a
        # single noise vector would make those assets perfectly correlated, and the
        # sample covariance would be singular.
        # NOTE: the noise is drawn from NumPy's global random generator.
        constant_returns = np.var(synthetic_data, axis=0) < 1e-14
        n_constant = int(np.sum(constant_returns))
        if n_constant > 0:
            noise = 1e-6 * np.random.randn(len(synthetic_data), n_constant)
            synthetic_data[:, constant_returns] += noise

        # Fit the empirical posterior estimator on the synthetic returns.
        posterior_estimator = EmpiricalPrior()
        posterior_estimator.fit(synthetic_data)
        self.prior_model_ = posterior_estimator.prior_model_

        return self
206
+
207
+
208
def _check_sample_method(distribution_estimator: "skb.BaseEstimator") -> None:
    """Check that the distribution_estimator implements a valid 'sample' method.

    The estimator must expose a callable ``sample`` attribute. That method must be
    callable with ``n_samples`` as a keyword argument, because `SyntheticData.fit`
    invokes ``sample(n_samples=...)``. A method accepts the keyword if it has a
    non-positional-only ``n_samples`` parameter or a ``**kwargs`` catch-all.

    Parameters
    ----------
    distribution_estimator : BaseEstimator
        The estimator whose 'sample' method is to be validated.

    Raises
    ------
    ValueError
        If the 'sample' method is missing, not callable, or cannot receive
        ``n_samples`` as a keyword argument.
    """
    # Fetch the 'sample' attribute; raise if it is absent or not callable.
    sample_method = getattr(distribution_estimator, "sample", None)
    if sample_method is None or not callable(sample_method):
        raise ValueError(
            f"The distribution_estimator {distribution_estimator} must implement a "
            "`sample` method"
        )

    parameters = inspect.signature(sample_method).parameters

    n_samples_param = parameters.get("n_samples")
    # A positional-only `n_samples` cannot be passed by keyword, so it does not
    # qualify.
    has_keyword_n_samples = (
        n_samples_param is not None
        and n_samples_param.kind is not inspect.Parameter.POSITIONAL_ONLY
    )
    # A `**kwargs` catch-all also accepts `n_samples=...`.
    has_var_keyword = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters.values()
    )

    if not (has_keyword_n_samples or has_var_keyword):
        raise ValueError(
            "The `sample` method of the distribution_estimator "
            f"{distribution_estimator} must have `n_samples` as parameter"
        )
@@ -0,0 +1 @@
1
+ """Synthetic Data module."""
skfolio/typing.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
 
7
7
  from collections.abc import Callable
8
8
 
@@ -1,3 +1,5 @@
1
+ """Uncertainty Set module."""
2
+
1
3
  from skfolio.uncertainty_set._base import (
2
4
  BaseCovarianceUncertaintySet,
3
5
  BaseMuUncertaintySet,
@@ -1,8 +1,8 @@
1
- """Base Uncertainty estimator"""
1
+ """Base Uncertainty estimator."""
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
 
7
7
  from abc import ABC, abstractmethod
8
8
  from dataclasses import dataclass
@@ -2,7 +2,7 @@
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
  # Implementation derived from:
7
7
  # Riskfolio-Lib, Copyright (c) 2020-2023, Dany Cajas, Licensed under BSD 3 clause.
8
8
  # scikit-learn, Copyright (c) 2007-2010 David Cournapeau, Fabian Pedregosa, Olivier
@@ -2,7 +2,7 @@
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
  # Implementation derived from:
7
7
  # Riskfolio-Lib, Copyright (c) 2020-2023, Dany Cajas, Licensed under BSD 3 clause.
8
8
  # scikit-learn, Copyright (c) 2007-2010 David Cournapeau, Fabian Pedregosa, Olivier
skfolio/utils/__init__.py CHANGED
@@ -0,0 +1 @@
1
+ """Utils module."""
@@ -2,7 +2,7 @@
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
  # Implementation derived from:
7
7
  # Riskfolio-Lib, Copyright (c) 2020-2023, Dany Cajas, Licensed under BSD 3 clause.
8
8
 
@@ -71,7 +71,7 @@ def stationary_bootstrap(
71
71
  block_size: float | None = None,
72
72
  seed: int | None = None,
73
73
  ) -> np.ndarray:
74
- """Creates `n_bootstrap_samples` samples from a multivariate return series via
74
+ """Create `n_bootstrap_samples` samples from a multivariate return series via
75
75
  stationary bootstrapping.
76
76
 
77
77
  Parameters
@@ -1,8 +1,8 @@
1
- """Equation module"""
1
+ """Equation module."""
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
 
7
7
  import re
8
8
  import warnings
@@ -44,7 +44,8 @@ def equations_to_matrix(
44
44
  groups : array-like of shape (n_groups, n_assets)
45
45
  2D array of assets groups.
46
46
 
47
- Examples:
47
+ For example:
48
+
48
49
  groups = np.array(
49
50
  [
50
51
  ["SPX", "SX5E", "NKY", "TLT"],
@@ -66,7 +67,8 @@ def equations_to_matrix(
66
67
  The second expression means that the sum of all assets in "group_1" should be
67
68
  less or equal to "number" times the sum of all assets in "group_2".
68
69
 
69
- Examples:
70
+ For example:
71
+
70
72
  equations = [
71
73
  "Equity <= 3 * Bond",
72
74
  "US >= 1.5",
@@ -143,9 +145,10 @@ def group_cardinalities_to_matrix(
143
145
  Parameters
144
146
  ----------
145
147
  groups : array-like of shape (n_groups, n_assets)
146
- 2D array of assets groups.
148
+ 2D array of assets groups.
149
+
150
+ For example:
147
151
 
148
- Examples:
149
152
  groups = np.array(
150
153
  [
151
154
  ["Equity", "Equity", "Equity", "Bond"],
@@ -154,8 +157,8 @@ def group_cardinalities_to_matrix(
154
157
  )
155
158
 
156
159
  group_cardinalities : dict[str, int]
157
- Dictionary of cardinality constraint per group.
158
- Examples: {"Equity": 1, "US": 3}
160
+ Dictionary of cardinality constraint per group.
161
+ For example: {"Equity": 1, "US": 3}
159
162
 
160
163
  raise_if_group_missing : bool, default=False
161
164
  If this is set to True, an error is raised when a group is not found in the
@@ -302,7 +305,7 @@ def _comparison_operator_sign(operator: str) -> int:
302
305
 
303
306
 
304
307
  def _sub_add_operator_sign(operator: str) -> int:
305
- """Convert the operators '+' and '-' into 1 or -1
308
+ """Convert the operators '+' and '-' into 1 or -1.
306
309
 
307
310
  Parameters
308
311
  ----------
@@ -342,7 +345,7 @@ def _string_to_float(string: str) -> float:
342
345
 
343
346
 
344
347
  def _split_equation_string(string: str) -> list[str]:
345
- """Split an equation strings by operators"""
348
+ """Split an equation strings by operators."""
346
349
  comp_pattern = "(?=" + "|".join([".+\\" + e for e in _COMPARISON_OPERATORS]) + ")"
347
350
  if not bool(re.match(comp_pattern, string)):
348
351
  raise EquationToMatrixError(
skfolio/utils/sorting.py CHANGED
@@ -1,8 +1,8 @@
1
- """Fast non-dominated sorting module"""
1
+ """Fast non-dominated sorting module."""
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
 
7
7
  import numpy as np
8
8
 
skfolio/utils/stats.py CHANGED
@@ -1,10 +1,10 @@
1
- """Tools module"""
1
+ """Tools module."""
2
2
 
3
3
  import warnings
4
4
 
5
5
  # Copyright (c) 2023
6
6
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
7
- # License: BSD 3 clause
7
+ # SPDX-License-Identifier: BSD-3-Clause
8
8
  # Implementation derived from:
9
9
  # Riskfolio-Lib, Copyright (c) 2020-2023, Dany Cajas, Licensed under BSD 3 clause.
10
10
  # Statsmodels, Copyright (C) 2006, Jonathan E. Taylor, Licensed under BSD 3 clause.
@@ -41,7 +41,7 @@ __all__ = [
41
41
 
42
42
 
43
43
  class NBinsMethod(AutoEnum):
44
- """Enumeration of the Number of Bins Methods
44
+ """Enumeration of the Number of Bins Methods.
45
45
 
46
46
  Parameters
47
47
  ----------
@@ -82,7 +82,7 @@ def n_bins_freedman(x: np.ndarray) -> int:
82
82
  if d == 0:
83
83
  return 5
84
84
  n_bins = max(1, np.ceil((np.max(x) - np.min(x)) / d))
85
- return int(round(n_bins))
85
+ return round(n_bins)
86
86
 
87
87
 
88
88
  def n_bins_knuth(x: np.ndarray) -> int:
@@ -122,12 +122,12 @@ def n_bins_knuth(x: np.ndarray) -> int:
122
122
 
123
123
  n_bins_init = n_bins_freedman(x)
124
124
  n_bins = sco.fmin(func, n_bins_init, disp=0)[0]
125
- return int(round(n_bins))
125
+ return round(n_bins)
126
126
 
127
127
 
128
128
  def rand_weights_dirichlet(n: int) -> np.array:
129
129
  """Produces n random weights that sum to one from a dirichlet distribution
130
- (uniform distribution over a simplex)
130
+ (uniform distribution over a simplex).
131
131
 
132
132
  Parameters
133
133
  ----------
@@ -144,7 +144,7 @@ def rand_weights_dirichlet(n: int) -> np.array:
144
144
 
145
145
  def rand_weights(n: int, zeros: int = 0) -> np.array:
146
146
  """Produces n random weights that sum to one from an uniform distribution
147
- (non-uniform distribution over a simplex)
147
+ (non-uniform distribution over a simplex).
148
148
 
149
149
  Parameters
150
150
  ----------
skfolio/utils/tools.py CHANGED
@@ -1,12 +1,13 @@
1
- """Tools module"""
1
+ """Tools module."""
2
2
 
3
3
  # Copyright (c) 2023
4
4
  # Author: Hugo Delatte <delatte.hugo@gmail.com>
5
- # License: BSD 3 clause
5
+ # SPDX-License-Identifier: BSD-3-Clause
6
6
  # Implementation derived from:
7
7
  # scikit-learn, Copyright (c) 2007-2010 David Cournapeau, Fabian Pedregosa, Olivier
8
8
  # Grisel Licensed under BSD 3 clause.
9
9
 
10
+ import warnings
10
11
  from collections.abc import Callable, Iterator
11
12
  from enum import Enum
12
13
  from functools import wraps
@@ -36,19 +37,20 @@ __all__ = [
36
37
  "optimal_rounding_decimals",
37
38
  "safe_indexing",
38
39
  "safe_split",
40
+ "validate_input_list",
39
41
  ]
40
42
 
41
43
  GenericAlias = type(list[int])
42
44
 
43
45
 
44
46
  class AutoEnum(str, Enum):
45
- """Base Enum class used in `skfolio`"""
47
+ """Base Enum class used in `skfolio`."""
46
48
 
47
49
  @staticmethod
48
50
  def _generate_next_value_(
49
51
  name: str, start: int, count: int, last_values: Any
50
52
  ) -> str:
51
- """Overriding `auto()`"""
53
+ """Overriding `auto()`."""
52
54
  return name.lower()
53
55
 
54
56
  @classmethod
@@ -68,13 +70,13 @@ class AutoEnum(str, Enum):
68
70
  return value in cls._value2member_map_
69
71
 
70
72
  def __repr__(self) -> str:
71
- """Representation of the Enum"""
73
+ """Representation of the Enum."""
72
74
  return self.name
73
75
 
74
76
 
75
77
  # noinspection PyPep8Naming
76
78
  class cached_property_slots:
77
- """Cached property decorator for slots"""
79
+ """Cached property decorator for slots."""
78
80
 
79
81
  def __init__(self, func):
80
82
  self.func = func
@@ -83,10 +85,12 @@ class cached_property_slots:
83
85
  self.__doc__ = func.__doc__
84
86
 
85
87
  def __set_name__(self, owner, name):
88
+ """Set Name."""
86
89
  self.public_name = name
87
90
  self.private_name = f"_{name}"
88
91
 
89
92
  def __get__(self, instance, owner=None):
93
+ """Getter."""
90
94
  if instance is None:
91
95
  return self
92
96
  if self.private_name is None:
@@ -102,6 +106,7 @@ class cached_property_slots:
102
106
  return value
103
107
 
104
108
  def __set__(self, instance, owner=None):
109
+ """Setter."""
105
110
  raise AttributeError(
106
111
  f"'{type(instance).__name__}' object attribute '{self.public_name}' is"
107
112
  " read-only"
@@ -111,7 +116,7 @@ class cached_property_slots:
111
116
 
112
117
 
113
118
  def _make_key(args, kwds) -> int:
114
- """Make a cache key from optionally typed positional and keyword arguments"""
119
+ """Make a cache key from optionally typed positional and keyword arguments."""
115
120
  key = args
116
121
  if kwds:
117
122
  for item in kwds.items():
@@ -248,7 +253,6 @@ def safe_split(
248
253
  y_subset : array-like
249
254
  Indexed targets.
250
255
  """
251
-
252
256
  X_subset = safe_indexing(X, indices=indices, axis=axis)
253
257
  if y is not None:
254
258
  y_subset = safe_indexing(y, indices=indices, axis=axis)
@@ -340,10 +344,9 @@ def check_estimator(
340
344
 
341
345
  Returns
342
346
  -------
343
- estimator: Estimator
347
+ estimator : Estimator
344
348
  The checked estimator or the default.
345
349
  """
346
-
347
350
  if estimator is None:
348
351
  return default
349
352
  if not isinstance(estimator, check_type):
@@ -439,6 +442,67 @@ def input_to_array(
439
442
  return arr
440
443
 
441
444
 
445
+ def validate_input_list(
446
+ items: list[int | str],
447
+ n_assets: int,
448
+ assets_names: np.ndarray[str] | None,
449
+ name: str,
450
+ raise_if_string_missing: bool = True,
451
+ ) -> list[int]:
452
+ """Convert a list of items (asset indices or asset names) into a list of
453
+ validated asset indices.
454
+
455
+ Parameters
456
+ ----------
457
+ items : list[int | str]
458
+ List of asset indices or asset names.
459
+
460
+ n_assets : int
461
+ Expected number of assets.
462
+ Used for verification.
463
+
464
+ assets_names : ndarray, optional
465
+ Asset names used when `items` contain strings.
466
+
467
+ name : str
468
+ Name of the items used for error messages.
469
+
470
+ raise_if_string_missing : bool, default=True
471
+ If set to True, raises an error if an item string is missing from assets_names;
472
+ otherwise, issue a User Warning.
473
+
474
+ Returns
475
+ -------
476
+ values : list[int]
477
+ Converted and validated list.
478
+ """
479
+ if len(set(items)) != len(items):
480
+ raise ValueError(f"Duplicates found in {items}")
481
+
482
+ asset_indices = set(range(n_assets))
483
+ res = []
484
+ for asset in items:
485
+ if isinstance(asset, str):
486
+ if assets_names is None:
487
+ raise ValueError(
488
+ f"If `{name}` is provided as a list of string, you must input `X` "
489
+ f"as a DataFrame with assets names in columns."
490
+ )
491
+ mask = assets_names == asset
492
+ if np.any(mask):
493
+ res.append(int(np.where(mask)[0][0]))
494
+ else:
495
+ if raise_if_string_missing:
496
+ raise ValueError(f"{asset} not found in {assets_names}")
497
+ else:
498
+ warnings.warn(f"{asset} not found in {assets_names}", stacklevel=2)
499
+ else:
500
+ if asset not in asset_indices:
501
+ raise ValueError(f"`central_assets` {asset} is not in {asset_indices}.")
502
+ res.append(int(asset))
503
+ return res
504
+
505
+
442
506
  def format_measure(x: float, percent: bool = False) -> str:
443
507
  """Format a measure number into a user-friendly string.
444
508
 
@@ -514,7 +578,7 @@ def fit_single_estimator(
514
578
  indices: np.ndarray | None = None,
515
579
  axis: int = 0,
516
580
  ):
517
- """function used to fit an estimator within a job.
581
+ """Function used to fit an estimator within a job.
518
582
 
519
583
  Parameters
520
584
  ----------
@@ -622,7 +686,7 @@ def fit_and_predict(
622
686
 
623
687
 
624
688
  def default_asset_names(n_assets: int) -> np.ndarray:
625
- """Default asset names are `["x0", "x1", ..., "x(n_assets - 1)"]`
689
+ """Default asset names are `["x0", "x1", ..., "x(n_assets - 1)"]`.
626
690
 
627
691
  Parameters
628
692
  ----------