minimizesolve 1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- minimizesolve-1.0.dist-info/METADATA +5 -0
- minimizesolve-1.0.dist-info/RECORD +12 -0
- minimizesolve-1.0.dist-info/WHEEL +5 -0
- minimizesolve-1.0.dist-info/top_level.txt +1 -0
- miso/__init__.py +23 -0
- miso/codegen.py +503 -0
- miso/codewriter.py +267 -0
- miso/domain.py +398 -0
- miso/generate.py +266 -0
- miso/logger.py +55 -0
- miso/poly.py +62 -0
- miso/subdivision.py +68 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
miso/__init__.py,sha256=7h-MClkP2Rffy-TUZkP0Q_5nS2E5KKliY_ssGqrSwQY,837
|
|
2
|
+
miso/codegen.py,sha256=LJHZGiULHn3v__W1IehSKGJ1-tThqVuFblDMNf_gLHE,16840
|
|
3
|
+
miso/codewriter.py,sha256=XPePIIIGkDQGLEMqlvboiq4MhWdciCb9QELbnmoJn78,8066
|
|
4
|
+
miso/domain.py,sha256=OV3uzG5ULUZLwkXXgNS5sTtwE-JkwX7I3eO2Y9N2MDU,12368
|
|
5
|
+
miso/generate.py,sha256=mOhyfabDRPF6gQyz0Wmfh1z9FJd23jWsBiFsu2DmVjM,8776
|
|
6
|
+
miso/logger.py,sha256=UzPGDq4-BJ-45DN47uB_fUo110YNpcTglIk7eduTfk8,1173
|
|
7
|
+
miso/poly.py,sha256=uNtsNDiz0JkSdCN_RxeTy2swygVEjBLkTfC2AiUvCro,1546
|
|
8
|
+
miso/subdivision.py,sha256=wzD1jeSrctyQ0s29FzK-rRynwFIQ-LEHnqb2kjXsN3k,2045
|
|
9
|
+
minimizesolve-1.0.dist-info/METADATA,sha256=Ar-6UeUPqcYX0fcQzBKH9_Cl9gYySnr2yzs-yQfni4Q,99
|
|
10
|
+
minimizesolve-1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
11
|
+
minimizesolve-1.0.dist-info/top_level.txt,sha256=iwv6V2F286f_4aS4Rq0XaYXFfE_w0LTpI9J-DycHZig,5
|
|
12
|
+
minimizesolve-1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
miso
|
miso/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import sys

# Package version as a tuple and as the dotted string derived from it.
MISO_VNUM = (1, 0)
MISO_VERSION = '.'.join(map(str, MISO_VNUM))

# Oldest interpreter the package accepts; checked before any submodule is imported.
MIN_PYTHON_VERSION = (3, 7)
if sys.version_info < MIN_PYTHON_VERSION:
    # NOTE: this guard deliberately avoids f-strings. An interpreter that predates
    # them would fail with a SyntaxError while parsing this module and never reach
    # the explanatory RuntimeError below.
    # The running version is reported as "major.minor.micro" instead of the raw
    # sys.version_info named-tuple repr.
    raise RuntimeError(
        'Python ({}) is not supported. MiSo {} requires Python {} or higher.'.format(
            '.'.join(map(str, sys.version_info[:3])),
            MISO_VERSION,
            '.'.join(map(str, MIN_PYTHON_VERSION)),
        )
    )
