srbf 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- srbf/__init__.py +10 -0
- srbf/__main__.py +96 -0
- srbf/baselines/__init__.py +7 -0
- srbf/baselines/brute_force_model.py +329 -0
- srbf/baselines/skeleton_pool_model.py +389 -0
- srbf/benchmarks/__init__.py +5 -0
- srbf/benchmarks/fastsrb.py +524 -0
- srbf/compat/__init__.py +5 -0
- srbf/compat/nesymres.py +74 -0
- srbf/eval/__init__.py +15 -0
- srbf/eval/candidate_store.py +308 -0
- srbf/eval/core.py +109 -0
- srbf/eval/data_sources.py +1015 -0
- srbf/eval/engine.py +599 -0
- srbf/eval/evaluation.py +159 -0
- srbf/eval/formatting.py +59 -0
- srbf/eval/metrics/__init__.py +24 -0
- srbf/eval/metrics/bootstrap.py +31 -0
- srbf/eval/metrics/numeric.py +97 -0
- srbf/eval/metrics/symbolic.py +34 -0
- srbf/eval/metrics/token_prediction.py +464 -0
- srbf/eval/metrics/zss.py +42 -0
- srbf/eval/model_adapters.py +1021 -0
- srbf/eval/provenance.py +186 -0
- srbf/eval/result_processing.py +351 -0
- srbf/eval/result_store.py +134 -0
- srbf/eval/run_config.py +779 -0
- srbf/eval/sample_metadata.py +85 -0
- srbf/eval/variable_renaming.py +128 -0
- srbf/py.typed +0 -0
- srbf-0.1.0.dist-info/METADATA +106 -0
- srbf-0.1.0.dist-info/RECORD +37 -0
- srbf-0.1.0.dist-info/WHEEL +5 -0
- srbf-0.1.0.dist-info/entry_points.txt +2 -0
- srbf-0.1.0.dist-info/licenses/LICENSE +21 -0
- srbf-0.1.0.dist-info/licenses/THIRD_PARTY_LICENSES +84 -0
- srbf-0.1.0.dist-info/top_level.txt +1 -0
srbf/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
"""srbf: the symbolic-regression evaluation framework, carved from flash-ansr.
|
|
2
|
+
|
|
3
|
+
Engine + model adapters + benchmarks + metrics for evaluating symbolic-regression models.
|
|
4
|
+
Depends one-way on flash-ansr (srbf imports flash-ansr; flash-ansr never imports srbf).
|
|
5
|
+
"""
|
|
6
|
+
from srbf.eval.evaluation import Evaluation
|
|
7
|
+
from srbf.eval.run_config import EvaluationRunPlan, build_evaluation_run
|
|
8
|
+
|
|
9
|
+
# Public, package-level API: the three names re-exported from ``srbf.eval`` by the
# imports above, so ``from srbf import *`` exposes exactly these.
__all__ = [
    "Evaluation",
    "EvaluationRunPlan",
    "build_evaluation_run",
]

