minimizesolve 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- minimizesolve-1.0.dist-info/METADATA +5 -0
- minimizesolve-1.0.dist-info/RECORD +12 -0
- minimizesolve-1.0.dist-info/WHEEL +5 -0
- minimizesolve-1.0.dist-info/top_level.txt +1 -0
- miso/__init__.py +23 -0
- miso/codegen.py +503 -0
- miso/codewriter.py +267 -0
- miso/domain.py +398 -0
- miso/generate.py +266 -0
- miso/logger.py +55 -0
- miso/poly.py +62 -0
- miso/subdivision.py +68 -0
miso/codewriter.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from math import nextafter, inf
|
|
3
|
+
from sympy.printing.cxx import CXX11CodePrinter
|
|
4
|
+
from sympy import S, Float, Rational, Function
|
|
5
|
+
from .poly import is_miso_poly
|
|
6
|
+
|
|
7
|
+
# Placeholder functions used inside inclusion/sample expressions.  The C++
# printer recognizes them by name ('MiSoInclusion' / 'MiSoSample') and emits
# `<arg>.inclusion()` and `<arg>.sample<i,...>()` respectively.
inclusion_func, sample_func = (
    Function(fname, real=True) for fname in ('MiSoInclusion', 'MiSoSample')
)

# Undefined SymPy functions standing in for C++ math functions that SymPy
# does not provide natively.  replace_special_func() rewrites matching
# expression patterns into these, and the printer emits them by name.
(
    sinpi, cospi, tanpi, asinpi, acospi, atanpi,
    exp2, exp10, expm1, exp2m1, exp10m1,
    log2, log10, log1p, log2p1, log10p1,
    cbrt, rsqrt, hypot,
) = (
    Function(fname) for fname in (
        'sinpi', 'cospi', 'tanpi', 'asinpi', 'acospi', 'atanpi',
        'exp2', 'exp10', 'expm1', 'exp2m1', 'exp10m1',
        'log2', 'log10', 'log1p', 'log2p1', 'log10p1',
        'cbrt', 'rsqrt', 'hypot',
    )
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def replace_special_func(expr):
    """Replace expression patterns with specialized C++ math functions.

    Bottom-up traversal: processes children first, then matches at each node.
    Recognized patterns:
        sin(pi*x) -> sinpi(x), asin(x)/pi -> asinpi(x), etc.
        2**x -> exp2(x), exp(x)-1 -> expm1(x), 2**x-1 -> exp2m1(x), etc.
        log(x)/log(2) -> log2(x), log(1+x) -> log1p(x), log(1+x)/log(2) -> log2p1(x), etc.
        x**(1/3) -> cbrt(x), x**(-1/2) -> rsqrt(x), sqrt(x**2+y**2) -> hypot(x,y)

    Atoms (nodes without args) and MiSoPoly nodes (is_miso_poly) are returned
    unchanged; MiSoPoly sub-trees are not descended into.  A node is rebuilt
    with node.func(*new_args) only when one of its children changed.

    NOTE(review): x**(1/3) -> cbrt(x) selects the real cube root, which
    differs from SymPy's principal root for negative x -- presumably intended
    for real-valued codegen; confirm.
    """
    from sympy import (pi, sin, cos, tan, asin, acos, atan, exp, log, Add)

    # Pattern tables; rebuilt on every call alongside the local import.
    trig_pi = {sin: sinpi, cos: cospi, tan: tanpi}
    atrig_pi = {asin: asinpi, acos: acospi, atan: atanpi}
    inv_log2, inv_log10 = 1/log(S(2)), 1/log(S(10))
    inv_pi = 1/pi

    def _walk(node):
        if not node.args:
            return node
        if is_miso_poly(node):
            return node

        # Bottom-up: children first
        new_args = tuple(_walk(a) for a in node.args)
        # Identity check: rebuild only when some child was actually replaced.
        if any(a is not b for a, b in zip(new_args, node.args)):
            node = node.func(*new_args)

        # trig(pi*x) -> trigpi(x)
        if node.func in trig_pi:
            c = node.args[0].as_coefficient(pi)
            if c is not None:
                return trig_pi[node.func](c)

        # Pow patterns
        if node.is_Pow:
            base, e = node.as_base_exp()
            if e == Rational(1, 3):
                return cbrt(base)
            if e == Rational(-1, 2):
                return rsqrt(base)
            # sqrt(a**2 + b**2) -> hypot(a, b); only a sum of exactly two squares
            if e == S.Half and base.is_Add and len(base.args) == 2:
                a, b = base.args
                if a.is_Pow and a.exp == 2 and b.is_Pow and b.exp == 2:
                    return hypot(a.base, b.base)
            # 2**x / 10**x only for non-numeric exponents
            if base == S(2) and not e.is_Number:
                return exp2(e)
            if base == S(10) and not e.is_Number:
                return exp10(e)

        # log(1+x) -> log1p(x)
        if isinstance(node, log) and len(node.args) == 1:
            arg = node.args[0]
            if arg.is_Add and S.One in arg.args:
                return log1p(arg - 1)

        # Mul: atrig/pi, log/log(base)
        if node.is_Mul:
            r = node.as_coefficient(inv_pi)
            if r is not None and r.func in atrig_pi:
                return atrig_pi[r.func](*r.args)
            for inv_logb, f_log, f_logp1 in [
                (inv_log2, log2, log2p1), (inv_log10, log10, log10p1)
            ]:
                r = node.as_coefficient(inv_logb)
                if r is not None:
                    # Children are rewritten first, so log(1+x) normally
                    # arrives here already converted to log1p(x).
                    if r.func == log1p:
                        return f_logp1(r.args[0])
                    if isinstance(r, log):
                        a = r.args[0]
                        if a.is_Add and S.One in a.args:
                            return f_logp1(a - 1)
                        return f_log(a)

        # Add: f(x) - 1 -> fm1(x)
        if node.is_Add and S.NegativeOne in node.args:
            others = [a for a in node.args if a != S.NegativeOne]
            for i, a in enumerate(others):
                repl = None
                if isinstance(a, exp):
                    repl = expm1(a.args[0])
                elif a.func == exp2:
                    repl = exp2m1(a.args[0])
                elif a.func == exp10:
                    repl = exp10m1(a.args[0])
                # The first matching term absorbs the -1; the rest are kept.
                if repl is not None:
                    return Add(repl, *others[:i], *others[i+1:])

        return node

    return _walk(expr)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class CodeWriter:
    ''' Manages a C++ source file.

    Lines written through write() are prefixed with tabs.  The indent level
    follows each line's bracket balance: a line with more closers than openers
    is dedented by one level before being written, and a line with more
    openers than closers indents the following lines by one level.  SymPy
    expressions are converted to C++ by the nested SymPyFormatter.  Instances
    are context managers; the file is closed on exit.
    '''

    class SymPyFormatter(CXX11CodePrinter):
        ''' Formats SymPy expressions for C++ '''

        def _print_Rational(self, expr):
            '''Print a rational as a double literal.

            When float(expr) is not exactly equal to expr, emit `_R(lo, hi)`
            with the two adjacent doubles bracketing the exact value
            (presumably a helper provided by the target C++ library -- confirm).
            '''
            fp = float(expr)
            if fp < expr:
                # Rounded down: the exact value lies above fp.
                fpu = nextafter(fp, +inf)
                return f'_R({fp}, {fpu})'
            elif fp > expr:
                # Rounded up: the exact value lies below fp.
                fpl = nextafter(fp, -inf)
                return f'_R({fpl}, {fp})'
            else:
                return f'{fp}'

        def _print_Pow(self, expr):
            '''Print a power as pown / sqrt / rsqrt / pow.

            Only literal integer exponents use `pown`; symbolic exponents,
            even ones assumed integer, fall back to `pow`.
            '''
            base, e = expr.as_base_exp()
            # is_Integer (a literal SymPy Integer), not is_integer (an
            # assumption query): Symbol('n', integer=True) satisfies the
            # latter, but int() cannot convert it.
            if e.is_Integer:
                return f"pown({self._print(base)}, {int(e)})"
            if e == Rational(1, 2):
                return f"sqrt({self._print(base)})"
            if e == Rational(-1, 2):
                return f"rsqrt({self._print(base)})"
            return f"pow({self._print(base)}, {self._print(e)})"

        def _print_Abs(self, expr):
            return f'abs({self._print(expr.args[0])})'

        def _print_Max(self, expr):
            # Emitted as the `|` operator on the arguments; its meaning is
            # defined by the target library's overloads.
            return '(' + '|'.join(self._print(a) for a in expr.args) + ')'

        def _print_Min(self, expr):
            # Emitted as the `&` operator on the arguments (see _print_Max).
            return '(' + '&'.join(self._print(a) for a in expr.args) + ')'

        # Standard math functions: emit unqualified so ADL finds tight:: overloads
        def _unary(self, expr):
            '''Print `name(arg)`, using the SymPy class name as the C++ name.'''
            return f'{type(expr).__name__}({self._print(expr.args[0])})'

        _print_sin = _print_cos = _print_tan = _print_asin = _print_acos = _print_atan = \
            _print_sinh = _print_cosh = _print_tanh = _print_asinh = _print_acosh = _print_atanh = \
            _print_exp = _print_log = _print_erf = _unary

        def _print_atan2(self, expr):
            return f'atan2({self._print(expr.args[0])}, {self._print(expr.args[1])})'

        # Function names printed verbatim as `name(args...)`.
        _special_funcs = {f.__name__ for f in (
            sinpi, cospi, tanpi, asinpi, acospi, atanpi,
            exp2, exp10, expm1, exp2m1, exp10m1,
            log2, log10, log1p, log2p1, log10p1,
            cbrt, rsqrt, hypot,
        )} | {'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
              'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
              'exp', 'log', 'sqrt', 'erf'}

        def _print_Function(self, expr):
            '''Print MiSo placeholders and special functions; defer the rest.'''
            if expr.func.__name__ == 'MiSoInclusion':
                return f'{self._print(expr.args[0])}.inclusion()'
            if expr.func.__name__ == 'MiSoSample':
                name = self._print(expr.args[0])
                # Remaining args are integer template indices.
                indices = ','.join(str(int(a)) for a in expr.args[1:])
                return f'{name}.sample<{indices}>()'
            if expr.func.__name__ in self._special_funcs:
                args = ', '.join(self._print(a) for a in expr.args)
                return f'{expr.func.__name__}({args})'
            return super()._print_Function(expr)

    def __init__(self, filepath, overwrite=True, pragma_regions=True):
        ''' Open and possibly create the managed file.

        Arguments:
            filepath -- path of the C++ file to write
            overwrite -- truncate an existing file (default True); when False,
                         an existing file raises FileExistsError
            pragma_regions -- emit `#pragma region` / `#pragma endregion`
                              (default True); otherwise region() writes a
                              `/* name */` comment and endregion() nothing
        '''
        self.filepath = filepath
        try:
            # Mode 'x' creates the file and fails if it exists in a single
            # step, so there is no race between an existence check and open().
            self.file = open(filepath, 'w' if overwrite else 'x')
        except FileExistsError:
            raise FileExistsError(f'{filepath} already exists') from None
        self.tablv = 0          # current indent depth, in tabs
        self.formatter = self.SymPyFormatter()
        self.newline = True     # whether the previous write() ended its line
        self.pragma_regions = pragma_regions

    def __enter__(self): return self
    def __exit__(self, exc_type, exc_val, ex_tb): self.close()
    def __str__(self): return f'File {self.filepath}'

    def close(self):
        self.file.close()

    def tabin(self):
        ''' Increase the indent level by one; returns self. '''
        self.tablv += 1
        return self

    def tabout(self):
        ''' Decrease the indent level by one; returns self. '''
        self.tablv -= 1
        assert self.tablv >= 0
        return self

    def write(self, l='', newline=True):
        ''' Write a single line with indent.

        Arguments:
            l -- text to write; its bracket balance adjusts the indent by at
                 most one level per call
            newline -- end the line (default True).  When False, the next
                       write continues this line without indentation.
        Returns self.
        '''
        def brackets_imbalance(l):
            # Net count of opening minus closing brackets ({[( vs }])).
            d = 0
            for c in l:
                if c in '{[(': d += 1
                elif c in '}])': d -= 1
            return d

        if len(l) > 0:
            diff = brackets_imbalance(l)
            # Closing lines are dedented before they are written.
            if diff < 0: self.tabout()
            if self.newline:
                for i in range(self.tablv): self.file.write('\t')
            # Opening lines indent what follows.
            if diff > 0: self.tabin()
        self.file.write(l)
        if newline: self.file.write('\n')
        self.newline = newline
        return self

    def joint_write(self, separator, lines):
        ''' Write all lines, adding separator to each except the last '''
        if len(lines) == 0: return self
        for l in lines[:-1]:
            self.write(l + separator)
        self.write(lines[-1])
        return self

    def region(self, name):
        ''' Open a named region (pragma or comment, see pragma_regions). '''
        if self.pragma_regions:
            self.write(f'#pragma region {name}')
        else:
            self.write(f'/* {name} */')

    def endregion(self):
        ''' Close a region; writes nothing when pragma_regions is False. '''
        if self.pragma_regions:
            self.write('#pragma endregion')

    def format(self, expr):
        ''' Return a formatted SymPy expression as a C++ string '''
        return self.formatter.doprint(expr)

    def w_assign(self, lhs, rhs, declare=''):
        ''' Write an assignment using the SymPy formatter.

        `declare` (e.g. a type name) is prepended to the statement when given.
        '''
        s = self.formatter.doprint(rhs, lhs)
        if declare: self.write(f'{declare} {s}')
        else: self.write(s)
        return self
|
miso/domain.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
from sympy import symbols, Rational, Matrix, numbered_symbols, Symbol, Add, Mul, Pow, diff, factorial
|
|
2
|
+
from sympy.polys.polyfuncs import interpolate
|
|
3
|
+
from sympy.polys.matrices import DomainMatrix
|
|
4
|
+
from sympy import QQ
|
|
5
|
+
from itertools import product, chain, starmap
|
|
6
|
+
from math import prod
|
|
7
|
+
|
|
8
|
+
from .poly import make_poly, is_miso_poly
|
|
9
|
+
from .codewriter import inclusion_func, sample_func
|
|
10
|
+
from .subdivision import MiSoSubdivision, subdivision, combine_subdiv
|
|
11
|
+
from .logger import MiSoLogger
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _degree_in_vars(expr, variables):
    '''Max total degree of expr in the given variables, via tree walk (no expansion).

    Each sub-expression maps either to a DegExpr, meaning it depends on one of
    `variables` with a tracked degree, or to itself, meaning it is treated as
    a constant.  Add takes the max of its terms' degrees, Mul the sum, and Pow
    multiplies the base degree by int(exponent).

    NOTE(review): any other node type (e.g. sin(x)) is returned unchanged and
    so counts as degree 0 even if it contains a variable.  Non-integer
    exponents are truncated by int(), and a symbolic exponent makes int()
    raise TypeError.  Only nodes whose func is exactly Symbol are matched, so
    subclasses such as Dummy are not.

    Returns an int; 0 when no DegExpr survives the walk.
    '''
    vs = set(variables)

    class DegExpr:
        # Marks "depends on a variable"; deg is the tracked total degree.
        def __init__(self, d=1): self.deg = d
        # Combining with a non-DegExpr (a constant) leaves self unchanged.
        def __add__(self, o): return DegExpr(max(self.deg, o.deg)) if isinstance(o, DegExpr) else self
        def __radd__(self, o): return self + o
        def __mul__(self, o): return DegExpr(self.deg + o.deg) if isinstance(o, DegExpr) else self
        def __rmul__(self, o): return self * o
        def __pow__(self, o): return DegExpr(self.deg * int(o))

    def propagate(e):
        if e.func == Symbol:
            return DegExpr(1) if e in vs else e
        children = tuple(propagate(x) for x in e.args)
        if e.func == Add:
            # sum()/prod() start from 0/1; DegExpr's reflected operators
            # absorb those start values and any constant children.
            return sum(children)
        elif e.func == Mul:
            return prod(children)
        elif e.func == Pow:
            return children[0] ** children[1]
        return e

    d = propagate(expr)
    return d.deg if isinstance(d, DegExpr) else 0
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _build_l2b(B, variables, E):
    '''Return the inverse of the basis evaluation matrix (DomainMatrix over QQ).

    Row i of the evaluation matrix holds every polynomial of B evaluated at
    point E[i], whose coordinates are matched positionally with `variables`.
    The inverse maps values sampled at the points E to coefficients in B.
    '''
    size = len(B)

    def evaluate_row(point):
        assignment = dict(zip(variables, point))
        return [QQ.from_sympy(poly.eval(assignment).as_expr()) for poly in B]

    evaluation = DomainMatrix([evaluate_row(pt) for pt in E], (size, size), QQ)
    return evaluation.inv()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _compute_b2b(l2b, B, variables, E, trans, bv_mat):
    '''Compute one basis-to-basis conversion applied to bv_mat (picklable for parallel use).

    Steps:
      1. Map each point of E through `trans` in homogeneous coordinates
         (trans @ [e..., 1]); trans presumably is an affine matrix with
         len(e) + 1 columns -- confirm against callers.
      2. Evaluate every polynomial of B at the transformed points (MT).
      3. b2b = l2b * MT, with l2b presumably from _build_l2b().
      4. Apply b2b to the column formed by bv_mat.flat().

    Returns the resulting column as a flat list of QQ domain elements (the
    product with bv_mat, not the conversion matrix itself).
    NOTE(review): bv_mat.flat() must yield elements of QQ, e.g. when bv_mat
    is a DomainMatrix over QQ -- confirm.
    '''
    def transform(point):
        # Compute the homogeneous product once per point, then read its rows.
        image = trans @ Matrix(list(point) + [Rational(1)])
        return tuple(image[i, 0] for i in range(image.rows))

    ET = tuple(transform(e) for e in E)
    N = len(B)
    rows = [[QQ.from_sympy(b.eval(dict(zip(variables, e))).as_expr()) for b in B] for e in ET]
    MT = DomainMatrix(rows, (N, N), QQ)
    b2b = l2b * MT  # DomainMatrix over QQ
    bv = list(bv_mat.flat())
    bv_dm = DomainMatrix([[v] for v in bv], (N, 1), QQ)
    result = b2b * bv_dm
    return result.to_list_flat()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _sig_name(simplex_degs):
|
|
69
|
+
'''Format signature as e.g. "2p3" or "2p1_1p1" '''
|
|
70
|
+
return '_'.join(f'{dim}p{deg}' for dim, deg in simplex_degs)
|
|
71
|
+
|
|
72
|
+
def _all_compositions(n, total):
|
|
73
|
+
if n == 1:
|
|
74
|
+
yield (total,)
|
|
75
|
+
return
|
|
76
|
+
for i in range(total + 1):
|
|
77
|
+
for rest in _all_compositions(n - 1, total - i):
|
|
78
|
+
yield (i,) + rest
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class MiSoDomain:
    '''Cartesian product of simplices, each spanned by its own SymPy symbols.

    A simplex with variables (x_1, ..., x_d) has the origin and the unit
    vectors as vertices; its barycentric coordinates are the variables plus
    1 - sum(variables).  The domain provides vertices, subdivisions,
    tensor-product polynomial bases and evaluation points.
    '''

    # Class-wide counter used to auto-name maps created by poly_map().
    _map_counter = 0

    def __init__(self, *variables):
        '''Create a domain from simplices or single simplex specifications.

        Each argument is either a MiSoSimplex or one value accepted by
        MiSoSimplex(...), such as a name string ('u v') or a Symbol.
        Simplices without variables are dropped.
        '''
        ms = MiSoDomain.MiSoSimplex
        simplices = (v if isinstance(v, ms) else ms(v) for v in variables)
        self.simplices = tuple(filter(lambda s: not s.is_empty, simplices))
        # Memoized basis tuples keyed by (basis kind, degrees).
        self._basis_cache = {}

    @property
    def variables(self):
        '''All variables, simplex by simplex, in declaration order.'''
        return tuple(v for s in self.simplices for v in s.variables)

    @property
    def is_empty(self): return len(self.variables) == 0

    @property
    def dimension(self):
        '''Sum of the simplex dimensions.'''
        return sum(s.dimension for s in self.simplices)

    @property
    def num_vertices(self):
        '''Product of the simplices' vertex counts.'''
        return prod(s.num_vertices for s in self.simplices)

    @property
    def vertices(self):
        '''
        All vertices of the domain as tuples of coordinate values.

        Each vertex concatenates one vertex per simplex, enumerated in
        itertools.product order (last simplex varies fastest).
        '''
        factors = [s.vertices for s in self.simplices]
        return list(map(lambda x: tuple(chain(*x)), product(*factors)))

    @property
    def default_subdivision(self):
        '''Combine every simplex's default subdivision over all product pieces.'''
        sub = tuple(s.default_subdivision for s in self.simplices)
        prod_mats = tuple(starmap(combine_subdiv, product(*sub)))
        return MiSoSubdivision(*prod_mats)

    def subdivision(self, *simplex_domains):
        '''Bisect only the listed simplex domains; identity on the rest.

        Each argument should be a MiSoDomain whose variables match one of the
        simplices.  A simplex is subdivided when any of its variables appears
        among the arguments' variables.
        '''
        target_vars = set()
        for sd in simplex_domains:
            target_vars.update(sd.variables)
        sub = []
        for s in self.simplices:
            if any(v in target_vars for v in s.variables):
                sub.append(s.default_subdivision)
            else:
                # Identity: single matrix that maps domain to itself.  Vertices
                # are listed as unit vectors e_0..e_{d-1} followed by the
                # origin -- presumably the order subdivision() expects; confirm.
                d = s.dimension
                identity = subdivision(
                    *(tuple(int(i == j) for j in range(d)) for i in range(d + 1))
                )
                sub.append((identity,))
        prod_mats = tuple(starmap(combine_subdiv, product(*sub)))
        return MiSoSubdivision(*prod_mats)

    def __getitem__(self, i): return self.variables[i]

    def __str__(self):
        return 'Domain(' + ', '.join(str(s.variables) for s in self.simplices) + ')'

    def __mul__(self, o):
        '''Product domain; raises ValueError if the operands share variables.'''
        V = set(self.variables)
        OV = set(o.variables)
        if not V.isdisjoint(OV):
            raise ValueError('Domains cannot share variables')
        return MiSoDomain(*self.simplices, *o.simplices)

    @property
    def signature(self):
        '''Number of variables of each simplex, in order.'''
        return tuple(len(s.variables) for s in self.simplices)

    def subdomain(self, expr):
        '''Return the smallest subdomain that can represent the expression.

        Keeps, per simplex, only the variables occurring in expr.free_symbols.
        '''
        return MiSoDomain(*(s.subdomain(expr) for s in self.simplices))

    def degs(self, expr):
        '''Return the positive degrees of expr, one per simplex it depends on.

        Simplices in which expr has degree 0 (per _degree_in_vars) are
        skipped, so the result can be shorter than self.simplices.
        '''
        res = []
        for s in self.simplices:
            d = _degree_in_vars(expr, s.variables)
            if d > 0:
                res.append(d)
        return tuple(res)

    def parameters(self, *expr):
        '''
        Get free parameters (arguments) from a list of expressions --
        all free symbols not in domain variables.

        Returns an empty set when no expressions are given.
        Raises ValueError if a free symbol is not a Symbol.
        '''
        if not expr:
            # set.union() needs at least one argument.
            return set()
        fs = set.union(*(ex.free_symbols for ex in expr))
        if any(not s.is_Symbol for s in fs):
            raise ValueError(f'{fs} contains non-Symbols')
        return fs.difference(self.variables)

    def _product_basis(self, simplexf, *degs):
        '''Tensor-product basis: all products of one simplexf(s, d) polynomial
        per simplex, in itertools.product order, as Polys in self.variables.'''
        assert len(self.simplices) == len(degs), \
            'Degrees/simplices list size mismatch'
        factors = tuple(simplexf(s, d) for s, d in zip(self.simplices, degs))
        return tuple(map(
            lambda p: p.as_poly(self.variables),
            map(prod, product(*factors))
        ))

    def bernstein(self, *degs):
        '''Bernstein product basis for the given per-simplex degrees (cached).'''
        key = ('bernstein', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.bernstein(d), *degs)
        return self._basis_cache[key]

    def lagrange(self, *degs):
        '''Lagrange product basis for the given per-simplex degrees (cached).'''
        key = ('lagrange', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.lagrange(d), *degs)
        return self._basis_cache[key]

    def monomial(self, *degs):
        '''Monomial product basis for the given per-simplex degrees (cached).'''
        key = ('monomial', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.monomial(d), *degs)
        return self._basis_cache[key]

    def eval_points(self, *degs):
        '''Product of the simplices' evaluation points, concatenated per point.'''
        assert len(self.simplices) == len(degs), \
            'Degrees/simplices list size mismatch'
        factors = tuple(s.eval_points(d) for s, d in zip(self.simplices, degs))
        return tuple(map(lambda x: tuple(chain(*x)), product(*factors)))

    def poly_map(self, deg, basis, codomain_dim=1, name=None, output_miso_poly=True):
        '''
        Create a polynomial map over this domain.

        Generates control point symbols and returns a Matrix of `codomain_dim`
        MiSoPoly (or plain SymPy) expressions -- one per output coordinate.
        Control points are named '<name><coord>[<i>]' with coord taken from
        'xyzw' (or 'x<d>_' beyond the fourth coordinate).

        Arguments:
        deg -- polynomial degree (int or tuple, one per simplex)
        basis -- Basis.LAGRANGE, Basis.BERNSTEIN, or Basis.MONOMIAL
        codomain_dim -- number of output coordinates (default 1)
        name -- base name for control point symbols (auto-generated if None)
        output_miso_poly -- wrap output in MiSoPoly if True (default True)

        Raises ValueError for an unknown basis.
        '''
        degs = (deg,) * len(self.simplices) if isinstance(deg, int) else tuple(deg)

        if basis == Basis.LAGRANGE:
            basis_polys = self.lagrange(*degs)
        elif basis == Basis.BERNSTEIN:
            basis_polys = self.bernstein(*degs)
        elif basis == Basis.MONOMIAL:
            basis_polys = self.monomial(*degs)
        else:
            raise ValueError(f'Unknown basis: {basis}')

        N = len(basis_polys)

        if name is None:
            name = f'p{MiSoDomain._map_counter}'
            MiSoDomain._map_counter += 1

        coord_names = 'xyzw'
        result = []
        for d in range(codomain_dim):
            cname = coord_names[d] if d < len(coord_names) else f'x{d}_'
            cps = tuple(Symbol(f'{name}{cname}[{i}]') for i in range(N))
            poly_expr = sum(cp * b.as_expr() for cp, b in zip(cps, basis_polys))
            result.append(make_poly(poly_expr) if output_miso_poly else poly_expr)

        return Matrix(result)

    # Alias for backward compatibility
    geo_map = poly_map

    ###########################################################

    class MiSoSimplex:
        '''A single simplex spanned by positive real SymPy symbols.'''

        def __init__(self, *names):
            '''Create the simplex variables.

            Accepts either name strings or existing Symbols.  Each string is
            passed to sympy.symbols(), so one string may declare several
            names ('x y').  With no arguments the simplex is empty.

            Raises ValueError if arguments are not all strings or all Symbols.
            '''
            if not names:
                self._variables = tuple()
            elif all(isinstance(n, str) for n in names):
                # sympy.symbols() accepts a single `names` argument, so call it
                # once per string and flatten: symbols('x') returns a Symbol,
                # symbols('x y') a tuple.
                collected = []
                for n in names:
                    vs = symbols(n, real=True, positive=True)
                    if isinstance(vs, (list, tuple)):
                        collected.extend(vs)
                    else:
                        collected.append(vs)
                self._variables = tuple(collected)
            elif all(getattr(n, 'is_Symbol', False) for n in names):
                self._variables = tuple(names)
            else:
                raise ValueError('Variable names must be all strings or symbols')

        @property
        def variables(self): return self._variables

        @property
        def is_empty(self): return len(self.variables) == 0

        @property
        def dimension(self): return len(self.variables)

        @property
        def num_vertices(self): return self.dimension + 1

        @property
        def vertices(self):
            '''Vertices as int tuples (the degree-1 evaluation points).'''
            return [tuple(int(x) for x in e) for e in self.eval_points(1)]

        @property
        def default_subdivision(self):
            '''Midpoint subdivision: 2 pieces in 1D, 4 in 2D, 8 in 3D.

            Raises NotImplementedError for higher dimensions.
            '''
            H = Rational(1, 2)
            D = self.dimension
            if D == 1:
                return MiSoSubdivision(
                    subdivision((0,), (H,)),
                    subdivision((H,), (1,)),
                )
            elif D == 2:
                return MiSoSubdivision(
                    subdivision((0, 0), (H, 0), (0, H)),
                    subdivision((H, 0), (1, 0), (H, H)),
                    subdivision((0, H), (H, H), (0, 1)),
                    subdivision((H, H), (0, H), (H, 0)),
                )
            elif D == 3:
                # 4 corner tetrahedra, then the central octahedron split into
                # 4 tetrahedra around the (0,H,0)-(H,0,H) diagonal.
                return MiSoSubdivision(
                    subdivision((0, 0, 0), (H, 0, 0), (0, H, 0), (0, 0, H)),
                    subdivision((H, 0, 0), (1, 0, 0), (H, H, 0), (H, 0, H)),
                    subdivision((0, H, 0), (H, H, 0), (0, 1, 0), (0, H, H)),
                    subdivision((0, 0, H), (H, 0, H), (0, H, H), (0, 0, 1)),
                    subdivision((H, 0, 0), (0, H, 0), (0, 0, H), (H, 0, H)),
                    subdivision((H, 0, 0), (0, H, 0), (H, H, 0), (H, 0, H)),
                    subdivision((0, H, 0), (0, 0, H), (H, 0, H), (0, H, H)),
                    subdivision((0, H, 0), (H, H, 0), (H, 0, H), (0, H, H)),
                )
            else:
                raise NotImplementedError(f'Default subdivision not implemented for simplex dimension {D}, supply your own')

        def subdomain(self, expr):
            '''Simplex restricted to the variables occurring in expr.'''
            fs = expr.free_symbols
            return MiSoDomain.MiSoSimplex(*(v for v in self.variables if v in fs))

        def bernstein(self, deg):
            '''Bernstein polynomials of degree `deg` in barycentric coordinates.'''
            vs = self.variables + (1 - sum(self.variables),)
            n = len(vs)

            def bp(combo_coeff):
                combo, coeff = combo_coeff
                return (coeff * prod(v ** k for v, k in zip(vs, combo))).as_poly(self.variables)

            return tuple(map(bp, _all_compositions_with_coeff(n, deg)))

        def lagrange(self, deg):
            '''Lagrange polynomials on the uniform degree-`deg` lattice.

            Each basis polynomial is a product over barycentric coordinates v
            of the 1D interpolant in v that vanishes at 0, 1/deg, ...,
            (k-1)/deg and equals 1 at k/deg, where k is the lattice index.
            '''
            vs = self.variables + (1 - sum(self.variables),)

            def lp(combo):
                return prod(
                    interpolate(
                        {Rational(x, deg): int(x == n) for x in range(n + 1)}, v)
                    for v, n in zip(vs, combo)
                ).as_poly(self.variables)

            return tuple(map(lp, _all_compositions(len(vs), deg)))

        def monomial(self, deg):
            '''All monomials of total degree <= deg, ordered by total degree.'''
            vs = self.variables
            k = len(vs)
            result = []
            for total in range(deg + 1):
                for combo in _all_compositions(k, total):
                    mono = prod(v**e for v, e in zip(vs, combo))
                    result.append(mono.as_poly(vs))
            return tuple(result)

        def eval_points(self, deg):
            '''Uniform lattice points with spacing 1/deg (the origin for deg 0).'''
            lv = len(self.variables)
            # Avoid dividing by zero for deg == 0 (single point at the origin).
            D = max(int(deg), 1)
            return tuple(
                tuple(Rational(n, D) for n in combo[:lv])
                for combo in _all_compositions(1 + lv, deg)
            )
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
# NOTE: `_all_compositions` is defined once, earlier in this module.  An
# identical second definition here only rebound the same name, so it was
# removed.
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _all_compositions_with_coeff(n, deg):
    '''Yield (combo, coeff) pairs for every n-part composition of `deg`.

    `combo` comes from _all_compositions(n, deg), in the same order.  `coeff`
    is its multinomial coefficient sum(combo)! / (k_1! * ... * k_n!), which
    bernstein() uses to scale each Bernstein polynomial.
    '''
    from math import factorial, prod

    for combo in _all_compositions(n, deg):
        numerator = factorial(sum(combo))
        yield combo, numerator // prod(factorial(k) for k in combo)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
###########################################################
|
|
387
|
+
# Basis enum and geo_map helper
|
|
388
|
+
###########################################################
|
|
389
|
+
|
|
390
|
+
class Basis:
    '''String constants naming the bases accepted by MiSoDomain.poly_map().'''

    LAGRANGE, BERNSTEIN, MONOMIAL = 'lagrange', 'bernstein', 'monomial'
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def geo_map(domain, deg, basis, embed_dim=1, name=None, poly=True):
    '''Free-function form of domain.poly_map(), kept for backward compatibility.

    `embed_dim` is forwarded as `codomain_dim` and `poly` as
    `output_miso_poly`; `deg`, `basis` and `name` are passed through.
    '''
    options = dict(codomain_dim=embed_dim, name=name, output_miso_poly=poly)
    return domain.poly_map(deg, basis, **options)
|