srbf 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
srbf/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """srbf: the symbolic-regression evaluation framework, carved from flash-ansr.
2
+
3
+ Engine + model adapters + benchmarks + metrics for evaluating symbolic-regression models.
4
+ Depends one-way on flash-ansr (srbf imports flash-ansr; flash-ansr never imports srbf).
5
+ """
6
+ from srbf.eval.evaluation import Evaluation
7
+ from srbf.eval.run_config import EvaluationRunPlan, build_evaluation_run
8
+
9
# Public names re-exported at package level; ``__version__`` is the package release string.
__all__ = ["Evaluation", "EvaluationRunPlan", "build_evaluation_run"]
__version__ = "0.1.0"
srbf/__main__.py ADDED
@@ -0,0 +1,96 @@
1
+ """srbf command-line interface: the ``run`` subcommand carved from flash-ansr.
2
+
3
+ flash-ansr keeps the rest of its CLI (train / benchmark / import-data / install / ...); only
4
+ ``run`` is evaluation-bound and lives here. The eval imports are ``srbf.eval.*``; the
5
+ flash-ansr ``utils`` imports are the cross-repo contract (srbf depends one-way on flash-ansr).
6
+ """
7
+ import argparse
8
+
9
+
10
+ def main(argv: list[str] | None = None) -> None:
11
+ parser = argparse.ArgumentParser(description="srbf: symbolic-regression evaluation framework")
12
+ subparsers = parser.add_subparsers(dest="command_name", required=True)
13
+
14
+ run_parser = subparsers.add_parser("run", help="Run an evaluation from a unified config")
15
+ run_parser.add_argument('-c', '--config', type=str, required=True, help='Path to the evaluation run config file')
16
+ run_parser.add_argument('-n', '--limit', type=int, default=None, help='Override the sample limit specified in the config')
17
+ run_parser.add_argument('-o', '--output-file', type=str, default=None, help='Override the output file path from the config')
18
+ run_parser.add_argument('--save-every', type=int, default=None, help='Override periodic save frequency')
19
+ run_parser.add_argument('--no-resume', action='store_true', help='Ignore previous results even if the output file exists')
20
+ run_parser.add_argument('--experiment', type=str, default=None, help='Name of the experiment defined in the config to execute')
21
+ run_parser.add_argument('-v', '--verbose', action='store_true', help='Print a progress bar')
22
+
23
+ args = parser.parse_args(argv)
24
+
25
+ match args.command_name:
26
+ case 'run':
27
+ from srbf.eval.run_config import build_evaluation_run, EvaluationRunPlan
28
+ from flash_ansr.utils.config_io import load_config
29
+ from flash_ansr.utils.paths import substitute_root_path
30
+
31
+ config_path = substitute_root_path(args.config)
32
+ if args.verbose:
33
+ print(f"Running evaluation plan from {config_path}")
34
+
35
+ raw_config = load_config(config_path)
36
+ experiment_map = raw_config.get("experiments") if isinstance(raw_config, dict) else None
37
+
38
+ from srbf.eval.provenance import collect_provenance, format_provenance
39
+ base_prov = collect_provenance(config_path, None)
40
+ print(format_provenance(base_prov), flush=True)
41
+
42
+ def _execute_plan(plan: EvaluationRunPlan, experiment_name: str | None = None) -> None:
43
+ label = f"[{experiment_name}] " if experiment_name else ""
44
+ if plan.completed or plan.engine is None:
45
+ if args.verbose:
46
+ target = plan.total_limit or 'configured'
47
+ print(f"{label}Evaluation already completed ({plan.existing_results}/{target}). Nothing to do.")
48
+ return
49
+
50
+ plan.engine.run(
51
+ limit=plan.remaining,
52
+ save_every=plan.save_every,
53
+ output_path=plan.output_path,
54
+ verbose=args.verbose,
55
+ progress=args.verbose,
56
+ meta={**base_prov, "experiment": experiment_name},
57
+ )
58
+
59
+ if args.verbose:
60
+ total = plan.engine.result_store.size
61
+ destination = plan.output_path or 'memory'
62
+ print(f"{label}Evaluation finished with {total} samples (saved to {destination}).")
63
+
64
+ if experiment_map and args.experiment is None:
65
+ experiment_names = list(experiment_map.keys())
66
+ if args.verbose:
67
+ count = len(experiment_names)
68
+ print(f"No --experiment provided; running all {count} experiments defined in config.")
69
+ for experiment_name in experiment_names:
70
+ if args.verbose:
71
+ print(f"--> {experiment_name}")
72
+ plan = build_evaluation_run(
73
+ config=config_path,
74
+ limit_override=args.limit,
75
+ output_override=args.output_file,
76
+ save_every_override=args.save_every,
77
+ resume=None if not args.no_resume else False,
78
+ experiment=experiment_name,
79
+ )
80
+ _execute_plan(plan, experiment_name)
81
+ else:
82
+ plan = build_evaluation_run(
83
+ config=config_path,
84
+ limit_override=args.limit,
85
+ output_override=args.output_file,
86
+ save_every_override=args.save_every,
87
+ resume=None if not args.no_resume else False,
88
+ experiment=args.experiment,
89
+ )
90
+ _execute_plan(plan, args.experiment)
91
+ case _:
92
+ parser.print_help()
93
+
94
+
95
+ if __name__ == "__main__":
96
+ main()
@@ -0,0 +1,7 @@
1
+ from .skeleton_pool_model import SkeletonPoolModel
2
+ from .brute_force_model import BruteForceModel
3
+
4
# Model classes exported by this subpackage.
__all__ = [
    "SkeletonPoolModel",
    "BruteForceModel",
]
@@ -0,0 +1,329 @@
1
+ import copy
2
+ import os
3
+ from collections import defaultdict
4
+ from typing import Any, Generator, Literal, Sequence
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from sklearn.base import BaseEstimator
10
+ from simplipy import SimpliPyEngine
11
+ from simplipy.utils import construct_expressions
12
+
13
+ from flash_ansr.expressions import SkeletonPool
14
+ from flash_ansr.refine import Refiner, ConvergenceError
15
+ from flash_ansr.scoring import compute_fvu, count_constants, is_constant_token, normalize_variance, score_from_fvu
16
+ from flash_ansr.utils.paths import substitute_root_path
17
+
18
+
19
class BruteForceModel(BaseEstimator):
    """Exhaustive baseline that enumerates expressions in increasing length.

    Expressions are generated shortest-first using ``simplipy.utils.construct_expressions``
    over the operator and variable vocabulary defined by the provided
    ``SkeletonPool``. Each candidate is refined with the shared ``Refiner`` to
    fit constants against user-supplied data.

    Accepted candidates are ranked by ascending score, as computed by
    ``flash_ansr.scoring.score_from_fvu`` from the FVU and the length/constant
    penalties. ``nth_best=0`` therefore refers to the lowest score.
    """

    FLOAT64_EPS: float = float(np.finfo(np.float64).eps)

    def __init__(
        self,
        *,
        simplipy_engine: SimpliPyEngine,
        skeleton_pool: str | dict[str, Any] | SkeletonPool,
        max_expressions: int = 10_000,
        max_length: int | None = None,
        include_constant_token: bool = True,
        ignore_holdouts: bool = True,
        n_restarts: int = 8,
        refiner_method: Literal[
            'curve_fit_lm',
            'minimize_bfgs',
            'minimize_lbfgsb',
            'minimize_neldermead',
            'minimize_powell',
            'least_squares_trf',
            'least_squares_dogbox',
        ] = 'curve_fit_lm',
        refiner_p0_noise: Literal['uniform', 'normal'] | None = 'normal',
        refiner_p0_noise_kwargs: dict | Literal['default'] | None = 'default',
        numpy_errors: Literal['ignore', 'warn', 'raise', 'call', 'print', 'log'] | None = 'ignore',
        length_penalty: float = 0.05,
        constants_penalty: float = 0.0,
        likelihood_penalty: float = 0.0,
    ) -> None:
        """Configure the enumeration and the constant refinement.

        Args:
            simplipy_engine: Supplies operator arities and the ``is_valid`` check.
            skeleton_pool: A ``SkeletonPool``, a config dict, or a path. A path is
                resolved with ``substitute_root_path``; a directory is loaded with
                ``SkeletonPool.load``, anything else with ``SkeletonPool.from_config``.
                The pool provides the variables and the operator weights.
            max_expressions: Maximum number of enumerated candidates.
            max_length: Optional maximum token length of a candidate.
            include_constant_token: Whether ``'<constant>'`` is used as a leaf.
            ignore_holdouts: Clear the pool's holdouts after loading it.
            n_restarts: Forwarded to ``Refiner.fit``.
            refiner_method: Forwarded to ``Refiner.fit`` as ``method``.
            refiner_p0_noise: Forwarded to ``Refiner.fit`` as ``p0_noise``.
            refiner_p0_noise_kwargs: Forwarded to ``Refiner.fit`` as
                ``p0_noise_kwargs``. ``'default'`` means ``{'loc': 0.0, 'scale': 5.0}``.
            numpy_errors: NumPy floating-point error mode applied while ``fit`` runs.
            length_penalty: Score penalty forwarded to ``score_from_fvu``.
            constants_penalty: Score penalty forwarded to ``score_from_fvu``.
            likelihood_penalty: Score penalty forwarded to ``score_from_fvu``.
        """
        self.simplipy_engine = simplipy_engine
        # Stored verbatim because sklearn's get_params()/repr() read every constructor
        # argument back as an attribute. The resolved pool lives in ``self._pool``.
        self.skeleton_pool = skeleton_pool
        self.max_expressions = int(max_expressions)
        self.max_length = max_length
        self.include_constant_token = include_constant_token
        self.ignore_holdouts = ignore_holdouts
        self.n_restarts = n_restarts
        self.refiner_method = refiner_method
        self.refiner_p0_noise = refiner_p0_noise
        if refiner_p0_noise_kwargs == 'default':
            refiner_p0_noise_kwargs = {'loc': 0.0, 'scale': 5.0}
        self.refiner_p0_noise_kwargs = copy.deepcopy(refiner_p0_noise_kwargs) if refiner_p0_noise_kwargs is not None else None
        self.numpy_errors = numpy_errors
        self.length_penalty = float(length_penalty)
        self.constants_penalty = float(constants_penalty)
        self.likelihood_penalty = float(likelihood_penalty)

        # Must run after ``ignore_holdouts`` is set; _ensure_pool reads it.
        self._pool = self._ensure_pool(skeleton_pool)
        self._results: list[dict[str, Any]] = []
        self.results: pd.DataFrame = pd.DataFrame()
        self._input_dim: int | None = None

    @property
    def n_variables(self) -> int:
        """Number of input variables of the underlying skeleton pool."""
        return self._pool.n_variables

    def _ensure_pool(self, skeleton_pool_ref: str | dict[str, Any] | SkeletonPool) -> SkeletonPool:
        """Resolve ``skeleton_pool_ref`` into a ``SkeletonPool``, clearing holdouts if configured.

        Raises:
            TypeError: If the reference is not a pool, a path string, or a dict.
        """
        if isinstance(skeleton_pool_ref, SkeletonPool):
            pool = skeleton_pool_ref
        elif isinstance(skeleton_pool_ref, str):
            resolved = substitute_root_path(skeleton_pool_ref)
            if os.path.isdir(resolved):
                _, pool = SkeletonPool.load(resolved)
            else:
                pool = SkeletonPool.from_config(resolved)
        elif isinstance(skeleton_pool_ref, dict):
            pool = SkeletonPool.from_config(copy.deepcopy(skeleton_pool_ref))
        else:
            raise TypeError("`skeleton_pool` must be a SkeletonPool, path string, or configuration dictionary.")

        if self.ignore_holdouts:
            pool.clear_holdouts()

        return pool

    def _truncate_input(self, X: np.ndarray) -> np.ndarray:
        """Make the last axis exactly ``n_variables`` wide.

        Missing trailing columns are padded with zeros, and extra trailing columns
        are dropped.
        """
        n_features = X.shape[-1]
        if n_features == self.n_variables:
            return X
        if n_features < self.n_variables:
            pad_width = self.n_variables - n_features
            pad = np.zeros((*X.shape[:-1], pad_width), dtype=X.dtype)
            return np.concatenate([X, pad], axis=-1)

        return X[..., : self.n_variables]

    @staticmethod
    def _to_numpy(data: Any) -> np.ndarray:
        """Convert a torch tensor, pandas object, or array-like to a NumPy array.

        Tensors are detached and moved to the CPU first, so tensors on CUDA or
        tensors that require grad convert cleanly.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.values
        return np.asarray(data)

    @staticmethod
    def _normalize_variance(variance: float) -> float:
        return normalize_variance(variance)

    @staticmethod
    def _compute_fvu(loss: float, sample_count: int, variance: float) -> float:
        return compute_fvu(loss, sample_count, variance)

    @staticmethod
    def _is_constant_token(token: str) -> bool:
        return is_constant_token(token)

    @classmethod
    def _count_constants(cls, expression: Sequence[str]) -> int:
        return count_constants(expression)

    @staticmethod
    def _score_from_fvu(
            fvu: float,
            complexity: int,
            constant_count: int,
            log_prob: float | None,
            length_penalty: float,
            constants_penalty: float,
            likelihood_penalty: float) -> float:
        return score_from_fvu(
            fvu, complexity, constant_count, log_prob,
            length_penalty, constants_penalty, likelihood_penalty)

    def _leaf_nodes(self) -> list[str]:
        """Return the leaf tokens: the pool's variables, plus ``'<constant>'`` if enabled."""
        leaves = list(self._pool.variables)
        if self.include_constant_token:
            leaves.append('<constant>')
        return leaves

    def _non_leaf_nodes(self) -> dict[str, int]:
        """Map each operator with positive pool weight to its arity from the engine."""
        operator_weights = self._pool.operator_weights or {}
        return {op: arity for op, arity in self.simplipy_engine.operator_arity.items() if operator_weights.get(op, 0) > 0}

    def _expression_generator(self) -> Generator[tuple[str, ...], None, None]:
        """Yield unique, valid prefix expressions shortest-first, at most ``max_expressions``.

        Leaves are yielded first. Each later round builds operator applications
        over all previously accepted sub-expressions and keeps only the new, valid
        ones of the current ``target_length``.

        A length with no expressions is skipped rather than ending the search.
        For example, length 2 is empty when the vocabulary has no unary operators.
        Enumeration stops once no candidate of the current length or longer
        (within ``max_length``) can be built.
        """
        hashes_by_size: defaultdict[int, set[tuple[str, ...]]] = defaultdict(set)
        seen: set[tuple[str, ...]] = set()

        for leaf in self._leaf_nodes():
            expr = (leaf,)
            hashes_by_size[1].add(expr)
            seen.add(expr)
            yield expr
            if len(seen) >= self.max_expressions:
                return

        # The operator vocabulary is fixed during enumeration; compute it once.
        non_leaf_nodes = self._non_leaf_nodes()

        target_length = 2
        while len(seen) < self.max_expressions:
            new_expressions: list[tuple[str, ...]] = []
            # True if a later round could still produce something (a candidate longer
            # than target_length that respects max_length).
            longer_candidates = False
            for expr in construct_expressions(hashes_by_size, non_leaf_nodes, must_have_sizes=None):
                expr_len = len(expr)
                if self.max_length is not None and expr_len > self.max_length:
                    continue
                if expr_len != target_length:
                    if expr_len > target_length:
                        longer_candidates = True
                    continue
                if expr in seen:
                    continue
                if not self.simplipy_engine.is_valid(list(expr)):
                    continue

                seen.add(expr)
                new_expressions.append(expr)
                yield expr
                if len(seen) >= self.max_expressions:
                    break

            if new_expressions:
                hashes_by_size[target_length].update(new_expressions)
            elif not longer_candidates:
                # Nothing at this length and nothing longer is reachable: the search is exhausted.
                break
            target_length += 1

    def _evaluate_candidate(
        self,
        expression_tokens: list[str],
        X_np: np.ndarray,
        y_np: np.ndarray,
        sample_count: int,
        y_variance: float,
    ) -> dict[str, Any] | None:
        """Refine one candidate's constants and build its result record.

        Returns ``None`` when the candidate is rejected. That happens when:

        - refinement raises ``ConvergenceError``;
        - refinement stores no fits;
        - an expression with constants gets an invalid fit;
        - the loss or the FVU is not finite.
        """
        try:
            refiner = Refiner(self.simplipy_engine, n_variables=self.n_variables).fit(
                expression=expression_tokens,
                X=X_np,
                y=y_np,
                n_restarts=self.n_restarts,
                method=self.refiner_method,
                p0=None,
                p0_noise=self.refiner_p0_noise,
                # A fresh copy per candidate, presumably so the refiner cannot mutate the stored kwargs.
                p0_noise_kwargs=copy.deepcopy(self.refiner_p0_noise_kwargs) if self.refiner_p0_noise_kwargs is not None else None,
                converge_error='ignore',
            )
        except ConvergenceError:
            return None

        if len(refiner._all_constants_values) == 0:
            return None

        # Constant-free expressions have nothing to fit, so they are accepted even if
        # the refiner does not flag the fit as valid.
        has_constants = len(refiner.constants_symbols) > 0
        valid_fit = refiner.valid_fit or not has_constants
        if not valid_fit:
            return None

        # Assumes the first stored fit is the best one and its last element is the loss — TODO confirm in Refiner.
        loss = float(refiner._all_constants_values[0][-1])
        if not np.isfinite(loss):
            return None

        fvu = self._compute_fvu(loss, sample_count, y_variance)
        if not np.isfinite(fvu):
            return None

        constant_count = self._count_constants(expression_tokens)
        score = self._score_from_fvu(
            fvu,
            len(expression_tokens),
            constant_count,
            None,
            self.length_penalty,
            self.constants_penalty,
            self.likelihood_penalty,
        )

        return {
            'log_prob': float('nan'),
            'fvu': fvu,
            'score': score,
            'expression': expression_tokens,
            'constant_count': constant_count,
            'complexity': len(expression_tokens),
            'requested_complexity': None,
            'raw_beam': expression_tokens,
            'beam': expression_tokens,
            'raw_beam_decoded': ' '.join(expression_tokens),
            'function': refiner.expression_lambda,
            'refiner': refiner,
            'fits': copy.deepcopy(refiner._all_constants_values),
            'prompt_metadata': None,
        }

    def fit(self, X: np.ndarray | torch.Tensor | pd.DataFrame, y: np.ndarray | torch.Tensor | pd.DataFrame | Sequence[float], *, verbose: bool = False) -> "BruteForceModel":
        """Enumerate candidates, refine their constants on ``(X, y)``, and rank them.

        Args:
            X: Inputs. They are converted to NumPy, then zero-padded or truncated
                to ``n_variables`` columns.
            y: Single-output targets. 1-D input is reshaped to a column.
            verbose: Accepted for API compatibility; not used.

        Returns:
            ``self``. ``results`` (a DataFrame) and the internal result list are
            sorted by ascending score.

        Raises:
            ValueError: If ``y`` has more than one output column.
        """
        X_np = self._to_numpy(X)
        y_np = self._to_numpy(y)

        if y_np.ndim == 1:
            y_np = y_np.reshape(-1, 1)
        elif y_np.shape[-1] != 1:
            raise ValueError("The target data must have a single output dimension.")

        X_np = self._truncate_input(np.asarray(X_np))
        self._input_dim = X_np.shape[1]

        sample_count = y_np.shape[0]
        # The sample variance (ddof=1) needs at least two samples; otherwise it is NaN.
        if sample_count <= 1:
            y_variance = float('nan')
        else:
            y_variance = float(np.var(y_np, axis=0, ddof=1).item())

        # np.seterr returns the previous settings; restore them even if refinement raises.
        previous_errors = np.seterr(all=self.numpy_errors)
        try:
            results: list[dict[str, Any]] = []
            for skeleton in self._expression_generator():
                record = self._evaluate_candidate(list(skeleton), X_np, y_np, sample_count, y_variance)
                if record is None:
                    continue
                results.append(record)
                if len(results) >= self.max_expressions:
                    break
        finally:
            np.seterr(**previous_errors)

        results.sort(key=lambda item: item['score'])

        self._results = results
        self.results = pd.DataFrame(results)
        return self

    def _get_result(self, nth_best: int) -> dict[str, Any]:
        """Return the ``nth_best`` ranked result record, validating that the model is fitted.

        Raises:
            ValueError: If ``fit`` has not produced any results.
            IndexError: If ``nth_best`` is not smaller than the number of results.
        """
        if not self._results:
            raise ValueError("The model has not been fitted yet. Please call `fit` first.")

        if nth_best >= len(self._results):
            raise IndexError(f"nth_best={nth_best} is out of range for {len(self._results)} results.")

        return self._results[nth_best]

    def predict(self, X: np.ndarray | torch.Tensor | pd.DataFrame, nth_best: int = 0) -> np.ndarray:
        """Evaluate the ``nth_best`` ranked expression (0 = lowest score) on ``X``.

        ``X`` is converted and zero-padded or truncated exactly as in ``fit``.
        """
        refiner = self._get_result(nth_best)['refiner']
        X_np = self._truncate_input(np.asarray(self._to_numpy(X)))
        return refiner.predict(X_np)

    def get_expression(self, nth_best: int = 0, *, return_prefix: bool = False, precision: int = 2) -> list[str] | str:
        """Return the ``nth_best`` ranked expression with its fitted constants substituted.

        The rendering is delegated to ``Refiner.transform``, which uses the
        refiner's first stored constant set.
        """
        result = self._get_result(nth_best)
        return result['refiner'].transform(
            result['expression'],
            nth_best_constants=0,
            return_prefix=return_prefix,
            precision=precision,
        )