trade-study 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,54 @@
1
"""Multi-objective trade-study orchestration.

Scoring, Pareto optimization, and Bayesian stacking.

This module is a pure re-export surface: every supported public name is
imported from a submodule and listed in ``__all__`` below.
"""

from ._pareto import extract_front, hypervolume, igd_plus, pareto_rank
from ._scoring import coverage_curve, score
from ._version import __version__
from .design import Factor, FactorType, build_grid, reduce_factors, screen
from .io import load_results, save_results
from .protocols import (
    Annotation,
    Direction,
    Observable,
    ResultsTable,
    Scorer,
    Simulator,
    TrialResult,
)
from .runner import run_adaptive, run_grid
from .stacking import ensemble_predict, stack_bayesian, stack_scores
from .study import Phase, Study, top_k_pareto_filter

# Public API, kept in ASCII sort order for easy diffing.
__all__ = [
    "Annotation",
    "Direction",
    "Factor",
    "FactorType",
    "Observable",
    "Phase",
    "ResultsTable",
    "Scorer",
    "Simulator",
    "Study",
    "TrialResult",
    "__version__",
    "build_grid",
    "coverage_curve",
    "ensemble_predict",
    "extract_front",
    "hypervolume",
    "igd_plus",
    "load_results",
    "pareto_rank",
    "reduce_factors",
    "run_adaptive",
    "run_grid",
    "save_results",
    "score",
    "screen",
    "stack_bayesian",
    "stack_scores",
    "top_k_pareto_filter",
]
trade_study/_pareto.py ADDED
@@ -0,0 +1,128 @@
1
+ """Pareto front extraction and performance indicators.
2
+
3
+ Wraps pymoo for non-dominated sorting and hypervolume computation.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import numpy as np
11
+
12
+ from .protocols import Direction
13
+
14
+ if TYPE_CHECKING:
15
+ from numpy.typing import NDArray
16
+
17
+
18
def extract_front(
    scores: NDArray[np.floating[Any]],
    directions: list[Direction],
) -> NDArray[np.intp]:
    """Return the row indices of the non-dominated (Pareto-optimal) trials.

    Args:
        scores: Array of shape (n_trials, n_objectives).
        directions: Optimization direction for each objective.

    Returns:
        Integer array of row indices on the first Pareto front.
    """
    from pymoo.util.nds.non_dominated_sorting import (  # type: ignore[import-untyped]
        NonDominatedSorting,
    )

    # pymoo's sorter assumes minimization, so negate maximize columns
    # on a copy (the caller's array is left untouched).
    minimized = scores.copy()
    for col, direction in enumerate(directions):
        if direction == Direction.MAXIMIZE:
            minimized[:, col] = -minimized[:, col]

    first_front = NonDominatedSorting().do(minimized)[0]
    return np.asarray(first_front, dtype=np.intp)
44
+
45
+
46
def pareto_rank(
    scores: NDArray[np.floating[Any]],
    directions: list[Direction],
) -> NDArray[np.intp]:
    """Assign a Pareto rank to every trial (0 = front, 1 = next layer, ...).

    Args:
        scores: Array of shape (n_trials, n_objectives).
        directions: Optimization direction for each objective.

    Returns:
        Integer array of ranks, shape (n_trials,).
    """
    from pymoo.util.nds.non_dominated_sorting import (
        NonDominatedSorting,
    )

    # Normalize everything to minimization for pymoo (copy, don't mutate).
    minimized = scores.copy()
    for col, direction in enumerate(directions):
        if direction == Direction.MAXIMIZE:
            minimized[:, col] = -minimized[:, col]

    # Each successive front from the sorter gets the next rank number.
    ranks = np.empty(len(scores), dtype=np.intp)
    for layer, members in enumerate(NonDominatedSorting().do(minimized)):
        ranks[members] = layer
    return ranks
74
+
75
+
76
def hypervolume(
    front: NDArray[np.floating[Any]],
    ref_point: NDArray[np.floating[Any]],
    directions: list[Direction] | None = None,
) -> float:
    """Compute the hypervolume indicator for a Pareto front.

    Args:
        front: Array of shape (n_points, n_objectives) on the front.
        ref_point: Reference point (should dominate all front points after
            direction normalization).
        directions: If provided, maximize objectives are negated before
            computing, in both the front and the reference point.

    Returns:
        Hypervolume value.
    """
    from pymoo.indicators.hv import HV  # type: ignore[import-untyped]

    points = front.copy()
    reference = ref_point.copy()
    if directions is not None:
        # Flip maximize objectives consistently in points and reference.
        for col, direction in enumerate(directions):
            if direction == Direction.MAXIMIZE:
                points[:, col] = -points[:, col]
                reference[col] = -reference[col]

    indicator = HV(ref_point=reference)
    return float(indicator(points))
102
+
103
+
104
def igd_plus(
    front: NDArray[np.floating[Any]],
    reference: NDArray[np.floating[Any]],
    directions: list[Direction] | None = None,
) -> float:
    """Compute the IGD+ indicator of a front against a reference front.

    Args:
        front: Obtained Pareto front.
        reference: Reference Pareto front.
        directions: If provided, maximize objectives are negated in both
            fronts before computing.

    Returns:
        IGD+ value (lower is better).
    """
    from pymoo.indicators.igd_plus import IGDPlus  # type: ignore[import-untyped]

    obtained = front.copy()
    target = reference.copy()
    if directions is not None:
        # Normalize both fronts to minimization before the comparison.
        for col, direction in enumerate(directions):
            if direction == Direction.MAXIMIZE:
                obtained[:, col] = -obtained[:, col]
                target[:, col] = -target[:, col]

    return float(IGDPlus(target)(obtained))
@@ -0,0 +1,213 @@
1
+ """Scoring functions wrapping scoringrules and scipy.
2
+
3
+ Provides a uniform ``score(metric, predictions, truth)`` interface
4
+ for all proper scoring rules and calibration diagnostics.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ import numpy as np
12
+
13
+ if TYPE_CHECKING:
14
+ from numpy.typing import NDArray
15
+
16
+
17
def score(
    metric: str,
    predictions: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
    *,
    alpha: float | NDArray[np.floating[Any]] | None = None,
    level: float = 0.95,
) -> float:
    """Compute a scalar scoring rule.

    Args:
        metric: One of "crps", "wis", "interval", "energy",
            "rmse", "mae", "coverage", "brier".
        predictions: Model predictions (ensemble members, quantiles, etc.).
        truth: Known ground truth values.
        alpha: Significance level for interval-based scores ("wis",
            "interval"); ignored by the other metrics.
        level: Nominal coverage level for the "coverage" metric.

    Returns:
        Scalar score value.

    Raises:
        ValueError: If the metric name is not recognized.
    """
    # Two-argument metrics first, then the keyword-taking ones.
    if metric == "crps":
        return _crps(predictions, truth)
    if metric == "energy":
        return _energy(predictions, truth)
    if metric == "brier":
        return _brier(predictions, truth)
    if metric == "rmse":
        return _rmse(predictions, truth)
    if metric == "mae":
        return _mae(predictions, truth)
    if metric == "wis":
        return _wis(predictions, truth, alpha=alpha)
    if metric == "interval":
        return _interval(predictions, truth, alpha=alpha)
    if metric == "coverage":
        return _coverage(predictions, truth, level=level)
    msg = f"Unknown metric: {metric!r}"
    raise ValueError(msg)