|
|
12
|
+
|
|
13
|
+
from .domain import MiSoDomain, geo_map, Basis
|
|
14
|
+
from .generate import generate
|
|
15
|
+
from .poly import make_poly, unmake_poly, collapse, expand, is_miso_poly, tree
|
|
16
|
+
from .subdivision import MiSoSubdivision, subdivision, combine_subdiv
|
|
17
|
+
from .codewriter import (CodeWriter, inclusion_func, replace_special_func,
|
|
18
|
+
sinpi, cospi, tanpi, asinpi, acospi, atanpi,
|
|
19
|
+
exp2, exp10, expm1, exp2m1, exp10m1,
|
|
20
|
+
log2, log10, log1p, log2p1, log10p1,
|
|
21
|
+
cbrt, rsqrt, hypot)
|
|
22
|
+
from .codegen import MiSoCodeGenerator
|
|
23
|
+
from .logger import MiSoLogger
|
miso/codegen.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from sympy import cse, numbered_symbols
|
|
3
|
+
from .codewriter import CodeWriter, inclusion_func
|
|
4
|
+
from .logger import MiSoLogger
|
|
5
|
+
|
|
6
|
+
NAMESPACE = 'miso'
|
|
7
|
+
TARGET_NAME = NAMESPACE
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _group_args(syms):
|
|
11
|
+
'''
|
|
12
|
+
Group symbols into named vectors or scalars.
|
|
13
|
+
Symbols named 'base[i]' are grouped into a vector (base, [sym0, sym1, ...]).
|
|
14
|
+
Indices must be contiguous from 0; raises ValueError otherwise.
|
|
15
|
+
Symbols without '[' are treated as scalars: (name, [sym]).
|
|
16
|
+
Groups are sorted alphabetically by base name.
|
|
17
|
+
'''
|
|
18
|
+
vectors = {}
|
|
19
|
+
scalars = []
|
|
20
|
+
for s in syms:
|
|
21
|
+
n = s.name
|
|
22
|
+
if '[' in n:
|
|
23
|
+
base, rest = n.rsplit('[', 1)
|
|
24
|
+
idx = int(rest.rstrip(']'))
|
|
25
|
+
vectors.setdefault(base, {})[idx] = s
|
|
26
|
+
else:
|
|
27
|
+
scalars.append(s)
|
|
28
|
+
|
|
29
|
+
result = []
|
|
30
|
+
for base in sorted(vectors):
|
|
31
|
+
d = vectors[base]
|
|
32
|
+
count = len(d)
|
|
33
|
+
for i in range(count):
|
|
34
|
+
if i not in d:
|
|
35
|
+
raise ValueError(f'Missing index {i} in vector argument {base!r}')
|
|
36
|
+
result.append((base, [d[i] for i in range(count)]))
|
|
37
|
+
|
|
38
|
+
for s in sorted(scalars, key=lambda s: s.name):
|
|
39
|
+
result.append((s.name, [s]))
|
|
40
|
+
|
|
41
|
+
return result
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _args_decl(grouped):
|
|
45
|
+
'''Declaration strings for grouped args: RealVector<N> for vectors, Real for scalars.'''
|
|
46
|
+
parts = []
|
|
47
|
+
for name, syms in grouped:
|
|
48
|
+
if '[' in syms[0].name:
|
|
49
|
+
parts.append(f'const RealVector<{len(syms)}> &{name}')
|
|
50
|
+
else:
|
|
51
|
+
parts.append(f'const Real {name}')
|
|
52
|
+
return parts
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MiSoCodeGenerator:
|
|
56
|
+
''' Manages C++ code generation for a MiSo problem class '''
|
|
57
|
+
|
|
58
|
+
def __init__(self, source_dir, class_name, domain, clean=True):
    '''
    Prepare the output directory <source_dir>/<class_name> for code generation.

    source_dir: parent directory of the generated sources.
    class_name: name of the generated class; also the output sub-directory name.
    domain: domain object stored for later use by write_class().
    clean: when true, every regular file directly inside an already existing
           output directory is removed (sub-directories are left alone).
    '''
    self.class_name = class_name
    self.domain = domain
    self.directory = os.path.join(source_dir, class_name)
    self.added_cpp = []   # file names of the generated .cpp files, in write order
    self.added_sig = []   # static member declarations to emit in the class header
    self.log = MiSoLogger.get()

    # exist_ok avoids the check-then-create race of testing os.path.exists() first
    os.makedirs(self.directory, exist_ok=True)
    if clean:
        # A freshly created directory is empty, so this loop is a no-op in that case
        for file_name in os.listdir(self.directory):
            file_path = os.path.join(self.directory, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)
