invrs-opt 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +103 -77
- invrs_opt/optimizers/wrapped_optax.py +90 -67
- invrs_opt/parameterization/base.py +62 -16
- invrs_opt/parameterization/filter_project.py +48 -18
- invrs_opt/parameterization/gaussian_levelset.py +88 -96
- invrs_opt/parameterization/pixel.py +20 -6
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.9.0.dist-info}/METADATA +2 -2
- invrs_opt-0.9.0.dist-info/RECORD +20 -0
- invrs_opt-0.8.0.dist-info/RECORD +0 -20
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.9.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.9.0.dist-info}/WHEEL +0 -0
- {invrs_opt-0.8.0.dist-info → invrs_opt-0.9.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
@@ -19,7 +19,7 @@ from totypes import types
|
|
19
19
|
|
20
20
|
from invrs_opt.optimizers import base
|
21
21
|
from invrs_opt.parameterization import (
|
22
|
-
base as
|
22
|
+
base as param_base,
|
23
23
|
filter_project,
|
24
24
|
gaussian_levelset,
|
25
25
|
pixel,
|
@@ -252,7 +252,7 @@ def levelset_lbfgsb(
|
|
252
252
|
|
253
253
|
|
254
254
|
def parameterized_lbfgsb(
|
255
|
-
density_parameterization: Optional[
|
255
|
+
density_parameterization: Optional[param_base.Density2DParameterization],
|
256
256
|
penalty: float,
|
257
257
|
maxcor: int = DEFAULT_MAXCOR,
|
258
258
|
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
@@ -296,59 +296,6 @@ def parameterized_lbfgsb(
|
|
296
296
|
if density_parameterization is None:
|
297
297
|
density_parameterization = pixel.pixel()
|
298
298
|
|
299
|
-
def _init_latents(params: PyTree) -> PyTree:
|
300
|
-
def _leaf_init_latents(leaf: Any) -> Any:
|
301
|
-
leaf = _clip(leaf)
|
302
|
-
if not _is_density(leaf) or density_parameterization is None:
|
303
|
-
return leaf
|
304
|
-
return density_parameterization.from_density(leaf)
|
305
|
-
|
306
|
-
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
307
|
-
|
308
|
-
def _params_from_latents(latent_params: PyTree) -> PyTree:
|
309
|
-
def _leaf_params_from_latents(leaf: Any) -> Any:
|
310
|
-
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
311
|
-
return leaf
|
312
|
-
return density_parameterization.to_density(leaf)
|
313
|
-
|
314
|
-
return tree_util.tree_map(
|
315
|
-
_leaf_params_from_latents,
|
316
|
-
latent_params,
|
317
|
-
is_leaf=_is_parameterized_density,
|
318
|
-
)
|
319
|
-
|
320
|
-
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
321
|
-
def _constraint_loss_leaf(
|
322
|
-
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
323
|
-
) -> jnp.ndarray:
|
324
|
-
constraints = density_parameterization.constraints(params)
|
325
|
-
constraints = tree_util.tree_map(
|
326
|
-
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
327
|
-
constraints,
|
328
|
-
)
|
329
|
-
return jnp.sum(jnp.asarray(constraints))
|
330
|
-
|
331
|
-
losses = [0.0] + [
|
332
|
-
_constraint_loss_leaf(p)
|
333
|
-
for p in tree_util.tree_leaves(
|
334
|
-
latent_params, is_leaf=_is_parameterized_density
|
335
|
-
)
|
336
|
-
if _is_parameterized_density(p)
|
337
|
-
]
|
338
|
-
return penalty * jnp.sum(jnp.asarray(losses))
|
339
|
-
|
340
|
-
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
341
|
-
def _update_leaf(leaf: Any) -> Any:
|
342
|
-
if not _is_parameterized_density(leaf):
|
343
|
-
return leaf
|
344
|
-
return density_parameterization.update(leaf, step)
|
345
|
-
|
346
|
-
return tree_util.tree_map(
|
347
|
-
_update_leaf,
|
348
|
-
latent_params,
|
349
|
-
is_leaf=_is_parameterized_density,
|
350
|
-
)
|
351
|
-
|
352
299
|
def init_fn(params: PyTree) -> LbfgsbState:
|
353
300
|
"""Initializes the optimization state."""
|
354
301
|
|
@@ -368,10 +315,17 @@ def parameterized_lbfgsb(
|
|
368
315
|
return latent_params, scipy_lbfgsb_state.to_jax()
|
369
316
|
|
370
317
|
latent_params = _init_latents(params)
|
371
|
-
|
372
|
-
|
318
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
319
|
+
latents, jax_lbfgsb_state = jax.pure_callback(
|
320
|
+
_init_state_pure, _example_state(latents, maxcor), latents
|
321
|
+
)
|
322
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
323
|
+
return (
|
324
|
+
0, # step
|
325
|
+
_params_from_latent_params(latent_params), # params
|
326
|
+
latent_params, # latent params
|
327
|
+
jax_lbfgsb_state, # opt state
|
373
328
|
)
|
374
|
-
return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
375
329
|
|
376
330
|
def params_fn(state: LbfgsbState) -> PyTree:
|
377
331
|
"""Returns the parameters for the given `state`."""
|
@@ -403,46 +357,118 @@ def parameterized_lbfgsb(
|
|
403
357
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
404
358
|
|
405
359
|
step, _, latent_params, jax_lbfgsb_state = state
|
406
|
-
|
407
|
-
|
360
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
361
|
+
|
362
|
+
def _params_from_latents(latents: PyTree) -> PyTree:
|
363
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
364
|
+
return _params_from_latent_params(latent_params)
|
365
|
+
|
366
|
+
def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
|
367
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
368
|
+
return _constraint_loss(latent_params)
|
369
|
+
|
370
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
371
|
+
(latents_grad,) = vjp_fn(grad)
|
408
372
|
|
409
373
|
if not (
|
410
|
-
tree_util.tree_structure(
|
411
|
-
== tree_util.tree_structure(
|
374
|
+
tree_util.tree_structure(latents_grad)
|
375
|
+
== tree_util.tree_structure(latents) # type: ignore[operator]
|
412
376
|
):
|
413
377
|
raise ValueError(
|
414
|
-
f"Tree structure of `
|
415
|
-
f"{tree_util.tree_structure(
|
416
|
-
f"{tree_util.tree_structure(
|
378
|
+
f"Tree structure of `latents_grad` was different than expected, got \n"
|
379
|
+
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
380
|
+
f"{tree_util.tree_structure(latents)}."
|
417
381
|
)
|
418
382
|
|
419
383
|
(
|
420
384
|
constraint_loss_value,
|
421
385
|
constraint_loss_grad,
|
422
386
|
) = jax.value_and_grad(
|
423
|
-
|
424
|
-
)(
|
387
|
+
_constraint_loss_latents
|
388
|
+
)(latents)
|
425
389
|
value += constraint_loss_value
|
426
|
-
|
427
|
-
lambda a, b: a + b,
|
390
|
+
latents_grad = tree_util.tree_map(
|
391
|
+
lambda a, b: a + b, latents_grad, constraint_loss_grad
|
428
392
|
)
|
429
393
|
|
430
|
-
|
431
|
-
|
394
|
+
flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
|
395
|
+
latents_grad
|
432
396
|
) # type: ignore[no-untyped-call]
|
433
397
|
|
434
|
-
|
398
|
+
flat_latents, jax_lbfgsb_state = jax.pure_callback(
|
435
399
|
_update_pure,
|
436
|
-
(
|
437
|
-
|
400
|
+
(flat_latents_grad, jax_lbfgsb_state),
|
401
|
+
flat_latents_grad,
|
438
402
|
value,
|
439
403
|
jax_lbfgsb_state,
|
440
404
|
)
|
441
|
-
|
405
|
+
latents = unflatten_fn(flat_latents)
|
406
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
442
407
|
latent_params = _update_parameterized_densities(latent_params, step)
|
443
|
-
params =
|
408
|
+
params = _params_from_latent_params(latent_params)
|
444
409
|
return step + 1, params, latent_params, jax_lbfgsb_state
|
445
410
|
|
411
|
+
# -------------------------------------------------------------------------
|
412
|
+
# Functions related to the density parameterization.
|
413
|
+
# -------------------------------------------------------------------------
|
414
|
+
|
415
|
+
def _init_latents(params: PyTree) -> PyTree:
|
416
|
+
def _leaf_init_latents(leaf: Any) -> Any:
|
417
|
+
leaf = _clip(leaf)
|
418
|
+
if not _is_density(leaf) or density_parameterization is None:
|
419
|
+
return leaf
|
420
|
+
return density_parameterization.from_density(leaf)
|
421
|
+
|
422
|
+
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
423
|
+
|
424
|
+
def _params_from_latent_params(latent_params: PyTree) -> PyTree:
|
425
|
+
def _leaf_params_from_latents(leaf: Any) -> Any:
|
426
|
+
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
427
|
+
return leaf
|
428
|
+
return density_parameterization.to_density(leaf)
|
429
|
+
|
430
|
+
return tree_util.tree_map(
|
431
|
+
_leaf_params_from_latents,
|
432
|
+
latent_params,
|
433
|
+
is_leaf=_is_parameterized_density,
|
434
|
+
)
|
435
|
+
|
436
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
437
|
+
def _update_leaf(leaf: Any) -> Any:
|
438
|
+
if not _is_parameterized_density(leaf):
|
439
|
+
return leaf
|
440
|
+
return density_parameterization.update(leaf, step)
|
441
|
+
|
442
|
+
return tree_util.tree_map(
|
443
|
+
_update_leaf,
|
444
|
+
latent_params,
|
445
|
+
is_leaf=_is_parameterized_density,
|
446
|
+
)
|
447
|
+
|
448
|
+
# -------------------------------------------------------------------------
|
449
|
+
# Functions related to the constraints to be minimized.
|
450
|
+
# -------------------------------------------------------------------------
|
451
|
+
|
452
|
+
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
453
|
+
def _constraint_loss_leaf(
|
454
|
+
leaf: param_base.ParameterizedDensity2DArray,
|
455
|
+
) -> jnp.ndarray:
|
456
|
+
constraints = density_parameterization.constraints(leaf)
|
457
|
+
constraints = tree_util.tree_map(
|
458
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
459
|
+
constraints,
|
460
|
+
)
|
461
|
+
return jnp.sum(jnp.asarray(constraints))
|
462
|
+
|
463
|
+
losses = [0.0] + [
|
464
|
+
_constraint_loss_leaf(p)
|
465
|
+
for p in tree_util.tree_leaves(
|
466
|
+
latent_params, is_leaf=_is_parameterized_density
|
467
|
+
)
|
468
|
+
if _is_parameterized_density(p)
|
469
|
+
]
|
470
|
+
return penalty * jnp.sum(jnp.asarray(losses))
|
471
|
+
|
446
472
|
return base.Optimizer(
|
447
473
|
init=init_fn,
|
448
474
|
params=params_fn,
|
@@ -467,7 +493,7 @@ def _is_density(leaf: Any) -> Any:
|
|
467
493
|
|
468
494
|
def _is_parameterized_density(leaf: Any) -> Any:
|
469
495
|
"""Return `True` if `leaf` is a parameterized density array."""
|
470
|
-
return isinstance(leaf,
|
496
|
+
return isinstance(leaf, param_base.ParameterizedDensity2DArray)
|
471
497
|
|
472
498
|
|
473
499
|
def _is_custom_type(leaf: Any) -> bool:
|
@@ -13,7 +13,7 @@ from totypes import types
|
|
13
13
|
|
14
14
|
from invrs_opt.optimizers import base
|
15
15
|
from invrs_opt.parameterization import (
|
16
|
-
base as
|
16
|
+
base as param_base,
|
17
17
|
filter_project,
|
18
18
|
gaussian_levelset,
|
19
19
|
pixel,
|
@@ -158,7 +158,7 @@ def levelset_wrapped_optax(
|
|
158
158
|
|
159
159
|
def parameterized_wrapped_optax(
|
160
160
|
opt: optax.GradientTransformation,
|
161
|
-
density_parameterization: Optional[
|
161
|
+
density_parameterization: Optional[param_base.Density2DParameterization],
|
162
162
|
penalty: float,
|
163
163
|
) -> base.Optimizer:
|
164
164
|
"""Wrapped optax optimizer with specified density parameterization.
|
@@ -178,6 +178,74 @@ def parameterized_wrapped_optax(
|
|
178
178
|
if density_parameterization is None:
|
179
179
|
density_parameterization = pixel.pixel()
|
180
180
|
|
181
|
+
def init_fn(params: PyTree) -> WrappedOptaxState:
|
182
|
+
"""Initializes the optimization state."""
|
183
|
+
latent_params = _init_latents(params)
|
184
|
+
_, latents = param_base.partition_density_metadata(latent_params)
|
185
|
+
return (
|
186
|
+
0, # step
|
187
|
+
_params_from_latent_params(latent_params), # params
|
188
|
+
latent_params, # latent params
|
189
|
+
opt.init(latents), # opt state
|
190
|
+
)
|
191
|
+
|
192
|
+
def params_fn(state: WrappedOptaxState) -> PyTree:
|
193
|
+
"""Returns the parameters for the given `state`."""
|
194
|
+
_, params, _, _ = state
|
195
|
+
return params
|
196
|
+
|
197
|
+
def update_fn(
|
198
|
+
*,
|
199
|
+
grad: PyTree,
|
200
|
+
value: float,
|
201
|
+
params: PyTree,
|
202
|
+
state: WrappedOptaxState,
|
203
|
+
) -> WrappedOptaxState:
|
204
|
+
"""Updates the state."""
|
205
|
+
del value, params
|
206
|
+
|
207
|
+
step, params, latent_params, opt_state = state
|
208
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
209
|
+
|
210
|
+
def _params_from_latents(latents: PyTree) -> PyTree:
|
211
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
212
|
+
return _params_from_latent_params(latent_params)
|
213
|
+
|
214
|
+
def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
|
215
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
216
|
+
return _constraint_loss(latent_params)
|
217
|
+
|
218
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
219
|
+
(latents_grad,) = vjp_fn(grad)
|
220
|
+
|
221
|
+
if not (
|
222
|
+
tree_util.tree_structure(latents_grad)
|
223
|
+
== tree_util.tree_structure(latents) # type: ignore[operator]
|
224
|
+
):
|
225
|
+
raise ValueError(
|
226
|
+
f"Tree structure of `latents_grad` was different than expected, got \n"
|
227
|
+
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
228
|
+
f"{tree_util.tree_structure(latents)}."
|
229
|
+
)
|
230
|
+
|
231
|
+
constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
|
232
|
+
latents_grad = tree_util.tree_map(
|
233
|
+
lambda a, b: a + b, latents_grad, constraint_loss_grad
|
234
|
+
)
|
235
|
+
|
236
|
+
updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
|
237
|
+
latents = optax.apply_updates(params=latents, updates=updates)
|
238
|
+
|
239
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
240
|
+
latent_params = _clip(latent_params)
|
241
|
+
latent_params = _update_parameterized_densities(latent_params, step + 1)
|
242
|
+
params = _params_from_latent_params(latent_params)
|
243
|
+
return (step + 1, params, latent_params, opt_state)
|
244
|
+
|
245
|
+
# -------------------------------------------------------------------------
|
246
|
+
# Functions related to the density parameterization.
|
247
|
+
# -------------------------------------------------------------------------
|
248
|
+
|
181
249
|
def _init_latents(params: PyTree) -> PyTree:
|
182
250
|
def _leaf_init_latents(leaf: Any) -> Any:
|
183
251
|
leaf = _clip(leaf)
|
@@ -187,7 +255,7 @@ def parameterized_wrapped_optax(
|
|
187
255
|
|
188
256
|
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
189
257
|
|
190
|
-
def
|
258
|
+
def _params_from_latent_params(params: PyTree) -> PyTree:
|
191
259
|
def _leaf_params_from_latents(leaf: Any) -> Any:
|
192
260
|
if not _is_parameterized_density(leaf):
|
193
261
|
return leaf
|
@@ -199,11 +267,27 @@ def parameterized_wrapped_optax(
|
|
199
267
|
is_leaf=_is_parameterized_density,
|
200
268
|
)
|
201
269
|
|
270
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
271
|
+
def _update_leaf(leaf: Any) -> Any:
|
272
|
+
if not _is_parameterized_density(leaf):
|
273
|
+
return leaf
|
274
|
+
return density_parameterization.update(leaf, step)
|
275
|
+
|
276
|
+
return tree_util.tree_map(
|
277
|
+
_update_leaf,
|
278
|
+
latent_params,
|
279
|
+
is_leaf=_is_parameterized_density,
|
280
|
+
)
|
281
|
+
|
282
|
+
# -------------------------------------------------------------------------
|
283
|
+
# Functions related to the constraints to be minimized.
|
284
|
+
# -------------------------------------------------------------------------
|
285
|
+
|
202
286
|
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
203
287
|
def _constraint_loss_leaf(
|
204
|
-
|
288
|
+
leaf: param_base.ParameterizedDensity2DArray,
|
205
289
|
) -> jnp.ndarray:
|
206
|
-
constraints = density_parameterization.constraints(
|
290
|
+
constraints = density_parameterization.constraints(leaf)
|
207
291
|
constraints = tree_util.tree_map(
|
208
292
|
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
209
293
|
constraints,
|
@@ -219,67 +303,6 @@ def parameterized_wrapped_optax(
|
|
219
303
|
]
|
220
304
|
return penalty * jnp.sum(jnp.asarray(losses))
|
221
305
|
|
222
|
-
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
223
|
-
def _update_leaf(leaf: Any) -> Any:
|
224
|
-
if not _is_parameterized_density(leaf):
|
225
|
-
return leaf
|
226
|
-
return density_parameterization.update(leaf, step)
|
227
|
-
|
228
|
-
return tree_util.tree_map(
|
229
|
-
_update_leaf,
|
230
|
-
latent_params,
|
231
|
-
is_leaf=_is_parameterized_density,
|
232
|
-
)
|
233
|
-
|
234
|
-
def init_fn(params: PyTree) -> WrappedOptaxState:
|
235
|
-
"""Initializes the optimization state."""
|
236
|
-
latent_params = _init_latents(params)
|
237
|
-
params = _params_from_latents(latent_params)
|
238
|
-
return 0, params, latent_params, opt.init(latent_params)
|
239
|
-
|
240
|
-
def params_fn(state: WrappedOptaxState) -> PyTree:
|
241
|
-
"""Returns the parameters for the given `state`."""
|
242
|
-
_, params, _, _ = state
|
243
|
-
return params
|
244
|
-
|
245
|
-
def update_fn(
|
246
|
-
*,
|
247
|
-
grad: PyTree,
|
248
|
-
value: float,
|
249
|
-
params: PyTree,
|
250
|
-
state: WrappedOptaxState,
|
251
|
-
) -> WrappedOptaxState:
|
252
|
-
"""Updates the state."""
|
253
|
-
del value, params
|
254
|
-
|
255
|
-
step, _, latent_params, opt_state = state
|
256
|
-
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
257
|
-
(latent_grad,) = vjp_fn(grad)
|
258
|
-
|
259
|
-
if not (
|
260
|
-
tree_util.tree_structure(latent_grad)
|
261
|
-
== tree_util.tree_structure(latent_params) # type: ignore[operator]
|
262
|
-
):
|
263
|
-
raise ValueError(
|
264
|
-
f"Tree structure of `latent_grad` was different than expected, got \n"
|
265
|
-
f"{tree_util.tree_structure(latent_grad)} but expected \n"
|
266
|
-
f"{tree_util.tree_structure(latent_params)}."
|
267
|
-
)
|
268
|
-
|
269
|
-
constraint_loss_grad = jax.grad(_constraint_loss)(latent_params)
|
270
|
-
latent_grad = tree_util.tree_map(
|
271
|
-
lambda a, b: a + b, latent_grad, constraint_loss_grad
|
272
|
-
)
|
273
|
-
|
274
|
-
updates, opt_state = opt.update(
|
275
|
-
updates=latent_grad, state=opt_state, params=latent_params
|
276
|
-
)
|
277
|
-
latent_params = optax.apply_updates(params=latent_params, updates=updates)
|
278
|
-
latent_params = _clip(latent_params)
|
279
|
-
latent_params = _update_parameterized_densities(latent_params, step)
|
280
|
-
params = _params_from_latents(latent_params)
|
281
|
-
return step + 1, params, latent_params, opt_state
|
282
|
-
|
283
306
|
return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
|
284
307
|
|
285
308
|
|
@@ -290,7 +313,7 @@ def _is_density(leaf: Any) -> Any:
|
|
290
313
|
|
291
314
|
def _is_parameterized_density(leaf: Any) -> Any:
|
292
315
|
"""Return `True` if `leaf` is a parameterized density array."""
|
293
|
-
return isinstance(leaf,
|
316
|
+
return isinstance(leaf, param_base.ParameterizedDensity2DArray)
|
294
317
|
|
295
318
|
|
296
319
|
def _is_custom_type(leaf: Any) -> bool:
|
@@ -9,24 +9,74 @@ from typing import Any, Optional, Protocol, Sequence, Tuple
|
|
9
9
|
import jax.numpy as jnp
|
10
10
|
import numpy as onp
|
11
11
|
from jax import tree_util
|
12
|
-
from totypes import json_utils, types
|
12
|
+
from totypes import json_utils, partition_utils, types
|
13
13
|
|
14
14
|
Array = jnp.ndarray | onp.ndarray[Any, Any]
|
15
15
|
PyTree = Any
|
16
16
|
|
17
17
|
|
18
|
-
|
19
|
-
|
18
|
+
@dataclasses.dataclass
|
19
|
+
class ParameterizedDensity2DArray:
|
20
|
+
"""Stores latents and metadata for a parameterized density array."""
|
21
|
+
|
22
|
+
latents: "LatentsBase"
|
23
|
+
metadata: Optional["MetadataBase"]
|
24
|
+
|
25
|
+
|
26
|
+
class LatentsBase:
|
27
|
+
"""Base class for latents of a parameterized density array."""
|
28
|
+
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
class MetadataBase:
|
33
|
+
"""Base class for metadata of a parameterized density array."""
|
20
34
|
|
21
35
|
pass
|
22
36
|
|
23
37
|
|
38
|
+
tree_util.register_dataclass(
|
39
|
+
ParameterizedDensity2DArray,
|
40
|
+
data_fields=["latents", "metadata"],
|
41
|
+
meta_fields=[],
|
42
|
+
)
|
43
|
+
json_utils.register_custom_type(ParameterizedDensity2DArray)
|
44
|
+
|
45
|
+
|
46
|
+
def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
|
47
|
+
"""Splits a pytree with parameterized densities into metadata from latents."""
|
48
|
+
metadata, latents = partition_utils.partition(
|
49
|
+
tree,
|
50
|
+
select_fn=lambda x: isinstance(x, MetadataBase),
|
51
|
+
is_leaf=_is_metadata_or_none,
|
52
|
+
)
|
53
|
+
return metadata, latents
|
54
|
+
|
55
|
+
|
56
|
+
def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
|
57
|
+
"""Combines pytrees containing metadata and latents."""
|
58
|
+
return partition_utils.combine(metadata, latents, is_leaf=_is_metadata_or_none)
|
59
|
+
|
60
|
+
|
61
|
+
def _is_metadata_or_none(leaf: Any) -> bool:
|
62
|
+
"""Return `True` if `leaf` is `None` or density metadata."""
|
63
|
+
return leaf is None or isinstance(leaf, MetadataBase)
|
64
|
+
|
65
|
+
|
66
|
+
@dataclasses.dataclass
|
67
|
+
class Density2DParameterization:
|
68
|
+
"""Stores `(from_density, to_density, constraints, update)` function triple."""
|
69
|
+
|
70
|
+
from_density: "FromDensityFn"
|
71
|
+
to_density: "ToDensityFn"
|
72
|
+
constraints: "ConstraintsFn"
|
73
|
+
update: "UpdateFn"
|
74
|
+
|
75
|
+
|
24
76
|
class FromDensityFn(Protocol):
|
25
77
|
"""Generate the latent representation of a density array."""
|
26
78
|
|
27
|
-
def __call__(
|
28
|
-
self, density: types.Density2DArray
|
29
|
-
) -> ParameterizedDensity2DArrayBase:
|
79
|
+
def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
|
30
80
|
...
|
31
81
|
|
32
82
|
|
@@ -51,16 +101,6 @@ class UpdateFn(Protocol):
|
|
51
101
|
...
|
52
102
|
|
53
103
|
|
54
|
-
@dataclasses.dataclass
|
55
|
-
class Density2DParameterization:
|
56
|
-
"""Stores `(from_density, to_density, constraints)` function triple."""
|
57
|
-
|
58
|
-
from_density: FromDensityFn
|
59
|
-
to_density: ToDensityFn
|
60
|
-
constraints: ConstraintsFn
|
61
|
-
update: UpdateFn
|
62
|
-
|
63
|
-
|
64
104
|
@dataclasses.dataclass
|
65
105
|
class Density2DMetadata:
|
66
106
|
"""Stores the metadata of a `Density2DArray`."""
|
@@ -78,6 +118,12 @@ class Density2DMetadata:
|
|
78
118
|
self.periodic = tuple(self.periodic)
|
79
119
|
self.symmetries = tuple(self.symmetries)
|
80
120
|
|
121
|
+
@classmethod
|
122
|
+
def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
|
123
|
+
density_metadata_dict = dataclasses.asdict(density)
|
124
|
+
del density_metadata_dict["array"]
|
125
|
+
return Density2DMetadata(**density_metadata_dict)
|
126
|
+
|
81
127
|
|
82
128
|
def _flatten_density_2d_metadata(
|
83
129
|
metadata: Density2DMetadata,
|
@@ -13,25 +13,53 @@ from invrs_opt.parameterization import base, transforms
|
|
13
13
|
|
14
14
|
|
15
15
|
@dataclasses.dataclass
|
16
|
-
class
|
17
|
-
"""Stores
|
16
|
+
class FilterProjectParams(base.ParameterizedDensity2DArray):
|
17
|
+
"""Stores parameters for the filter-project parameterization."""
|
18
18
|
|
19
|
-
|
19
|
+
latents: "FilterProjectLatents"
|
20
|
+
metadata: "FilterProjectMetadata"
|
21
|
+
|
22
|
+
|
23
|
+
@dataclasses.dataclass
|
24
|
+
class FilterProjectLatents(base.LatentsBase):
|
25
|
+
"""Stores latent parameters for the filter-project parameterization.
|
26
|
+
|
27
|
+
Attributes:
|
20
28
|
latent_density: The latent variable from which the density is obtained.
|
21
|
-
beta: Determines the sharpness of the thresholding operation.
|
22
29
|
"""
|
23
30
|
|
24
31
|
latent_density: types.Density2DArray
|
32
|
+
|
33
|
+
|
34
|
+
@dataclasses.dataclass
|
35
|
+
class FilterProjectMetadata(base.MetadataBase):
|
36
|
+
"""Stores metadata for the filter-project parameterization.
|
37
|
+
|
38
|
+
Attributes:
|
39
|
+
beta: Determines the sharpness of the thresholding operation.
|
40
|
+
"""
|
41
|
+
|
25
42
|
beta: float
|
26
43
|
|
27
44
|
|
28
45
|
tree_util.register_dataclass(
|
29
|
-
|
46
|
+
FilterProjectParams,
|
47
|
+
data_fields=["latents", "metadata"],
|
48
|
+
meta_fields=[],
|
49
|
+
)
|
50
|
+
tree_util.register_dataclass(
|
51
|
+
FilterProjectLatents,
|
30
52
|
data_fields=["latent_density"],
|
53
|
+
meta_fields=[],
|
54
|
+
)
|
55
|
+
tree_util.register_dataclass(
|
56
|
+
FilterProjectMetadata,
|
57
|
+
data_fields=[],
|
31
58
|
meta_fields=["beta"],
|
32
59
|
)
|
33
|
-
|
34
|
-
json_utils.register_custom_type(
|
60
|
+
json_utils.register_custom_type(FilterProjectParams)
|
61
|
+
json_utils.register_custom_type(FilterProjectLatents)
|
62
|
+
json_utils.register_custom_type(FilterProjectMetadata)
|
35
63
|
|
36
64
|
|
37
65
|
def filter_project(beta: float) -> base.Density2DParameterization:
|
@@ -55,24 +83,26 @@ def filter_project(beta: float) -> base.Density2DParameterization:
|
|
55
83
|
The `Density2DParameterization`.
|
56
84
|
"""
|
57
85
|
|
58
|
-
def from_density_fn(density: types.Density2DArray) ->
|
86
|
+
def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
|
59
87
|
"""Return latent parameters for the given `density`."""
|
60
88
|
array = transforms.normalized_array_from_density(density)
|
61
89
|
array = jnp.clip(array, -1, 1)
|
62
90
|
array *= jnp.tanh(beta)
|
63
91
|
latent_array = jnp.arctanh(array) / beta
|
64
92
|
latent_array = transforms.rescale_array_for_density(latent_array, density)
|
65
|
-
|
66
|
-
|
67
|
-
|
93
|
+
latent_density = density = dataclasses.replace(density, array=latent_array)
|
94
|
+
return FilterProjectParams(
|
95
|
+
latents=FilterProjectLatents(latent_density=latent_density),
|
96
|
+
metadata=FilterProjectMetadata(beta=beta),
|
68
97
|
)
|
69
98
|
|
70
|
-
def to_density_fn(params:
|
99
|
+
def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
|
71
100
|
"""Return a density from the latent parameters."""
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
)
|
101
|
+
latent_density = params.latents.latent_density
|
102
|
+
beta = params.metadata.beta
|
103
|
+
|
104
|
+
transformed = types.symmetrize_density(latent_density)
|
105
|
+
transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
|
76
106
|
# Scale to ensure that the full valid range of the density array is reachable.
|
77
107
|
mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
|
78
108
|
transformed = tree_util.tree_map(
|
@@ -80,12 +110,12 @@ def filter_project(beta: float) -> base.Density2DParameterization:
|
|
80
110
|
)
|
81
111
|
return transforms.apply_fixed_pixels(transformed)
|
82
112
|
|
83
|
-
def constraints_fn(params:
|
113
|
+
def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
|
84
114
|
"""Computes constraints associated with the params."""
|
85
115
|
del params
|
86
116
|
return jnp.asarray(0.0)
|
87
117
|
|
88
|
-
def update_fn(params:
|
118
|
+
def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
|
89
119
|
"""Perform updates to `params` required for the given `step`."""
|
90
120
|
del step
|
91
121
|
return params
|
@@ -29,12 +29,30 @@ DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(1e-1)
|
|
29
29
|
|
30
30
|
|
31
31
|
@dataclasses.dataclass
|
32
|
-
class GaussianLevelsetParams(base.
|
33
|
-
"""
|
32
|
+
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
|
33
|
+
"""Stores parameters for the Gaussian levelset parameterization."""
|
34
|
+
|
35
|
+
latents: "GaussianLevelsetLatents"
|
36
|
+
metadata: "GaussianLevelsetMetadata"
|
37
|
+
|
38
|
+
|
39
|
+
@dataclasses.dataclass
|
40
|
+
class GaussianLevelsetLatents(base.LatentsBase):
|
41
|
+
"""Stores latent parameters for the Gaussian levelset parameterization.
|
34
42
|
|
35
43
|
Attributes:
|
36
44
|
amplitude: Array giving the amplitude of the Gaussian basis function at
|
37
45
|
levelset control points.
|
46
|
+
"""
|
47
|
+
|
48
|
+
amplitude: jnp.ndarray
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
|
52
|
+
class GaussianLevelsetMetadata(base.MetadataBase):
|
53
|
+
"""Stores metadata for the Gaussian levelset parameterization.
|
54
|
+
|
55
|
+
Attributes:
|
38
56
|
length_scale_spacing_factor: The number of levelset control points per unit of
|
39
57
|
minimum length scale (mean of density minimum width and minimum spacing).
|
40
58
|
length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
|
@@ -45,7 +63,6 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
|
|
45
63
|
density_metadata: Metadata for the density array obtained from the parameters.
|
46
64
|
"""
|
47
65
|
|
48
|
-
amplitude: jnp.ndarray
|
49
66
|
length_scale_spacing_factor: float
|
50
67
|
length_scale_fwhm_factor: float
|
51
68
|
smoothing_factor: int
|
@@ -55,68 +72,31 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
|
|
55
72
|
def __post_init__(self) -> None:
|
56
73
|
self.density_shape = tuple(self.density_shape)
|
57
74
|
|
58
|
-
def example_density(self) -> types.Density2DArray:
|
59
|
-
"""Returns an example density with appropriate shape and metadata."""
|
60
|
-
with jax.ensure_compile_time_eval():
|
61
|
-
return types.Density2DArray(
|
62
|
-
array=jnp.zeros(self.density_shape),
|
63
|
-
**dataclasses.asdict(self.density_metadata),
|
64
|
-
)
|
65
|
-
|
66
|
-
|
67
|
-
_GaussianLevelsetParamsAux = Tuple[
|
68
|
-
float, float, int, Tuple[int, ...], tree_util.PyTreeDef
|
69
|
-
]
|
70
|
-
|
71
75
|
|
72
|
-
|
73
|
-
params: GaussianLevelsetParams,
|
74
|
-
) -> Tuple[Tuple[jnp.ndarray], _GaussianLevelsetParamsAux]:
|
75
|
-
_, flat_metadata = tree_util.tree_flatten(params.density_metadata)
|
76
|
-
return (
|
77
|
-
(params.amplitude,),
|
78
|
-
(
|
79
|
-
params.length_scale_spacing_factor,
|
80
|
-
params.length_scale_fwhm_factor,
|
81
|
-
params.smoothing_factor,
|
82
|
-
params.density_shape,
|
83
|
-
flat_metadata,
|
84
|
-
),
|
85
|
-
)
|
86
|
-
|
87
|
-
|
88
|
-
def _unflatten_gaussian_levelset_params(
|
89
|
-
aux: _GaussianLevelsetParamsAux,
|
90
|
-
children: Tuple[jnp.ndarray],
|
91
|
-
) -> GaussianLevelsetParams:
|
92
|
-
(amplitude,) = children
|
93
|
-
(
|
94
|
-
length_scale_spacing_factor,
|
95
|
-
length_scale_fwhm_factor,
|
96
|
-
smoothing_factor,
|
97
|
-
density_shape,
|
98
|
-
flat_metadata,
|
99
|
-
) = aux
|
100
|
-
|
101
|
-
density_metadata = tree_util.tree_unflatten(flat_metadata, ())
|
102
|
-
return GaussianLevelsetParams(
|
103
|
-
amplitude=amplitude,
|
104
|
-
length_scale_spacing_factor=length_scale_spacing_factor,
|
105
|
-
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
106
|
-
smoothing_factor=smoothing_factor,
|
107
|
-
density_shape=tuple(density_shape),
|
108
|
-
density_metadata=density_metadata,
|
109
|
-
)
|
110
|
-
|
111
|
-
|
112
|
-
tree_util.register_pytree_node(
|
76
|
+
tree_util.register_dataclass(
|
113
77
|
GaussianLevelsetParams,
|
114
|
-
|
115
|
-
|
78
|
+
data_fields=["latents", "metadata"],
|
79
|
+
meta_fields=[],
|
80
|
+
)
|
81
|
+
tree_util.register_dataclass(
|
82
|
+
GaussianLevelsetLatents,
|
83
|
+
data_fields=["amplitude"],
|
84
|
+
meta_fields=[],
|
85
|
+
)
|
86
|
+
tree_util.register_dataclass(
|
87
|
+
GaussianLevelsetMetadata,
|
88
|
+
data_fields=[
|
89
|
+
"length_scale_spacing_factor",
|
90
|
+
"length_scale_fwhm_factor",
|
91
|
+
"smoothing_factor",
|
92
|
+
"density_shape",
|
93
|
+
"density_metadata",
|
94
|
+
],
|
95
|
+
meta_fields=[],
|
116
96
|
)
|
117
|
-
|
118
|
-
|
119
97
|
json_utils.register_custom_type(GaussianLevelsetParams)
|
98
|
+
json_utils.register_custom_type(GaussianLevelsetLatents)
|
99
|
+
json_utils.register_custom_type(GaussianLevelsetMetadata)
|
120
100
|
|
121
101
|
|
122
102
|
def gaussian_levelset(
|
@@ -187,23 +167,21 @@ def gaussian_levelset(
|
|
187
167
|
pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
|
188
168
|
amplitude = jnp.pad(amplitude, pad_width, mode="edge")
|
189
169
|
|
190
|
-
|
191
|
-
|
192
|
-
density_metadata = base.Density2DMetadata(**density_metadata_dict)
|
193
|
-
params = GaussianLevelsetParams(
|
194
|
-
amplitude=amplitude,
|
170
|
+
latents = GaussianLevelsetLatents(amplitude=amplitude)
|
171
|
+
metadata = GaussianLevelsetMetadata(
|
195
172
|
length_scale_spacing_factor=length_scale_spacing_factor,
|
196
173
|
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
197
174
|
smoothing_factor=smoothing_factor,
|
198
175
|
density_shape=density.shape,
|
199
|
-
density_metadata=
|
176
|
+
density_metadata=base.Density2DMetadata.from_density(density),
|
200
177
|
)
|
201
178
|
|
202
179
|
def step_fn(
|
203
180
|
_: int,
|
204
181
|
params_and_state: Tuple[PyTree, PyTree],
|
205
182
|
) -> Tuple[PyTree, PyTree]:
|
206
|
-
def loss_fn(
|
183
|
+
def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
|
184
|
+
params = GaussianLevelsetParams(latents, metadata=metadata)
|
207
185
|
density_from_params = to_density_fn(params, mask_gradient=False)
|
208
186
|
return jnp.mean((density_from_params.array - target_array) ** 2)
|
209
187
|
|
@@ -213,13 +191,14 @@ def gaussian_levelset(
|
|
213
191
|
params = optax.apply_updates(params, updates)
|
214
192
|
return params, state
|
215
193
|
|
216
|
-
state = init_optimizer.init(
|
217
|
-
|
218
|
-
0, init_steps, body_fun=step_fn, init_val=(
|
194
|
+
state = init_optimizer.init(latents)
|
195
|
+
latents, _ = jax.lax.fori_loop(
|
196
|
+
0, init_steps, body_fun=step_fn, init_val=(latents, state)
|
219
197
|
)
|
220
198
|
|
221
|
-
maxval = jnp.amax(jnp.abs(
|
222
|
-
|
199
|
+
maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
|
200
|
+
latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
|
201
|
+
return GaussianLevelsetParams(latents=latents, metadata=metadata)
|
223
202
|
|
224
203
|
def to_density_fn(
|
225
204
|
params: GaussianLevelsetParams,
|
@@ -228,7 +207,7 @@ def gaussian_levelset(
|
|
228
207
|
"""Return a density from the latent parameters."""
|
229
208
|
array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)
|
230
209
|
|
231
|
-
example_density = params
|
210
|
+
example_density = _example_density(params)
|
232
211
|
lb = example_density.lower_bound
|
233
212
|
ub = example_density.upper_bound
|
234
213
|
array = lb + array * (ub - lb)
|
@@ -263,7 +242,7 @@ def gaussian_levelset(
|
|
263
242
|
)
|
264
243
|
|
265
244
|
# Normalize constraints to make them (somewhat) resolution-independent.
|
266
|
-
example_density = params
|
245
|
+
example_density = _example_density(params)
|
267
246
|
length_scale = 0.5 * (
|
268
247
|
example_density.minimum_spacing + example_density.minimum_width
|
269
248
|
)
|
@@ -287,6 +266,15 @@ def gaussian_levelset(
|
|
287
266
|
# -----------------------------------------------------------------------------
|
288
267
|
|
289
268
|
|
269
|
+
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
|
270
|
+
"""Returns an example density with appropriate shape and metadata."""
|
271
|
+
with jax.ensure_compile_time_eval():
|
272
|
+
return types.Density2DArray(
|
273
|
+
array=jnp.zeros(params.metadata.density_shape),
|
274
|
+
**dataclasses.asdict(params.metadata.density_metadata),
|
275
|
+
)
|
276
|
+
|
277
|
+
|
290
278
|
def _to_array(
|
291
279
|
params: GaussianLevelsetParams,
|
292
280
|
mask_gradient: bool,
|
@@ -308,7 +296,7 @@ def _to_array(
|
|
308
296
|
Returns:
|
309
297
|
The array.
|
310
298
|
"""
|
311
|
-
example_density = params
|
299
|
+
example_density = _example_density(params)
|
312
300
|
periodic: Tuple[bool, bool] = example_density.periodic
|
313
301
|
phi = _phi_from_params(
|
314
302
|
params=params,
|
@@ -319,7 +307,7 @@ def _to_array(
|
|
319
307
|
periodic=periodic,
|
320
308
|
mask_gradient=mask_gradient,
|
321
309
|
)
|
322
|
-
return _downsample_spatial_dims(array, params.smoothing_factor)
|
310
|
+
return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
|
323
311
|
|
324
312
|
|
325
313
|
def _phi_from_params(
|
@@ -337,32 +325,35 @@ def _phi_from_params(
|
|
337
325
|
The levelset array `phi`.
|
338
326
|
"""
|
339
327
|
with jax.ensure_compile_time_eval():
|
340
|
-
example_density = params
|
328
|
+
example_density = _example_density(params)
|
341
329
|
length_scale = 0.5 * (
|
342
330
|
example_density.minimum_width + example_density.minimum_spacing
|
343
331
|
)
|
344
|
-
fwhm = length_scale * params.length_scale_fwhm_factor
|
332
|
+
fwhm = length_scale * params.metadata.length_scale_fwhm_factor
|
345
333
|
sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
|
346
334
|
|
335
|
+
s_factor = params.metadata.smoothing_factor
|
347
336
|
highres_i = (
|
348
337
|
0.5
|
349
338
|
+ jnp.arange(
|
350
|
-
|
351
|
-
|
339
|
+
s_factor * (-pad_pixels),
|
340
|
+
s_factor * (pad_pixels + example_density.shape[-2]),
|
352
341
|
)
|
353
|
-
) /
|
342
|
+
) / s_factor
|
354
343
|
highres_j = (
|
355
344
|
0.5
|
356
345
|
+ jnp.arange(
|
357
|
-
|
358
|
-
|
346
|
+
s_factor * (-pad_pixels),
|
347
|
+
s_factor * (pad_pixels + example_density.shape[-1]),
|
359
348
|
)
|
360
|
-
) /
|
349
|
+
) / s_factor
|
361
350
|
|
362
351
|
# Coordinates for the control points of the Gaussian radial basis functions.
|
363
352
|
levelset_i, levelset_j = _control_point_coords(
|
364
|
-
density_shape=params.density_shape[-2:], # type: ignore[arg-type]
|
365
|
-
levelset_shape=
|
353
|
+
density_shape=params.metadata.density_shape[-2:], # type: ignore[arg-type]
|
354
|
+
levelset_shape=(
|
355
|
+
params.latents.amplitude.shape[-2:] # type: ignore[arg-type]
|
356
|
+
),
|
366
357
|
periodic=example_density.periodic,
|
367
358
|
)
|
368
359
|
|
@@ -391,7 +382,7 @@ def _phi_from_params(
|
|
391
382
|
levelset_i = levelset_i.flatten()
|
392
383
|
levelset_j = levelset_j.flatten()
|
393
384
|
|
394
|
-
amplitude = params.amplitude
|
385
|
+
amplitude = params.latents.amplitude
|
395
386
|
if example_density.periodic[0]:
|
396
387
|
amplitude = jnp.concat([amplitude] * 3, axis=-2)
|
397
388
|
if example_density.periodic[1]:
|
@@ -410,8 +401,8 @@ def _phi_from_params(
|
|
410
401
|
_, array = jax.lax.scan(scan_fn, (), xs=highres_i)
|
411
402
|
array = jnp.moveaxis(array, 0, -2)
|
412
403
|
|
413
|
-
assert array.shape[-2] %
|
414
|
-
assert array.shape[-1] %
|
404
|
+
assert array.shape[-2] % s_factor == 0
|
405
|
+
assert array.shape[-1] % s_factor == 0
|
415
406
|
array = symmetry.symmetrize(array, tuple(example_density.symmetries))
|
416
407
|
return array
|
417
408
|
|
@@ -443,7 +434,7 @@ def _fixed_pixel_constraint(
|
|
443
434
|
"""
|
444
435
|
array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)
|
445
436
|
|
446
|
-
example_density = params
|
437
|
+
example_density = _example_density(params)
|
447
438
|
fixed_solid = jnp.zeros(example_density.shape[-2:], dtype=bool)
|
448
439
|
fixed_void = jnp.zeros(example_density.shape[-2:], dtype=bool)
|
449
440
|
if example_density.fixed_solid is not None:
|
@@ -491,9 +482,9 @@ def _levelset_constraints(
|
|
491
482
|
beyond the boundaries of the parameterized density.
|
492
483
|
|
493
484
|
Returns:
|
494
|
-
The minimum length scale and minimum curvature constraint arrays.
|
485
|
+
The minimum length scale and minimum curvature constraint arrays.
|
495
486
|
"""
|
496
|
-
example_density = params
|
487
|
+
example_density = _example_density(params)
|
497
488
|
minimum_length_scale = 0.5 * (
|
498
489
|
example_density.minimum_width + example_density.minimum_spacing
|
499
490
|
)
|
@@ -512,9 +503,10 @@ def _levelset_constraints(
|
|
512
503
|
)
|
513
504
|
|
514
505
|
# Downsample so that constraints shape matches the density shape.
|
506
|
+
factor = params.metadata.smoothing_factor
|
515
507
|
return (
|
516
|
-
_downsample_spatial_dims(length_scale_constraint,
|
517
|
-
_downsample_spatial_dims(curvature_constraint,
|
508
|
+
_downsample_spatial_dims(length_scale_constraint, factor),
|
509
|
+
_downsample_spatial_dims(curvature_constraint, factor),
|
518
510
|
)
|
519
511
|
|
520
512
|
|
@@ -529,7 +521,7 @@ def _phi_derivatives_and_inverse_radius(
|
|
529
521
|
pad_pixels=pad_pixels,
|
530
522
|
)
|
531
523
|
|
532
|
-
d = 1 / params.smoothing_factor
|
524
|
+
d = 1 / params.metadata.smoothing_factor
|
533
525
|
phi_x, phi_y = jnp.gradient(phi, d, axis=(-2, -1))
|
534
526
|
phi_xx, phi_yx = jnp.gradient(phi_x, d, axis=(-2, -1))
|
535
527
|
phi_xy, phi_yy = jnp.gradient(phi_y, d, axis=(-2, -1))
|
@@ -13,26 +13,40 @@ from invrs_opt.parameterization import base
|
|
13
13
|
|
14
14
|
|
15
15
|
@dataclasses.dataclass
|
16
|
-
class PixelParams(base.
|
17
|
-
|
16
|
+
class PixelParams(base.ParameterizedDensity2DArray):
|
17
|
+
latents: "PixelLatents"
|
18
|
+
metadata: None = None
|
18
19
|
|
19
|
-
density: types.Density2DArray
|
20
20
|
|
21
|
+
@dataclasses.dataclass
|
22
|
+
class PixelLatents(base.LatentsBase):
|
23
|
+
"""Stores latent parameters for the direct pixel parameterization."""
|
21
24
|
|
22
|
-
|
25
|
+
density: types.Density2DArray
|
23
26
|
|
24
27
|
|
28
|
+
tree_util.register_dataclass(
|
29
|
+
PixelParams,
|
30
|
+
data_fields=["latents"],
|
31
|
+
meta_fields=[],
|
32
|
+
)
|
33
|
+
tree_util.register_dataclass(
|
34
|
+
PixelLatents,
|
35
|
+
data_fields=["density"],
|
36
|
+
meta_fields=[],
|
37
|
+
)
|
25
38
|
json_utils.register_custom_type(PixelParams)
|
39
|
+
json_utils.register_custom_type(PixelLatents)
|
26
40
|
|
27
41
|
|
28
42
|
def pixel() -> base.Density2DParameterization:
|
29
43
|
"""Return the direct pixel parameterization."""
|
30
44
|
|
31
45
|
def from_density_fn(density: types.Density2DArray) -> PixelParams:
|
32
|
-
return PixelParams(density=density)
|
46
|
+
return PixelParams(latents=PixelLatents(density=density))
|
33
47
|
|
34
48
|
def to_density_fn(params: PixelParams) -> types.Density2DArray:
|
35
|
-
return params.density
|
49
|
+
return params.latents.density
|
36
50
|
|
37
51
|
def constraints_fn(params: PixelParams) -> jnp.ndarray:
|
38
52
|
del params
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.8.0
|
3
|
+
Version: 0.9.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
533
533
|
Requires-Dist: pytest-subtests; extra == "tests"
|
534
534
|
|
535
535
|
# invrs-opt - Optimization algorithms for inverse design
|
536
|
-
`v0.8.0`
|
536
|
+
`v0.9.0`
|
537
537
|
|
538
538
|
## Overview
|
539
539
|
|
@@ -0,0 +1,20 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=LkjQEBq5HYuP2WgS265RPBoIe48tcG4dxGfMxPOXa68,585
|
2
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
+
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
+
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
+
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=8BPiEAqececL-zLnqrgN0CogGDkAd1tyAGndUB-kahc,36349
|
9
|
+
invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
|
10
|
+
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
|
12
|
+
invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg_Wtin8tMKKKxWvQ,4309
|
13
|
+
invrs_opt/parameterization/gaussian_levelset.py,sha256=N10UeezN4LdrlACxTEfJgJfhqh79zhUAY8XiSYvFYRY,24865
|
14
|
+
invrs_opt/parameterization/pixel.py,sha256=CCSuWF_bDebwmwTG33vmZwmvZoDJXYAR1kCrz4KnLt8,1607
|
15
|
+
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
+
invrs_opt-0.9.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.9.0.dist-info/METADATA,sha256=qS0kFxxwAsfBW1igDyJHel8xFDsQQWCiUh7f1QRRGEs,32633
|
18
|
+
invrs_opt-0.9.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
+
invrs_opt-0.9.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.9.0.dist-info/RECORD,,
|
invrs_opt-0.8.0.dist-info/RECORD
DELETED
@@ -1,20 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=kTrg48iZu7i5OlH0Nqtfh_wBn3be9u2eZgJBRqUr6uQ,585
|
2
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
-
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
-
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
-
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
|
9
|
-
invrs_opt/optimizers/wrapped_optax.py,sha256=zPC2j_KkfF2RqeorxB38ovuUsg1SNwdxwhjA7gvOMC4,12387
|
10
|
-
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
|
12
|
-
invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
|
13
|
-
invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
|
14
|
-
invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
|
15
|
-
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.8.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
-
invrs_opt-0.8.0.dist-info/METADATA,sha256=WSzpZwByiBqOZfiwSGvaVaFa5TUAPj4b-mr2qZRGLq0,32633
|
18
|
-
invrs_opt-0.8.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
-
invrs_opt-0.8.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
-
invrs_opt-0.8.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|