quickseries 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quickseries-0.2.1/LICENSE +28 -0
- quickseries-0.2.1/PKG-INFO +16 -0
- quickseries-0.2.1/README.md +234 -0
- quickseries-0.2.1/quickseries/__init__.py +4 -0
- quickseries-0.2.1/quickseries/approximate.py +319 -0
- quickseries-0.2.1/quickseries/benchmark.py +108 -0
- quickseries-0.2.1/quickseries/expansions.py +126 -0
- quickseries-0.2.1/quickseries/simplefit.py +65 -0
- quickseries-0.2.1/quickseries/sourceutils.py +133 -0
- quickseries-0.2.1/quickseries/sputils.py +26 -0
- quickseries-0.2.1/quickseries.egg-info/PKG-INFO +16 -0
- quickseries-0.2.1/quickseries.egg-info/SOURCES.txt +15 -0
- quickseries-0.2.1/quickseries.egg-info/dependency_links.txt +1 -0
- quickseries-0.2.1/quickseries.egg-info/requires.txt +10 -0
- quickseries-0.2.1/quickseries.egg-info/top_level.txt +1 -0
- quickseries-0.2.1/setup.cfg +4 -0
- quickseries-0.2.1/setup.py +13 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023, Million Concepts
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: quickseries
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Home-page: https://github.com/millionconcepts/quickseries.git
|
|
5
|
+
Author: Michael St. Clair
|
|
6
|
+
Author-email: mstclair@millionconcepts.com
|
|
7
|
+
Requires-Python: >=3.11
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: dustgoggles
|
|
10
|
+
Requires-Dist: numpy
|
|
11
|
+
Requires-Dist: scipy
|
|
12
|
+
Requires-Dist: sympy
|
|
13
|
+
Provides-Extra: jit
|
|
14
|
+
Requires-Dist: numba; extra == "jit"
|
|
15
|
+
Provides-Extra: tests
|
|
16
|
+
Requires-Dist: pytest; extra == "tests"
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# quickseries
|
|
2
|
+
|
|
3
|
+
`quickseries` generates Python functions that perform fast vectorized power
|
|
4
|
+
series approximations of mathematical functions. It can provide performance
|
|
5
|
+
improvements ranging from ~3x (simple functions, no fiddling around with
|
|
6
|
+
parameters) to ~100x (complicated functions, some parameter tuning).
|
|
7
|
+
|
|
8
|
+
`quickseries` is in beta; bug reports are appreciated.
|
|
9
|
+
|
|
10
|
+
Install from source using `pip install .`. Dependencies are also described
|
|
11
|
+
in a Conda `environment.yml` file.
|
|
12
|
+
|
|
13
|
+
The minimum supported version of Python is *3.11*.
|
|
14
|
+
|
|
15
|
+
## example of use
|
|
16
|
+
|
|
17
|
+
```
|
|
18
|
+
>>> import numpy as np
|
|
19
|
+
>>> from quickseries import quickseries
|
|
20
|
+
|
|
21
|
+
>>> bounds = (-np.pi, np.pi)
|
|
22
|
+
>>> approx = quickseries("sin(x)*cos(x)", point=0, nterms=12, bounds=bounds)
|
|
23
|
+
>>> x = np.linspace(*bounds, 100000)
|
|
24
|
+
>>> print(f"max error: {max(abs(np.sin(x) * np.cos(x) - approx(x)))}")
|
|
25
|
+
>>> print("original runtime:")
|
|
26
|
+
>>> %timeit np.sin(x) * np.cos(x)
|
|
27
|
+
>>> print("approx runtime:")
|
|
28
|
+
>>> %timeit approx(x)
|
|
29
|
+
|
|
30
|
+
max error: 0.0003270875375037813
|
|
31
|
+
original runtime:
|
|
32
|
+
968 µs ± 2.17 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
|
|
33
|
+
approx runtime:
|
|
34
|
+
325 µs ± 3.89 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
## usage notes
|
|
39
|
+
|
|
40
|
+
### features
|
|
41
|
+
|
|
42
|
+
* The most important keyword arguments to `quickseries` are `bounds`,
|
|
43
|
+
`nterms`, and `point`. `bounds` specifies the range (or ranges, for
|
|
44
|
+
multivariate functions) of values across which to approximate the function.
|
|
45
|
+
`nterms` specifies how many terms to use in the series expansion. `point`
|
|
46
|
+
specifies the value (or values, for multivariate functions) about which
|
|
47
|
+
to generate the series expansion. See "limitations" and "tips" below for
|
|
48
|
+
examples and discussion.
|
|
49
|
+
* `quickseries()` is capable of auto-jitting the functions it generates
|
|
50
|
+
with `numba`. Pass the `jit=True` argument. `numba` is an optional dependency;
|
|
51
|
+
install it with your preferred package manager.
|
|
52
|
+
* `quickseries.benchmark()` offers an easy way to test the accuracy and
|
|
53
|
+
efficiency of `quickseries.quickseries()`-generated functions.
|
|
54
|
+
* By default, `quickseries()` caches the code it generates. If you wish to
|
|
55
|
+
turn this behavior off, pass `cache=False`.
|
|
56
|
+
* If you call `quickseries()` with the same arguments from separate modules,
|
|
57
|
+
it will write separate caches for each module.
|
|
58
|
+
* ipython/Jupyter shells/kernels all share one cache within the same user
|
|
59
|
+
account.
|
|
60
|
+
* `quickseries()` treats stdin or similar 'anonymous' invocation contexts
|
|
61
|
+
like modules named "__quickseries_anonymous_caller_cache__" in the current
|
|
62
|
+
working directory.
|
|
63
|
+
* In this mode, `quickseries()` also caches any results of `numba` JIT
|
|
64
|
+
compilation.
|
|
65
|
+
* Caching is turned _off_ by default for `benchmark()`.
|
|
66
|
+
* If you pass the `precision` argument to `quickseries()`, it will attempt to
|
|
67
|
+
guarantee that the function it returns will not cast input values to bit widths
|
|
68
|
+
greater than the value of `precision`. Legal values of `precision` are 16, 32,
|
|
69
|
+
and 64. The returned function will not, however, attempt to reduce the precision
|
|
70
|
+
of its arguments. For instance, `quickseries("sin(x) + exp(x)", precision=32)`
|
|
71
|
+
will return a Python `float` if passed a `float`, and a `np.float64` `ndarray`
|
|
72
|
+
if passed a `np.float64` `ndarray`. However, it will return a `np.float32`
|
|
73
|
+
`ndarray` if passed a `np.float32` `ndarray`, which is not guaranteed without
|
|
74
|
+
the `precision=32` argument.
|
|
75
|
+
|
|
76
|
+
### argument naming
|
|
77
|
+
|
|
78
|
+
* Multivariate `quickseries()`-generated functions always map positional arguments
|
|
79
|
+
to variables in the string representation of the input function in alphanumeric
|
|
80
|
+
order. This is in order to maintain consistency between slightly different
|
|
81
|
+
forms of the same expression.
|
|
82
|
+
* Examples:
|
|
83
|
+
* `quickseries("cos(x) * sin(y)")(1, 2)` approximates `cos(1) * sin(2)`
|
|
84
|
+
* `quickseries("sin(y) * cos(x)")(1, 2)` approximates `cos(1) * sin(2)`
|
|
85
|
+
* `quickseries("sin(x) * cos(y)")(1, 2)` approximates `sin(1) * cos(2)`
|
|
86
|
+
* Note that you can always determine the argument order of a `quickseries()`-
|
|
87
|
+
generated function by using the `help()` builtin, `inspect.getfullargspec()`,
|
|
88
|
+
examining the function's docstring, etc.
|
|
89
|
+
* Most legal Python variable names are allowable names for free variables.
|
|
90
|
+
Named mathematical functions and constants are the major exceptions.
|
|
91
|
+
* Examples:
|
|
92
|
+
* `"ln(_)"`, `"ln(One_kitty)"`, `"ln(x0)"`, and `"ln(ă)"` will all work fine.
|
|
93
|
+
* `"ln(if)"` and `"ln(🔥)"` will both fail, because `if` and `🔥` are not
|
|
94
|
+
legal Python variable names.
|
|
95
|
+
* `"ln(gamma)"` will fail, because `quickseries()` will interpret "gamma"
|
|
96
|
+
as the gamma function.
|
|
97
|
+
* `"cos(x) * cos(pi * 2)"` will succeed, but `quickseries()` will interpret
|
|
98
|
+
it as "the cosine of a variable named 'x' times the cosine of two times
|
|
99
|
+
the mathematical constant pi" -- in other words, as `"cos(x)"`.
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
### limitations
|
|
103
|
+
|
|
104
|
+
* `quickseries` only works for functions ℝ<sup>_n_</sup>🡒ℝ for finite _n_. In
|
|
105
|
+
programming terms, this means it will only produce functions that accept a
|
|
106
|
+
fixed number of floating-point or integer arguments (which may be 'arraylike'
|
|
107
|
+
objects such as pandas `Series` or numpy `ndarrays`) and return a single
|
|
108
|
+
floating-point value (or a 1-D floating-point array if passed arraylike
|
|
109
|
+
arguments).
|
|
110
|
+
* `quickseries` only works consistently on functions that are continuous and
|
|
111
|
+
infinitely differentiable within the domain of interest. Specifically, they
|
|
112
|
+
should not have singularities, discontinuities, or infinite / undefined
|
|
113
|
+
values at `point` or within `bounds`. Failure cases differ:
|
|
114
|
+
* `quickseries` will always fail on functions that are infinite/undefined
|
|
115
|
+
at `point`, like `quickseries("ln(x)", point=-1)`.
|
|
116
|
+
* It will almost always fail on functions with a largeish interval of
|
|
117
|
+
infinite/undefined values within `bounds`, such as
|
|
118
|
+
`quickseries("gamma(x + y)", bounds=((-1.1, 0), (0, 1)), point=(-0.5, 0))`.
|
|
119
|
+
* It will usually succeed but produce bad results on functions with
|
|
120
|
+
singularities or point discontinuities within `bounds` or
|
|
121
|
+
near `point` but not at `point`, such as `quickseries("tan(x)", bounds=(1, 2))`.
|
|
122
|
+
* It will often succeed, but usually produce bad results, on univariate
|
|
123
|
+
functions that are continuous but not differentiable at `point`, such as
|
|
124
|
+
`quickseries("abs(sin(x))", point=0)`. It will always fail on multivariate
|
|
125
|
+
functions of this kind.
|
|
126
|
+
* Functions given to `quickseries` must be expressed in strict closed form
|
|
127
|
+
and include only finite terms. They cannot contain limits, integrals,
|
|
128
|
+
derivatives, summations, continued fractions, etc.
|
|
129
|
+
* `quickseries` is not guaranteed to work for all such functions.
|
|
130
|
+
|
|
131
|
+
### tips
|
|
132
|
+
|
|
133
|
+
* Narrowing `bounds` will tend to make the approximation more accurate within
|
|
134
|
+
those bounds. In the example at the top of this README, setting `bounds` to
|
|
135
|
+
`(-1, 1)` provides ~20x greater accuracy within the (-1, 1) interval (with
|
|
136
|
+
the downside that the resulting approximation will get pretty bad past about
|
|
137
|
+
+/-pi/2).
|
|
138
|
+
* Like many optimizers, `quickseries()` tends to be much more effective
|
|
139
|
+
closer to 0 and when its input arguments have similar orders of
|
|
140
|
+
magnitude. If it is practical to shift/squeeze your data towards 0, you
|
|
141
|
+
may be able to get more use out of `quickseries`. One of the biggest reasons
|
|
142
|
+
for this is that high-order polynomials are more numerically stable with
|
|
143
|
+
smaller input values.
|
|
144
|
+
* Functions with a pole at 0 can of course present an exception to this
|
|
145
|
+
rule. It will still generally be better to keep their input values small.
|
|
146
|
+
* Increasing `nterms` will tend to make the approximation slower but more
|
|
147
|
+
accurate. In the example above, increasing `nterms` to 14 provides ~20x
|
|
148
|
+
greater accuracy but makes the approximation ~20% slower.
|
|
149
|
+
* This tends to have diminishing returns. In the example above, increasing
|
|
150
|
+
`nterms` to 30 provides no meaningful increase in accuracy over `nterms=14`,
|
|
151
|
+
but makes the approximation *slower* than `np.sin(x) * np.cos(x)`.
|
|
152
|
+
* Setting `nterms` too high can also cause the approximation algorithm to
|
|
153
|
+
fail entirely.
|
|
154
|
+
* For most functions, placing `point` in the middle of `bounds` will produce the
|
|
155
|
+
best results, and if you don't pass `point` at all, `quickseries` defaults to
|
|
156
|
+
placing it in the middle of `bounds`.
|
|
157
|
+
* The location of accuracy/performance "sweet spots" in the parameter space
|
|
158
|
+
depends on the function and the approximation bounds. If you want to
|
|
159
|
+
seriously optimize a particular function in a particular interval, you will
|
|
160
|
+
need to play around with these parameters.
|
|
161
|
+
* The speedup (or lack thereof) that a `quickseries()`-generated approximation
|
|
162
|
+
provides can vary greatly in different operating environments and on different
|
|
163
|
+
processors.
|
|
164
|
+
* It can also vary depending on the length of the input arguments. It generally
|
|
165
|
+
provides most benefit on arrays with tens or hundreds of thousands of elements,
|
|
166
|
+
although this again varies depending on operating environment, the particular
|
|
167
|
+
approximated function, etc.
|
|
168
|
+
* In general, `quickseries` provides more performance benefits for more 'complicated'
|
|
169
|
+
input functions. This is due to the implicit 'simplification' offered by the
|
|
170
|
+
power series expansion.
|
|
171
|
+
* It is often difficult to generate a polynomial approximation that
|
|
172
|
+
remains good across a wide range of input values. In some cases, it may be
|
|
173
|
+
useful to generate different functions for different parts of your code, or
|
|
174
|
+
even to perform piecewise operations with multiple functions (although this
|
|
175
|
+
of course adds complexity and overhead).
|
|
176
|
+
* By default, if you pass a simple polynomial expression to `quickseries()`
|
|
177
|
+
(e.g. `"x**4 + 2 * x**3"`), it does not actually generate an approximation,
|
|
178
|
+
but instead simply attempts to rewrite it in a more efficient form.
|
|
179
|
+
* `nterms`, `bounds`, and `point` are ignored in this "rewrite" mode.
|
|
180
|
+
* This type of `quickseries()`-generated function should produce the same
|
|
181
|
+
results as any other Python function that straightforwardly implements a
|
|
182
|
+
form of the input polynomial (down to floating-point error).
|
|
183
|
+
* This can produce surprising speedups even in simple cases -- for example,
|
|
184
|
+
`quickseries("x**4")` is ~20x faster than `lambda x: x ** 4` on some
|
|
185
|
+
`numpy` arrays.
|
|
186
|
+
* If you want `quickseries()` to actually create an approximation of a
|
|
187
|
+
simple polynomial, pass `approx_poly=True`.
|
|
188
|
+
* When approximating a polynomial, there is generally no good reason to
|
|
189
|
+
set `nterms` > that polynomial's order. If you do, the function
|
|
190
|
+
`quickseries()` generates will typically be very similar to a simple
|
|
191
|
+
rewrite of the input polynomial, but with slightly worse performance and
|
|
192
|
+
accuracy.
|
|
193
|
+
* `point=0` often produces boring results for polynomial approximation.
|
|
194
|
+
* In many, but not all, cases, `jit=True` will provide a significant performance
|
|
195
|
+
improvement, sometimes by an order of magnitude. It also permits calling
|
|
196
|
+
`quickseries`-generated functions from within other `numba`-compiled
|
|
197
|
+
functions.
|
|
198
|
+
* Note that some functions may not be compatible with `numba`.
|
|
199
|
+
* `quickseries` tends to be most effective on univariate functions, mostly
|
|
200
|
+
because the number of terms in a function's power expansion increases
|
|
201
|
+
geometrically with its number of free parameters.
|
|
202
|
+
* Functions generated by `quickseries()` may in some cases be less
|
|
203
|
+
space/memory-efficient even if they are more time/compute-efficient.
|
|
204
|
+
* By default, `quickseries` takes the analytic series expansion of the input
|
|
205
|
+
function as a strong suggestion rather than the last word on the topic, and
|
|
206
|
+
performs a numerical optimization step to improve its goodness of fit across
|
|
207
|
+
`bounds`. There are good reasons you might not want it to do this, though --
|
|
208
|
+
for instance, if your input arguments are always going to be quite close to
|
|
209
|
+
`point`, messing with the analytic series expansion may be wasteful or even
|
|
210
|
+
counterproductive. If you don't want it to do this, pass `fit_series_expansion=False`.
|
|
211
|
+
In this case, `quickseries` ignores the `bounds` argument, except to infer
|
|
212
|
+
a value for `point` if you do not specify one.
|
|
213
|
+
* In some cases, this optimization step can become numerically unstable. In
|
|
214
|
+
these cases, you may wish to experiment with constraining it rather than
|
|
215
|
+
turning it off completely. You can do this by passing `bound_series_fit=True`.
|
|
216
|
+
* By default, the functions that `quickseries` generates precompute all repeated
|
|
217
|
+
exponents in the generated polynomial. This is a space-for-time trade, and
|
|
218
|
+
may not always be desirable (or even effective). You can turn this off by
|
|
219
|
+
passing `prefactor=False`.
|
|
220
|
+
* If `jit=True`, `quickseries` does _not_ do this by default. The `numba`
|
|
221
|
+
compiler implicitly performs a similar optimization, and computing these
|
|
222
|
+
terms explicitly tends to be counterproductive. If you want `quickseries`
|
|
223
|
+
to do it anyway, you can pass `prefactor=True`.
|
|
224
|
+
* Specifying `precision` can lead to significant speedups and memory usage
|
|
225
|
+
improvements.
|
|
226
|
+
* Many libraries and formats do not support the "half-float" values generated
|
|
227
|
+
by `quickseries` when passed `precision=16`.
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
## tests
|
|
231
|
+
|
|
232
|
+
`quickseries` has a few simple tests. You can run them by executing `pytest`
|
|
233
|
+
in the repository's root directory. More comprehensive test coverage is
|
|
234
|
+
planned.
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from inspect import getfullargspec, signature
|
|
3
|
+
from itertools import chain
|
|
4
|
+
from typing import Literal, Optional, Sequence, Union, Collection
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import sympy as sp
|
|
8
|
+
from cytoolz import groupby
|
|
9
|
+
from dustgoggles.func import gmap
|
|
10
|
+
|
|
11
|
+
from quickseries.expansions import multivariate_taylor, series_lambda
|
|
12
|
+
from quickseries.simplefit import fit
|
|
13
|
+
from quickseries.sourceutils import (
|
|
14
|
+
_cacheget, _cacheid, _finalize_quickseries, lastline
|
|
15
|
+
)
|
|
16
|
+
from quickseries.sputils import LmSig, lambdify

# NOTE: `LmSig`, the signature of sympy-lambdified numpy/scipy functions, is
# defined in quickseries.sputils and imported on the line above.

EXP_PATTERN = re.compile(r"\w+ ?\*\* ?(\d+)")
"""what exponentials in sympy-lambdified functions look like"""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def is_simple_poly(expr: sp.Expr) -> bool:
    """
    Report whether sympy sees `expr` as a polynomial in plain symbols.

    `sp.poly_from_expr()` chooses "generators" for the expression; if every
    generator is a bare Symbol (rather than, e.g., `sin(x)`), `expr` is a
    simple polynomial.
    """
    _, options = sp.poly_from_expr(expr)
    for generator in options["gens"]:
        if not isinstance(generator, sp.Symbol):
            return False
    return True
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def regexponents(text: str) -> tuple[int, ...]:
    """
    Return every integer exponent written as `name ** n` in `text` (as
    matched by EXP_PATTERN), in order of appearance.

    The result has one entry per occurrence, so its length varies with
    `text` -- hence `tuple[int, ...]` rather than the 1-tuple `tuple[int]`.
    """
    return tuple(int(exponent) for exponent in EXP_PATTERN.findall(text))
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _decompose(
    remaining: tuple[int, ...],
    reduced: set[int],
    replacements: list[tuple[int, list[int]]]
) -> bool:
    """
    Run one pass of the greedy exponent-decomposition search used by
    `optimize_exponents()`.

    Args:
        remaining: every factor currently used across all decompositions in
            `replacements`, flattened (repeats included).
        reduced: factors that are already settled and must not be decomposed
            further. Modified in place.
        replacements: list of (exponent, factors) pairs, where `factors` is
            the current decomposition of `exponent`; the factors of each
            decomposition sum to its exponent. The `factors` lists are
            modified in place.

    Returns:
        True if the search is finished; False if the caller should rebuild
        `remaining` from `replacements` and call this function again.
    """
    if len(remaining) == 1:  # trivial case
        # a lone exponent e becomes e copies of power 1, i.e. the bare
        # variable multiplied by itself e times
        replacements[0][1][:] = [1 for _ in range(replacements[0][0])]
        return True
    # how often each not-yet-settled factor occurs across all decompositions
    counts = {
        k: len(v)
        for k, v in groupby(lambda x: x, remaining).items()
        if k not in reduced
    }
    if len(counts) < 2:  # nothing useful left to do
        return True
    elif counts[max(counts)] > 1:
        # don't decompose the biggest factor; because it appears more than
        # once, we'd like to precompute it
        reduced.add(max(counts))
        return False
    # otherwise, do a decomposition pass with the smallest factor
    factor = sorted(counts.keys())[0]
    for k, v in replacements:
        factorization = []
        # "divide out" `factor` from elements of existing decomposition
        for f in v:
            # don't decompose factors we've already evaluated, and don't
            # try to divide `factor` out of smaller factors (nonsensical)
            if f in reduced or f <= factor:
                factorization.append(f)
                continue
            factorization.append(factor)
            difference = f - factor
            # keep dividing out `factor` while what is left is at least as
            # large as the largest other factor in `remaining`. The list
            # given to max() is never empty: `counts` has >= 2 distinct
            # keys, so `remaining` holds some value other than `f`.
            while difference >= max([e for e in remaining if e != f]):
                factorization.append(factor)
                difference = difference - factor
            if difference > 0:
                factorization.append(difference)
        # overwrite in place so the caller's `replacements` sees the change
        v[:] = factorization
    reduced.add(factor)
    return False
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def optimize_exponents(
    exps: Sequence[int],
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Plan how to compute a collection of integer powers of one variable with
    few multiplications.

    Args:
        exps: the exponents that occur in an expression, one entry per
            occurrence (so a repeated power appears repeatedly). Must support
            `.count()`.

    Returns:
        replacements: maps each exponent in `exps` to the exponents whose
            powers multiply together to give it (1 means the bare variable).
        variables: maps each exponent worth precomputing to its "building
            blocks": exponents (1 or smaller keys of `variables`) whose
            powers multiply together to give it.
    """
    # list of tuples like: (power, [powers to use in decomposition])
    replacements = [(e, [e]) for e in exps]
    # which powers have we already assessed?
    reduced = set()
    # every factor currently used by any decomposition, flattened (with
    # repeats); recomputed after each pass below
    remaining = tuple(chain(*[r[1] for r in replacements]))
    # NOTE: _decompose() modifies `reduced` and the factor lists inside
    # `replacements` in place. `remaining` is an immutable tuple, so it is
    # rebuilt from `replacements` after every pass.
    while _decompose(remaining, reduced, replacements) is False:
        remaining = tuple(chain(*[r[1] for r in replacements]))
    # this is analogous to casting to set: we no longer care about number of
    # occurrences
    replacements = {k: v for k, v in replacements}
    # figure out which factors we'd like to predefine as variables, and what
    # the "building blocks" of those variables are. 1 is a placeholder: we
    # will never define it, but it's useful in this loop.
    variables = {1: [1]}
    for e in sorted(set(remaining)):
        if e == 1:
            continue
        if exps.count(e) == 1:
            # a power used only once is not stored unless some larger
            # exponent also occurs
            if not any(k > e for k, v in replacements.items()):
                continue
        # greedily build `e` from the largest powers defined so far
        vfactor, remainder = [], e
        while remainder > 0:
            pick = max([v for v in variables.keys() if v <= remainder])
            vfactor.append(pick)
            remainder -= pick
        variables[e] = vfactor
    # remove the placeholder first-order variable
    variables.pop(1)
    return replacements, variables
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def force_line_precision(line: str, precision: Literal[16, 32, 64]) -> str:
    """
    Wrap each numeric literal in a line of generated Python source in a
    fixed-width numpy float constructor (e.g. `2.5` -> `float32(2.5)`), so
    that the literals do not promote narrower float inputs to float64.

    Integer exponents (the `n` in `x**n`) are left untouched.

    Args:
        line: a single Python expression, e.g. a hornerized polynomial.
        precision: bit width of the numpy float type to use.

    Returns:
        The rewritten expression.
    """
    constructor_rep = f"float{precision}"
    constructor = getattr(np, constructor_rep)
    last, out = 0, ""
    for match in re.finditer(
        r"([+* (-]+|^)([\d.]+)(e[+\-]?\d+)?(.*?)([+* )]|$)", line
    ):
        lead, mantissa, exponent, tail, terminator = match.groups()
        out += line[last : match.start()]
        # don't replace exponents
        if lead == "**":
            out += match.group(0)
        else:
            # NOTE: casting number to string within the f-string statement
            # appears to upcast it before generating the representation.
            number = str(constructor(float(mantissa)))
            out += f"{lead}{constructor_rep}({number}{exponent or ''})"
            # anything between the literal and the next operator (for
            # instance the "/6" in "-1/6*x") is part of the expression and
            # must be kept, not discarded
            out += f"{tail}{terminator}"
        last = match.end()
    return out + line[last:]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def rewrite(
    poly_lambda: LmSig,
    precompute: bool = True,
    precision: Optional[Literal[16, 32, 64]] = None,
) -> str:
    """
    Generate the source code of a standalone function that evaluates the
    polynomial computed by `poly_lambda`.

    Args:
        poly_lambda: sympy-lambdified function whose last source line
            returns the (hornerized) polynomial.
        precompute: if True, precompute repeated integer powers of the
            arguments as local variables before evaluating the polynomial.
        precision: if given, wrap numeric literals in `float{precision}`
            constructors (see `force_line_precision()`).

    Returns:
        Source of a `def` statement, named by the key from `_cacheid()`,
        that takes the same arguments as `poly_lambda`.
    """
    # sympy will always place this on a single line. it includes
    # the Python expression form of the hornerized polynomial
    # and a return statement. lastline() grabs polynomial and strips return.
    polyexpr = lastline(poly_lambda)
    # remove pointless unit coefficients ("1.0*x" -> "x"). The lookbehind
    # keeps this from eating the tail of larger literals such as "21.0*x",
    # and consuming exactly one trailing "*" (never a leading one) keeps
    # "a*1.0*b" from collapsing into "ab" and leaves "1.0**n" alone.
    polyexpr = re.sub(r"(?<![\w.])1\.0\*(?!\*)", "", polyexpr)
    # names of arguments to the lambdified function
    free = getfullargspec(poly_lambda).args
    lines = []
    if precompute is True:
        polyexpr, factorlines = _rewrite_precomputed(polyexpr, free)
        lines += factorlines
    if precision is not None:
        polyexpr = force_line_precision(polyexpr, precision)
    lines.append(f"return {polyexpr}")
    _, key = _cacheid()
    lines.insert(0, f"def {key}({', '.join(free)}):")
    return "\n ".join(lines)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _rewrite_precomputed(
    polyexpr: str, free: Collection[str]
) -> tuple[str, list[str]]:
    """
    Rewrite integer powers of the free variables in `polyexpr` so that
    repeated powers are computed once, up front.

    Args:
        polyexpr: single-line Python expression for the polynomial.
        free: names of the generated function's arguments.

    Returns:
        (rewritten expression, assignment statements defining the
        precomputed powers -- these must run before the expression)
    """
    factorlines = []
    # names already visible in the generated function; new variables must
    # shadow neither an argument nor each other (e.g. "x" with exponent 12
    # and "x1" with exponent 2 would otherwise both produce "x12")
    taken = set(free)
    for f in free:
        # `f ** n` with `f` as a whole identifier and `n` a whole integer, so
        # "x" doesn't match inside "zx**2" or "xx**2", "x**2" doesn't match
        # the head of "x**23", and float exponents like "x**2.5" are skipped
        expat = re.compile(
            rf"(?<![\w.]){re.escape(f)} ?\*\* ?(\d+)(?![\d.])"
        )
        # replacements: what factors we will decompose each exponent into
        # variables: which factors we will define as variables, and their
        # "building blocks"
        replacements, variables = optimize_exponents(
            gmap(int, expat.findall(polyexpr))
        )
        # choose a collision-free name for each precomputed power
        names = {1: f}
        for k in variables:
            name = f"{f}{k}"
            while name in taken:
                name = f"_{name}"
            taken.add(name)
            names[k] = name
        for k, v in variables.items():
            multiplicands = [names.get(power, f"{f}{power}") for power in v]
            factorlines.append(f"{names[k]} = {'*'.join(multiplicands)}")
        substitutions = {
            k: "*".join(names.get(r, f"{f}{r}") for r in v)
            for k, v in replacements.items()
        }
        # substitute every matched power in one pass, using the same pattern
        # that found the exponents
        polyexpr = expat.sub(
            lambda m, subs=substitutions: subs[int(m.group(1))], polyexpr
        )
    return polyexpr, factorlines
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _pvec(
|
|
188
|
+
bounds: Sequence[tuple[float, float]], offset_resolution: int
|
|
189
|
+
) -> list[np.ndarray]:
|
|
190
|
+
axes = [np.linspace(*b, offset_resolution) for b in bounds]
|
|
191
|
+
indices = map(np.ravel, np.indices([offset_resolution for _ in bounds]))
|
|
192
|
+
return [j[i] for j, i in zip(axes, indices)]
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _perform_series_fit(
    func: str | sp.Expr,
    bounds: tuple[float, float] | Sequence[tuple[float, float]],
    nterms: int,
    fitres: int,
    point: float | Sequence[float],
    apply_bounds: bool,
    is_poly: bool
) -> tuple[sp.Expr, np.ndarray]:
    """
    Fit the coefficients of a truncated series expansion of `func` to `func`
    itself, sampled on a regular grid over `bounds`.

    Args:
        func: expression to approximate.
        bounds: one (low, high) pair per free variable.
        nterms: number of terms in the series expansion.
        fitres: number of sample points per free variable.
        point: expansion point, one value per free variable.
        apply_bounds: if True, pass bounds=(-5, 5) to `fit()`.
        is_poly: whether `func` is a simple polynomial. Univariate
            non-polynomials are expanded with `series_lambda()`; everything
            else with `multivariate_taylor()`.

    Returns:
        (expansion with the fitted coefficients substituted in, fitted
        coefficients)
    """
    if (len(bounds) == 1) and (is_poly is False):
        approx, expr = series_lambda(func, point[0], nterms, True)
    else:
        approx, expr = multivariate_taylor(func, point, nterms, True)
    lamb, vecs = lambdify(func), _pvec(bounds, fitres)
    try:
        dep = lamb(*vecs)
    except TypeError as err:
        # this is a potentially slow but unavoidable case
        if "converted to Python scalars" not in str(err):
            raise
        # the lambdified function only accepts scalars: evaluate it one
        # sample point at a time, passing one value per free variable
        dep = np.array([lamb(*sample) for sample in zip(*vecs)])
    # the expansion's arguments are the free variables followed by the
    # coefficients to fit; start every coefficient at 1
    guess = [1 for _ in range(len(signature(approx).parameters) - len(vecs))]
    params, _ = fit(
        func=approx,
        vecs=vecs,
        dependent_variable=dep,
        guess=guess,
        bounds=(-5, 5) if apply_bounds is True else None,
    )
    # insert coefficients (named a_0, a_1, ...) into polynomial
    expr = expr.subs({f"a_{i}": coef for i, coef in enumerate(params)})
    return expr, params
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _makebounds(
|
|
230
|
+
bounds: Optional[Sequence[tuple[float, float]] | tuple[float, float]],
|
|
231
|
+
n_free: int,
|
|
232
|
+
point: Optional[Sequence[float] | float]
|
|
233
|
+
) -> tuple[list[tuple[float, float]], list[float]]:
|
|
234
|
+
bounds = (-1, 1) if bounds is None else bounds
|
|
235
|
+
if not isinstance(bounds[0], (list, tuple)):
|
|
236
|
+
bounds = [bounds for _ in range(n_free)]
|
|
237
|
+
if point is None:
|
|
238
|
+
point = [np.mean(b) for b in bounds]
|
|
239
|
+
elif not isinstance(point, (list, tuple)):
|
|
240
|
+
point = [point for _ in bounds]
|
|
241
|
+
return bounds, point
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _make_quickseries(
    approx_poly: bool,
    bound_series_fit: bool,
    bounds: Optional[Sequence[tuple[float, float]] | tuple[float, float]],
    expr: sp.Expr,
    fit_series_expansion: bool,
    fitres: int,
    nterms: int,
    point: Optional[Sequence[float] | float],
    precision: Optional[Literal[16, 32, 64]],
    prefactor: bool,
) -> dict[str, sp.Expr | np.ndarray | str]:
    """
    Build the source code of a quickseries function for `expr`.

    Returns:
        A dict with "expr" (the hornerized polynomial), "source" (the
        generated Python source), and -- only when a numerical fit was
        performed -- "params" (the fitted coefficients).

    Raises:
        ValueError: if `expr` has no free symbols.
    """
    symbols = expr.free_symbols
    if not symbols:
        raise ValueError("func must have at least one free variable.")
    # positional arguments of the generated function, sorted by name
    free = sorted(symbols, key=str)
    bounds, point = _makebounds(bounds, len(free), point)
    is_poly = is_simple_poly(expr)
    output = {}
    # simple polynomials are just rewritten unless an approximation is
    # explicitly requested
    expand = approx_poly is True or is_poly is False
    if expand and fit_series_expansion is True:
        expr, output["params"] = _perform_series_fit(
            expr, bounds, nterms, fitres, point, bound_series_fit, is_poly
        )
    elif expand and (len(free) > 1 or is_poly is True):
        _, expr = multivariate_taylor(expr, point, nterms, False)
    elif expand:
        _, expr = series_lambda(expr, point[0], nterms, False)
    # Horner form of the polynomial evaluates with fewer operations
    horner_form = sp.horner(expr)
    output["expr"] = horner_form
    polyfunc = sp.lambdify(free, horner_form, ("scipy", "numpy"))
    # polish the source; optionally precompute repeated powers and/or force
    # precision
    output["source"] = rewrite(polyfunc, prefactor, precision)
    return output
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def quickseries(
    func: Union[str, sp.Expr],
    *,
    bounds: tuple[float, float] | Sequence[tuple[float, float]] = (-1, 1),
    nterms: int = 9,
    point: Optional[float | Sequence[float]] = None,
    fitres: int = 100,
    prefactor: Optional[bool] = None,
    approx_poly: bool = False,
    jit: bool = False,
    precision: Optional[Literal[16, 32, 64]] = None,
    fit_series_expansion: bool = True,
    bound_series_fit: bool = False,
    extended_output: bool = False,
    cache: bool = True,
) -> Union[LmSig, tuple[LmSig, dict]]:
    """
    Generate a fast power-series approximation of `func`.

    Args:
        func: the function to approximate, as a string or sympy expression.
            Strings are parsed with `sympy.sympify()`. The generated function
            takes the free variables as positional arguments, sorted by name.
        bounds: (low, high) range to approximate over, either one pair for
            every free variable or one pair per free variable.
        nterms: number of terms in the series expansion.
        point: value(s) about which to expand; defaults to the midpoint of
            each interval in `bounds`.
        fitres: number of sample points per free variable used when fitting
            the expansion's coefficients.
        prefactor: precompute repeated powers in the generated code. If None,
            this is done only when `jit` is False.
        approx_poly: approximate simple polynomials instead of just
            rewriting them.
        jit: compile the generated function with numba.
        precision: if 16, 32, or 64, avoid casting inputs to wider floats.
        fit_series_expansion: numerically refine the expansion's coefficients
            across `bounds`.
        bound_series_fit: constrain that fit (bounds=(-5, 5) is passed to
            the fitting routine).
        extended_output: also return a dict of details. It always has
            "cache" ("off", "miss", or "hit") and "source"; when the function
            was built rather than loaded from cache it also has "expr" and,
            if a fit was performed, "params".
        cache: reuse and store generated code in the quickseries cache.

    Returns:
        The generated function, or (function, details dict) if
        `extended_output` is True.

    Raises:
        TypeError: if `func` is neither a str nor a sympy expression.
    """
    if not isinstance(func, (str, sp.Expr)):
        raise TypeError(f"Unsupported type for func {type(func)}.")
    polyfunc, ext = None, {"cache": "off"}
    if cache is True:
        # NOTE(review): _cacheget() receives only `jit`, so it presumably
        # identifies the cache entry by inspecting the calling context --
        # check sourceutils before restructuring this function.
        polyfunc, source = _cacheget(jit)
        if polyfunc is not None:
            ext |= {"source": source, "cache": "hit"}
        else:
            ext["cache"] = "miss"
    if polyfunc is None:
        ext |= _make_quickseries(
            approx_poly,
            bound_series_fit,
            bounds,
            # NOTE(review): sympify() evaluates strings; per the SymPy docs
            # it must not be given untrusted input.
            func if isinstance(func, sp.Expr) else sp.sympify(func),
            fit_series_expansion,
            fitres,
            nterms,
            point,
            precision,
            # by default, precompute repeated powers only when not jitting
            prefactor if prefactor is not None else not jit
        )
        polyfunc = _finalize_quickseries(ext["source"], jit, cache)
    if extended_output is True:
        return polyfunc, ext
    return polyfunc
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import timeit
|
|
2
|
+
from inspect import getfullargspec
|
|
3
|
+
from itertools import product
|
|
4
|
+
from time import time
|
|
5
|
+
from typing import Union, Sequence, Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import sympy as sp
|
|
9
|
+
from dustgoggles.func import gmap
|
|
10
|
+
|
|
11
|
+
from quickseries import quickseries
|
|
12
|
+
from quickseries.approximate import _makebounds
|
|
13
|
+
from quickseries.sputils import lambdify, LmSig
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _offset_check_cycle(
    absdiff: float,
    frange: tuple[float, float],
    lamb: LmSig,
    quick: LmSig,
    vecs: Sequence[np.ndarray],
    worstpoint: Optional[list[float]],
) -> tuple[float, float, float, tuple[float, float], list[float]]:
    """
    Compare `quick` against `lamb` on one set of sample vectors.

    Returns a 5-tuple:
        - the running maximum absolute error (`absdiff`, replaced only if
          this cycle found a larger error),
        - this cycle's median absolute error,
        - this cycle's mean squared error,
        - the running (min, max) of `lamb`'s output (`frange`, widened to
          include this cycle's outputs),
        - the input point at which the running maximum error occurred
          (`worstpoint`, replaced together with `absdiff`).
    """
    approx = quick(*vecs)
    exact = lamb(*vecs)
    low = min(exact.min(), frange[0])
    high = max(exact.max(), frange[1])
    error = abs(approx - exact)
    worst = np.argmax(error)
    candidate = error[worst]
    if candidate > absdiff:
        absdiff = candidate
        worstpoint = [vec[worst] for vec in vecs]
    squared_mean = np.mean(error ** 2)
    return absdiff, np.median(error), squared_mean, (low, high), worstpoint
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def benchmark(
    func: Union[str, sp.Expr, sp.core.function.FunctionClass],
    offset_resolution: int = 10000,
    n_offset_shuffles: int = 50,
    timeit_cycles: int = 20000,
    testbounds="equal",
    cache: bool = False,
    **quickkwargs
) -> dict[str, sp.Expr | float | np.ndarray | str | list[float]]:
    """
    Build a quickseries approximant of `func` and measure its accuracy and
    speed against a direct lambdification of `func`.

    Args:
        func: function to approximate.
        offset_resolution: number of evenly spaced samples per input
            dimension.
        n_offset_shuffles: for multivariate functions, the number of times
            the per-dimension samples are independently shuffled and
            evaluated. Univariate functions are evaluated once on all
            samples.
        timeit_cycles: number of calls timed for each implementation.
        testbounds: sequence of per-dimension (low, high) sampling bounds,
            or "equal" to derive them from quickkwargs["bounds"] via
            `_makebounds()`.
        cache: passed to quickseries() as its `cache` argument.
        **quickkwargs: further keyword arguments for quickseries().

    Returns:
        A dict with these keys, merged with quickseries()'s extended output:
        'absdiff' (largest absolute error found), 'reldiff' (absdiff divided
        by the spread of the original function's output), 'mediff' (median
        absolute error), 'mse' (mean squared error), 'worstpoint' (input at
        which 'absdiff' occurred), 'range' (min, max of the original
        function's output), 'orig_s' / 'approx_s' (seconds per call on the
        full sample vectors), 'timeratio' (approx_s / orig_s), 'gentime'
        (seconds spent in quickseries()), and 'polyfunc' (the approximant).
        For multivariate functions, 'mediff' and 'mse' are medians over the
        shuffles.
    """
    lamb = lambdify(func)
    compile_start = time()
    # NOTE: quickseries() must be called directly from this function. The
    # quickseries cache skips frames named "benchmark" when deciding which
    # file "called" quickseries().
    quick, ext = quickseries(
        func, **(quickkwargs | {'extended_output': True, 'cache': cache})
    )
    gentime = time() - compile_start
    # isinstance guard: an array-valued testbounds must not be compared
    # elementwise against the string.
    if isinstance(testbounds, str) and testbounds == "equal":
        testbounds, _ = _makebounds(
            quickkwargs.get("bounds"), len(getfullargspec(lamb).args), None
        )
    vecs = [np.linspace(*b, offset_resolution) for b in testbounds]
    if (pre := quickkwargs.get("precision")) is not None:
        vecs = gmap(
            lambda arr: arr.astype(getattr(np, f"float{pre}")), vecs
        )
    if len(testbounds) > 1:
        # Always check the extrema of the bounds: evaluate at every corner of
        # the sampled hyperrectangle. That means every combination of the
        # first (index 0) and last (index -1) sample along each axis; the
        # linspace vectors are still sorted at this point, before the
        # in-place shuffles below.
        extrema = [[] for _ in vecs]
        for corner in product((0, -1), repeat=len(vecs)):
            for i, side in enumerate(corner):
                extrema[i].append(vecs[i][side])
        extrema = [np.array(e) for e in extrema]
        absdiff, _, __, frange, worstpoint = _offset_check_cycle(
            0, (np.inf, -np.inf), lamb, quick, extrema, None
        )
        medians, mses = [], []
        for _ in range(n_offset_shuffles):
            # shuffle each axis independently to pair up random coordinates
            gmap(np.random.shuffle, vecs)
            absdiff, mediff, mse, frange, worstpoint = _offset_check_cycle(
                absdiff, frange, lamb, quick, vecs, worstpoint
            )
            medians.append(mediff)
            mses.append(mse)
        mediff, mse = np.median(medians), np.median(mses)
    # no point in shuffling for 1D -- we're doing that for > 1D
    # because it becomes quickly unreasonable in terms of memory
    # to be exhaustive, but this _is_ exhaustive for 1D
    else:
        approx_y, orig_y = quick(*vecs), lamb(*vecs)
        frange = (orig_y.min(), orig_y.max())
        offset = abs(approx_y - orig_y)
        worstix = np.argmax(offset)
        absdiff = offset[worstix]
        mediff = np.median(offset)
        mse = np.mean(offset ** 2)
        worstpoint = [vecs[0][worstix]]
        del offset, orig_y, approx_y
    # TODO: should probably permit specifying dtype for jitted
    #  functions -- both here and in primary quickseries().
    approx_time = timeit.timeit(lambda: quick(*vecs), number=timeit_cycles)
    orig_time = timeit.timeit(lambda: lamb(*vecs), number=timeit_cycles)
    orig_s = orig_time / timeit_cycles
    approx_s = approx_time / timeit_cycles
    return {
        'absdiff': absdiff,
        'reldiff': absdiff / np.ptp(frange),
        'mediff': mediff,
        'mse': mse,
        'worstpoint': worstpoint,
        'range': frange,
        'orig_s': orig_s,
        'approx_s': approx_s,
        'timeratio': approx_s / orig_s,
        'gentime': gentime,
        'polyfunc': quick
    } | ext
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from functools import reduce
|
|
2
|
+
from typing import Union, Sequence
|
|
3
|
+
|
|
4
|
+
from dustgoggles.structures import listify
|
|
5
|
+
import sympy as sp
|
|
6
|
+
|
|
7
|
+
from quickseries.sputils import LmSig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _rectify_series(series, add_coefficients):
    """
    Turn a sympy series expression into a plain polynomial Expr.

    Order (limit-behavior) terms are dropped. Terms that evaluate to sympy
    Numbers are kept as those Numbers. Mul, Symbol and Pow terms are kept
    after evalf(). If `add_coefficients` is True, each of those non-numeric
    terms is instead multiplied by a fresh symbol a_0, a_1, ... so that
    downstream code can fit the coefficients.

    Args:
        series: sympy expression, typically the output of sp.series().
        add_coefficients: whether to attach free coefficient symbols.

    Returns:
        (polynomial Expr, list of added coefficient symbols)

    Raises:
        ValueError: if `series` is nothing but an Order term, or contains a
            term of a type this function does not handle.
    """
    if isinstance(series, sp.Order):
        raise ValueError(
            "Cannot produce a meaningful approximation with the requested "
            "parameters (most likely order is too low)."
        )
    outargs, coefsyms = [], []
    # Iterate over the *additive terms* of the series. `series.args` is only
    # correct when `series` is an Add. For a single-term result (e.g. a lone
    # Mul such as 1.0*x*y, or a bare Number), `.args` yields its factors (or
    # nothing), and summing those gives a wrong expression.
    for a in sp.Add.make_args(series):
        # NOTE: the Expr.evalf() calls are simply to try to evaluate
        #  anything we can.
        if hasattr(a, "evalf") and isinstance((n := a.evalf()), sp.Number):
            outargs.append(n)
        elif isinstance(a, sp.Order):
            continue
        elif isinstance(a, (sp.Mul, sp.Symbol, sp.Pow)):
            if add_coefficients is True:
                coefficient = sp.symbols(f"a_{len(coefsyms)}")
                outargs.append((coefficient * a).evalf())
                coefsyms.append(coefficient)
            else:
                outargs.append(a.evalf())
        else:
            raise ValueError(
                f"don't know how to handle expression element {a} of "
                f"type({type(a)})"
            )
    return sum(outargs), coefsyms
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def series_lambda(
    func: Union[str, sp.Expr],
    x0: float = 0,
    nterms: int = 9,
    add_coefficients: bool = False,
    modules: Union[str, Sequence[str]] = ("scipy", "numpy")
) -> tuple[LmSig, sp.Expr]:
    """
    Expand a univariate function as a power series and lambdify the result.

    Optionally, free coefficients are attached to the terms of the resulting
    polynomial so that downstream code can optimize them.

    Args:
        func: Mathematical function to expand, as a string or a sympy Expr.
        x0: Point about which to expand func (rounded to 6 decimal places).
        nterms: Order of the power expansion (sp.series' `n`).
        add_coefficients: If True, the returned function and Expr gain extra
            symbols/arguments, one per polynomial coefficient, placed after
            func's own symbols.
        modules: Modules from which lambdify draws the building blocks of
            the returned function.

    Returns:
        approximant: Python function that implements the power expansion.
        expr: sympy Expr used to construct approximant.
    """
    if isinstance(func, str):
        func = sp.sympify(func)
    # x0's precision is limited to work around a bug in sp.series
    expansion = sp.series(func, x0=round(x0, 6), n=nterms)
    # drop Order terms and separate numeric constants from polynomial terms
    # noinspection PyTypeChecker
    polynomial, coefficient_symbols = _rectify_series(
        expansion, add_coefficients
    )
    arguments = sorted(func.free_symbols, key=str)
    # noinspection PyTypeChecker
    approximant = sp.lambdify(
        arguments + coefficient_symbols, polynomial, modules
    )
    return approximant, polynomial
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def additive_combinations(n_terms, number):
    """
    List every `n_terms`-tuple of non-negative integers whose sum is at most
    `number`.

    Tuples are returned in lexicographic order (first element varies
    slowest). `multivariate_taylor()` uses these as the per-variable
    derivative orders of a Taylor expansion of bounded total degree.

    Args:
        n_terms: length of each tuple. 0 yields the single empty tuple.
        number: maximum permitted sum. A negative value yields no tuples.

    Returns:
        list of tuples of ints.

    Raises:
        ValueError: if `n_terms` is negative (and `number` is not).
    """
    if number < 0:
        # no tuple of non-negative integers can sum to a negative number
        return []
    if n_terms < 0:
        raise ValueError("n_terms must be non-negative")
    if n_terms == 0:
        # The empty tuple is the only 0-length tuple; its sum (0) is within
        # any non-negative bound. Without this base case, the recursion
        # below would never terminate.
        return [()]
    if n_terms == 1:
        return [(n,) for n in range(number + 1)]
    return [
        (j, *t)
        for j in range(number + 1)
        for t in additive_combinations(n_terms - 1, number - j)
    ]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def multivariate_taylor(
    func: Union[str, sp.Expr],
    point: Sequence[float],
    nterms: int,
    add_coefficients: bool = False
) -> tuple[LmSig, sp.Expr]:
    """
    Construct a multivariate Taylor expansion of `func` about `point`.

    The expansion keeps every term whose total derivative order is at most
    `nterms - 1`.

    Args:
        func: function to expand, as a string or a sympy Expr.
        point: expansion point, one coordinate per free symbol of `func`,
            in the order of those symbols sorted by name.
        nterms: one more than the maximum total degree of retained terms.
        add_coefficients: if True, attach free coefficient symbols to the
            polynomial terms (see `_rectify_series()`).

    Returns:
        approximant: sp.lambdify'd expansion taking the free symbols of
            `func` (sorted by name), then any added coefficient symbols.
        expr: sympy Expr of the expansion.
    """
    func = sp.sympify(func) if isinstance(func, str) else func
    pointsyms = sorted(func.free_symbols, key=lambda s: str(s))
    dimensionality = len(pointsyms)
    # Dummy symbols are guaranteed distinct from every symbol in func. Plain
    # symbols named x0, x1, ... / i0, i1, ... would be *equal* to identically
    # named symbols in func (e.g. func = "x0**2 + x1"). That would make every
    # (x - a) factor zero and collapse the expansion to a constant.
    argsyms = [sp.Dummy(f"x{i}") for i in range(dimensionality)]
    ixsyms = [sp.Dummy(f"i{i}") for i in range(dimensionality)]
    # general term: d^(i0+i1+...)f / (dx0^i0 dx1^i1 ...) / (i0! i1! ...)
    #   * (x0 - a0)^i0 * (x1 - a1)^i1 * ...
    deriv = sp.Derivative(func, *zip(pointsyms, ixsyms))
    # noinspection PyTypeChecker
    fact = sp.Mul(*(sp.factorial(i) for i in ixsyms))
    err = sp.Mul(
        *((x - a) ** i for x, a, i in zip(argsyms, pointsyms, ixsyms))
    )
    taylor = deriv / fact * err
    # TODO, probably: there's a considerably faster way to do this in some
    #  cases by precomputing partial derivatives
    decomp = additive_combinations(dimensionality, nterms - 1)
    built = sp.Add(
        *(taylor.subs(dict(zip(ixsyms, d))) for d in decomp)
    ).doit()
    # evaluate the derivatives at the expansion point
    evaluated = built.subs(dict(zip(pointsyms, point))).evalf()
    # rename the placeholder variables back to func's own symbols so the
    # argument names are consistent with the input
    evaluated = evaluated.subs(dict(zip(argsyms, pointsyms)))
    evaluated, coefsyms = _rectify_series(evaluated, add_coefficients)
    # noinspection PyTypeChecker
    return sp.lambdify(pointsyms + coefsyms, evaluated), evaluated
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""lightweight version of `moonbow`'s polynomial fit functionality"""
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from inspect import Parameter, signature
|
|
4
|
+
from typing import Callable, Optional, Sequence, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from scipy.optimize import curve_fit
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def fit_wrap(
|
|
11
|
+
func: Callable[[np.ndarray | float, ...], np.ndarray | float],
|
|
12
|
+
dimensionality: int,
|
|
13
|
+
fit_parameters: Sequence[str]
|
|
14
|
+
) -> Callable[[np.ndarray | float, ...], np.ndarray | float]:
|
|
15
|
+
@wraps(func)
|
|
16
|
+
def wrapped_fit(independent_variable, *params):
|
|
17
|
+
variable_components = [
|
|
18
|
+
independent_variable[n] for n in range(dimensionality)
|
|
19
|
+
]
|
|
20
|
+
exploded_function = func(*variable_components, *params)
|
|
21
|
+
return exploded_function
|
|
22
|
+
|
|
23
|
+
# rewrite the signature so that curve_fit will like it
|
|
24
|
+
sig = signature(wrapped_fit)
|
|
25
|
+
curve_fit_params = (
|
|
26
|
+
Parameter("independent_variable", Parameter.POSITIONAL_ONLY),
|
|
27
|
+
*fit_parameters,
|
|
28
|
+
)
|
|
29
|
+
wrapped_fit.__signature__ = sig.replace(parameters=curve_fit_params)
|
|
30
|
+
return wrapped_fit
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def fit(
    func: Callable,
    vecs: list[np.ndarray],
    dependent_variable: np.ndarray,
    guess: Optional[Sequence[float]] = None,
    bounds: Optional[
        Union[tuple[tuple[float, float]], tuple[float, float]]
    ] = None
) -> tuple[np.ndarray, np.ndarray]:
    """
    Fit the free parameters of a model function with scipy's curve_fit.

    The first `len(vecs)` parameters of `func` are its independent variables
    (fed from `vecs`); all remaining parameters are fit.

    Args:
        func: model function.
        vecs: 1-D arrays of independent-variable values, one per independent
            variable of `func`.
        dependent_variable: observed values to fit.
        guess: initial parameter values (curve_fit's `p0`).
        bounds: parameter bounds, passed to curve_fit only if not None.

    Returns:
        curve_fit's (optimal parameters, estimated covariance) pair.

    Raises:
        ValueError: if `func` has no parameters beyond its independent
            variables, or if any element of `vecs` is not 1-dimensional.
    """
    sig = signature(func)
    # explicit check rather than `assert`, which is skipped under `python -O`
    if len(vecs) >= len(sig.parameters):
        raise ValueError(
            "The model function must have at least one 'free' "
            "parameter to be a meaningful candidate for fitting."
        )
    fit_parameters = list(sig.parameters.values())[len(vecs):]
    # TODO: check dim of dependent
    if not all(p.ndim == 1 for p in vecs):
        raise ValueError("each input vector must be 1-dimensional")
    # TODO: optional goodness-of-fit evaluation
    kw = {'bounds': bounds} if bounds is not None else {}
    # noinspection PyTypeChecker
    return curve_fit(
        fit_wrap(func, len(vecs), fit_parameters),
        vecs,
        dependent_variable,
        maxfev=20000,
        p0=guess,
        **kw
    )
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
import pickle
|
|
2
|
+
from hashlib import md5
|
|
3
|
+
from inspect import currentframe, getargvalues, getsource
|
|
4
|
+
import linecache
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
from types import FunctionType
|
|
8
|
+
from typing import Callable
|
|
9
|
+
|
|
10
|
+
from dustgoggles.dynamic import define, get_codechild
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Names of quickseries() arguments whose values make up the cache key.
# _cachekey() reads them from the quickseries() frame's local variables.
CACHE_ARGS = (
    "func",
    "bounds",
    "nterms",
    "point",
    "fitres",
    "prefactor",
    "approx_poly",
    "precision",
    "fit_series_expansion",
    "bound_series_fit",
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def cache_source(source: str, fn: Path):
    """
    Register `source` in linecache under the filename `str(fn)`.

    This lets linecache-based tools (tracebacks, inspect.getsource) show
    dynamically compiled code as if it lived in `fn`. The entry's mtime slot
    is None.
    """
    filename = str(fn)
    lines = source.splitlines(keepends=True)
    linecache.cache[filename] = (len(source), None, lines, filename)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# TODO: pull this little bitty change up to dustgoggles
def compile_source(source: str, fn: str = ""):
    """
    Compile `source` in "exec" mode under filename `fn` and return the child
    code object selected by dustgoggles' get_codechild().

    Presumably that is the code object of the function defined in `source`;
    see dustgoggles.dynamic.get_codechild to confirm.
    """
    module_code = compile(source, fn, "exec")
    return get_codechild(module_code)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _cachedir(callfile: str) -> Path:
|
|
38
|
+
if callfile == 'ipython_shell':
|
|
39
|
+
import IPython.paths
|
|
40
|
+
|
|
41
|
+
return Path(IPython.paths.get_ipython_cache_dir()) / "qs_cache"
|
|
42
|
+
return Path(callfile).parent / "__pycache__" / "qs_cache"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _cachekey(args, callfile=None):
    """
    Derive a cache key name from the arguments of a quickseries() call.

    Pickles a dict of the CACHE_ARGS values in `args.locals` (keys sorted by
    name), followed by the calling file and the installed quickseries
    version. Returns "quickseries_" plus the first 14 hex digits of that
    pickle's MD5 digest.
    """
    from quickseries import __version__

    # TODO: is this actually stable?
    payload = {name: args.locals[name] for name in sorted(CACHE_ARGS)}
    payload |= {'f': callfile, '__version__': __version__}
    digest = md5(pickle.dumps(payload)).hexdigest()
    # arbitrary cutoff for a reasonable tradeoff between collision safety and
    # readability
    return "quickseries_" + digest[:14]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# TODO, maybe: the frame traversal is potentially wasteful when repeated,
#  although it probably doesn't matter too much.
def _cacheid():
    """
    Locate the cache file and cache key for the current quickseries() call.

    Walks up the call stack from here:
      - The first frame whose code is named "quickseries" supplies the
        argument values (via getargvalues) used to build the key.
      - Frames named "benchmark" are skipped.
      - The next frame above that quickseries() frame is the "calling
        file"; the cache lives alongside it (see _cachedir()).
    A frame whose filename is "<stdin>", or running out of frames, selects
    a relative anonymous cache location instead. Calling files whose path
    matches IPython's are mapped to the 'ipython_shell' sentinel.

    WARNING: do not call this outside the normal quickseries workflow. It can
    be tricked, but to no good end.

    Returns:
        (path of the cached source file, cache key)

    Raises:
        ReferenceError: if no quickseries() frame is found on the stack.
    """
    frame, callfile, args = currentframe(), None, None
    while callfile is None:
        frame = frame.f_back
        if frame is None or frame.f_code.co_filename == "<stdin>":
            callfile = "__quickseries_anonymous_caller_cache__/anonymous"
        elif hasattr(frame.f_code, "co_name"):
            if args is None and frame.f_code.co_name == "quickseries":
                args = getargvalues(frame)
            elif frame.f_code.co_name == "benchmark":
                continue
            elif args is not None:
                callfile = frame.f_code.co_filename
    if re.search(r"interactiveshell.py|ipython", callfile):
        callfile = 'ipython_shell'
    if args is None:
        raise ReferenceError("Cannot use _cacheid() outside quickseries().")
    key = _cachekey(args, callfile)
    return _cachedir(callfile) / key / "func", key
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _compile_quickseries(source, jit, cache, cachefile):
    """
    Turn generated quickseries source into a callable.

    The code is compiled under the filename `str(cachefile)`. Its source is
    registered in linecache under that name and stored as the function's
    __doc__. If `jit` is True, the function is wrapped with
    numba.njit(..., cache=cache).

    The function runs in this module's global namespace. If the source
    contains a name like "float32" (the first match of float + two digits),
    the numpy attribute of that name is inserted into this module's globals
    so the name resolves.
    """
    namespace = globals()
    precision_match = re.search(r"float\d\d", source)
    if precision_match is not None:
        import numpy

        dtype_name = precision_match.group()
        namespace[dtype_name] = getattr(numpy, dtype_name)
    code = compile_source(source, str(cachefile))
    func = FunctionType(code, namespace)
    cache_source(source, cachefile)
    func.__doc__ = source
    if jit is not True:
        return func
    import numba as nb

    return nb.njit(func, cache=cache)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _cacheget(jit=False):
    """
    Look up cached source for the current quickseries() call.

    Returns:
        (compiled function, source) if the cache file exists, otherwise
        (None, None).
    """
    cachefile, _key = _cacheid()
    if cachefile.exists():
        source = cachefile.read_text()
        return _compile_quickseries(source, jit, True, cachefile), source
    return None, None
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _cachewrite(source, cachefile):
|
|
111
|
+
# we make the __pycache__ directory to enable numba JIT result caching,
|
|
112
|
+
# just in case it happens; if it doesn't, the presence of the directory is
|
|
113
|
+
# harmless.
|
|
114
|
+
(cachefile.parent / "__pycache__").mkdir(exist_ok=True, parents=True)
|
|
115
|
+
# TODO, maybe: use a more sensible data structure
|
|
116
|
+
with cachefile.open("w") as stream:
|
|
117
|
+
stream.write(source)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _finalize_quickseries(source, jit=False, cache=False):
    """
    Compile freshly generated quickseries source, writing it to the cache
    first if `cache` is True.

    The cache path from _cacheid() is always computed. Even with the cache
    off, it serves as the compiled function's identifier and its 'fake'
    linecache filename.
    """
    cachefile = _cacheid()[0]
    if cache is True:
        _cachewrite(source, cachefile)
    return _compile_quickseries(source, jit, cache, cachefile)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def lastline(func: Callable) -> str:
    """
    Try to get the last line of a function's source, sans return statement.

    Blank and whitespace-only lines are ignored. A leading `return` keyword
    is removed from the last remaining line, and the result is stripped of
    surrounding whitespace. Only the keyword itself is removed; identifiers
    that merely contain "return" (e.g. `returned`) are left intact.
    """
    lines = [ln for ln in getsource(func).splitlines() if ln.strip()]
    return re.sub(r"^return\b", "", lines[-1].strip()).strip()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from typing import Any, Callable, Sequence, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import sympy as sp
|
|
5
|
+
|
|
6
|
+
LmSig = Callable[[np.ndarray | float, ...], np.ndarray | float]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def lambdify(
    func: Union[str, sp.Expr],
    modules: Union[str, Sequence[str]] = ("scipy", "numpy")
) -> LmSig:
    """
    Transform a sympy Expr or a string representation of a function into a
    callable with enforced argument order, incorporating code from specified
    modules.

    The callable's positional arguments are the free symbols of `func`,
    sorted by name.

    Raises:
        ValueError: if `func` is a string that sympy cannot parse.
    """
    if isinstance(func, str):
        try:
            func = sp.sympify(func)
        except sp.SympifyError as err:
            # chain the cause so the underlying parse error stays visible
            raise ValueError(f"Unable to parse {func}.") from err
    # noinspection PyTypeChecker
    return sp.lambdify(
        sorted(func.free_symbols, key=lambda x: str(x)), func, modules
    )
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: quickseries
|
|
3
|
+
Version: 0.2.1
|
|
4
|
+
Home-page: https://github.com/millionconcepts/quickseries.git
|
|
5
|
+
Author: Michael St. Clair
|
|
6
|
+
Author-email: mstclair@millionconcepts.com
|
|
7
|
+
Requires-Python: >=3.11
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Requires-Dist: dustgoggles
|
|
10
|
+
Requires-Dist: numpy
|
|
11
|
+
Requires-Dist: scipy
|
|
12
|
+
Requires-Dist: sympy
|
|
13
|
+
Provides-Extra: jit
|
|
14
|
+
Requires-Dist: numba; extra == "jit"
|
|
15
|
+
Provides-Extra: tests
|
|
16
|
+
Requires-Dist: pytest; extra == "tests"
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
setup.py
|
|
4
|
+
quickseries/__init__.py
|
|
5
|
+
quickseries/approximate.py
|
|
6
|
+
quickseries/benchmark.py
|
|
7
|
+
quickseries/expansions.py
|
|
8
|
+
quickseries/simplefit.py
|
|
9
|
+
quickseries/sourceutils.py
|
|
10
|
+
quickseries/sputils.py
|
|
11
|
+
quickseries.egg-info/PKG-INFO
|
|
12
|
+
quickseries.egg-info/SOURCES.txt
|
|
13
|
+
quickseries.egg-info/dependency_links.txt
|
|
14
|
+
quickseries.egg-info/requires.txt
|
|
15
|
+
quickseries.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
quickseries
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from setuptools import find_packages, setup
|
|
2
|
+
|
|
3
|
+
setup(
    name="quickseries",
    version="0.2.1",
    author="Michael St. Clair",
    author_email="mstclair@millionconcepts.com",
    url="https://github.com/millionconcepts/quickseries.git",
    packages=find_packages(),
    python_requires=">=3.11",
    install_requires=["dustgoggles", "numpy", "scipy", "sympy"],
    # optional extras: numba is imported when quickseries(jit=True) is used;
    # pytest runs the test suite
    extras_require={"jit": "numba", "tests": "pytest"},
)
|