|
|
73
|
+
|
|
74
|
+
def _cpp_path(self, name):
|
|
75
|
+
return os.path.join(self.directory, name + '.cpp')
|
|
76
|
+
|
|
77
|
+
def _write_with_cse(self, f, exprs):
    '''
    Emit `exprs` as a braced C++ return statement after running CSE on them.

    f: code writer providing region()/endregion(), write() and format().
    exprs: expressions handed to sympy's cse(); temporaries are named _t0, _t1, ...

    When cse() finds shared subexpressions they are written first, inside a
    'Temporaries' region, as ``const Real`` definitions. The reduced expressions
    follow in an 'Expressions' region, one per line, each with a trailing comma.
    '''
    temporaries, outputs = cse(exprs, symbols=numbered_symbols('_t'))

    if temporaries:
        f.region('Temporaries')
        for temp, definition in temporaries:
            rendered = f.format(definition)
            f.write(f'const Real {temp} = {rendered};')
        f.endregion()

    f.region('Expressions')
    f.write('return {')
    for output in outputs:
        entry = f.format(output)
        f.write(entry + ',')
    f.write('};')
    f.endregion()
|
|
91
|
+
|
|
92
|
+
def write_cl(self, member_name, lag_exprs, poly_args):
    '''
    Write CL_<member_name>.cpp — evaluates polynomial at Lagrange points.

    member_name: suffix of the generated function name.
    lag_exprs: expressions returned by the generated function; their count is
               the size of the returned RealVector.
    poly_args: symbols turned into the function's parameters via
               _group_args() and _args_decl().

    The file name is recorded in added_cpp and the matching static declaration
    in added_sig.
    '''
    func_name = f'CL_{member_name}'
    cpp_file = func_name + '.cpp'
    self.log.info(f'Writing {cpp_file}')
    self.added_cpp.append(cpp_file)

    return_type = f'RealVector<{len(lag_exprs)}>'
    params = ', '.join(_args_decl(_group_args(poly_args)))

    with CodeWriter(self._cpp_path(func_name)) as f:
        f.write(f'#include "{self.class_name}.hpp"')
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'{return_type} {self.class_name}::{func_name}({params}) {{')
        self._write_with_cse(f, lag_exprs)
        f.write('}')  # closes the function
        f.write('}')  # closes the namespace

    self.added_sig.append(f'static {return_type} {func_name}({params});')
|
|
115
|
+
|
|
116
|
+
def write_cb(self, member_name, cb_exprs, poly_args):
    '''
    Write CB_<member_name>.cpp — combined Lagrange-to-Bernstein conversion.

    member_name: suffix of the generated function name.
    cb_exprs: expressions returned by the generated function; their count is
              the size of the returned RealVector.
    poly_args: symbols turned into the function's parameters via
               _group_args() and _args_decl().

    The file name is recorded in added_cpp and the matching static declaration
    in added_sig.
    '''
    func_name = f'CB_{member_name}'
    cpp_file = func_name + '.cpp'
    self.log.info(f'Writing {cpp_file}')
    self.added_cpp.append(cpp_file)

    return_type = f'RealVector<{len(cb_exprs)}>'
    params = ', '.join(_args_decl(_group_args(poly_args)))

    with CodeWriter(self._cpp_path(func_name)) as f:
        f.write(f'#include "{self.class_name}.hpp"')
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'{return_type} {self.class_name}::{func_name}({params}) {{')
        self._write_with_cse(f, cb_exprs)
        f.write('}')  # closes the function
        f.write('}')  # closes the namespace

    self.added_sig.append(f'static {return_type} {func_name}({params});')
|
|
139
|
+
|
|
140
|
+
def write_lb(self, sig_name, lb_exprs, N):
    '''
    Write LB_<sig>.cpp — Lagrange to Bernstein conversion.

    The generated function takes const RealVector<N> &_l as input and returns
    a RealVector<N> built from lb_exprs.

    The file name is recorded in added_cpp and the matching static declaration
    in added_sig.
    '''
    func_name = f'LB_{sig_name}'
    cpp_file = func_name + '.cpp'
    self.log.info(f'Writing {cpp_file}')
    self.added_cpp.append(cpp_file)

    vector_type = f'RealVector<{N}>'
    prototype = f'{func_name}(const {vector_type} &_l)'

    with CodeWriter(self._cpp_path(func_name)) as f:
        f.write(f'#include "{self.class_name}.hpp"')
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'{vector_type} {self.class_name}::{prototype} {{')
        self._write_with_cse(f, lb_exprs)
        f.write('}')  # closes the function
        f.write('}')  # closes the namespace

    self.added_sig.append(f'static {vector_type} {prototype};')
