quickseries 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from quickseries.approximate import quickseries
2
+ from quickseries.benchmark import benchmark
3
+
4
+ __version__ = "0.2.1"
@@ -0,0 +1,319 @@
1
+ import re
2
+ from inspect import getfullargspec, signature
3
+ from itertools import chain
4
+ from typing import Literal, Optional, Sequence, Union, Collection
5
+
6
+ import numpy as np
7
+ import sympy as sp
8
+ from cytoolz import groupby
9
+ from dustgoggles.func import gmap
10
+
11
+ from quickseries.expansions import multivariate_taylor, series_lambda
12
+ from quickseries.simplefit import fit
13
+ from quickseries.sourceutils import (
14
+ _cacheget, _cacheid, _finalize_quickseries, lastline
15
+ )
16
+ from quickseries.sputils import LmSig, lambdify
17
+
18
+ """signature of sympy-lambdified numpy/scipy functions"""
19
+
20
# What exponentials look like in the source of sympy-lambdified functions:
# a word, optional spaces around `**`, and a captured integer power.
EXP_PATTERN = re.compile(r"\w+ ?\*\* ?(\d+)")
22
+
23
+
24
def is_simple_poly(expr: sp.Expr) -> bool:
    """
    Report whether every generator sympy identifies for `expr` is a plain
    Symbol.

    Args:
        expr: sympy expression to inspect.

    Returns:
        True if all entries of the "gens" option returned by
        `sp.poly_from_expr()` are `sp.Symbol` instances; False otherwise.
    """
    _, options = sp.poly_from_expr(expr)
    for generator in options["gens"]:
        if not isinstance(generator, sp.Symbol):
            return False
    return True
27
+
28
+
29
def regexponents(text: str) -> tuple[int]:
    """
    Collect every integer exponent EXP_PATTERN finds in `text`, in order of
    appearance.
    """
    # noinspection PyTypeChecker
    return tuple(int(exponent) for exponent in EXP_PATTERN.findall(text))
32
+
33
+
34
def _decompose(
    remaining: tuple[int, ...],
    reduced: set[int],
    replacements: list[tuple[int, list[int]]]
) -> bool:
    """
    Perform one pass of the exponent-decomposition procedure.

    Args:
        remaining: every factor currently present in the decompositions held
            by `replacements`, flattened (with repeats).
        reduced: powers that have already been assessed. Modified in place.
        replacements: (power, [factors]) pairs. The factor lists are modified
            in place.

    Returns:
        True when there is nothing further to decompose; False when the
        caller should recompute `remaining` and call again.
    """
    from collections import Counter

    if len(remaining) == 1:  # trivial case
        power, factors = replacements[0]
        factors[:] = [1] * power
        return True
    counts = {
        power: n
        for power, n in Counter(remaining).items()
        if power not in reduced
    }
    if len(counts) < 2:  # nothing useful left to do
        return True
    biggest = max(counts)
    if counts[biggest] > 1:
        # don't decompose the biggest factor; because it appears more than
        # once, we'd like to precompute it
        reduced.add(biggest)
        return False
    # otherwise, do a decomposition pass with the smallest factor
    factor = min(counts)
    for _, factors in replacements:
        factorization = []
        # "divide out" `factor` from elements of existing decomposition
        for f in factors:
            # don't decompose factors we've already evaluated, and don't
            # try to divide `factor` out of smaller factors (nonsensical)
            if f in reduced or f <= factor:
                factorization.append(f)
                continue
            factorization.append(factor)
            difference = f - factor
            # largest other factor present; it does not change while we
            # divide, so compute it once rather than on every iteration
            ceiling = max(e for e in remaining if e != f)
            while difference >= ceiling:
                factorization.append(factor)
                difference -= factor
            if difference > 0:
                factorization.append(difference)
        factors[:] = factorization
    reduced.add(factor)
    return False


