minimizesolve 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
miso/codewriter.py ADDED
@@ -0,0 +1,267 @@
1
+ import os
2
+ from math import nextafter, inf
3
+ from sympy.printing.cxx import CXX11CodePrinter
4
+ from sympy import S, Float, Rational, Function
5
+ from .poly import is_miso_poly
6
+
7
# Placeholder SymPy functions used in inclusion/sample expressions.  The C++
# printer (CodeWriter.SymPyFormatter._print_Function) recognises them by name
# and renders them as `<arg>.inclusion()` / `<arg>.sample<...>()`.
inclusion_func, sample_func = (
    Function(tag, real=True) for tag in ('MiSoInclusion', 'MiSoSample')
)
10
+
11
# C++ math functions with no native SymPy class.  They are undefined SymPy
# functions: replace_special_func() substitutes them for matching patterns,
# and the C++ printer emits them verbatim by name.
sinpi, cospi, tanpi = map(Function, ('sinpi', 'cospi', 'tanpi'))
asinpi, acospi, atanpi = map(Function, ('asinpi', 'acospi', 'atanpi'))
exp2, exp10 = map(Function, ('exp2', 'exp10'))
expm1, exp2m1, exp10m1 = map(Function, ('expm1', 'exp2m1', 'exp10m1'))
log2, log10 = map(Function, ('log2', 'log10'))
log1p, log2p1, log10p1 = map(Function, ('log1p', 'log2p1', 'log10p1'))
cbrt, rsqrt, hypot = map(Function, ('cbrt', 'rsqrt', 'hypot'))
31
+
32
+
33
def replace_special_func(expr):
    """Rewrite sub-expressions into specialised C++ math functions.

    The tree is rewritten bottom-up: every node's arguments are processed
    first (and the node rebuilt if any changed), then the node itself is
    tried against the rules below in order; the first rule that fires wins.
    Leaves and MiSo polynomials (is_miso_poly) are returned unchanged.

    Rules:
        sin/cos/tan(pi*x)           -> sinpi/cospi/tanpi(x)
        x**(1/3), x**(-1/2)         -> cbrt(x), rsqrt(x)
        sqrt(a**2 + b**2)           -> hypot(a, b)
        2**x, 10**x (x non-numeric) -> exp2(x), exp10(x)
        log(1 + x)                  -> log1p(x)
        asin/acos/atan(x)/pi        -> asinpi/acospi/atanpi(x)
        log(x)/log(2|10)            -> log2/log10(x)
        log(1 + x)/log(2|10)        -> log2p1/log10p1(x)
        exp/exp2/exp10(x) - 1       -> expm1/exp2m1/exp10m1(x)
    """
    from sympy import (pi, sin, cos, tan, asin, acos, atan, exp, log, Add)

    trig_to_pi = {sin: sinpi, cos: cospi, tan: tanpi}
    arc_to_pi = {asin: asinpi, acos: acospi, atan: atanpi}
    one_over_pi = 1/pi
    # (1/log(base), plain log function, log(1 + x) function) per base
    log_bases = (
        (1/log(S(2)), log2, log2p1),
        (1/log(S(10)), log10, log10p1),
    )

    def _trig_of_pi(node):
        # sin(pi*c) -> sinpi(c), etc.
        target = trig_to_pi.get(node.func)
        if target is None:
            return None
        coeff = node.args[0].as_coefficient(pi)
        return None if coeff is None else target(coeff)

    def _power(node):
        if not node.is_Pow:
            return None
        base, expo = node.as_base_exp()
        if expo == Rational(1, 3):
            return cbrt(base)
        if expo == Rational(-1, 2):
            return rsqrt(base)
        if expo == S.Half and base.is_Add and len(base.args) == 2:
            lhs, rhs = base.args
            if lhs.is_Pow and lhs.exp == 2 and rhs.is_Pow and rhs.exp == 2:
                return hypot(lhs.base, rhs.base)
        # Numeric exponents are left alone (2**3, sqrt(2), ... stay as powers)
        if not expo.is_Number:
            if base == S(2):
                return exp2(expo)
            if base == S(10):
                return exp10(expo)
        return None

    def _log_one_plus(node):
        if not (isinstance(node, log) and len(node.args) == 1):
            return None
        inner = node.args[0]
        if inner.is_Add and S.One in inner.args:
            return log1p(inner - 1)
        return None

    def _scaled(node):
        # Products: inverse-trig divided by pi, logs divided by log(base)
        if not node.is_Mul:
            return None
        rest = node.as_coefficient(one_over_pi)
        if rest is not None and rest.func in arc_to_pi:
            return arc_to_pi[rest.func](*rest.args)
        for inv_logb, plain_log, log_p1 in log_bases:
            rest = node.as_coefficient(inv_logb)
            if rest is None:
                continue
            if rest.func == log1p:
                return log_p1(rest.args[0])
            if isinstance(rest, log):
                arg = rest.args[0]
                if arg.is_Add and S.One in arg.args:
                    return log_p1(arg - 1)
                return plain_log(arg)
        return None

    def _minus_one(node):
        # f(x) - 1 -> fm1(x) for the first exp-like term found
        if not (node.is_Add and S.NegativeOne in node.args):
            return None
        terms = [t for t in node.args if t != S.NegativeOne]
        for idx, term in enumerate(terms):
            if isinstance(term, exp):
                fused = expm1(term.args[0])
            elif term.func == exp2:
                fused = exp2m1(term.args[0])
            elif term.func == exp10:
                fused = exp10m1(term.args[0])
            else:
                continue
            return Add(fused, *terms[:idx], *terms[idx + 1:])
        return None

    rules = (_trig_of_pi, _power, _log_one_plus, _scaled, _minus_one)

    def _walk(node):
        if not node.args or is_miso_poly(node):
            return node
        walked = tuple(map(_walk, node.args))
        # Rebuild only when some child actually changed identity
        if any(new is not old for new, old in zip(walked, node.args)):
            node = node.func(*walked)
        for rule in rules:
            replacement = rule(node)
            if replacement is not None:
                return replacement
        return node

    return _walk(expr)