|
|
159
|
+
|
|
160
|
+
def write_subdivision(self, sig_name, si, q, b2b_exprs, N):
    '''
    Write subdiv_<si>_<sig>_<q>.cpp — B2B subdivision (one file per child).

    The generated function takes const RealVector<N> &_b as input and returns
    a RealVector<N> built from b2b_exprs.

    The file name is recorded in added_cpp and the matching static declaration
    in added_sig.
    '''
    func_name = f'subdiv_{si}_{sig_name}_{q}'
    cpp_file = func_name + '.cpp'
    self.log.info(f'Writing {cpp_file}')
    self.added_cpp.append(cpp_file)

    vector_type = f'RealVector<{N}>'
    prototype = f'{func_name}(const {vector_type} &_b)'

    with CodeWriter(self._cpp_path(func_name)) as f:
        f.write(f'#include "{self.class_name}.hpp"')
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'{vector_type} {self.class_name}::{prototype} {{')
        self._write_with_cse(f, b2b_exprs)
        f.write('}')  # closes the function
        f.write('}')  # closes the namespace

    self.added_sig.append(f'static {vector_type} {prototype};')
|
|
179
|
+
|
|
180
|
+
def write_subdivision_combined(self, sig_name, si, all_b2b_exprs, N):
    '''
    Write subdiv_<si>_<sig>.cpp — all children in one function with shared CSE.

    all_b2b_exprs: one sequence of expressions per child; Q is the number of children.
    The generated function takes const RealVector<N> &_b and returns
    std::array<RealVector<N>, Q>.

    The expressions of all children are concatenated before CSE, and the reduced
    list is cut back into Q consecutive slices of N entries each.

    The file name is recorded in added_cpp and the matching static declaration
    in added_sig.
    '''
    num_children = len(all_b2b_exprs)
    func_name = f'subdiv_{si}_{sig_name}'
    cpp_file = func_name + '.cpp'
    self.log.info(f'Writing {cpp_file}')
    self.added_cpp.append(cpp_file)

    # One flat list so that temporaries can be shared between children
    flat = []
    for child_exprs in all_b2b_exprs:
        flat.extend(child_exprs)
    temporaries, outputs = cse(flat, symbols=numbered_symbols('_t'))

    vector_type = f'RealVector<{N}>'
    return_type = f'std::array<{vector_type}, {num_children}>'
    prototype = f'{func_name}(const {vector_type} &_b)'

    with CodeWriter(self._cpp_path(func_name)) as f:
        f.write(f'#include "{self.class_name}.hpp"')
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'{return_type} {self.class_name}::{prototype} {{')

        if temporaries:
            f.region('Temporaries')
            for temp, definition in temporaries:
                rendered = f.format(definition)
                f.write(f'const Real {temp} = {rendered};')
            f.endregion()

        f.region('Expressions')
        f.write('return {{')
        for child in range(num_children):
            start = child * N
            f.write('{')
            for output in outputs[start:start + N]:
                entry = f.format(output)
                f.write(entry + ',')
            f.write('},')
        f.write('}};')
        f.endregion()

        f.write('}')  # closes the function
        f.write('}')  # closes the namespace

    self.added_sig.append(f'static {return_type} {prototype};')