def optimize_exponents(
    exps: Sequence[int],
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Work out how to build a collection of integer powers from a small set of
    shared, precomputed lower powers.

    Args:
        exps: the integer exponents to decompose. Any iterable of ints is
            accepted; it is materialized as a tuple before use.

    Returns:
        replacements: maps each exponent to the list of factors (powers) its
            decomposition multiplies together.
        variables: maps each power that should be predefined as a variable
            to the list of smaller powers it is built from.
    """
    # materialize once: we need both iteration and .count() below
    exps = tuple(exps)
    # list of tuples like: (power, [powers to use in decomposition])
    replacements = [(e, [e]) for e in exps]
    # which powers have we already assessed?
    reduced = set()
    # which powers haven't we?
    remaining = tuple(chain.from_iterable(r[1] for r in replacements))
    # NOTE: _decompose() modifies reduced and the lists inside replacements
    # inplace
    while _decompose(remaining, reduced, replacements) is False:
        remaining = tuple(chain.from_iterable(r[1] for r in replacements))
    # this is analogous to casting to set: we no longer care about number of
    # occurrences
    replacements = dict(replacements)
    # figure out which factors we'd like to predefine as variables, and what
    # the "building blocks" of those variables are. 1 is a placeholder: we
    # will never define it, but it's useful in this loop.
    variables = {1: [1]}
    for e in sorted(set(remaining)):
        if e == 1:
            continue
        # a power requested exactly once, with no larger power to reuse it,
        # gains nothing from being predefined
        if exps.count(e) == 1 and not any(k > e for k in replacements):
            continue
        # greedily assemble e from the largest already-defined powers
        vfactor, remainder = [], e
        while remainder > 0:
            pick = max(v for v in variables if v <= remainder)
            vfactor.append(pick)
            remainder -= pick
        variables[e] = vfactor
    # remove the placeholder first-order variable
    variables.pop(1)
    return replacements, variables
111
+
112
+
113
def force_line_precision(line: str, precision: Literal[16, 32, 64]) -> str:
    """
    Wrap the numeric literals found in `line` in a numpy float constructor
    call of the requested precision.

    Args:
        line: a line of Python source text.
        precision: bit width; selects `np.float{precision}`.

    Returns:
        `line` with matched literals rewritten as `float{precision}(...)`.
        Numbers immediately preceded by exactly "**" are left untouched.
    """
    type_name = f"float{precision}"
    caster = getattr(np, type_name)
    pattern = r"([+* (-]+|^)([\d.]+)(e[+\-]?\d+)?.*?([+* )]|$)"
    pieces, cursor = [], 0
    for hit in re.finditer(pattern, line):
        prefix, digits, exponent, suffix = hit.groups()
        # carry over whatever sits between the previous match and this one
        pieces.append(line[cursor:hit.start()])
        if prefix == "**":
            # don't replace exponents
            pieces.append(hit.group(0))
        else:
            # NOTE: casting number to string within the f-string statement
            # appears to upcast it before generating the representation.
            literal = str(caster(float(digits)))
            if exponent is not None:  # scientific notation
                literal += exponent
            pieces.append(f"{prefix}{type_name}({literal}){suffix}")
        cursor = hit.end()
    pieces.append(line[cursor:])
    return "".join(pieces)
134
+
135
+
136
def rewrite(
    poly_lambda: LmSig,
    precompute: bool = True,
    precision: Optional[Literal[16, 32, 64]] = None,
) -> str:
    """
    Generate source code for a new function from a lambdified polynomial.

    Args:
        poly_lambda: lambdified function whose final source line holds the
            polynomial expression.
        precompute: if True, pass the expression through
            `_rewrite_precomputed()` and emit the factor assignments it
            returns ahead of the return statement.
        precision: if not None, pass the expression through
            `force_line_precision()` with this bit width.

    Returns:
        Source text of a function named with the key from `_cacheid()`,
        taking the same argument names as `poly_lambda`.
    """
    # sympy will always place this on a single line. it includes
    # the Python expression form of the hornerized polynomial
    # and a return statement. lastline() grabs polynomial and strips return.
    # the substitution then removes pointless '1.0' terms.
    body = re.sub(r"(?:\*+)?1\.0\*+", "", lastline(poly_lambda))
    # names of arguments to the lambdified function
    argnames = getfullargspec(poly_lambda).args
    statements = []
    if precompute is True:
        body, definitions = _rewrite_precomputed(body, argnames)
        statements.extend(definitions)
    if precision is not None:
        body = force_line_precision(body, precision)
    _, key = _cacheid()
    header = f"def {key}({', '.join(argnames)}):"
    return "\n ".join([header, *statements, f"return {body}"])
159
+
160
+
161
def _rewrite_precomputed(
    polyexpr: str, free: Collection[str]
) -> tuple[str, list[str]]:
    """
    Replace `name**power` terms in `polyexpr` with products of predefined
    power variables, and produce the assignments defining those variables.

    Args:
        polyexpr: source text of a polynomial expression.
        free: names of the expression's variables.

    Returns:
        polyexpr: the expression with powers substituted.
        factorlines: assignment statements (e.g. "x2 = x*x") defining the
            variables the substituted expression refers to.
    """
    # replacements: what factors we will decompose each exponent into
    # variables: which factors we will define as variables, and their
    # "building blocks"
    factorlines = []
    for f in free:
        # Match this exact variable name only: escape it, and refuse a match
        # preceded by a word character so that "x" is not found inside "ax".
        power_prefix = rf"(?<!\w){re.escape(f)} ?\*\* ?"
        exps = tuple(
            int(e) for e in re.findall(power_prefix + r"(\d+)", polyexpr)
        )
        replacements, variables = optimize_exponents(exps)
        for k, v in variables.items():
            multiplicands = [
                f if power == 1 else f"{f}{power}" for power in v
            ]
            factorlines.append(f"{f}{k} = {'*'.join(multiplicands)}")
        for k, v in replacements.items():
            substitution = "*".join([f"{f}{r}" if r != 1 else f for r in v])
            # (?!\d) keeps e.g. x**2 from matching the front of x**21; the
            # callable replacement inserts `substitution` verbatim.
            polyexpr = re.sub(
                power_prefix + rf"{k}(?!\d)",
                lambda _match, text=substitution: text,
                polyexpr,
            )
    return polyexpr, factorlines
185
+
186
+
187
def _pvec(
    bounds: Sequence[tuple[float, float]], offset_resolution: int
) -> list[np.ndarray]:
    """
    Produce flattened coordinates of a regular grid spanning `bounds`.

    Args:
        bounds: one (start, stop) pair per axis.
        offset_resolution: number of evenly spaced samples along each axis.

    Returns:
        One 1-D array per axis, each holding that axis's coordinate for
        every grid point (offset_resolution ** len(bounds) entries).
    """
    shape = [offset_resolution] * len(bounds)
    coordinates = []
    for bound, grid in zip(bounds, np.indices(shape)):
        axis = np.linspace(*bound, offset_resolution)
        coordinates.append(axis[np.ravel(grid)])
    return coordinates
193
+
194
+
195
def _perform_series_fit(
    func: str | sp.Expr,
    bounds: tuple[float, float] | Sequence[tuple[float, float]],
    nterms: int,
    fitres: int,
    point: float | Sequence[float],
    apply_bounds: bool,
    is_poly: bool
) -> tuple[sp.Expr, np.ndarray]:
    """
    Expand `func` as a polynomial with free coefficients, then fit those
    coefficients to `func` sampled on a grid over `bounds`.

    Args:
        func: function to approximate.
        bounds: (start, stop) pairs, one per axis; `len(bounds)` is used as
            the number of axes.
        nterms: passed to the expansion routine.
        fitres: number of grid samples along each axis.
        point: expansion point; indexed per axis.
        apply_bounds: if True, constrain the fit with bounds (-5, 5).
        is_poly: whether `func` is already a simple polynomial.

    Returns:
        expr: the expansion with fitted coefficients substituted in.
        params: the fitted coefficients returned by `fit()`.
    """
    if (len(bounds) == 1) and (is_poly is False):
        approx, expr = series_lambda(func, point[0], nterms, True)
    else:
        approx, expr = multivariate_taylor(func, point, nterms, True)
    lamb, vecs = lambdify(func), _pvec(bounds, fitres)
    try:
        dep = lamb(*vecs)
    except TypeError as err:
        # this is a potentially slow but unavoidable case
        if "converted to Python scalars" not in str(err):
            raise
        # the function only accepts scalars, so evaluate it one grid point
        # at a time: zip(*vecs) yields one coordinate per axis for each point
        dep = np.array([lamb(*sample) for sample in zip(*vecs)])
    # one initial guess of 1 for each coefficient argument of `approx`
    guess = [1 for _ in range(len(signature(approx).parameters) - len(vecs))]
    params, _ = fit(
        func=approx,
        vecs=vecs,
        dependent_variable=dep,
        guess=guess,
        bounds=(-5, 5) if apply_bounds is True else None,
    )
    # insert coefficients into polynomial
    expr = expr.subs({f"a_{i}": coef for i, coef in enumerate(params)})
    return expr, params
227
+
228
+
229
+ def _makebounds(
230
+ bounds: Optional[Sequence[tuple[float, float]] | tuple[float, float]],
231
+ n_free: int,
232
+ point: Optional[Sequence[float] | float]
233
+ ) -> tuple[list[tuple[float, float]], list[float]]:
234
+ bounds = (-1, 1) if bounds is None else bounds
235
+ if not isinstance(bounds[0], (list, tuple)):
236
+ bounds = [bounds for _ in range(n_free)]
237
+ if point is None:
238
+ point = [np.mean(b) for b in bounds]
239
+ elif not isinstance(point, (list, tuple)):
240
+ point = [point for _ in bounds]
241
+ return bounds, point
242
+
243
+
244
def _make_quickseries(
    approx_poly: bool,
    bound_series_fit: bool,
    bounds: Optional[Sequence[tuple[float, float]] | tuple[float, float]],
    expr: sp.Expr,
    fit_series_expansion: bool,
    fitres: int,
    nterms: int,
    point: Optional[Sequence[float] | float],
    precision: Optional[Literal[16, 32, 64]],
    prefactor: bool,
) -> dict[str, sp.Expr | np.ndarray | str]:
    """
    Build the polynomial approximation of `expr` and the source code of a
    function that evaluates it.

    Returns:
        dict with "expr" (the Horner-form polynomial), "source" (generated
        function source), and, when a series fit was performed, "params".

    Raises:
        ValueError: if `expr` has no free symbols.
    """
    if len(expr.free_symbols) == 0:
        raise ValueError("func must have at least one free variable.")
    free = sorted(expr.free_symbols, key=str)
    bounds, point = _makebounds(bounds, len(free), point)
    output = {}
    is_poly = is_simple_poly(expr)
    # simple polynomials are used as-is unless approximation is requested
    needs_expansion = (approx_poly is True) or (is_poly is False)
    if needs_expansion and fit_series_expansion is True:
        expr, output["params"] = _perform_series_fit(
            expr, bounds, nterms, fitres, point, bound_series_fit, is_poly
        )
    elif needs_expansion and ((len(free) > 1) or (is_poly is True)):
        _, expr = multivariate_taylor(expr, point, nterms, False)
    elif needs_expansion:
        _, expr = series_lambda(expr, point[0], nterms, False)
    # rewrite polynomial in horner form for fast evaluation
    output["expr"] = sp.horner(expr)
    polyfunc = sp.lambdify(free, output["expr"], ("scipy", "numpy"))
    # polish it and optionally rewrite it to precompute repeated powers or
    # force precision
    source = rewrite(polyfunc, prefactor, precision)
    return output | {"source": source}
276
+
277
+
278
def quickseries(
    func: Union[str, sp.Expr],
    *,
    bounds: tuple[float, float] = (-1, 1),
    nterms: int = 9,
    point: Optional[float] = None,
    fitres: int = 100,
    prefactor: Optional[bool] = None,
    approx_poly: bool = False,
    jit: bool = False,
    precision: Optional[Literal[16, 32, 64]] = None,
    fit_series_expansion: bool = True,
    bound_series_fit: bool = False,
    extended_output: bool = False,
    cache: bool = True,
) -> Union[LmSig, tuple[LmSig, dict]]:
    """
    Produce a function that evaluates a polynomial approximation of `func`.

    When `cache` is True, a previously cached function is looked up first;
    on a miss (or with caching off) the approximation is generated with
    `_make_quickseries()` and compiled with `_finalize_quickseries()`.
    `prefactor` defaults to `not jit` when left as None.

    Returns:
        The generated function; if `extended_output` is True, a tuple of
        the function and a dict that includes "cache" ("off", "hit", or
        "miss") and "source".

    Raises:
        TypeError: if `func` is neither a str nor a sympy Expr.
    """
    # NOTE: sourceutils._cacheid() finds this frame by name and reads the
    # arguments listed in CACHE_ARGS from its locals, so the parameters are
    # deliberately never rebound in this body.
    if not isinstance(func, (str, sp.Expr)):
        raise TypeError(f"Unsupported type for func {type(func)}.")
    ext = {"cache": "off"}
    polyfunc = None
    if cache is True:
        polyfunc, source = _cacheget(jit)
        if polyfunc is None:
            ext["cache"] = "miss"
        else:
            ext |= {"source": source, "cache": "hit"}
    if polyfunc is None:
        expr = func if isinstance(func, sp.Expr) else sp.sympify(func)
        use_prefactor = (not jit) if prefactor is None else prefactor
        ext |= _make_quickseries(
            approx_poly,
            bound_series_fit,
            bounds,
            expr,
            fit_series_expansion,
            fitres,
            nterms,
            point,
            precision,
            use_prefactor,
        )
        polyfunc = _finalize_quickseries(ext["source"], jit, cache)
    return (polyfunc, ext) if extended_output is True else polyfunc
@@ -0,0 +1,108 @@
1
+ import timeit
2
+ from inspect import getfullargspec
3
+ from itertools import product
4
+ from time import time
5
+ from typing import Union, Sequence, Optional
6
+
7
+ import numpy as np
8
+ import sympy as sp
9
+ from dustgoggles.func import gmap
10
+
11
+ from quickseries import quickseries
12
+ from quickseries.approximate import _makebounds
13
+ from quickseries.sputils import lambdify, LmSig
14
+
15
+
16
def _offset_check_cycle(
    absdiff: float,
    frange: tuple[float, float],
    lamb: LmSig,
    quick: LmSig,
    vecs: Sequence[np.ndarray],
    worstpoint: Optional[list[float]],
) -> tuple[float, float, float, tuple[float, float], list[float]]:
    """
    Compare `quick` against `lamb` at the points in `vecs` and update
    running error statistics.

    Args:
        absdiff: largest absolute offset seen so far.
        frange: (min, max) of `lamb`'s output seen so far.
        lamb: reference function.
        quick: approximating function.
        vecs: one coordinate array per argument.
        worstpoint: coordinates at which `absdiff` occurred, if any.

    Returns:
        Updated absdiff; median offset and mean squared offset for this
        cycle; updated frange; updated worstpoint.
    """
    estimate = quick(*vecs)
    truth = lamb(*vecs)
    low, high = frange
    frange = (min(truth.min(), low), max(truth.max(), high))
    offset = abs(estimate - truth)
    worstix = np.argmax(offset)
    cycle_worst = offset[worstix]
    # only replace the running maximum when this cycle strictly exceeds it
    if cycle_worst > absdiff:
        absdiff = cycle_worst
        worstpoint = [v[worstix] for v in vecs]
    median_offset = np.median(offset)
    mean_square = np.mean(offset ** 2)
    return absdiff, median_offset, mean_square, frange, worstpoint
32
+
33
+
34
def benchmark(
    func: Union[str, sp.Expr, sp.core.function.FunctionClass],
    offset_resolution: int = 10000,
    n_offset_shuffles: int = 50,
    timeit_cycles: int = 20000,
    testbounds="equal",
    cache: bool = False,
    **quickkwargs
) -> dict[str, sp.Expr | float | np.ndarray | str | list[float]]:
    """
    Compare the accuracy and speed of a quickseries() approximation of
    `func` against the directly lambdified `func`.

    Args:
        func: function to approximate.
        offset_resolution: number of sample points per axis.
        n_offset_shuffles: for functions of more than one variable, number
            of random shufflings of the sample axes to evaluate.
        timeit_cycles: number of calls timed for each function.
        testbounds: per-axis (start, stop) pairs to test over, or "equal"
            to derive them from the "bounds" entry of `quickkwargs`.
        cache: passed to quickseries() as `cache`.
        **quickkwargs: passed to quickseries().

    Returns:
        dict of error statistics ('absdiff', 'reldiff', 'mediff', 'mse',
        'worstpoint', 'range'), timings ('orig_s', 'approx_s', 'timeratio',
        'gentime'), the approximating function ('polyfunc'), and the
        extended output of quickseries().
    """
    # NOTE: quickseries() is called directly from this frame on purpose:
    # sourceutils._cacheid() skips frames named "benchmark" when working out
    # the calling file.
    lamb = lambdify(func)
    compile_start = time()
    quick, ext = quickseries(
        func, **(quickkwargs | {'extended_output': True, 'cache': cache})
    )
    gentime = time() - compile_start
    if testbounds == "equal":
        testbounds, _ = _makebounds(
            quickkwargs.get("bounds"), len(getfullargspec(lamb).args), None
        )
    vecs = [np.linspace(*b, offset_resolution) for b in testbounds]
    if (pre := quickkwargs.get("precision")) is not None:
        vecs = gmap(
            lambda arr: arr.astype(getattr(np, f"float{pre}")), vecs
        )
    if len(testbounds) > 1:
        # always check the extrema of the bounds: every combination of the
        # first (index 0) and last (index -1) sample of each axis
        extrema = [[] for _ in vecs]
        for p in product((0, -1), repeat=len(vecs)):
            for i, side in enumerate(p):
                extrema[i].append(vecs[i][side])
        extrema = [np.array(e) for e in extrema]
        absdiff, _, __, frange, worstpoint = _offset_check_cycle(
            0, (np.inf, -np.inf), lamb, quick, extrema, None
        )
        medians, mses = [], []
        for _ in range(n_offset_shuffles):
            # shuffle each axis in place to sample new coordinate pairings
            gmap(np.random.shuffle, vecs)
            absdiff, mediff, mse, frange, worstpoint = _offset_check_cycle(
                absdiff, frange, lamb, quick, vecs, worstpoint
            )
            medians.append(mediff)
            mses.append(mse)
        mediff, mse = np.median(medians), np.median(mses)
    # no point in shuffling for 1D -- we're doing that for > 1D
    # because it becomes quickly unreasonable in terms of memory
    # to be exhaustive, but this _is_ exhaustive for 1D
    else:
        approx_y, orig_y = quick(*vecs), lamb(*vecs)
        frange = (orig_y.min(), orig_y.max())
        offset = abs(approx_y - orig_y)
        worstix = np.argmax(offset)
        absdiff = offset[worstix]
        mediff = np.median(offset)
        mse = np.mean(offset ** 2)
        worstpoint = [vecs[0][worstix]]
        del offset, orig_y, approx_y
    # TODO: should probably permit specifying dtype for jitted
    #  functions -- both here and in primary quickseries().
    approx_time = timeit.timeit(lambda: quick(*vecs), number=timeit_cycles)
    orig_time = timeit.timeit(lambda: lamb(*vecs), number=timeit_cycles)
    orig_s = orig_time / timeit_cycles
    approx_s = approx_time / timeit_cycles
    return {
        'absdiff': absdiff,
        'reldiff': absdiff / np.ptp(frange),
        'mediff': mediff,
        'mse': mse,
        'worstpoint': worstpoint,
        'range': frange,
        'orig_s': orig_s,
        'approx_s': approx_s,
        'timeratio': approx_s / orig_s,
        'gentime': gentime,
        'polyfunc': quick
    } | ext
@@ -0,0 +1,126 @@
1
+ from functools import reduce
2
+ from typing import Union, Sequence
3
+
4
+ from dustgoggles.structures import listify
5
+ import sympy as sp
6
+
7
+ from quickseries.sputils import LmSig
8
+
9
+
10
def _rectify_series(series, add_coefficients):
    """
    Rebuild a series expression from its args, dropping Order terms and
    optionally attaching a free coefficient symbol to each non-numeric term.

    Args:
        series: sympy expression whose `.args` are the terms to process.
        add_coefficients: if True, multiply each Mul/Symbol/Pow term by a
            new symbol named a_0, a_1, ...

    Returns:
        The sum of the processed terms, and the list of coefficient symbols
        created (empty if `add_coefficients` is not True).

    Raises:
        ValueError: if `series` is itself an Order, or contains a term of
            an unhandled type.
    """
    if isinstance(series, sp.Order):
        raise ValueError(
            "Cannot produce a meaningful approximation with the requested "
            "parameters (most likely order is too low)."
        )
    outargs, coefsyms = [], []
    for term in series.args:
        # NOTE: the Expr.evalf() calls are simply to try to evaluate
        # anything we can.
        if hasattr(term, "evalf"):
            numeric = term.evalf()
            if isinstance(numeric, sp.Number):
                outargs.append(numeric)
                continue
        if isinstance(term, sp.Order):
            continue
        if not isinstance(term, (sp.Mul, sp.Symbol, sp.Pow)):
            raise ValueError(
                f"don't know how to handle expression element {term} of "
                f"type({type(term)})"
            )
        if add_coefficients is True:
            coefficient = sp.symbols(f"a_{len(coefsyms)}")
            outargs.append((coefficient * term).evalf())
            coefsyms.append(coefficient)
        else:
            outargs.append(term.evalf())
    return sum(outargs), coefsyms
37
+
38
+
39
def series_lambda(
    func: Union[str, sp.Expr],
    x0: float = 0,
    nterms: int = 9,
    add_coefficients: bool = False,
    modules: Union[str, Sequence[str]] = ("scipy", "numpy")
) -> tuple[LmSig, sp.Expr]:
    """
    Build a power expansion of a function given as a sympy Expr or as a
    string; optionally give each term of the resulting polynomial a free
    coefficient so that downstream functions can optimize them.

    Args:
        func: Mathematical function to expand, expressed as a string or a
            sympy Expr.
        x0: Point about which to expand func.
        nterms: Order of power expansion.
        add_coefficients: If True, add additional arguments/symbols to the
            returned function and Expr corresponding to the polynomial's
            coefficients.
        modules: Modules from which to draw the building blocks of the
            returned function.

    Returns:
        approximant: Python function that implements the power expansion.
        expr: sympy Expr used to construct approximant.
    """
    if isinstance(func, str):
        func = sp.sympify(func)
    # limiting precision of x0 is necessary due to a bug in sp.series
    expansion = sp.series(func, x0=round(x0, 6), n=nterms)
    # remove Order (limit behavior) terms, try to split constants from
    # polynomial terms
    # noinspection PyTypeChecker
    expr, coefsyms = _rectify_series(expansion, add_coefficients)
    # arguments: func's symbols in name order, then any coefficient symbols
    arguments = sorted(func.free_symbols, key=str) + coefsyms
    # noinspection PyTypeChecker
    return sp.lambdify(arguments, expr, modules), expr
76
+
77
+
78
def additive_combinations(n_terms, number):
    """
    List every tuple of `n_terms` non-negative integers whose sum does not
    exceed `number`, in lexicographic order.
    """
    leading_values = range(number + 1)
    if n_terms == 1:
        return [(value,) for value in leading_values]
    # fix the first element, then recurse on what is left of the budget
    return [
        (head, *tail)
        for head in leading_values
        for tail in additive_combinations(n_terms - 1, number - head)
    ]
88
+
89
+
90
def multivariate_taylor(
    func: Union[str, sp.Expr],
    point: Sequence[float],
    nterms: int,
    add_coefficients: bool = False
) -> tuple[LmSig, sp.Expr]:
    """
    Build a Taylor polynomial of `func` about `point` in all of its free
    symbols.

    Args:
        func: function to expand, as a string or sympy Expr.
        point: expansion point, one value per free symbol (symbols taken
            in name order).
        nterms: derivative-order combinations summing to at most
            `nterms - 1` are included.
        add_coefficients: passed to `_rectify_series()`.

    Returns:
        A lambdified function of the free symbols (plus any coefficient
        symbols), and the sympy Expr it was built from.
    """
    if isinstance(func, str):
        func = sp.sympify(func)
    pointsyms = sorted(func.free_symbols, key=str)
    dimensionality = len(pointsyms)
    # NOTE(review): the placeholder names x0.. and i0.. are not checked
    # against func's own symbol names -- confirm callers never use them.
    argsyms = listify(
        sp.symbols(",".join([f"x{i}" for i in range(dimensionality)]))
    )
    ixsyms = listify(
        sp.symbols(",".join(f"i{i}" for i in range(dimensionality)))
    )
    # generic term: derivative / product of factorials * product of offsets,
    # with the derivative orders left symbolic
    deriv = sp.Derivative(func, *zip(pointsyms, ixsyms))
    # noinspection PyTypeChecker
    fact = reduce(sp.Mul, [sp.factorial(i) for i in ixsyms])
    err = reduce(
        sp.Mul,
        [(x - a) ** i for x, a, i in zip(argsyms, pointsyms, ixsyms)]
    )
    taylor = deriv / fact * err
    # TODO, probably: there's a considerably faster way to do this in some
    #  cases by precomputing partial derivatives
    decomp = additive_combinations(dimensionality, nterms - 1)
    terms = (taylor.subs(dict(zip(ixsyms, orders))) for orders in decomp)
    built = reduce(sp.Add, terms).doit()
    evaluated = built.subs(dict(zip(pointsyms, point))).evalf()
    # this next line is kind of aesthetic -- we just want the argument names
    # to be consistent with the input
    evaluated = evaluated.subs(dict(zip(argsyms, pointsyms)))
    evaluated, coefsyms = _rectify_series(evaluated, add_coefficients)
    # noinspection PyTypeChecker
    return sp.lambdify(pointsyms + coefsyms, evaluated), evaluated
@@ -0,0 +1,65 @@
1
+ """lightweight version of `moonbow`'s polynomial fit functionality"""
2
+ from functools import wraps
3
+ from inspect import Parameter, signature
4
+ from typing import Callable, Optional, Sequence, Union
5
+
6
+ import numpy as np
7
+ from scipy.optimize import curve_fit
8
+
9
+
10
+ def fit_wrap(
11
+ func: Callable[[np.ndarray | float, ...], np.ndarray | float],
12
+ dimensionality: int,
13
+ fit_parameters: Sequence[str]
14
+ ) -> Callable[[np.ndarray | float, ...], np.ndarray | float]:
15
+ @wraps(func)
16
+ def wrapped_fit(independent_variable, *params):
17
+ variable_components = [
18
+ independent_variable[n] for n in range(dimensionality)
19
+ ]
20
+ exploded_function = func(*variable_components, *params)
21
+ return exploded_function
22
+
23
+ # rewrite the signature so that curve_fit will like it
24
+ sig = signature(wrapped_fit)
25
+ curve_fit_params = (
26
+ Parameter("independent_variable", Parameter.POSITIONAL_ONLY),
27
+ *fit_parameters,
28
+ )
29
+ wrapped_fit.__signature__ = sig.replace(parameters=curve_fit_params)
30
+ return wrapped_fit
31
+
32
+
33
def fit(
    func: Callable,
    vecs: list[np.ndarray],
    dependent_variable: np.ndarray,
    guess: Optional[Sequence[float]] = None,
    bounds: Optional[
        Union[tuple[tuple[float, float]], tuple[float, float]]
    ] = None
) -> tuple[np.ndarray, np.ndarray]:
    """
    Fit the trailing parameters of `func` to data with
    scipy.optimize.curve_fit.

    Args:
        func: model function; its first len(vecs) parameters are the
            independent variables, the rest are fitted.
        vecs: 1-D arrays of independent-variable values.
        dependent_variable: values to fit against.
        guess: initial parameter values, passed as `p0`.
        bounds: passed to curve_fit as `bounds` when not None.

    Returns:
        Whatever curve_fit returns for the wrapped model.

    Raises:
        ValueError: if any array in `vecs` is not 1-dimensional.
    """
    sig = signature(func)
    n_independent = len(vecs)
    assert n_independent < len(sig.parameters), (
        "The model function must have at least one 'free' "
        "parameter to be a meaningful candidate for fitting."
    )
    # everything after the independent variables is a fit parameter
    fit_parameters = list(sig.parameters.values())[n_independent:]
    # TODO: check dim of dependent
    if any(p.ndim != 1 for p in vecs):
        raise ValueError("each input vector must be 1-dimensional")
    # TODO: optional goodness-of-fit evaluation
    kw = {} if bounds is None else {'bounds': bounds}
    model = fit_wrap(func, n_independent, fit_parameters)
    # noinspection PyTypeChecker
    return curve_fit(
        model, vecs, dependent_variable, maxfev=20000, p0=guess, **kw
    )
@@ -0,0 +1,133 @@
1
+ import pickle
2
+ from hashlib import md5
3
+ from inspect import currentframe, getargvalues, getsource
4
+ import linecache
5
+ from pathlib import Path
6
+ import re
7
+ from types import FunctionType
8
+ from typing import Callable
9
+
10
+ from dustgoggles.dynamic import define, get_codechild
11
+
12
+
13
# names of the quickseries() arguments that _cachekey() reads from the
# caller's locals when building a cache key
CACHE_ARGS = (
    "func", "bounds", "nterms", "point", "fitres",
    "prefactor", "approx_poly", "precision",
    "fit_series_expansion", "bound_series_fit",
)
25
+
26
+
27
def cache_source(source: str, fn: Path):
    """
    Register `source` in `linecache.cache` under the string form of `fn`.
    """
    name = str(fn)
    # entry layout: (size, mtime, lines, fullname); mtime is stored as None
    entry = (len(source), None, source.splitlines(True), name)
    linecache.cache[name] = entry
30
+
31
+
32
+ # TODO: pull this little bitty change up to dustgoggles
33
def compile_source(source: str, fn: str = ""):
    """
    Compile `source` in "exec" mode, using `fn` as its filename, and return
    the result of `get_codechild()` applied to the compiled code.
    """
    compiled = compile(source, fn, "exec")
    return get_codechild(compiled)
35
+
36
+
37
def _cachedir(callfile: str) -> Path:
    """
    Return the cache directory for code called from `callfile`: a
    "qs_cache" folder inside the `__pycache__` directory next to the file,
    or inside IPython's cache directory when `callfile` is 'ipython_shell'.
    """
    if callfile != 'ipython_shell':
        return Path(callfile).parent / "__pycache__" / "qs_cache"
    import IPython.paths

    return Path(IPython.paths.get_ipython_cache_dir()) / "qs_cache"
43
+
44
+
45
def _cachekey(args, callfile=None):
    """
    Derive a function name / cache key from the CACHE_ARGS entries of
    `args.locals`, the calling file, and the package version.
    """
    from quickseries import __version__

    # NOTE: insertion order is kept (sorted CACHE_ARGS, then 'f', then
    # '__version__') because the pickled bytes are what gets hashed.
    hashed = {name: args.locals[name] for name in sorted(CACHE_ARGS)}
    hashed |= {'f': callfile, '__version__': __version__}
    # TODO: is this actually stable?
    digest = md5(pickle.dumps(hashed)).hexdigest()
    # arbitrary cutoff for a reasonable tradeoff between collision safety and
    # readability
    return f"quickseries_{digest}"[:-18]
56
+
57
+
58
+ # TODO, maybe: the frame traversal is potentially wasteful when repeated,
59
+ # although it probably doesn't matter too much.
60
def _cacheid():
    """
    Walk up the call stack to find the enclosing quickseries() call, and
    return the cache file path and key for it.

    WARNING: do not call this outside the normal quickseries workflow. It can
    be tricked, but to no good end.

    Raises:
        ReferenceError: if no frame named "quickseries" is found.
    """
    frame, callfile, args = currentframe(), None, None
    while callfile is None:
        frame = frame.f_back
        if frame is None or frame.f_code.co_filename == "<stdin>":
            # ran off the stack, or called from an interactive prompt
            callfile = "__quickseries_anonymous_caller_cache__/anonymous"
            break
        code = frame.f_code
        if not hasattr(code, "co_name"):
            continue
        if args is None and code.co_name == "quickseries":
            args = getargvalues(frame)
        elif code.co_name != "benchmark" and args is not None:
            # first frame above quickseries() that isn't benchmark()
            callfile = code.co_filename
    if re.search(r"interactiveshell.py|ipython", callfile):
        callfile = 'ipython_shell'
    if args is None:
        raise ReferenceError("Cannot use _cachefile() outside quickseries().")
    key = _cachekey(args, callfile)
    return _cachedir(callfile) / key / "func", key
83
+
84
+
85
def _compile_quickseries(source, jit, cache, cachefile):
    """
    Turn generated source into a function, optionally wrapped with numba.

    Args:
        source: source text of the function.
        jit: if True, return `numba.njit(func, cache=cache)`.
        cache: passed to numba as its `cache` option.
        cachefile: path used as the compiled code's filename and as the
            linecache key for `source`.

    Returns:
        The function, with `source` as its docstring.
    """
    # NOTE: this is this module's own globals dict, so a precision
    # constructor added below is also added to the module namespace.
    globals_ = globals()
    precmatch = re.search(r"float\d\d", source)
    if precmatch is not None:
        import numpy

        typename = precmatch.group()
        globals_[typename] = getattr(numpy, typename)
    code = compile_source(source, str(cachefile))
    func = FunctionType(code, globals_)
    cache_source(source, cachefile)
    func.__doc__ = source
    if jit is not True:
        return func
    import numba as nb

    return nb.njit(func, cache=cache)
99
+
100
+
101
def _cacheget(jit=False):
    """
    Load and compile the cached source for the current quickseries() call.

    Returns:
        (function, source) if the cache file exists, else (None, None).
    """
    cachefile, key = _cacheid()
    if not cachefile.exists():
        return None, None
    source = cachefile.read_text()
    return _compile_quickseries(source, jit, True, cachefile), source
108
+
109
+
110
def _cachewrite(source, cachefile):
    """
    Write `source` to `cachefile`, creating its parent directories (and a
    `__pycache__` directory beside it) as needed.
    """
    # we make the __pycache__ directory to enable numba JIT result caching,
    # just in case it happens; if it doesn't, the presence of the directory is
    # harmless.
    pycache = cachefile.parent / "__pycache__"
    pycache.mkdir(exist_ok=True, parents=True)
    # TODO, maybe: use a more sensible data structure
    cachefile.write_text(source)
118
+
119
+
120
def _finalize_quickseries(source, jit=False, cache=False):
    """
    Compile generated `source` into a function, first writing it to the
    quickseries cache when `cache` is True.
    """
    # note that we use this as a function identifier and 'fake' target for
    # linecache even if we're not actually using the quickseries cache
    cachefile, key = _cacheid()
    should_write = cache is True
    if should_write:
        _cachewrite(source, cachefile)
    return _compile_quickseries(source, jit, cache, cachefile)
127
+
128
+
129
def lastline(func: Callable) -> str:
    """
    Try to get the last line of a function, sans return statement.

    Only a leading `return` keyword is removed; identifiers that merely
    contain the text "return" (e.g. an argument named `returns`) are left
    intact.

    Args:
        func: function whose source is retrievable with inspect.getsource().

    Returns:
        The last non-empty source line, stripped of surrounding whitespace
        and of its leading `return`.
    """
    nonempty = [ln for ln in getsource(func).split("\n") if ln]
    # \b keeps "returns..." / "return_value..." from losing their prefix
    return re.sub(r"^\s*return\b", "", nonempty[-1]).strip()
quickseries/sputils.py ADDED
@@ -0,0 +1,26 @@
1
+ from typing import Any, Callable, Sequence, Union
2
+
3
+ import numpy as np
4
+ import sympy as sp
5
+
6
+ LmSig = Callable[[np.ndarray | float, ...], np.ndarray | float]
7
+
8
+
9
def lambdify(
    func: Union[str, sp.Expr],
    modules: Union[str, Sequence[str]] = ("scipy", "numpy")
) -> LmSig:
    """
    Turn a sympy Expr, or a string representation of a function, into a
    callable whose arguments are the expression's free symbols sorted by
    name, built from the specified modules.

    Raises:
        ValueError: if `func` is a string that sympy cannot parse.
    """
    if isinstance(func, str):
        try:
            func = sp.sympify(func)
        except sp.SympifyError:
            raise ValueError(f"Unable to parse {func}.")
    # sorting by name gives the callable a fixed argument order
    arguments = sorted(func.free_symbols, key=str)
    # noinspection PyTypeChecker
    return sp.lambdify(arguments, func, modules)
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Million Concepts
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.1
2
+ Name: quickseries
3
+ Version: 0.2.1
4
+ Home-page: https://github.com/millionconcepts/quickseries.git
5
+ Author: Michael St. Clair
6
+ Author-email: mstclair@millionconcepts.com
7
+ Requires-Python: >=3.11
8
+ License-File: LICENSE
9
+ Requires-Dist: dustgoggles
10
+ Requires-Dist: numpy
11
+ Requires-Dist: scipy
12
+ Requires-Dist: sympy
13
+ Provides-Extra: jit
14
+ Requires-Dist: numba ; extra == 'jit'
15
+ Provides-Extra: tests
16
+ Requires-Dist: pytest ; extra == 'tests'
17
+
@@ -0,0 +1,12 @@
1
+ quickseries/__init__.py,sha256=buB8Cr2L8b7cYZsTVtYnF_bBOLFq79BfCxGe47akPnc,115
2
+ quickseries/approximate.py,sha256=yepSj9WaMGkNcqdDgbJMHSoNtc83IhK1LGiAGvmvMYA,11704
3
+ quickseries/benchmark.py,sha256=duUzFz5L710B8MV0BMvA1u1z1QrD9YeeOz49GI6IdVk,3990
4
+ quickseries/expansions.py,sha256=nARZ7RusPVZ9YevdmMI0d3d4heFQoYX5FFalcstF8io,4823
5
+ quickseries/simplefit.py,sha256=07w_uyfEtkHj19lYe1NOkvaD1FNizkXk8d9Nx6h0YkY,2106
6
+ quickseries/sourceutils.py,sha256=BG1ar074-wtsNiYr5nryVfWFfClymmOk50nMqsovd7Y,4289
7
+ quickseries/sputils.py,sha256=cOaZ4OI9amLf3wexhMIjmgMNPaLwQXRF_2t7yHj23fg,756
8
+ quickseries-0.2.1.dist-info/LICENSE,sha256=q0EG1U_Kyw7_ubZTiZL51qAaMzklie_EbIIyq7Ff0zM,1503
9
+ quickseries-0.2.1.dist-info/METADATA,sha256=OfOgJSRuldi-x58ux1mAmEd7pV7NZa3V3o1bHkF7khA,444
10
+ quickseries-0.2.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
11
+ quickseries-0.2.1.dist-info/top_level.txt,sha256=cD5wC1LvdrL0KYP4te9ljRhXK4brBvyvlXGDsNnWoQA,12
12
+ quickseries-0.2.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.43.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ quickseries