# Package version string (this wheel is srbf 0.1.0).
__version__ = "0.1.0"
|
srbf/__main__.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""srbf command-line interface: the ``run`` subcommand carved from flash-ansr.
|
|
2
|
+
|
|
3
|
+
flash-ansr keeps the rest of its CLI (train / benchmark / import-data / install / ...); only
|
|
4
|
+
``run`` is evaluation-bound and lives here. The eval imports are ``srbf.eval.*``; the
|
|
5
|
+
flash-ansr ``utils`` imports are the cross-repo contract (srbf depends one-way on flash-ansr).
|
|
6
|
+
"""
|
|
7
|
+
import argparse
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def main(argv: list[str] | None = None) -> None:
|
|
11
|
+
parser = argparse.ArgumentParser(description="srbf: symbolic-regression evaluation framework")
|
|
12
|
+
subparsers = parser.add_subparsers(dest="command_name", required=True)
|
|
13
|
+
|
|
14
|
+
run_parser = subparsers.add_parser("run", help="Run an evaluation from a unified config")
|
|
15
|
+
run_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the evaluation run config file')
|
|
16
|
+
run_parser.add_argument('-n', '--limit', type=int, default=None, help='Override the sample limit specified in the config')
|
|
17
|
+
run_parser.add_argument('-o', '--output-file', type=str, default=None, help='Override the output file path from the config')
|
|
18
|
+
run_parser.add_argument('--save-every', type=int, default=None, help='Override periodic save frequency')
|
|
19
|
+
run_parser.add_argument('--no-resume', action='store_true', help='Ignore previous results even if the output file exists')
|
|
20
|
+
run_parser.add_argument('--experiment', type=str, default=None, help='Name of the experiment defined in the config to execute')
|
|
21
|
+
run_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
|
|
22
|
+
|
|
23
|
+
args = parser.parse_args(argv)
|
|
24
|
+
|
|
25
|
+
match args.command_name:
|
|
26
|
+
case 'run':
|
|
27
|
+
from srbf.eval.run_config import build_evaluation_run, EvaluationRunPlan
|
|
28
|
+
from flash_ansr.utils.config_io import load_config
|
|
29
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
30
|
+
|
|
31
|
+
config_path = substitute_root_path(args.config)
|
|
32
|
+
if args.verbose:
|
|
33
|
+
print(f"Running evaluation plan from {config_path}")
|
|
34
|
+
|
|
35
|
+
raw_config = load_config(config_path)
|
|
36
|
+
experiment_map = raw_config.get("experiments") if isinstance(raw_config, dict) else None
|
|
37
|
+
|
|
38
|
+
from srbf.eval.provenance import collect_provenance, format_provenance
|
|
39
|
+
base_prov = collect_provenance(config_path, None)
|
|
40
|
+
print(format_provenance(base_prov), flush=True)
|
|
41
|
+
|
|
42
|
+
def _execute_plan(plan: EvaluationRunPlan, experiment_name: str | None = None) -> None:
|
|
43
|
+
label = f"[{experiment_name}] " if experiment_name else ""
|
|
44
|
+
if plan.completed or plan.engine is None:
|
|
45
|
+
if args.verbose:
|
|
46
|
+
target = plan.total_limit or 'configured'
|
|
47
|
+
print(f"{label}Evaluation already completed ({plan.existing_results}/{target}). Nothing to do.")
|
|
48
|
+
return
|
|
49
|
+
|
|
50
|
+
plan.engine.run(
|
|
51
|
+
limit=plan.remaining,
|
|
52
|
+
save_every=plan.save_every,
|
|
53
|
+
output_path=plan.output_path,
|
|
54
|
+
verbose=args.verbose,
|
|
55
|
+
progress=args.verbose,
|
|
56
|
+
meta={**base_prov, "experiment": experiment_name},
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
if args.verbose:
|
|
60
|
+
total = plan.engine.result_store.size
|
|
61
|
+
destination = plan.output_path or 'memory'
|
|
62
|
+
print(f"{label}Evaluation finished with {total} samples (saved to {destination}).")
|
|
63
|
+
|
|
64
|
+
if experiment_map and args.experiment is None:
|
|
65
|
+
experiment_names = list(experiment_map.keys())
|
|
66
|
+
if args.verbose:
|
|
67
|
+
count = len(experiment_names)
|
|
68
|
+
print(f"No --experiment provided; running all {count} experiments defined in config.")
|
|
69
|
+
for experiment_name in experiment_names:
|
|
70
|
+
if args.verbose:
|
|
71
|
+
print(f"--> {experiment_name}")
|
|
72
|
+
plan = build_evaluation_run(
|
|
73
|
+
config=config_path,
|
|
74
|
+
limit_override=args.limit,
|
|
75
|
+
output_override=args.output_file,
|
|
76
|
+
save_every_override=args.save_every,
|
|
77
|
+
resume=None if not args.no_resume else False,
|
|
78
|
+
experiment=experiment_name,
|
|
79
|
+
)
|
|
80
|
+
_execute_plan(plan, experiment_name)
|
|
81
|
+
else:
|
|
82
|
+
plan = build_evaluation_run(
|
|
83
|
+
config=config_path,
|
|
84
|
+
limit_override=args.limit,
|
|
85
|
+
output_override=args.output_file,
|
|
86
|
+
save_every_override=args.save_every,
|
|
87
|
+
resume=None if not args.no_resume else False,
|
|
88
|
+
experiment=args.experiment,
|
|
89
|
+
)
|
|
90
|
+
_execute_plan(plan, args.experiment)
|
|
91
|
+
case _:
|
|
92
|
+
parser.print_help()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
if __name__ == "__main__":
|
|
96
|
+
main()
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import os
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Any, Generator, Literal, Sequence
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import torch
|
|
9
|
+
from sklearn.base import BaseEstimator
|
|
10
|
+
from simplipy import SimpliPyEngine
|
|
11
|
+
from simplipy.utils import construct_expressions
|
|
12
|
+
|
|
13
|
+
from flash_ansr.expressions import SkeletonPool
|
|
14
|
+
from flash_ansr.refine import Refiner, ConvergenceError
|
|
15
|
+
from flash_ansr.scoring import compute_fvu, count_constants, is_constant_token, normalize_variance, score_from_fvu
|
|
16
|
+
from flash_ansr.utils.paths import substitute_root_path
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BruteForceModel(BaseEstimator):
    """Exhaustive baseline that enumerates expressions in increasing length.

    Expressions are generated shortest-first using ``simplipy.utils.construct_expressions``
    over the operator and variable vocabulary defined by the provided
    ``SkeletonPool``. Each candidate is refined with the shared ``Refiner`` to
    fit constants against user-supplied data.

    After :meth:`fit`, the accepted candidates are stored in two places, both sorted by
    ascending ``score``: ``results`` is a ``pandas.DataFrame`` and ``_results`` is the
    list of record dicts. :meth:`predict` and :meth:`get_expression` pick one candidate
    by its position ``nth_best`` in that ranking.
    """

    FLOAT64_EPS: float = float(np.finfo(np.float64).eps)

    def __init__(
        self,
        *,
        simplipy_engine: SimpliPyEngine,
        skeleton_pool: str | dict[str, Any] | SkeletonPool,
        max_expressions: int = 10_000,
        max_length: int | None = None,
        include_constant_token: bool = True,
        ignore_holdouts: bool = True,
        n_restarts: int = 8,
        refiner_method: Literal[
            'curve_fit_lm',
            'minimize_bfgs',
            'minimize_lbfgsb',
            'minimize_neldermead',
            'minimize_powell',
            'least_squares_trf',
            'least_squares_dogbox',
        ] = 'curve_fit_lm',
        refiner_p0_noise: Literal['uniform', 'normal'] | None = 'normal',
        refiner_p0_noise_kwargs: dict | Literal['default'] | None = 'default',
        numpy_errors: Literal['ignore', 'warn', 'raise', 'call', 'print', 'log'] | None = 'ignore',
        length_penalty: float = 0.05,
        constants_penalty: float = 0.0,
        likelihood_penalty: float = 0.0,
    ) -> None:
        """Configure the search.

        Args:
            simplipy_engine: Engine providing operator arities and expression validation.
            skeleton_pool: A ``SkeletonPool``, a path (resolved with
                ``substitute_root_path``), or a config dict. The pool supplies the
                variables and the operator weights.
            max_expressions: Upper bound on the number of enumerated expressions.
            max_length: Optional upper bound on the prefix-token length of candidates.
            include_constant_token: Whether ``<constant>`` is used as a leaf.
            ignore_holdouts: Clear the pool's holdouts after loading it.
            n_restarts: Number of refiner restarts per candidate.
            refiner_method: Optimizer passed through to ``Refiner.fit``.
            refiner_p0_noise: Distribution for the initial-constant noise.
            refiner_p0_noise_kwargs: Keyword arguments for that distribution.
                ``'default'`` means ``{'loc': 0.0, 'scale': 5.0}``.
            numpy_errors: Floating-point error policy applied during :meth:`fit`.
            length_penalty: Score weight for the expression length.
            constants_penalty: Score weight for the number of constants.
            likelihood_penalty: Score weight for the log-probability (unused here,
                because brute force has none).
        """
        self.simplipy_engine = simplipy_engine
        # Stored under its constructor name so sklearn's get_params / repr / clone can
        # find it. The resolved pool object lives in ``_pool``.
        self.skeleton_pool = skeleton_pool
        self.max_expressions = int(max_expressions)
        self.max_length = max_length
        self.include_constant_token = include_constant_token
        self.ignore_holdouts = ignore_holdouts
        self.n_restarts = n_restarts
        self.refiner_method = refiner_method
        self.refiner_p0_noise = refiner_p0_noise
        if refiner_p0_noise_kwargs == 'default':
            refiner_p0_noise_kwargs = {'loc': 0.0, 'scale': 5.0}
        self.refiner_p0_noise_kwargs = copy.deepcopy(refiner_p0_noise_kwargs) if refiner_p0_noise_kwargs is not None else None
        self.numpy_errors = numpy_errors
        self.length_penalty = float(length_penalty)
        self.constants_penalty = float(constants_penalty)
        self.likelihood_penalty = float(likelihood_penalty)

        # Must run after ``ignore_holdouts`` is set, because _ensure_pool reads it.
        self._pool = self._ensure_pool(skeleton_pool)
        self._results: list[dict[str, Any]] = []
        self.results: pd.DataFrame = pd.DataFrame()
        self._input_dim: int | None = None

    @property
    def n_variables(self) -> int:
        """Number of input variables defined by the skeleton pool."""
        return self._pool.n_variables

    def _ensure_pool(self, skeleton_pool_ref: str | dict[str, Any] | SkeletonPool) -> SkeletonPool:
        """Resolve ``skeleton_pool_ref`` to a ``SkeletonPool`` instance.

        A string is resolved with ``substitute_root_path``. It is loaded with
        ``SkeletonPool.load`` if it names a directory, and treated as a config path
        otherwise. A dict is deep-copied before being passed to
        ``SkeletonPool.from_config``.

        Raises:
            TypeError: For any other input type.
        """
        if isinstance(skeleton_pool_ref, SkeletonPool):
            # NOTE(review): the caller's pool object is used as-is, so the
            # clear_holdouts() call below presumably modifies it in place.
            pool = skeleton_pool_ref
        elif isinstance(skeleton_pool_ref, str):
            resolved = substitute_root_path(skeleton_pool_ref)
            if os.path.isdir(resolved):
                _, pool = SkeletonPool.load(resolved)
            else:
                pool = SkeletonPool.from_config(resolved)
        elif isinstance(skeleton_pool_ref, dict):
            pool = SkeletonPool.from_config(copy.deepcopy(skeleton_pool_ref))
        else:
            raise TypeError("`skeleton_pool` must be a SkeletonPool, path string, or configuration dictionary.")

        if self.ignore_holdouts:
            pool.clear_holdouts()

        return pool

    def _truncate_input(self, X: np.ndarray) -> np.ndarray:
        """Zero-pad or truncate the last axis of ``X`` to exactly ``n_variables`` columns."""
        n_features = X.shape[-1]
        if n_features == self.n_variables:
            return X
        if n_features < self.n_variables:
            pad_width = self.n_variables - n_features
            pad = np.zeros((*X.shape[:-1], pad_width), dtype=X.dtype)
            return np.concatenate([X, pad], axis=-1)

        return X[..., : self.n_variables]

    @staticmethod
    def _to_numpy(data: np.ndarray | torch.Tensor | pd.DataFrame | pd.Series | Sequence[float]) -> np.ndarray:
        """Convert a torch tensor, pandas object, or array-like to a NumPy array.

        Tensors are detached and moved to the CPU before conversion, so tensors that
        live on a GPU or require gradients are accepted.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return np.asarray(data.values)
        return np.asarray(data)

    @staticmethod
    def _normalize_variance(variance: float) -> float:
        return normalize_variance(variance)

    @staticmethod
    def _compute_fvu(loss: float, sample_count: int, variance: float) -> float:
        return compute_fvu(loss, sample_count, variance)

    @staticmethod
    def _is_constant_token(token: str) -> bool:
        return is_constant_token(token)

    @classmethod
    def _count_constants(cls, expression: Sequence[str]) -> int:
        return count_constants(expression)

    @staticmethod
    def _score_from_fvu(
            fvu: float,
            complexity: int,
            constant_count: int,
            log_prob: float | None,
            length_penalty: float,
            constants_penalty: float,
            likelihood_penalty: float) -> float:
        return score_from_fvu(
            fvu, complexity, constant_count, log_prob,
            length_penalty, constants_penalty, likelihood_penalty)

    def _leaf_nodes(self) -> list[str]:
        """Leaf tokens: the pool's variables, plus ``<constant>`` if enabled."""
        leaves = list(self._pool.variables)
        if self.include_constant_token:
            leaves.append('<constant>')
        return leaves

    def _non_leaf_nodes(self) -> dict[str, int]:
        """Map each operator with a positive pool weight to its arity."""
        operator_weights = self._pool.operator_weights or {}
        return {op: arity for op, arity in self.simplipy_engine.operator_arity.items() if operator_weights.get(op, 0) > 0}

    def _expression_generator(self) -> Generator[tuple[str, ...], None, None]:
        """Yield unique, valid prefix expressions in order of increasing length.

        Leaves come first. Longer expressions are then built one length at a time from
        the expressions accepted so far. Generation stops in any of these cases:

        - ``max_expressions`` expressions have been yielded;
        - the next length exceeds ``max_length``;
        - no operator with positive arity is available;
        - ``max_arity`` consecutive lengths produce nothing, which means the search
          has stalled.
        """
        hashes_by_size: defaultdict[int, set[tuple[str, ...]]] = defaultdict(set)
        seen: set[tuple[str, ...]] = set()

        for leaf in self._leaf_nodes():
            expr = (leaf,)
            hashes_by_size[1].add(expr)
            seen.add(expr)
            yield expr
            if len(seen) >= self.max_expressions:
                return

        non_leaf_nodes = self._non_leaf_nodes()  # loop-invariant, so computed once
        # A length may legitimately have no expressions: with only binary operators,
        # every expression has odd length. An empty length therefore must not end the
        # search.
        #
        # Let m be the longest length built so far. An operator of arity a, applied to
        # one length-m child and (a - 1) leaves, gives an expression of length m + a.
        # So the next non-empty length is at most max_arity away.
        max_arity = max(non_leaf_nodes.values(), default=0)

        target_length = 2
        empty_lengths = 0
        while len(seen) < self.max_expressions and empty_lengths < max_arity:
            if self.max_length is not None and target_length > self.max_length:
                break  # every remaining candidate would be too long

            new_expressions: list[tuple[str, ...]] = []
            for expr in construct_expressions(hashes_by_size, non_leaf_nodes, must_have_sizes=None):
                # construct_expressions may emit expressions of other lengths as well;
                # only the current target length is consumed on this pass.
                if len(expr) != target_length or expr in seen:
                    continue
                if not self.simplipy_engine.is_valid(list(expr)):
                    continue

                seen.add(expr)
                new_expressions.append(expr)
                yield expr
                if len(seen) >= self.max_expressions:
                    break

            if new_expressions:
                hashes_by_size[target_length].update(new_expressions)
                empty_lengths = 0
            else:
                empty_lengths += 1
            target_length += 1

    def _evaluate_candidate(
            self,
            expression_tokens: list[str],
            X: np.ndarray,
            y: np.ndarray,
            sample_count: int,
            y_variance: float) -> dict[str, Any] | None:
        """Refine one candidate's constants and build its result record.

        Returns:
            The record dict, or ``None`` if any of the following holds:

            - refinement raises ``ConvergenceError`` or stores no fits;
            - the fit is invalid for an expression that has constants;
            - the loss or the FVU is not finite.
        """
        try:
            refiner = Refiner(self.simplipy_engine, n_variables=self.n_variables).fit(
                expression=expression_tokens,
                X=X,
                y=y,
                n_restarts=self.n_restarts,
                method=self.refiner_method,
                p0=None,
                p0_noise=self.refiner_p0_noise,
                # Fresh copy per candidate so the refiner cannot mutate the shared settings.
                p0_noise_kwargs=copy.deepcopy(self.refiner_p0_noise_kwargs) if self.refiner_p0_noise_kwargs is not None else None,
                converge_error='ignore',
            )
        except ConvergenceError:
            return None

        if len(refiner._all_constants_values) == 0:
            return None

        # Expressions without constants have nothing to optimise, so they are accepted
        # even when the refiner does not report a valid fit.
        has_constants = len(refiner.constants_symbols) > 0
        valid_fit = refiner.valid_fit or not has_constants
        if not valid_fit:
            return None

        # The loss is the last element of the first stored fit.
        # NOTE(review): this presumably is the best restart; confirm against Refiner.
        loss = float(refiner._all_constants_values[0][-1])
        if not np.isfinite(loss):
            return None

        fvu = self._compute_fvu(loss, sample_count, y_variance)
        if not np.isfinite(fvu):
            return None

        constant_count = self._count_constants(expression_tokens)
        score = self._score_from_fvu(
            fvu,
            len(expression_tokens),
            constant_count,
            None,  # brute force has no model log-probability
            self.length_penalty,
            self.constants_penalty,
            self.likelihood_penalty,
        )

        return {
            'log_prob': float('nan'),
            'fvu': fvu,
            'score': score,
            'expression': expression_tokens,
            'constant_count': constant_count,
            'complexity': len(expression_tokens),
            'requested_complexity': None,
            'raw_beam': expression_tokens,
            'beam': expression_tokens,
            'raw_beam_decoded': ' '.join(expression_tokens),
            'function': refiner.expression_lambda,
            'refiner': refiner,
            'fits': copy.deepcopy(refiner._all_constants_values),
            'prompt_metadata': None,
        }

    def fit(self, X: np.ndarray | torch.Tensor | pd.DataFrame, y: np.ndarray | torch.Tensor | pd.DataFrame | Sequence[float], *, verbose: bool = False) -> "BruteForceModel":
        """Enumerate candidate expressions, refine their constants, and rank them.

        Args:
            X: 2-D input samples. Columns are zero-padded or truncated to the pool's
                ``n_variables``.
            y: Targets with a single output dimension, given either as 1-D or as one
                column.
            verbose: Accepted for interface compatibility; currently unused.

        Returns:
            ``self``, with ``results`` populated and sorted by ascending ``score``.

        Raises:
            ValueError: If ``y`` has more than one output column, or if ``X`` and
                ``y`` have different numbers of samples.
        """
        X_np = self._truncate_input(self._to_numpy(X))

        # Convert first, then reshape. Reshaping a raw tensor through NumPy would fail
        # for GPU tensors and for tensors that require gradients.
        y_np = self._to_numpy(y)
        if y_np.ndim == 1:
            y_np = y_np.reshape(-1, 1)
        elif y_np.shape[-1] != 1:
            raise ValueError("The target data must have a single output dimension.")

        sample_count = y_np.shape[0]
        if X_np.shape[0] != sample_count:
            raise ValueError(
                f"X and y must have the same number of samples (got {X_np.shape[0]} and {sample_count})."
            )
        self._input_dim = X_np.shape[1]

        if sample_count <= 1:
            y_variance = float('nan')
        else:
            y_variance = float(np.var(y_np, axis=0, ddof=1).item())

        results: list[dict[str, Any]] = []
        # errstate restores the previous floating-point policy even if refinement raises
        # (e.g. FloatingPointError under numpy_errors='raise').
        with np.errstate(all=self.numpy_errors):
            for skeleton in self._expression_generator():
                record = self._evaluate_candidate(list(skeleton), X_np, y_np, sample_count, y_variance)
                if record is None:
                    continue
                results.append(record)
                if len(results) >= self.max_expressions:
                    break

        results.sort(key=lambda item: item['score'])

        self._results = results
        self.results = pd.DataFrame(results)
        return self

    def _get_result(self, nth_best: int) -> dict[str, Any]:
        """Return the ``nth_best`` result record.

        Raises:
            ValueError: If there are no results, because the model has not been
                fitted yet.
            IndexError: If ``nth_best`` is past the end of the results.
        """
        if not self._results:
            raise ValueError("The model has not been fitted yet. Please call `fit` first.")

        if nth_best >= len(self._results):
            raise IndexError(f"nth_best={nth_best} is out of range for {len(self._results)} results.")

        return self._results[nth_best]

    def predict(self, X: np.ndarray | torch.Tensor | pd.DataFrame, nth_best: int = 0) -> np.ndarray:
        """Evaluate the ``nth_best`` ranked expression on ``X``.

        Index 0 is the expression with the lowest score. ``X`` is padded or truncated
        to ``n_variables`` columns, exactly as in :meth:`fit`.
        """
        refiner = self._get_result(nth_best)['refiner']
        X_np = self._truncate_input(self._to_numpy(X))
        return refiner.predict(X_np)

    def get_expression(self, nth_best: int = 0, *, return_prefix: bool = False, precision: int = 2) -> list[str] | str:
        """Render the ``nth_best`` expression with its fitted constants substituted in.

        The rendering is delegated to the candidate's ``Refiner.transform``, which
        receives ``return_prefix`` and ``precision`` unchanged.
        """
        result = self._get_result(nth_best)
        return result['refiner'].transform(
            result['expression'],
            nth_best_constants=0,
            return_prefix=return_prefix,
            precision=precision,
        )
|