58
+
59
+
60
def _crps(
    ensemble: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
) -> float:
    """CRPS via scoringrules.

    Returns:
        Mean CRPS across observations.
    """
    import scoringrules as sr  # type: ignore[import-untyped]

    per_obs = sr.crps_ensemble(truth, ensemble)
    return float(np.mean(per_obs))
72
+
73
+
74
def _wis(
    predictions: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
    *,
    alpha: float | NDArray[np.floating[Any]] | None = None,
) -> float:
    """Weighted interval score via scoringrules.

    Returns:
        Mean WIS across observations.
    """
    import scoringrules as sr

    if alpha is None:
        # Default significance ladder when the caller gives none.
        alpha = np.array([0.02, 0.05, 0.1, 0.2, 0.5])

    # NOTE(review): last-axis columns are passed positionally in the order
    # scoringrules.weighted_interval_score expects — confirm the caller
    # packs predictions[..., 0:3] in that same order.
    col0 = predictions[..., 0]
    col1 = predictions[..., 1]
    col2 = predictions[..., 2]
    per_obs = sr.weighted_interval_score(truth, col0, col1, col2, alpha)
    return float(np.mean(per_obs))
100
+
101
+
102
def _interval(
    predictions: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
    *,
    alpha: float | NDArray[np.floating[Any]] | None = None,
) -> float:
    """Interval score via scoringrules.

    Returns:
        Mean interval score across observations.
    """
    import scoringrules as sr

    if alpha is None:
        alpha = 0.05  # default significance level

    # NOTE(review): predictions[..., 0] / [..., 1] are passed as the first
    # two positional bounds of scoringrules.interval_score — confirm the
    # caller packs (lower, upper) in that order.
    lower = predictions[..., 0]
    upper = predictions[..., 1]
    per_obs = sr.interval_score(truth, lower, upper, alpha)
    return float(np.mean(per_obs))