|
|
217
|
+
|
|
218
|
+
def write_class(self, *, members, var_syms, poly_sizes, member_sigs,
                var_sig, inclusions, samples, num_constraints, has_objective,
                args, all_subdivisions, conversions_B2B, separate_conversion,
                separate_subdivision=False):
    '''
    Write <class_name>.cpp and <class_name>.hpp.

    All parameters are keyword-only.

    members: mapping polynomial -> member symbol; each member becomes a
             RealVector data member of the generated class.
    var_syms: mapping whose values are the member symbols of the variable
              vectors; the i-th one is initialised from coordinate i of every
              domain vertex.
    poly_sizes: member symbol -> size of that member's RealVector.
    member_sigs: member symbol -> signature string used in the LB_<sig> and
                 subdiv_<si>_<sig> function names.
    var_sig: signature string used in the subdiv_ function names of the
             variable vectors.
    inclusions: expressions emitted as the cases of inclusion(unsigned).
    samples: expressions emitted as the cases of sample(unsigned).
    num_constraints, has_objective: emitted as static constexpr members.
    args: symbols forming the arguments of the main constructor.
    all_subdivisions: one subdivision scheme per split<SI>() specialisation;
                      each must support len() and projection(*dropped_axes),
                      whose result provides mapping(i).
    conversions_B2B: accepted but not used by this method.
    separate_conversion: initialise members with LB_<sig>(CL_<member>(...))
                         instead of CB_<member>(...).
    separate_subdivision: call one subdiv_<si>_<sig>_<q> function per child
                          instead of the combined subdiv_<si>_<sig> function.

    The source file name is recorded in added_cpp; every declaration collected
    in added_sig so far is emitted in the private section of the header.
    '''
    source_name = f'{self.class_name}.cpp'
    header_name = f'{self.class_name}.hpp'
    self.log.info(f'Writing {source_name} and {header_name}')
    self.added_cpp.append(source_name)

    name = self.class_name
    domain = self.domain
    dim = domain.dimension
    num_vertices = domain.num_vertices

    poly_member_names = [str(v) for v in members.values()]
    var_member_names = [str(v) for v in var_syms.values()]
    all_member_names = poly_member_names + var_member_names

    def base_name(sym):
        # 'v[3]' -> 'v'; a name without '[' comes back unchanged
        return sym.name.rsplit('[', 1)[0]

    # Constructor argument declarations
    grouped_args = _group_args(args)
    args_decl = _args_decl(grouped_args)

    # Built once: the membership tests below run for every candidate symbol
    arg_set = set(args)

    # Args that appear directly in inclusion expressions (not via poly members)
    # must be stored as members so inclusion()/sample() can access them.
    member_syms = set(members.values()) | set(var_syms.values())
    raw_arg_syms = {s for inc in inclusions for s in inc.free_symbols
                    if s in arg_set and s not in member_syms}
    raw_arg_groups = _group_args(raw_arg_syms)

    # Member initializations
    poly_init = []
    for poly, mem_sym in members.items():
        mname = str(mem_sym)
        sig = member_sigs[mem_sym]
        # A vector argument is passed whole as soon as one of its components
        # occurs in the polynomial, hence the detour through base names.
        poly_bases = {base_name(s) for s in poly.free_symbols if s in arg_set}
        full_poly_syms = {s for s in args if base_name(s) in poly_bases}
        call_args = ', '.join(gname for gname, _ in _group_args(full_poly_syms))
        if separate_conversion:
            poly_init.append(f'{mname}(LB_{sig}(CL_{mname}({call_args})))')
        else:
            poly_init.append(f'{mname}(CB_{mname}({call_args}))')

    var_init = []
    for i, vsym in enumerate(var_syms.values()):
        coords = ', '.join(str(v[i]) for v in domain.vertices)
        var_init.append(f'{vsym!s}{{{coords}}}')

    raw_arg_init = [f'{gname}({gname})' for gname, _ in raw_arg_groups]
    all_init = raw_arg_init + poly_init + var_init

    # Build mapping: for each strategy and combined subdivision index q,
    # which per-signature B2B function applies to each member
    # member_subdiv_mapping[si][mem_sym] = list of projected q indices
    member_subdiv_mapping = {}
    for si, subdivs in enumerate(all_subdivisions):
        member_subdiv_mapping[si] = {}
        for poly, mem_sym in members.items():
            subdomain = domain.subdomain(poly)
            drop = [i for i in range(dim) if domain.variables[i] not in subdomain.variables]
            proj = subdivs.projection(*drop)
            member_subdiv_mapping[si][mem_sym] = [proj.mapping(i) for i in range(len(subdivs))]

    ##########
    # SOURCE #
    ##########
    with CodeWriter(os.path.join(self.directory, source_name)) as f:
        f.write(f'#include "{header_name}"')
        f.write(f'namespace {NAMESPACE} {{')

        # Subdivision constructor
        member_decl = (
            [f'const RealVector<{poly_sizes[ms]}> &{str(ms)}' for ms in members.values()] +
            [f'const RealVector<{num_vertices}> &{str(vs)}' for vs in var_syms.values()]
        )
        f.write(f'{name}::{name}(const {name} &parent, unsigned q,')
        f.joint_write(',', member_decl)
        f.write(') :')
        f.joint_write(',', [f'{m}({m})' for m in all_member_names])
        f.write('{ inherit(parent, q); }')
        f.write()

        # Main constructor
        if args_decl:
            f.write(f'{name}::{name}(')
            f.joint_write(',', args_decl)
            f.write(') :')
        else:
            f.write(f'{name}::{name}() :')
        f.joint_write(',', all_init)
        f.write('{')
        f.write('history = std::make_shared<SubdivHistory>();')
        f.write('}')
        f.write()

        # inherit()
        f.write(f'void {name}::inherit(const {name} &parent, unsigned q) {{')
        f.write('depth = parent.depth + 1;')
        f.write('history = std::make_shared<SubdivHistory>(q, parent.history);')
        for gname, _ in raw_arg_groups:
            f.write(f'{gname} = parent.{gname};')
        f.write('}')

        # inclusion()
        f.write(f'RealInterval {name}::inclusion(unsigned i) const {{')
        f.write('switch (i) {')
        for j, inc in enumerate(inclusions):
            f.write(f'case {j}: return {f.format(inc)};')
        f.write('default: throw std::logic_error("Undefined inclusion function");')
        f.write('}')
        f.write('}')

        # sample(): same structure as inclusion but with .sample<>() calls
        f.write(f'RealVector<{name}::numVertices> {name}::sample(unsigned i) const {{')
        f.write('switch (i) {')
        for j, samp in enumerate(samples):
            f.write(f'case {j}: return {f.format(samp)};')
        f.write('default: throw std::logic_error("Undefined sampling function");')
        f.write('}')
        f.write('}')

        # split<SI>() explicit specializations
        for si, subdivs in enumerate(all_subdivisions):
            num_subs = len(subdivs)
            f.write(f'template<> std::array<{name}, {name}::schemes[{si}]> {name}::split<{si}>() const {{')
            if not separate_subdivision:
                # Combined functions: subdivide each member once, then pick the children
                for mem_sym in members.values():
                    mname = str(mem_sym)
                    sig = member_sigs[mem_sym]
                    f.write(f'auto _r_{mname} = subdiv_{si}_{sig}({mname});')
                for vsym in var_syms.values():
                    vname = str(vsym)
                    f.write(f'auto _r_{vname} = subdiv_{si}_{var_sig}({vname});')
                f.write('return {')
                for qi in range(num_subs):
                    child_args = []
                    for mem_sym in members.values():
                        mname = str(mem_sym)
                        pq = member_subdiv_mapping[si][mem_sym][qi]
                        child_args.append(f'_r_{mname}[{pq}]')
                    for vsym in var_syms.values():
                        vname = str(vsym)
                        child_args.append(f'_r_{vname}[{qi}]')
                    f.write(f'{name}{{*this, {qi},')
                    f.joint_write(', ', child_args)
                    f.write('},')
                f.write('};')
            else:
                # Separate functions: one call per member and child
                f.write('return {')
                for qi in range(num_subs):
                    splits = []
                    for mem_sym in members.values():
                        mname = str(mem_sym)
                        sig = member_sigs[mem_sym]
                        pq = member_subdiv_mapping[si][mem_sym][qi]
                        splits.append(f'subdiv_{si}_{sig}_{pq}({mname})')
                    for vsym in var_syms.values():
                        vname = str(vsym)
                        splits.append(f'subdiv_{si}_{var_sig}_{qi}({vname})')
                    f.write(f'{name}{{*this, {qi},')
                    f.joint_write(', ', splits)
                    f.write('},')
                f.write('};')
            f.write('}')

        # operator<<
        f.write(f'std::ostream &operator<<(std::ostream &out, const {name} &s) {{')
        f.write('for (unsigned i=0; i<s.numVertices; ++i) {')
        f.write("out << s.getVertex(i) << ' ';")
        f.write('}')
        f.write('return out;')
        f.write('}')

        f.write('}')

    ##########
    # HEADER #
    ##########
    with CodeWriter(os.path.join(self.directory, header_name)) as f:
        f.write('#pragma once')
        f.write('#include "core/RealVector.hpp"')
        f.write('#include "core/SubdivHistory.hpp"')
        f.write()
        f.write(f'namespace {NAMESPACE} {{')
        f.write(f'class {name} {{')
        f.write('public:')
        f.write(f'static constexpr unsigned dimension = {dim};')
        f.write(f'static constexpr unsigned numVertices = {num_vertices};')
        f.write(f'static constexpr unsigned numConstraints = {num_constraints};')
        f.write(f'static constexpr bool hasObjective = {"true" if has_objective else "false"};')
        f.write('unsigned scheme = 0;')
        f.write('unsigned depth = 0;')
        f.write('std::shared_ptr<const SubdivHistory> history;')
        f.write()

        # Main constructor declaration
        if args_decl:
            f.write(f'{name}(')
            f.joint_write(',', args_decl)
            f.write(');')
            f.write()
            f.write(f'{name}() = default;')
        else:
            f.write(f'{name}();')
        f.write()

        # schemeSize()
        f.write('unsigned schemeSize() const {')
        f.write('auto s = (scheme < schemes.size()) ? scheme : 0;')
        f.write('return schemes.at(s);')
        f.write('}')
        f.write()

        # width()
        widthlist = [f'{v}.inclusion().width()' for v in var_member_names]
        f.write('const Real width() const {')
        f.write('return std::max({')
        f.joint_write(',', widthlist)
        f.write('});')
        f.write('}')
        f.write()

        # getVertex()
        getvertlist = [f'{v}[i]' for v in var_member_names]
        f.write('RealVector<dimension> getVertex(unsigned i) const {')
        f.write('return {')
        f.joint_write(',', getvertlist)
        f.write('};')
        f.write('}')
        f.write()

        # schemes array (public, needed as template argument)
        scheme_sizes = ', '.join(str(len(s)) for s in all_subdivisions)
        f.write(f'static constexpr std::array<unsigned, {len(all_subdivisions)}> schemes = {{{scheme_sizes}}};')
        f.write()

        # Method declarations
        f.write('RealInterval inclusion(unsigned i) const;')
        f.write('RealVector<numVertices> sample(unsigned i) const;')
        f.write(f'template<unsigned SI=0> std::array<{name}, schemes[SI]> split() const;')
        f.write(f'void inherit(const {name} &parent, unsigned q);')
        f.write(f'friend std::ostream &operator<<(std::ostream &out, const {name} &s);')
        f.write()

        # Private
        f.write('private:')
        f.write()
        for gname, gsyms in raw_arg_groups:
            if '[' in gsyms[0].name:
                f.write(f'RealVector<{len(gsyms)}> {gname};')
            else:
                f.write(f'Real {gname};')
        for ms in members.values():
            f.write(f'RealVector<{poly_sizes[ms]}> {str(ms)};')
        for vsym in var_syms.values():
            f.write(f'RealVector<{num_vertices}> {str(vsym)};')
        f.write()

        # Private subdivision constructor
        f.write(f'{name}(const {name} &parent, unsigned q,')
        f.joint_write(',', member_decl)
        f.write(');')
        f.write()

        # Static function declarations
        for sig in self.added_sig:
            f.write(sig)

        f.write('};')
        f.write()
        # Declare the explicit specialisations that the source file defines
        for si in range(len(all_subdivisions)):
            f.write(f'template<> std::array<{name}, {name}::schemes[{si}]> {name}::split<{si}>() const;')
        f.write('}')
|
|
493
|
+
|
|
494
|
+
def write_cmakelists(self):
    '''
    Write CMakeLists.txt into the output directory.

    The file lists every name recorded in added_cpp in a SOURCES variable and
    attaches those sources to the TARGET_NAME target as PRIVATE sources.
    '''
    self.log.info('Writing CMakeLists.txt')
    cmake_path = os.path.join(self.directory, 'CMakeLists.txt')
    with CodeWriter(cmake_path) as f:
        f.write('set(SOURCES')
        for source_file in self.added_cpp:
            f.write(source_file)
        f.write(')')
        f.write()
        f.write(f'target_sources({TARGET_NAME} PRIVATE ${{SOURCES}})')
|