124
+
125
+
126
class CodeWriter:
    ''' Manages a C++ source file

    Lines are written with tab indentation that follows the bracket balance
    of each written line (see write()). SymPy expressions are printed through
    the nested SymPyFormatter. Writing methods return self so calls can be
    chained. Usable as a context manager; the file is closed on exit.
    '''

    class SymPyFormatter(CXX11CodePrinter):
        ''' Formats SymPy expressions for C++ '''

        def _print_Rational(self, expr):
            '''Print a rational as a double literal or as a bracketing pair.

            If float(expr) compares equal to expr, the plain double is emitted.
            Otherwise `_R(lo, hi)` is emitted, where lo/hi are float(expr) and
            its neighbouring double on the other side of expr. `_R` is
            presumably a range/interval constructor in the C++ runtime --
            confirm against the runtime headers.
            '''
            fp = float(expr)
            if fp < expr:
                # float rounded down: the next double up bounds expr above
                fpu = nextafter(fp, +inf)
                return f'_R({fp}, {fpu})'
            elif fp > expr:
                # float rounded up: the next double down bounds expr below
                fpl = nextafter(fp, -inf)
                return f'_R({fpl}, {fp})'
            else:
                return f'{fp}'

        def _print_Pow(self, expr):
            '''Print a power as pown/sqrt/rsqrt/pow.

            Only literal integer exponents go to pown(). An exponent that is
            merely *known* to be an integer (e.g. Symbol('n', integer=True))
            cannot be converted with int(), so it is printed via pow().
            '''
            base, exp = expr.as_base_exp()
            if exp.is_Integer:
                return f"pown({self._print(base)}, {int(exp)})"
            if exp == Rational(1, 2):
                return f"sqrt({self._print(base)})"
            if exp == Rational(-1, 2):
                return f"rsqrt({self._print(base)})"
            return f"pow({self._print(base)}, {self._print(exp)})"

        def _print_Abs(self, expr):
            return f'abs({self._print(expr.args[0])})'

        # Max/Min are emitted as `|` / `&` chains; presumably these operators
        # are overloaded by the target C++ number type -- confirm semantics.
        def _print_Max(self, expr):
            return '(' + '|'.join(self._print(a) for a in expr.args) + ')'

        def _print_Min(self, expr):
            return '(' + '&'.join(self._print(a) for a in expr.args) + ')'

        # Standard math functions: emit unqualified so ADL finds tight:: overloads
        def _unary(self, expr):
            return f'{type(expr).__name__}({self._print(expr.args[0])})'

        _print_sin = _print_cos = _print_tan = _print_asin = _print_acos = _print_atan = \
            _print_sinh = _print_cosh = _print_tanh = _print_asinh = _print_acosh = _print_atanh = \
            _print_exp = _print_log = _print_erf = _unary

        def _print_atan2(self, expr):
            return f'atan2({self._print(expr.args[0])}, {self._print(expr.args[1])})'

        # Function names printed verbatim as `name(args...)`
        _special_funcs = {f.__name__ for f in (
            sinpi, cospi, tanpi, asinpi, acospi, atanpi,
            exp2, exp10, expm1, exp2m1, exp10m1,
            log2, log10, log1p, log2p1, log10p1,
            cbrt, rsqrt, hypot,
        )} | {'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
              'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
              'exp', 'log', 'sqrt', 'erf'}

        def _print_Function(self, expr):
            '''Print MiSo placeholders and known functions; defer the rest.'''
            if expr.func.__name__ == 'MiSoInclusion':
                return f'{self._print(expr.args[0])}.inclusion()'
            if expr.func.__name__ == 'MiSoSample':
                name = self._print(expr.args[0])
                indices = ','.join(str(int(a)) for a in expr.args[1:])
                return f'{name}.sample<{indices}>()'
            if expr.func.__name__ in self._special_funcs:
                args = ', '.join(self._print(a) for a in expr.args)
                return f'{expr.func.__name__}({args})'
            return super()._print_Function(expr)

    def __init__(self, filepath, overwrite=True, pragma_regions=True):
        ''' Open and possibly create the managed file

        Raises FileExistsError if overwrite is False and filepath exists.
        '''
        self.filepath = filepath
        if overwrite or not os.path.exists(filepath):
            self.file = open(filepath, 'w')
        else:
            raise FileExistsError(f'{filepath} already exists')
        self.tablv = 0                      # current indentation depth (tabs)
        self.formatter = self.SymPyFormatter()
        self.newline = True                 # did the last write end a line?
        self.pragma_regions = pragma_regions

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, ex_tb):
        self.close()

    def __str__(self):
        return f'File {self.filepath}'

    def close(self):
        self.file.close()

    def tabin(self):
        self.tablv += 1
        return self

    def tabout(self):
        self.tablv -= 1
        assert self.tablv >= 0
        return self

    def write(self, l='', newline=True):
        ''' Write a single line with indent

        A line whose bracket balance is negative dedents one level before it
        is written; a positive balance indents one level after it. Indent is
        only emitted when the previous write ended the line.
        '''
        def brackets_imbalance(l):
            d = 0
            for c in l:
                if c in '{[(':
                    d += 1
                elif c in '}])':
                    d -= 1
            return d

        if len(l) > 0:
            diff = brackets_imbalance(l)
            if diff < 0:
                self.tabout()
            if self.newline:
                for i in range(self.tablv):
                    self.file.write('\t')
            if diff > 0:
                self.tabin()
            self.file.write(l)
        if newline:
            self.file.write('\n')
        self.newline = newline
        return self

    def joint_write(self, separator, lines):
        ''' Write all lines, adding separator to each except the last '''
        if len(lines) == 0:
            return self
        for l in lines[:-1]:
            self.write(l + separator)
        self.write(lines[-1])
        return self

    def region(self, name):
        '''Open a named region (#pragma region, or a comment if disabled).'''
        if self.pragma_regions:
            self.write(f'#pragma region {name}')
        else:
            self.write(f'/* {name} */')
        return self

    def endregion(self):
        '''Close a region opened by region() (no-op without pragma regions).'''
        if self.pragma_regions:
            self.write('#pragma endregion')
        return self

    def format(self, expr):
        ''' Return a formatted SymPy expression as a C++ string '''
        return self.formatter.doprint(expr)

    def w_assign(self, lhs, rhs, declare=''):
        ''' Write an assignment using the SymPy formatter

        `declare`, if non-empty, is prefixed to the statement (e.g. a type).
        '''
        s = self.formatter.doprint(rhs, lhs)
        if declare:
            self.write(f'{declare} {s}')
        else:
            self.write(s)
        return self