122
+
123
+
124
+ def _coverage(
125
+ predictions: NDArray[np.floating[Any]],
126
+ truth: NDArray[np.floating[Any]],
127
+ *,
128
+ level: float = 0.95,
129
+ ) -> float:
130
+ """Empirical coverage rate at a given nominal level.
131
+
132
+ Returns:
133
+ Fraction of truth values within the predicted interval.
134
+ """
135
+ cov_alpha = 1.0 - level
136
+ lower = np.quantile(predictions, cov_alpha / 2, axis=-1)
137
+ upper = np.quantile(predictions, 1 - cov_alpha / 2, axis=-1)
138
+ return float(np.mean((truth >= lower) & (truth <= upper)))
139
+
140
+
141
def _energy(
    ensemble: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
) -> float:
    """Energy score via scoringrules.

    Returns:
        Mean energy score across observations.
    """
    import scoringrules as sr

    per_obs = sr.es_ensemble(truth, ensemble)
    return float(np.mean(per_obs))
153
+
154
+
155
def _brier(
    predictions: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
) -> float:
    """Brier score via scoringrules.

    Returns:
        Mean Brier score across observations.
    """
    import scoringrules as sr

    per_obs = sr.brier_score(truth, predictions)
    return float(np.mean(per_obs))
167
+
168
+
169
+ def _rmse(
170
+ predictions: NDArray[np.floating[Any]],
171
+ truth: NDArray[np.floating[Any]],
172
+ ) -> float:
173
+ """Root mean squared error.
174
+
175
+ Returns:
176
+ RMSE value.
177
+ """
178
+ return float(np.sqrt(np.mean((predictions - truth) ** 2)))
179
+
180
+
181
+ def _mae(
182
+ predictions: NDArray[np.floating[Any]],
183
+ truth: NDArray[np.floating[Any]],
184
+ ) -> float:
185
+ """Mean absolute error.
186
+
187
+ Returns:
188
+ MAE value.
189
+ """
190
+ return float(np.mean(np.abs(predictions - truth)))
191
+
192
+
193
def coverage_curve(
    posteriors: NDArray[np.floating[Any]],
    truth: NDArray[np.floating[Any]],
    levels: NDArray[np.floating[Any]] | None = None,
) -> tuple[NDArray[np.floating[Any]], NDArray[np.floating[Any]]]:
    """Compute empirical coverage across a range of nominal levels.

    Args:
        posteriors: Posterior samples, shape (n_obs, n_samples).
        truth: True values, shape (n_obs,).
        levels: Nominal coverage levels (default: 50 points in 0.05-0.99).

    Returns:
        Tuple of (nominal_levels, empirical_coverage).
    """
    if levels is None:
        levels = np.linspace(0.05, 0.99, 50)
    empirical = np.array(
        [_coverage(posteriors, truth, level=float(lv)) for lv in levels],
    )
    return levels, empirical
