symbolic-data 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- symbolic_data/__init__.py +19 -0
- symbolic_data/benchmarks/__init__.py +4 -0
- symbolic_data/benchmarks/fastsrb.py +524 -0
- symbolic_data/compilation.py +27 -0
- symbolic_data/config_io.py +155 -0
- symbolic_data/datasets.py +77 -0
- symbolic_data/distributions.py +174 -0
- symbolic_data/holdout.py +106 -0
- symbolic_data/paths.py +68 -0
- symbolic_data/prior_factory.py +31 -0
- symbolic_data/registry.py +133 -0
- symbolic_data/samples.py +255 -0
- symbolic_data/skeleton_pool.py +817 -0
- symbolic_data/skeleton_sampling.py +129 -0
- symbolic_data/structure.py +24 -0
- symbolic_data/support_sampling.py +457 -0
- symbolic_data/sympy_timeout.py +89 -0
- symbolic_data/tensor_ops.py +55 -0
- symbolic_data/token_ops.py +72 -0
- symbolic_data-0.1.0.dist-info/METADATA +74 -0
- symbolic_data-0.1.0.dist-info/RECORD +24 -0
- symbolic_data-0.1.0.dist-info/WHEEL +5 -0
- symbolic_data-0.1.0.dist-info/licenses/LICENSE +21 -0
- symbolic_data-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""symbolic_data -- the model-agnostic symbolic-regression data layer.
|
|
2
|
+
|
|
3
|
+
Skeleton/expression sampling, priors, (X, y) support sampling, holdout management, and
|
|
4
|
+
dataset construction -- carved out of flash-ansr so symbolic-regression methods and the
|
|
5
|
+
srbf eval framework share one data substrate. Depends only on simplipy, numpy/sklearn, and PyYAML.
|
|
6
|
+
"""
|
|
7
|
+
from symbolic_data.skeleton_pool import SkeletonPool, NoValidSampleFoundError
|
|
8
|
+
from symbolic_data.holdout import HoldoutManager
|
|
9
|
+
from symbolic_data.skeleton_sampling import SkeletonSampler
|
|
10
|
+
from symbolic_data.support_sampling import SupportSampler, SupportSamplingError
|
|
11
|
+
from symbolic_data.distributions import get_distribution, DISTRIBUTIONS, BASE_DISTRIBUTIONS
|
|
12
|
+
from symbolic_data.prior_factory import build_prior_callable
|
|
13
|
+
from symbolic_data.registry import Registry
|
|
14
|
+
from symbolic_data.samples import Sample, sample_from_skeleton, iter_samples
|
|
15
|
+
from symbolic_data.tensor_ops import mask_unused_variable_columns
|
|
16
|
+
from symbolic_data.datasets import load_benchmark, BENCHMARKS
|
|
17
|
+
from symbolic_data.benchmarks import FastSRBBenchmark
|
|
18
|
+
from symbolic_data.paths import get_path, get_root, substitute_root_path
|
|
19
|
+
from symbolic_data.config_io import load_config, save_config
|
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
"""Utilities for sampling the FastSRB benchmark equations using SimpliPy.
|
|
2
|
+
|
|
3
|
+
Translated and adapted from the Julia FastSRB benchmarking code by Viktor Martinek
|
|
4
|
+
(https://github.com/viktmar/FastSRB, arXiv:2508.14481), distributed under the MIT License.
|
|
5
|
+
The full notice + citation is reproduced in ``THIRD_PARTY_LICENSES`` (FastSRB section).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
import warnings
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import yaml
|
|
17
|
+
|
|
18
|
+
from simplipy import SimpliPyEngine
|
|
19
|
+
|
|
20
|
+
from simplipy.utils import codify
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Scalar type of the sampling bounds read from a variable's ``sample_range``.
Number = Union[int, float]


class FastSRBBenchmark:
    """Sample datasets from the FastSRB benchmark YAML specification.

    The YAML document maps equation ids to entries. Each entry provides a ``prepared``
    expression string (``^`` is accepted as the power operator) and a ``vars`` mapping in
    which ``v0`` describes the target and ``v1`` .. ``vN`` describe the inputs. Every input
    carries a ``sample_type`` of ``[distribution, sign_mode]`` (distribution ``"uni"``,
    ``"log"`` or ``"int"``; sign mode ``"pos"``, ``"neg"`` or ``"pos_neg"``) and a
    ``sample_range`` whose first two items are the lower and upper bounds. An optional
    ``name`` field labels a variable in the sampled output.
    """

    def __init__(
        self,
        yaml_path: Union[str, Path],
        *,
        simplipy_engine: SimpliPyEngine | str = "dev_7-3",
        random_state: Optional[Union[int, np.random.Generator]] = None,
    ) -> None:
        """Load the YAML benchmark specification and prepare a SimpliPy engine.

        Parameters
        ----------
        yaml_path:
            Path to the FastSRB YAML file.
        simplipy_engine:
            A ready ``SimpliPyEngine`` or the name of one to load (installed on demand).
        random_state:
            Seed or generator used by default for every sampling call.

        Raises
        ------
        FileNotFoundError
            If ``yaml_path`` does not exist.
        ValueError
            If the YAML document is not a mapping.
        """
        path = Path(yaml_path)
        if not path.exists():
            raise FileNotFoundError(path)

        with path.open("r", encoding="utf-8") as handle:
            # safe_load: the spec is plain data; never construct arbitrary Python objects.
            entries = yaml.safe_load(handle)

        if not isinstance(entries, Mapping):
            raise ValueError("Benchmark specification must be a mapping from equation ids to entries.")

        self._entries: Dict[str, MutableMapping[str, Any]] = dict(entries)
        self._yaml_path = path
        self._rng = self._resolve_rng(random_state)

        if isinstance(simplipy_engine, SimpliPyEngine):
            self._simplipy_engine = simplipy_engine
        else:
            self._simplipy_engine = SimpliPyEngine.load(simplipy_engine, install=True)
        # eq_id -> compiled expression bundle, filled lazily by _compile_expression.
        self._compiled_cache: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _resolve_rng(random_state: Optional[Union[int, np.random.Generator]]) -> np.random.Generator:
        """Return ``random_state`` if it is a Generator, otherwise seed a new one from it."""
        if isinstance(random_state, np.random.Generator):
            return random_state
        return np.random.default_rng(random_state)

    def equation_ids(self) -> List[str]:
        """Return the identifiers of all benchmark equations."""
        return list(self._entries.keys())

    @staticmethod
    def _resolve_variable_order(vars_info: Mapping[str, Mapping[str, Any]]) -> List[str]:
        """Return the input variable ids ``["v1", ..., "vN"]`` declared in ``vars_info``.

        ``v0`` (the target) is excluded; ``N`` is the largest declared input index.

        Raises
        ------
        ValueError
            If no input variable is declared or an id does not follow ``v<int>``.
        KeyError
            If some index between 1 and ``N`` has no specification.
        """
        # Non-string keys (possible in hand-written YAML) cannot be variable ids; skip them
        # instead of crashing on ``.startswith``.
        candidate_keys = [
            key
            for key in vars_info.keys()
            if isinstance(key, str) and key.startswith("v") and key != "v0"
        ]
        if not candidate_keys:
            raise ValueError("Entry does not define any input variables")

        try:
            indices = sorted(int(key[1:]) for key in candidate_keys)
        except ValueError as exc:
            raise ValueError("Variable identifiers must follow the 'v<int>' pattern") from exc

        variable_order: List[str] = []
        for idx in range(1, indices[-1] + 1):
            key = f"v{idx}"
            if key not in vars_info:
                raise KeyError(f"Missing sampling specification for {key}")
            variable_order.append(key)
        return variable_order

    @staticmethod
    def _normalize_prepared_expression(expression: str) -> str:
        """Normalize prepared expressions so SimpliPy can parse them consistently (``^`` -> ``**``)."""
        return expression.replace("^", "**")

    def _compile_expression(self, eq_id: str, entry: Mapping[str, Any]) -> Dict[str, Any]:
        """Parse, simplify and compile the entry's prepared expression (cached per ``eq_id``).

        Returns a dict with ``"code"`` (compiled lambda), ``"callable"`` (evaluator taking the
        inputs positionally in ``"variable_order"``), ``"prefix"`` (simplified prefix tokens)
        and ``"normalized_infix"`` (simplified infix string). The callable is built from the
        *unsimplified* parse, so evaluation follows the expression as written in the benchmark.

        Raises
        ------
        ValueError
            If the entry lacks a prepared expression or variable definitions.
        KeyError
            If the expression references variables without a sampling specification.
        """
        cached = self._compiled_cache.get(eq_id)
        if cached is not None:
            return cached

        prepared = entry.get("prepared")
        if not isinstance(prepared, str) or not prepared.strip():
            raise ValueError(f"Entry {eq_id} has no prepared expression")
        prepared_text = self._normalize_prepared_expression(prepared)

        vars_info = entry.get("vars")
        if not isinstance(vars_info, Mapping):
            raise ValueError(f"Entry {eq_id} has no variable definitions")
        variable_order = self._resolve_variable_order(vars_info)

        prefix_parsed = self._simplipy_engine.parse(prepared_text, mask_numbers=False)
        try:
            prefix_simplified = self._simplipy_engine.simplify(prefix_parsed, max_pattern_length=4)
        except Exception as exc:  # pragma: no cover - defensive against SimpliPy regressions
            warnings.warn(
                f"Failed to simplify FastSRB expression {eq_id}: {exc}. Falling back to unsimplified prefix.",
                RuntimeWarning,
            )
            prefix_simplified = prefix_parsed

        # Validate the tokens of BOTH forms: the callable below evaluates the unsimplified parse,
        # so a variable that simplification happens to cancel would otherwise slip through here
        # and only fail later, at evaluation time.
        used_variables = {
            token
            for token in (*prefix_parsed, *prefix_simplified)
            if isinstance(token, str) and token.startswith("v")
        }
        unknown_variables = used_variables - set(variable_order) - {"v0"}
        if unknown_variables:
            unknown_str = ", ".join(sorted(unknown_variables))
            raise KeyError(f"Prepared expression for {eq_id} references undefined variables: {unknown_str}")

        prefix_realized = self._simplipy_engine.operators_to_realizations(prefix_parsed)
        code_string = self._simplipy_engine.prefix_to_infix(prefix_realized, realization=True)
        code = codify(code_string, variable_order)
        expression_callable = self._simplipy_engine.code_to_lambda(code)
        normalized_infix = self._simplipy_engine.prefix_to_infix(prefix_simplified, realization=False)

        compiled = {
            "code": code,
            "callable": expression_callable,
            "variable_order": variable_order,
            "prefix": tuple(prefix_simplified),
            "normalized_infix": normalized_infix,
        }
        self._compiled_cache[eq_id] = compiled
        return compiled

    def _evaluate(self, compiled: Dict[str, Any], values: Mapping[str, Any]) -> Any:
        """Call the compiled expression with ``values`` ordered by ``variable_order``.

        NumPy floating-point warnings are silenced; invalid results surface as inf/NaN and
        are filtered by the callers.
        """
        ordered_inputs = [values[name] for name in compiled["variable_order"]]
        with np.errstate(all="ignore"):
            return compiled["callable"](*ordered_inputs)

    def _sample_points(
        self,
        low: Number,
        high: Number,
        n_points: int,
        *,
        method: str,
        distribution: str,
        sign_mode: str,
        integer: bool,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Draw ``n_points`` float values for one variable.

        ``method="random"`` draws uniformly and ``method="range"`` uses a shuffled, evenly
        spaced grid. Both work in log10 space when ``distribution="log"``. The sign is then
        applied per ``sign_mode``: ``"pos"`` leaves values untouched, ``"neg"`` negates
        magnitudes, and ``"pos_neg"`` flips each sign at random. Values are rounded when
        ``integer`` is set. If ``low`` and ``high`` are (nearly) equal, every value is ``high``.

        Raises
        ------
        ValueError
            On an unknown option, ``n_points < 1``, ``low > high``, or non-positive bounds
            with ``distribution="log"``.
        """
        if method not in {"random", "range"}:
            raise ValueError("method must be 'random' or 'range'")
        if distribution not in {"uni", "log"}:
            raise ValueError("distribution must be 'uni' or 'log'")
        if sign_mode not in {"pos", "neg", "pos_neg"}:
            raise ValueError("sign_mode must be 'pos', 'neg', or 'pos_neg'")
        if n_points < 1:
            raise ValueError("n_points must be at least 1")
        if method == "range" and n_points == 1:
            warnings.warn("Sampling one point with method='range' is degenerate; consider method='random'", RuntimeWarning, stacklevel=2)
        low_f = float(low)
        high_f = float(high)
        if low_f > high_f:
            raise ValueError("sample_range lower bound must not exceed upper bound")
        if math.isclose(low_f, high_f):
            arr = np.full(n_points, high_f, dtype=float)
        else:
            if distribution == "log":
                if low_f <= 0 or high_f <= 0:
                    raise ValueError("log sampling requires strictly positive bounds")
                low_val = math.log10(low_f)
                high_val = math.log10(high_f)
            else:
                low_val = low_f
                high_val = high_f
            if method == "random":
                arr = rng.uniform(low_val, high_val, size=n_points)
            else:
                arr = np.linspace(low_val, high_val, n_points)
                rng.shuffle(arr)
            if distribution == "log":
                arr = 10.0 ** arr
        if sign_mode == "neg":
            arr = -np.abs(arr)
        elif sign_mode == "pos_neg":
            signs = rng.choice([-1.0, 1.0], size=arr.shape)
            arr = arr * signs
        if integer:
            arr = np.rint(arr)
        return arr.astype(float, copy=False)

    @staticmethod
    def _parse_variable_spec(
        key: str, vars_info: Mapping[str, Mapping[str, Any]]
    ) -> Tuple[Number, Number, str, str, bool]:
        """Return ``(low, high, distribution, sign_mode, integer)`` for variable ``key``.

        ``sample_type[0] == "int"`` is mapped to a uniform draw rounded to integers.

        Raises
        ------
        KeyError
            If ``key`` or one of its required fields is missing.
        ValueError
            If ``sample_type``/``sample_range`` are not sequences of at least two items.
        """
        spec = vars_info.get(key)
        if spec is None:
            raise KeyError(f"Missing sampling spec for {key}")
        try:
            sample_type = spec["sample_type"]
            sample_range = spec["sample_range"]
        except KeyError as exc:
            raise KeyError(f"Variable {key} is missing required field {exc.args[0]}") from exc
        # A bare string is a Sequence too ("uni"[0] == "u"), so reject it explicitly.
        if isinstance(sample_type, (str, bytes)) or not isinstance(sample_type, Sequence) or len(sample_type) < 2:
            raise ValueError(f"sample_type for {key} must have two entries")
        if isinstance(sample_range, (str, bytes)) or not isinstance(sample_range, Sequence) or len(sample_range) < 2:
            raise ValueError(f"sample_range for {key} must provide at least lower and upper bounds")
        distribution = sample_type[0]
        sign_mode = sample_type[1]
        integer = distribution == "int"
        if integer:
            distribution = "uni"
        return sample_range[0], sample_range[1], distribution, sign_mode, integer

    def _sample_matrix(
        self,
        vars_info: Mapping[str, Mapping[str, Any]],
        variable_order: Sequence[str],
        n_points: int,
        method: str,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Draw an ``(n_points, len(variable_order))`` input matrix, one column per variable."""
        columns: List[np.ndarray] = []
        for key in variable_order:
            low, high, distribution, sign_mode, integer = self._parse_variable_spec(key, vars_info)
            columns.append(
                self._sample_points(
                    low,
                    high,
                    n_points,
                    method=method,
                    distribution=distribution,
                    sign_mode=sign_mode,
                    integer=integer,
                    rng=rng,
                )
            )
        return np.column_stack(columns).astype(float, copy=False)

    def _sample_single_point(
        self,
        vars_info: Mapping[str, Mapping[str, Any]],
        variable_order: Sequence[str],
        method: str,
        rng: np.random.Generator,
        compiled: Dict[str, Any],
        max_trials: int,
    ) -> np.ndarray:
        """Draw one finite row ``[x_1, ..., x_n, y]``, retrying up to ``max_trials`` times.

        Raises
        ------
        RuntimeError
            If no finite row is found within ``max_trials`` attempts.
        """
        # A one-point "range" grid collapses to the lower bound (np.linspace(low, high, 1) is
        # [low]), which would make every row identical; draw single points at random instead.
        point_method = "random" if method == "range" else method
        for _ in range(max_trials):
            row: List[float] = []
            value_map: Dict[str, float] = {}
            for key in variable_order:
                low, high, distribution, sign_mode, integer = self._parse_variable_spec(key, vars_info)
                value = float(
                    self._sample_points(
                        low,
                        high,
                        1,
                        method=point_method,
                        distribution=distribution,
                        sign_mode=sign_mode,
                        integer=integer,
                        rng=rng,
                    )[0]
                )
                row.append(value)
                value_map[key] = value
            try:
                target = float(self._evaluate(compiled, value_map))
            except Exception:
                # Evaluation failures and non-scalar results simply reject this draw.
                continue
            combined = np.array(row + [target], dtype=float)
            if np.all(np.isfinite(combined)):
                return combined
        raise RuntimeError("Exceeded max_trials while sampling a single data point")

    @staticmethod
    def _coerce_target(target: Any, n_points: int, eq_id: str) -> np.ndarray:
        """Coerce an evaluated target into a float vector of length ``n_points``.

        Size-1 results (e.g. an expression that ignores its inputs) are repeated, extra
        singleton axes are squeezed, and anything else is broadcast.

        Raises
        ------
        ValueError
            If the values cannot be broadcast to ``(n_points,)``.
        """
        target_arr = np.asarray(target, dtype=float)
        if target_arr.shape == (n_points,):
            return target_arr
        if target_arr.size == 1:
            # .item() avoids NumPy's deprecated float() conversion of ndim > 0 arrays.
            return np.full(n_points, target_arr.item(), dtype=float)
        squeezed = np.squeeze(target_arr)
        if squeezed.shape == (n_points,):
            return squeezed
        try:
            return np.broadcast_to(target_arr, (n_points,))
        except ValueError as exc:
            raise ValueError(
                f"Could not broadcast target values to length {n_points} for {eq_id}"
            ) from exc

    def _sample_entry(
        self,
        eq_id: str,
        entry: Mapping[str, Any],
        n_points: int,
        method: str,
        max_trials: int,
        incremental: bool,
        rng: np.random.Generator,
    ) -> np.ndarray:
        """Return an ``(n_points, n_inputs + 1)`` finite matrix; the last column is the target.

        Up to ``max_trials`` attempts are made. Vectorized mode draws the whole input matrix
        at once and rejects it unless every value is finite. Incremental mode builds it row by
        row via :meth:`_sample_single_point`, which itself retries each row up to
        ``max_trials`` times.

        Raises
        ------
        ValueError
            If the entry has no variable definitions.
        RuntimeError
            If no attempt succeeds. The last underlying error, if any, is chained and quoted
            in the message, so misconfigured specs (e.g. invalid bounds) stay diagnosable.
        """
        compiled = self._compile_expression(eq_id, entry)
        vars_info = entry.get("vars")
        if not isinstance(vars_info, Mapping):
            raise ValueError(f"Entry {eq_id} has no variable definitions")

        variable_order = compiled["variable_order"]
        last_error: Optional[Exception] = None
        for _ in range(max_trials):
            try:
                if incremental:
                    matrix = np.vstack(
                        [
                            self._sample_single_point(vars_info, variable_order, method, rng, compiled, max_trials)
                            for _ in range(n_points)
                        ]
                    )
                else:
                    inputs = self._sample_matrix(vars_info, variable_order, n_points, method, rng)
                    value_map = {var: inputs[:, idx] for idx, var in enumerate(variable_order)}
                    target = self._evaluate(compiled, value_map)
                    matrix = np.column_stack((inputs, self._coerce_target(target, n_points, eq_id)))
                if np.all(np.isfinite(matrix)):
                    return matrix
            except Exception as exc:
                # Failures are often draw-dependent, so retry, but remember the cause for the
                # final error instead of discarding it.
                last_error = exc
        message = f"Failed to sample finite data for {eq_id} after {max_trials} attempts"
        if last_error is not None:
            message = f"{message} (last error: {type(last_error).__name__}: {last_error})"
        raise RuntimeError(message) from last_error

    def sample(
        self,
        eq_id: str,
        *,
        n_points: int = 100,
        method: str = "random",
        max_trials: int = 100,
        incremental: bool = False,
        random_state: Optional[Union[int, np.random.Generator]] = None,
    ) -> Dict[str, Any]:
        """Sample a dataset for the requested equation.

        Parameters
        ----------
        eq_id:
            Equation identifier from the benchmark file.
        n_points:
            Number of rows to draw.
        method:
            ``"random"`` (uniform draws) or ``"range"`` (shuffled, evenly spaced grid).
        max_trials:
            Retry budget for obtaining finite data.
        incremental:
            Sample row by row instead of drawing the whole matrix at once. A failed vectorized
            attempt automatically falls back to incremental sampling with a ``RuntimeWarning``.
        random_state:
            Seed or generator for this call; defaults to the instance generator.

        Returns
        -------
        dict
            Contains ``eq_id`` and the entry ``metadata`` (without ``vars``, plus the variable
            order and simplified expression). It also echoes the requested ``n_points``,
            ``method`` and ``incremental`` settings, and holds ``data`` (``X``, ``y``, column
            ids, target name) and per-variable ``variables`` metadata.

        Raises
        ------
        KeyError
            If ``eq_id`` is unknown.
        ValueError
            If ``n_points`` is not positive.
        RuntimeError
            If no finite dataset could be sampled.
        """
        if eq_id not in self._entries:
            raise KeyError(f"Unknown equation id: {eq_id}")
        if n_points < 1:
            raise ValueError("n_points must be positive")
        rng = self._resolve_rng(random_state) if random_state is not None else self._rng
        entry = self._entries[eq_id]
        try:
            matrix = self._sample_entry(eq_id, entry, n_points, method, max_trials, incremental, rng)
        except RuntimeError as exc:
            if incremental:
                raise
            warnings.warn(
                f"Falling back to incremental sampling for {eq_id} after vectorized sampling failed: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            matrix = self._sample_entry(eq_id, entry, n_points, method, max_trials, True, rng)
        inputs = matrix[:, :-1]
        target = matrix[:, -1]
        vars_info = entry.get("vars", {})
        compiled = self._compile_expression(eq_id, entry)
        variable_order = compiled["variable_order"]
        feature_meta: List[Dict[str, Any]] = []
        for idx, key in enumerate(variable_order):
            spec = vars_info.get(key, {})
            feature_meta.append(
                {
                    "id": key,
                    "name": spec.get("name", key),
                    "metadata": spec,
                    "values": inputs[:, idx],
                }
            )
        target_meta = vars_info.get("v0", {})
        metadata = {k: v for k, v in entry.items() if k != "vars"}
        metadata["variable_order"] = list(variable_order)
        metadata["prepared_prefix"] = list(compiled["prefix"])
        metadata["prepared_normalized"] = compiled.get("normalized_infix")

        return {
            "eq_id": eq_id,
            "metadata": metadata,
            "n_points": n_points,
            "method": method,
            "incremental": incremental,
            "data": {
                "X": inputs,
                "y": target,
                "columns": list(variable_order),
                "target": target_meta.get("name", "v0"),
            },
            "variables": {
                "inputs": feature_meta,
                "target": {
                    "id": "v0",
                    "name": target_meta.get("name", "v0"),
                    "metadata": target_meta,
                    "values": target,
                },
            },
        }

    def _resolve_eq_ids(self, eq_ids: Optional[Union[str, Sequence[str]]]) -> List[str]:
        """Normalise ``eq_ids`` (``None`` = all equations, a single id, or an iterable) to a list."""
        if eq_ids is None:
            return list(self._entries.keys())
        if isinstance(eq_ids, str):
            return [eq_ids]
        return list(eq_ids)

    def sample_multiple(
        self,
        eq_ids: Optional[Union[str, Sequence[str]]] = None,
        *,
        count: int = 5,
        random_state: Optional[Union[int, np.random.Generator]] = None,
        **sample_kwargs: Any,
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Sample multiple datasets per equation.

        Parameters
        ----------
        eq_ids:
            Optional identifier or iterable of identifiers to sample. If omitted, all
            equations in the benchmark are used.
        count:
            Number of datasets to draw for each equation (default: 5).
        random_state:
            Optional seed or generator to make the repeated sampling reproducible.
        sample_kwargs:
            Additional keyword arguments forwarded to :meth:`sample` (e.g. ``n_points``).

        Returns
        -------
        dict
            Mapping from equation id to a list with ``count`` sampled dataset dictionaries.
        """
        if count < 1:
            raise ValueError("count must be a positive integer")

        eq_list = self._resolve_eq_ids(eq_ids)
        # A caller-supplied seed becomes ONE generator shared by every draw, so the whole batch
        # is reproducible; without it, sample() uses the instance generator.
        rng_kwargs: Dict[str, Any] = {} if random_state is None else {"random_state": self._resolve_rng(random_state)}

        results: Dict[str, List[Dict[str, Any]]] = {}
        for eq_id in eq_list:
            results[eq_id] = [self.sample(eq_id, **rng_kwargs, **sample_kwargs) for _ in range(count)]
        return results

    def iter_samples(
        self,
        eq_ids: Optional[Union[str, Sequence[str]]] = None,
        *,
        count: int = 5,
        random_state: Optional[Union[int, np.random.Generator]] = None,
        **sample_kwargs: Any,
    ) -> Iterable[Tuple[str, int, Dict[str, Any]]]:
        """Yield datasets lazily, one equation-instance at a time.

        An equation whose sampling raises emits a ``RuntimeWarning``, and its remaining
        repeats are skipped.

        Parameters
        ----------
        eq_ids:
            Optional identifier or iterable of identifiers to sample. If omitted, iterate over
            every equation in the benchmark.
        count:
            Number of datasets to draw for each equation (default: 5).
        random_state:
            Optional seed or generator to ensure reproducible iteration.
        sample_kwargs:
            Extra keyword arguments for :meth:`sample` (e.g. ``n_points``).

        Yields
        ------
        tuple
            ``(eq_id, index, sample_dict)`` where ``index`` runs from ``0`` to ``count-1`` for each
            equation.
        """
        if count < 1:
            raise ValueError("count must be a positive integer")

        eq_list = self._resolve_eq_ids(eq_ids)
        rng_kwargs: Dict[str, Any] = {} if random_state is None else {"random_state": self._resolve_rng(random_state)}

        for eq_id in eq_list:
            for i in range(count):
                try:
                    sample = self.sample(eq_id, **rng_kwargs, **sample_kwargs)
                except Exception as exc:  # pragma: no cover - defensive against SimpliPy edge cases
                    warnings.warn(
                        f"Failed to sample FastSRB equation {eq_id}: {exc}. Skipping remaining repeats.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    break
                yield eq_id, i, sample
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Helpers for compiling and evaluating expression programs."""
|
|
2
|
+
import time
|
|
3
|
+
from typing import Callable
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from types import CodeType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def codify(code_string: str, variables: list[str] | None = None) -> CodeType:
|
|
11
|
+
"""Compile an infix expression body into a callable lambda."""
|
|
12
|
+
if variables is None:
|
|
13
|
+
variables = []
|
|
14
|
+
func_string = f"lambda {', '.join(variables)}: {code_string}"
|
|
15
|
+
filename = f"<lambdifygenerated-{time.time_ns()}"
|
|
16
|
+
return compile(func_string, filename, "eval")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def safe_f(f: Callable, X: np.ndarray, constants: np.ndarray | None = None) -> np.ndarray:
|
|
20
|
+
"""Evaluate ``f`` on ``X`` while normalising scalar outputs to vectors."""
|
|
21
|
+
if constants is None:
|
|
22
|
+
y = f(*X.T)
|
|
23
|
+
else:
|
|
24
|
+
y = f(*X.T, *constants)
|
|
25
|
+
if not isinstance(y, np.ndarray) or y.shape[0] == 1:
|
|
26
|
+
y = np.full(X.shape[0], y)
|
|
27
|
+
return y
|