miso/domain.py ADDED
@@ -0,0 +1,398 @@
1
+ from sympy import symbols, Rational, Matrix, numbered_symbols, Symbol, Add, Mul, Pow, diff, factorial
2
+ from sympy.polys.polyfuncs import interpolate
3
+ from sympy.polys.matrices import DomainMatrix
4
+ from sympy import QQ
5
+ from itertools import product, chain, starmap
6
+ from math import prod
7
+
8
+ from .poly import make_poly, is_miso_poly
9
+ from .codewriter import inclusion_func, sample_func
10
+ from .subdivision import MiSoSubdivision, subdivision, combine_subdiv
11
+ from .logger import MiSoLogger
12
+
13
+
14
def _degree_in_vars(expr, variables):
    '''Return the max total degree of `expr` in `variables` without expanding.

    The tree is folded bottom-up: target symbols become degree-1 markers,
    Add takes the max degree, Mul adds degrees, and Pow multiplies the base
    degree by int(exponent). Any other node is treated as a constant.
    Returns 0 if no target variable contributes.
    '''
    targets = set(variables)

    class DegExpr:
        # Marker for "degree `deg` in the targets"; non-marker operands are
        # constants with respect to the targets and leave the marker as is.
        def __init__(self, d=1):
            self.deg = d

        def __add__(self, other):
            if isinstance(other, DegExpr):
                return DegExpr(max(self.deg, other.deg))
            return self

        __radd__ = __add__

        def __mul__(self, other):
            if isinstance(other, DegExpr):
                return DegExpr(self.deg + other.deg)
            return self

        __rmul__ = __mul__

        def __pow__(self, other):
            return DegExpr(self.deg * int(other))

    folders = {
        Add: sum,
        Mul: prod,
        Pow: lambda parts: parts[0] ** parts[1],
    }

    def fold(node):
        if node.func == Symbol:
            return DegExpr(1) if node in targets else node
        # Children are always folded first, whatever the node type
        parts = tuple(fold(child) for child in node.args)
        combine = folders.get(node.func)
        return node if combine is None else combine(parts)

    folded = fold(expr)
    if isinstance(folded, DegExpr):
        return folded.deg
    return 0