@@ -0,0 +1 @@
1
# Single source of truth for the package version; re-exported by __init__.
__version__ = "0.1.0"
trade_study/design.py ADDED
@@ -0,0 +1,309 @@
1
+ """Experimental design and factor screening.
2
+
3
+ Wraps pyDOE3 for grid construction and SALib for sensitivity screening.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+ from itertools import product
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ import numpy as np
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import Callable
17
+
18
+ from numpy.typing import NDArray
19
+
20
+
21
class FactorType(Enum):
    """Type of design factor.

    Determines which Factor field carries the domain: continuous factors
    use ``bounds``; discrete and categorical factors use ``levels``.
    """

    CONTINUOUS = "continuous"  # real-valued, sampled within (low, high) bounds
    DISCRETE = "discrete"  # restricted to an explicit list of allowed values
    CATEGORICAL = "categorical"  # choice from an explicit list of allowed values
28
+
29
@dataclass(frozen=True)
class Factor:
    """A single design factor.

    Attributes:
        name: Factor identifier (e.g. "alpha", "layer1_method").
        factor_type: Continuous, discrete, or categorical.
        levels: For categorical/discrete: list of allowed values.
        bounds: For continuous: (low, high) tuple.
    """

    name: str
    factor_type: FactorType
    levels: list[Any] | None = None
    bounds: tuple[float, float] | None = None

    def __post_init__(self) -> None:
        """Validate factor constraints.

        Raises:
            ValueError: If name is empty, continuous factor has missing or
                invalid bounds, or discrete/categorical factor has empty
                levels.
        """
        if not self.name:
            msg = "Factor name must be a non-empty string"
            raise ValueError(msg)
        if self.factor_type == FactorType.CONTINUOUS:
            self._validate_bounds()
        else:
            self._validate_levels()

    def _validate_bounds(self) -> None:
        """Ensure a continuous factor has finite, strictly ordered bounds."""
        if self.bounds is None:
            msg = f"Continuous factor '{self.name}' requires bounds"
            raise ValueError(msg)
        lo, hi = self.bounds
        if not (np.isfinite(lo) and np.isfinite(hi)):
            msg = f"Continuous factor '{self.name}' bounds must be finite"
            raise ValueError(msg)
        if lo >= hi:
            msg = f"Continuous factor '{self.name}' requires lo < hi"
            raise ValueError(msg)

    def _validate_levels(self) -> None:
        """Ensure a discrete/categorical factor has a non-empty level list."""
        if self.levels is None:
            msg = f"Factor '{self.name}' of type {self.factor_type} requires levels"
            raise ValueError(msg)
        if not self.levels:
            msg = f"Factor '{self.name}' levels must be non-empty"
            raise ValueError(msg)
74
+
75
+
76
+ def build_grid(
77
+ factors: list[Factor],
78
+ *,
79
+ method: str = "full",
80
+ n_samples: int = 100,
81
+ seed: int = 42,
82
+ scramble: bool = True,
83
+ ) -> list[dict[str, Any]]:
84
+ """Build an experimental design grid.
85
+
86
+ Args:
87
+ factors: List of design factors.
88
+ method: Design method. One of:
89
+ - "full": Full factorial (categorical/discrete only).
90
+ - "lhs": Latin hypercube sampling (continuous factors, maps
91
+ categorical factors to uniform random selection).
92
+ - "sobol": Scrambled Sobol' sequence via ``scipy.stats.qmc``.
93
+ - "halton": Scrambled Halton sequence via ``scipy.stats.qmc``.
94
+ n_samples: Number of samples for LHS / QMC methods.
95
+ seed: Random seed.
96
+ scramble: Whether to apply scrambling to QMC sequences (Sobol /
97
+ Halton). Ignored for other methods.
98
+
99
+ Returns:
100
+ List of config dictionaries, one per design point.
101
+
102
+ Raises:
103
+ ValueError: If an unknown design method is specified.
104
+ """
105
+ if method == "full":
106
+ return _full_factorial(factors)
107
+ if method == "lhs":
108
+ return _latin_hypercube(factors, n_samples=n_samples, seed=seed)
109
+ if method in {"sobol", "halton"}:
110
+ return _qmc_sample(
111
+ factors,
112
+ n_samples=n_samples,
113
+ seed=seed,
114
+ qmc_method=method,
115
+ scramble=scramble,
116
+ )
117
+ msg = f"Unknown design method: {method!r}"
118
+ raise ValueError(msg)
119
+
120
+
121
+ def _full_factorial(factors: list[Factor]) -> list[dict[str, Any]]:
122
+ """Full factorial over all factor levels.
123
+
124
+ Returns:
125
+ List of config dictionaries, one per design point.
126
+
127
+ Raises:
128
+ ValueError: If a factor has bounds instead of levels.
129
+ """
130
+ level_lists = []
131
+ for f in factors:
132
+ if f.levels is not None:
133
+ level_lists.append(f.levels)
134
+ elif f.bounds is not None:
135
+ msg = f"Full factorial requires levels, not bounds, for factor '{f.name}'"
136
+ raise ValueError(msg)
137
+ names = [f.name for f in factors]
138
+ return [dict(zip(names, combo, strict=True)) for combo in product(*level_lists)]
139
+
140
+
141
def _latin_hypercube(
    factors: list[Factor],
    *,
    n_samples: int,
    seed: int,
) -> list[dict[str, Any]]:
    """Latin hypercube design via pyDOE3.

    Args:
        factors: List of design factors.
        n_samples: Number of design points to draw.
        seed: Random seed for reproducibility.

    Returns:
        List of config dictionaries, one per design point.
    """
    from pyDOE3 import lhs  # type: ignore[import-untyped]

    n_factors = len(factors)
    # BUG FIX: pyDOE3's ``lhs`` takes ``random_state=``, not ``seed=``;
    # passing ``seed=seed`` raises TypeError at call time.
    raw = lhs(
        n_factors,
        samples=n_samples,
        criterion="maximin",
        random_state=seed,
    )

    configs: list[dict[str, Any]] = []
    for row in raw:
        cfg: dict[str, Any] = {}
        for j, f in enumerate(factors):
            if f.factor_type == FactorType.CONTINUOUS and f.bounds is not None:
                # Scale the unit-interval sample into the factor's bounds.
                lo, hi = f.bounds
                cfg[f.name] = lo + row[j] * (hi - lo)
            elif f.levels is not None:
                # Map the unit-interval sample onto a level index; clamp the
                # top edge (row[j] == 1.0) into the last valid index.
                idx = min(int(row[j] * len(f.levels)), len(f.levels) - 1)
                cfg[f.name] = f.levels[idx]
        configs.append(cfg)
    return configs
170
+
171
+
172
def _qmc_sample(
    factors: list[Factor],
    *,
    n_samples: int,
    seed: int,
    qmc_method: str,
    scramble: bool,
) -> list[dict[str, Any]]:
    """Quasi-Monte Carlo design via ``scipy.stats.qmc``.

    Args:
        factors: List of design factors.
        n_samples: Number of sample points.
        seed: Random seed for scrambling.
        qmc_method: ``"sobol"`` or ``"halton"``.
        scramble: Whether to apply scrambling.

    Returns:
        List of config dictionaries, one per design point.
    """
    from scipy.stats import qmc  # type: ignore[import-untyped]

    dim = len(factors)
    engine_cls = qmc.Sobol if qmc_method == "sobol" else qmc.Halton
    sampler: qmc.QMCEngine = engine_cls(d=dim, scramble=scramble, seed=seed)
    unit_points = sampler.random(n_samples)

    designs: list[dict[str, Any]] = []
    for point in unit_points:
        cfg: dict[str, Any] = {}
        for pos, factor in enumerate(factors):
            is_continuous = (
                factor.factor_type == FactorType.CONTINUOUS
                and factor.bounds is not None
            )
            if is_continuous:
                # Stretch the unit sample into the factor's bounds.
                low, high = factor.bounds
                cfg[factor.name] = low + point[pos] * (high - low)
            elif factor.levels is not None:
                # Unit sample -> level index, clamped at the top edge.
                n_levels = len(factor.levels)
                idx = min(int(point[pos] * n_levels), n_levels - 1)
                cfg[factor.name] = factor.levels[idx]
        designs.append(cfg)
    return designs
215
+
216
+
217
def screen(
    run_fn: Callable[[dict[str, Any]], dict[str, float]],
    factors: list[Factor],
    *,
    method: str = "morris",
    n_trajectories: int = 100,
    seed: int = 42,
) -> dict[str, NDArray[np.floating[Any]]]:
    """Screen factors for influence on observables via SALib.

    Args:
        run_fn: Callable that takes a config dict and returns a dict of
            observable name → scalar score.
        factors: List of continuous factors to screen.
        method: Screening method ("morris" or "sobol").
        n_trajectories: Number of Morris trajectories or Sobol samples.
        seed: Random seed.

    Returns:
        Dictionary mapping observable names to arrays of factor importance
        (mu_star for Morris, S1 for Sobol), one value per factor.

    Raises:
        NotImplementedError: If method is not "morris".
        ValueError: If no continuous factors are provided.
    """
    from SALib.analyze import morris as morris_analyze  # type: ignore[import-untyped]
    from SALib.sample import morris as morris_sample  # type: ignore[import-untyped]

    if method != "morris":
        msg = f"Screening method {method!r} not yet implemented"
        raise NotImplementedError(msg)

    continuous = [f for f in factors if f.factor_type == FactorType.CONTINUOUS]
    if not continuous:
        msg = "Screening requires at least one continuous factor"
        raise ValueError(msg)

    # SALib problem specification built from the continuous factors only.
    problem: dict[str, Any] = {
        "num_vars": len(continuous),
        "names": [f.name for f in continuous],
        "bounds": [list(f.bounds) for f in continuous if f.bounds is not None],
    }
    samples = morris_sample.sample(problem, n_trajectories, seed=seed)

    # Run the model once per sample row, collecting a trace per observable.
    traces: dict[str, list[float]] = {}
    for row in samples:
        outputs = run_fn(dict(zip(problem["names"], row, strict=True)))
        for obs_name, value in outputs.items():
            traces.setdefault(obs_name, []).append(value)

    # One Morris analysis per observable; mu_star is the importance measure.
    importances: dict[str, NDArray[np.floating[Any]]] = {}
    for obs_name, trace in traces.items():
        analysis = morris_analyze.analyze(
            problem,
            samples,
            np.array(trace),
            seed=seed,
        )
        importances[obs_name] = np.asarray(analysis["mu_star"], dtype=np.float64)

    return importances
281
+
282
+
283
+ def reduce_factors(
284
+ factors: list[Factor],
285
+ importance: dict[str, NDArray[np.floating[Any]]],
286
+ *,
287
+ threshold: float = 0.1,
288
+ ) -> list[Factor]:
289
+ """Keep only factors whose max importance exceeds threshold.
290
+
291
+ Args:
292
+ factors: Original factor list.
293
+ importance: Output of ``screen()``.
294
+ threshold: Minimum importance to retain a factor.
295
+
296
+ Returns:
297
+ Reduced list of influential factors.
298
+ """
299
+ continuous = [f for f in factors if f.factor_type == FactorType.CONTINUOUS]
300
+ non_continuous = [f for f in factors if f.factor_type != FactorType.CONTINUOUS]
301
+
302
+ max_importance = np.zeros(len(continuous))
303
+ for arr in importance.values():
304
+ max_importance = np.maximum(max_importance, arr)
305
+
306
+ kept = [
307
+ f for f, imp in zip(continuous, max_importance, strict=True) if imp >= threshold
308
+ ]
309
+ return non_continuous + kept