40
+
41
+
42
def _build_l2b(B, variables, E):
    '''Return the inverse of the basis evaluation matrix, over QQ.

    Row k holds every polynomial of B evaluated at point E[k]; the square
    N x N DomainMatrix (N = len(B)) is then inverted.
    '''
    size = len(B)

    def _row(point):
        assignment = dict(zip(variables, point))
        return [QQ.from_sympy(b.eval(assignment).as_expr()) for b in B]

    evaluation = DomainMatrix([_row(point) for point in E], (size, size), QQ)
    return evaluation.inv()
47
+
48
+
49
def _compute_b2b(l2b, B, variables, E, trans, bv_mat):
    '''Apply one B2B conversion to a coefficient vector (picklable for parallel use).

    Each evaluation point in E is mapped through the affine transform
    `trans`, applied to the homogeneous column [e..., 1]. The basis B is
    evaluated at the mapped points to form MT. The conversion matrix
    l2b * MT is multiplied with the N values taken from `bv_mat`.

    Returns the N resulting QQ elements as a flat list.
    '''
    # The transformed point is computed once; reading each row from a fresh
    # matrix product would cost one product per coordinate.
    mapped = []
    for e in E:
        image = trans @ Matrix(list(e) + [Rational(1)])
        mapped.append(tuple(image[i, 0] for i in range(trans.rows)))
    ET = tuple(mapped)

    N = len(B)
    rows = [[QQ.from_sympy(b.eval(dict(zip(variables, e))).as_expr()) for b in B] for e in ET]
    MT = DomainMatrix(rows, (N, N), QQ)
    b2b = l2b * MT  # DomainMatrix over QQ
    bv = list(bv_mat.flat())
    bv_dm = DomainMatrix([[v] for v in bv], (N, 1), QQ)
    result = b2b * bv_dm
    return result.to_list_flat()
66
+
67
+
68
def _sig_name(simplex_degs):
    '''Join (dimension, degree) pairs into a name such as "2p3" or "2p1_1p1".'''
    pieces = []
    for dim, deg in simplex_degs:
        pieces.append(f'{dim}p{deg}')
    return '_'.join(pieces)
71
+
72
def _all_compositions(n, total):
    '''Yield every tuple of `n` non-negative ints summing to `total`.

    Tuples come out in ascending lexicographic order, e.g. n=2, total=2
    gives (0, 2), (1, 1), (2, 0). For n <= 0 the only composition is the
    empty tuple, and only when total == 0.
    '''
    if n <= 0:
        # Without this guard the recursion never reaches the n == 1 base case.
        if total == 0:
            yield ()
        return
    if n == 1:
        yield (total,)
        return
    for i in range(total + 1):
        for rest in _all_compositions(n - 1, total - i):
            yield (i,) + rest
79
+
80
+
81
class MiSoDomain:
    '''Cartesian product of simplices on which MiSo polynomials are defined.

    Each factor is a MiSoDomain.MiSoSimplex holding a tuple of SymPy symbols.
    The domain's variables are the factors' variables concatenated in order.
    '''

    # Counter used by poly_map() to auto-generate control-point names.
    _map_counter = 0

    def __init__(self, *variables):
        '''Build a domain from simplices or MiSoSimplex constructor arguments.

        Each argument is either a MiSoSimplex or a single value passed to the
        MiSoSimplex constructor. Empty simplices are dropped.
        '''
        ms = MiSoDomain.MiSoSimplex
        simplices = (v if isinstance(v, ms) else ms(v) for v in variables)
        self.simplices = tuple(filter(lambda s: not s.is_empty, simplices))
        # Basis polynomials memoized per (basis kind, degrees) key.
        self._basis_cache = {}

    @property
    def variables(self):
        '''All domain variables, simplex by simplex.'''
        return tuple(v for s in self.simplices for v in s.variables)

    @property
    def is_empty(self):
        return len(self.variables) == 0

    @property
    def dimension(self):
        '''Total number of variables (sum of simplex dimensions).'''
        return sum(s.dimension for s in self.simplices)

    @property
    def num_vertices(self):
        '''Number of vertices of the product domain.'''
        return prod(s.num_vertices for s in self.simplices)

    @property
    def vertices(self):
        '''
        All vertices of the domain as tuples of coordinate values.
        '''
        factors = [s.vertices for s in self.simplices]
        return list(map(lambda x: tuple(chain(*x)), product(*factors)))

    @property
    def default_subdivision(self):
        '''Product of every simplex's default subdivision.'''
        sub = tuple(s.default_subdivision for s in self.simplices)
        prod_mats = tuple(starmap(combine_subdiv, product(*sub)))
        return MiSoSubdivision(*prod_mats)

    def subdivision(self, *simplex_domains):
        '''Bisect only the listed simplex domains; identity on the rest.
        Each argument should be a MiSoDomain whose variables match one of the simplices.'''
        target_vars = set()
        for sd in simplex_domains:
            target_vars.update(sd.variables)
        sub = []
        for s in self.simplices:
            if any(v in target_vars for v in s.variables):
                sub.append(s.default_subdivision)
            else:
                # Identity: single matrix that maps domain to itself. The
                # vertices are listed as e_1, ..., e_d followed by the origin;
                # presumably the vertex order expected by subdivision() --
                # confirm against .subdivision.
                d = s.dimension
                identity = subdivision(
                    *(tuple(int(i == j) for j in range(d)) for i in range(d + 1))
                )
                sub.append((identity,))
        prod_mats = tuple(starmap(combine_subdiv, product(*sub)))
        return MiSoSubdivision(*prod_mats)

    def __getitem__(self, i):
        return self.variables[i]

    def __str__(self):
        return 'Domain(' + ', '.join(str(s.variables) for s in self.simplices) + ')'

    def __mul__(self, o):
        '''Product domain; raises ValueError if the domains share variables.'''
        V = set(self.variables)
        OV = set(o.variables)
        if not V.isdisjoint(OV):
            raise ValueError('Domains cannot share variables')
        return MiSoDomain(*self.simplices, *o.simplices)

    @property
    def signature(self):
        '''Number of variables of each simplex, in order.'''
        return tuple(len(s.variables) for s in self.simplices)

    def subdomain(self, expr):
        '''Return the smallest subdomain that can represent the expression.'''
        return MiSoDomain(*(s.subdomain(expr) for s in self.simplices))

    def degs(self, expr):
        '''Return the nonzero per-simplex degrees needed to represent `expr`.

        Simplices in which `expr` has degree 0 are skipped, so the result
        can be shorter than self.simplices.
        '''
        res = []
        for s in self.simplices:
            d = _degree_in_vars(expr, s.variables)
            if d > 0:
                res.append(d)
        return tuple(res)

    def parameters(self, *expr):
        '''
        Get free parameters (arguments) from a list of expressions --
        all free symbols not in domain variables.

        Returns an empty set when no expression is given. Raises ValueError
        if any free symbol is not a plain Symbol.
        '''
        # set().union(...) (unlike set.union(...)) also accepts zero iterables.
        fs = set().union(*(ex.free_symbols for ex in expr))
        if any(not s.is_Symbol for s in fs):
            raise ValueError(f'{fs} contains non-Symbols')
        return fs.difference(self.variables)

    def _product_basis(self, simplexf, *degs):
        '''Tensor-product basis: every product of one per-simplex basis element.'''
        assert len(self.simplices) == len(degs), \
            'Degrees/simplices list size mismatch'
        factors = tuple(simplexf(s, d) for s, d in zip(self.simplices, degs))
        return tuple(map(
            lambda p: p.as_poly(self.variables),
            map(prod, product(*factors))
        ))

    def bernstein(self, *degs):
        '''Cached tensor-product Bernstein basis for the given degrees.'''
        key = ('bernstein', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.bernstein(d), *degs)
        return self._basis_cache[key]

    def lagrange(self, *degs):
        '''Cached tensor-product Lagrange basis for the given degrees.'''
        key = ('lagrange', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.lagrange(d), *degs)
        return self._basis_cache[key]

    def monomial(self, *degs):
        '''Cached tensor-product monomial basis for the given degrees.'''
        key = ('monomial', degs)
        if key not in self._basis_cache:
            self._basis_cache[key] = self._product_basis(lambda s, d: s.monomial(d), *degs)
        return self._basis_cache[key]

    def eval_points(self, *degs):
        '''Product of the per-simplex evaluation lattices, as flat tuples.'''
        assert len(self.simplices) == len(degs), \
            'Degrees/simplices list size mismatch'
        factors = tuple(s.eval_points(d) for s, d in zip(self.simplices, degs))
        return tuple(map(lambda x: tuple(chain(*x)), product(*factors)))

    def poly_map(self, deg, basis, codomain_dim=1, name=None, output_miso_poly=True):
        '''
        Create a polynomial map over this domain.

        Generates control point symbols and returns a Matrix of `codomain_dim`
        MiSoPoly (or plain SymPy) expressions -- one per output coordinate.

        Arguments:
        deg -- polynomial degree (int or tuple, one per simplex)
        basis -- Basis.LAGRANGE, Basis.BERNSTEIN, or Basis.MONOMIAL
        codomain_dim -- number of output coordinates (default 1)
        name -- base name for control point symbols (auto-generated if None)
        output_miso_poly -- wrap output in MiSoPoly if True (default True)
        '''
        degs = (deg,) * len(self.simplices) if isinstance(deg, int) else tuple(deg)

        if basis == Basis.LAGRANGE:
            basis_polys = self.lagrange(*degs)
        elif basis == Basis.BERNSTEIN:
            basis_polys = self.bernstein(*degs)
        elif basis == Basis.MONOMIAL:
            basis_polys = self.monomial(*degs)
        else:
            raise ValueError(f'Unknown basis: {basis}')

        N = len(basis_polys)

        if name is None:
            name = f'p{MiSoDomain._map_counter}'
            MiSoDomain._map_counter += 1

        coord_names = 'xyzw'
        result = []
        for d in range(codomain_dim):
            cname = coord_names[d] if d < len(coord_names) else f'x{d}_'
            cps = tuple(Symbol(f'{name}{cname}[{i}]') for i in range(N))
            poly_expr = sum(cp * b.as_expr() for cp, b in zip(cps, basis_polys))
            result.append(make_poly(poly_expr) if output_miso_poly else poly_expr)

        return Matrix(result)

    # Alias for backward compatibility
    geo_map = poly_map

    ###########################################################

    class MiSoSimplex:
        '''A simplex described by its variables (the first d barycentric coordinates).'''

        def __init__(self, *names):
            '''Create from strings (expanded with sympy.symbols) or from Symbols.

            Several strings are each expanded and concatenated, so
            MiSoSimplex('u', 'v') and MiSoSimplex('u v') are equivalent.
            Raises ValueError unless the names are all strings or all Symbols.
            '''
            if not names:
                self._variables = tuple()
            elif all(isinstance(n, str) for n in names):
                collected = []
                # sympy.symbols takes a single string; expand each separately.
                for n in names:
                    vs = symbols(n, real=True, positive=True)
                    if isinstance(vs, (list, tuple)):
                        collected.extend(vs)
                    elif vs.is_Symbol:
                        collected.append(vs)
                    else:
                        collected.extend(vs)
                self._variables = tuple(collected)
            elif all(getattr(n, 'is_Symbol', False) for n in names):
                self._variables = tuple(names)
            else:
                raise ValueError('Variable names must be all strings or symbols')

        @property
        def variables(self):
            return self._variables

        @property
        def is_empty(self):
            return len(self.variables) == 0

        @property
        def dimension(self):
            return len(self.variables)

        @property
        def num_vertices(self):
            return self.dimension + 1

        @property
        def vertices(self):
            '''Vertices as integer coordinate tuples (the degree-1 lattice).'''
            return [tuple(int(x) for x in e) for e in self.eval_points(1)]

        @property
        def default_subdivision(self):
            '''Midpoint refinement into 2 / 4 / 8 children for D = 1 / 2 / 3.'''
            H = Rational(1, 2)
            D = self.dimension
            if D == 1:
                return MiSoSubdivision(
                    subdivision((0,), (H,)),
                    subdivision((H,), (1,)),
                )
            elif D == 2:
                return MiSoSubdivision(
                    subdivision((0, 0), (H, 0), (0, H)),
                    subdivision((H, 0), (1, 0), (H, H)),
                    subdivision((0, H), (H, H), (0, 1)),
                    subdivision((H, H), (0, H), (H, 0)),
                )
            elif D == 3:
                # Four corner tetrahedra, then the inner octahedron split into
                # four tetrahedra around the (0,H,0)-(H,0,H) diagonal.
                return MiSoSubdivision(
                    subdivision((0, 0, 0), (H, 0, 0), (0, H, 0), (0, 0, H)),
                    subdivision((H, 0, 0), (1, 0, 0), (H, H, 0), (H, 0, H)),
                    subdivision((0, H, 0), (H, H, 0), (0, 1, 0), (0, H, H)),
                    subdivision((0, 0, H), (H, 0, H), (0, H, H), (0, 0, 1)),
                    subdivision((H, 0, 0), (0, H, 0), (0, 0, H), (H, 0, H)),
                    subdivision((H, 0, 0), (0, H, 0), (H, H, 0), (H, 0, H)),
                    subdivision((0, H, 0), (0, 0, H), (H, 0, H), (0, H, H)),
                    subdivision((0, H, 0), (H, H, 0), (H, 0, H), (0, H, H)),
                )
            else:
                raise NotImplementedError(f'Default subdivision not implemented for simplex dimension {D}, supply your own')

        def subdomain(self, expr):
            '''Sub-simplex spanned by the variables that occur free in expr.'''
            fs = expr.free_symbols
            return MiSoDomain.MiSoSimplex(*(v for v in self.variables if v in fs))

        def bernstein(self, deg):
            '''Bernstein basis of degree `deg` in barycentric coordinates.'''
            vs = self.variables + (1 - sum(self.variables),)
            n = len(vs)

            def bp(combo_coeff):
                combo, coeff = combo_coeff
                return (coeff * prod(v ** k for v, k in zip(vs, combo))).as_poly(self.variables)

            return tuple(map(bp, _all_compositions_with_coeff(n, deg)))

        def lagrange(self, deg):
            '''Lagrange basis of degree `deg` on the uniform barycentric lattice.'''
            vs = self.variables + (1 - sum(self.variables),)

            def lp(combo):
                # Per coordinate v with index n: the degree-n polynomial in v
                # equal to 1 at n/deg and 0 at 0, 1/deg, ..., (n-1)/deg.
                return prod(
                    interpolate(
                        {Rational(x, deg): int(x == n) for x in range(n + 1)}, v)
                    for v, n in zip(vs, combo)
                ).as_poly(self.variables)

            return tuple(map(lp, _all_compositions(len(vs), deg)))

        def monomial(self, deg):
            '''All monomials of total degree <= deg, grouped by total degree.'''
            vs = self.variables
            k = len(vs)
            result = []
            for total in range(deg + 1):
                for combo in _all_compositions(k, total):
                    mono = prod(v**e for v, e in zip(vs, combo))
                    result.append(mono.as_poly(vs))
            return tuple(result)

        def eval_points(self, deg):
            '''Uniform lattice points of spacing 1/deg (a single point for deg 0).'''
            lv = len(self.variables)
            D = max(int(deg), 1)
            return tuple(
                tuple(Rational(n, D) for n in combo[:lv])
                for combo in _all_compositions(1 + lv, deg)
            )
363
+
364
+
365
def _all_compositions(n, total):
    '''Yield every tuple of `n` non-negative ints summing to `total`.

    Tuples come out in ascending lexicographic order, e.g. n=2, total=2
    gives (0, 2), (1, 1), (2, 0). For n <= 0 the only composition is the
    empty tuple, and only when total == 0.
    '''
    if n <= 0:
        # Without this guard the recursion never reaches the n == 1 base case.
        if total == 0:
            yield ()
        return
    if n == 1:
        yield (total,)
        return
    for i in range(total + 1):
        for rest in _all_compositions(n - 1, total - i):
            yield (i,) + rest
372
+
373
+
374
def _all_compositions_with_coeff(n, deg):
    '''Yield (combo, multinomial coefficient) for each composition of deg into n parts.

    The coefficient is deg! / prod(k! for k in combo), computed with exact
    integer division.
    '''
    from math import factorial

    def multinomial(parts):
        coeff = factorial(sum(parts))
        for k in parts:
            coeff //= factorial(k)
        return coeff

    yield from ((combo, multinomial(combo)) for combo in _all_compositions(n, deg))
384
+
385
+
386
+ ###########################################################
387
+ # Basis enum and geo_map helper
388
+ ###########################################################
389
+
390
class Basis:
    '''String tags selecting the polynomial basis used by MiSoDomain.poly_map.'''
    LAGRANGE, BERNSTEIN, MONOMIAL = 'lagrange', 'bernstein', 'monomial'
394
+
395
+
396
def geo_map(domain, deg, basis, embed_dim=1, name=None, poly=True):
    '''Free-function form of domain.poly_map() with the legacy argument names.

    embed_dim maps to codomain_dim and poly maps to output_miso_poly.
    '''
    options = dict(codomain_dim=embed_dim, name=name, output_miso_poly=poly)
    return domain.poly_map(deg, basis, **options)