flowMC 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flowMC-0.3.0/src/flowMC.egg-info → flowmc-0.3.2}/PKG-INFO +1 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/setup.cfg +1 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/Sampler.py +17 -9
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/base.py +18 -11
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/common.py +2 -7
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/realNVP.py +7 -5
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/rqSpline.py +7 -5
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/Gaussian_random_walk.py +3 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/HMC.py +3 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/MALA.py +4 -2
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/NF_proposal.py +5 -3
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/base.py +2 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/flowHMC.py +7 -6
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/base.py +5 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/global_tuning.py +24 -11
- flowmc-0.3.2/src/flowMC/strategy/optimization.py +120 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/PythonFunctionWrap.py +2 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/postprocessing.py +3 -2
- {flowMC-0.3.0 → flowmc-0.3.2/src/flowMC.egg-info}/PKG-INFO +1 -1
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/SOURCES.txt +1 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/LICENSE +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/README.md +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/pyproject.toml +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/__init__.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/__init__.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/__init__.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/__init__.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/importance_sampling.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/EvolutionaryOptimizer.py +2 -2
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/__init__.py +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/dependency_links.txt +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/requires.txt +0 -0
- {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/top_level.txt +0 -0
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
import pickle
|
|
2
|
-
|
|
3
|
-
import jax.numpy as jnp
|
|
4
|
-
from jaxtyping import Array, Int, Float, PRNGKeyArray
|
|
5
|
-
from tqdm import tqdm
|
|
2
|
+
|
|
6
3
|
import equinox as eqx
|
|
4
|
+
import jax.numpy as jnp
|
|
7
5
|
import optax
|
|
8
|
-
from
|
|
9
|
-
|
|
6
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray
|
|
7
|
+
|
|
10
8
|
from flowMC.nfmodel.base import NFModel
|
|
9
|
+
from flowMC.proposal.base import ProposalBase
|
|
10
|
+
from flowMC.proposal.NF_proposal import NFProposal
|
|
11
11
|
from flowMC.strategy.base import Strategy
|
|
12
|
-
from flowMC.strategy.global_tuning import
|
|
12
|
+
from flowMC.strategy.global_tuning import GlobalSampling, GlobalTuning
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class Sampler:
|
|
@@ -131,7 +131,7 @@ class Sampler:
|
|
|
131
131
|
)
|
|
132
132
|
self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))
|
|
133
133
|
|
|
134
|
-
|
|
134
|
+
default_strategies = [
|
|
135
135
|
GlobalTuning(
|
|
136
136
|
n_dim=self.n_dim,
|
|
137
137
|
n_chains=self.n_chains,
|
|
@@ -161,7 +161,15 @@ class Sampler:
|
|
|
161
161
|
if kwargs.get("strategies") is not None:
|
|
162
162
|
kwargs_strategies = kwargs.get("strategies")
|
|
163
163
|
assert isinstance(kwargs_strategies, list)
|
|
164
|
-
self.strategies =
|
|
164
|
+
self.strategies = []
|
|
165
|
+
for strategy in kwargs_strategies:
|
|
166
|
+
if isinstance(strategy, str):
|
|
167
|
+
if strategy.lower() == "default":
|
|
168
|
+
self.strategies += default_strategies
|
|
169
|
+
else:
|
|
170
|
+
self.strategies.append(strategy)
|
|
171
|
+
else:
|
|
172
|
+
self.strategies = default_strategies
|
|
165
173
|
|
|
166
174
|
self.summary = {}
|
|
167
175
|
|
|
@@ -1,13 +1,14 @@
|
|
|
1
|
-
from abc import abstractmethod
|
|
2
1
|
import copy
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import Optional, overload
|
|
4
|
+
|
|
3
5
|
import equinox as eqx
|
|
4
|
-
from typing import overload, Optional
|
|
5
|
-
from typing_extensions import Self
|
|
6
|
-
from jaxtyping import Array, PRNGKeyArray, Float
|
|
7
|
-
import optax
|
|
8
|
-
from tqdm import trange, tqdm
|
|
9
|
-
import jax.numpy as jnp
|
|
10
6
|
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import optax
|
|
9
|
+
from jaxtyping import Array, Float, PRNGKeyArray
|
|
10
|
+
from tqdm import tqdm, trange
|
|
11
|
+
from typing_extensions import Self
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class NFModel(eqx.Module):
|
|
@@ -102,9 +103,10 @@ class NFModel(eqx.Module):
|
|
|
102
103
|
model (eqx.Model): Updated model.
|
|
103
104
|
opt_state (optax.OptState): Updated optimizer state.
|
|
104
105
|
"""
|
|
105
|
-
|
|
106
|
+
model = self
|
|
107
|
+
loss, grads = model.loss_fn(x)
|
|
106
108
|
updates, state = optim.update(grads, state)
|
|
107
|
-
model = eqx.apply_updates(
|
|
109
|
+
model = eqx.apply_updates(model, updates)
|
|
108
110
|
return loss, model, state
|
|
109
111
|
|
|
110
112
|
def train_epoch(
|
|
@@ -162,16 +164,21 @@ class NFModel(eqx.Module):
|
|
|
162
164
|
pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
|
|
163
165
|
else:
|
|
164
166
|
pbar = range(num_epochs)
|
|
167
|
+
|
|
165
168
|
best_model = model = self
|
|
169
|
+
best_state = state
|
|
166
170
|
best_loss = 1e9
|
|
167
171
|
for epoch in pbar:
|
|
168
172
|
# Use a separate PRNG key to permute image data during shuffling
|
|
169
173
|
rng, input_rng = jax.random.split(rng)
|
|
170
174
|
# Run an optimization step over a training batch
|
|
171
|
-
value, model, state = model.train_epoch(
|
|
175
|
+
value, model, state = model.train_epoch(
|
|
176
|
+
input_rng, optim, state, data, batch_size
|
|
177
|
+
)
|
|
172
178
|
loss_values = loss_values.at[epoch].set(value)
|
|
173
179
|
if loss_values[epoch] < best_loss:
|
|
174
180
|
best_model = model
|
|
181
|
+
best_state = state
|
|
175
182
|
best_loss = loss_values[epoch]
|
|
176
183
|
if verbose:
|
|
177
184
|
assert isinstance(pbar, tqdm)
|
|
@@ -182,7 +189,7 @@ class NFModel(eqx.Module):
|
|
|
182
189
|
if epoch == num_epochs:
|
|
183
190
|
pbar.set_description(f"Training NF, current loss: {value:.3f}")
|
|
184
191
|
|
|
185
|
-
return rng, best_model,
|
|
192
|
+
return rng, best_model, best_state, loss_values
|
|
186
193
|
|
|
187
194
|
|
|
188
195
|
class Bijection(eqx.Module):
|
|
@@ -1,17 +1,12 @@
|
|
|
1
1
|
from typing import Callable, List, Tuple
|
|
2
2
|
|
|
3
|
-
import jax
|
|
4
|
-
import jax.numpy as jnp
|
|
5
|
-
from jaxtyping import Array
|
|
6
3
|
import equinox as eqx
|
|
7
|
-
|
|
8
|
-
from flowMC.nfmodel.base import Bijection, Distribution
|
|
9
|
-
|
|
10
4
|
import jax
|
|
11
5
|
import jax.numpy as jnp
|
|
12
|
-
import equinox as eqx
|
|
13
6
|
from jaxtyping import Array, Float, PRNGKeyArray
|
|
14
7
|
|
|
8
|
+
from flowMC.nfmodel.base import Bijection, Distribution
|
|
9
|
+
|
|
15
10
|
|
|
16
11
|
class MLP(eqx.Module):
|
|
17
12
|
r"""Multilayer perceptron.
|
|
@@ -1,13 +1,15 @@
|
|
|
1
|
+
from functools import partial
|
|
1
2
|
from typing import List, Tuple
|
|
3
|
+
|
|
4
|
+
import equinox as eqx
|
|
2
5
|
import jax
|
|
3
6
|
import jax.numpy as jnp
|
|
4
7
|
import numpy as np
|
|
5
|
-
import equinox as eqx
|
|
6
|
-
from flowMC.nfmodel.base import NFModel, Distribution
|
|
7
|
-
from flowMC.nfmodel.common import MLP, MaskedCouplingLayer, MLPAffine, Gaussian
|
|
8
|
-
from jaxtyping import Array, Float, PRNGKeyArray
|
|
9
|
-
from functools import partial
|
|
10
8
|
import optax
|
|
9
|
+
from jaxtyping import Array, Float, PRNGKeyArray
|
|
10
|
+
|
|
11
|
+
from flowMC.nfmodel.base import Distribution, NFModel
|
|
12
|
+
from flowMC.nfmodel.common import MLP, Gaussian, MaskedCouplingLayer, MLPAffine
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class AffineCoupling(eqx.Module):
|
|
@@ -1,11 +1,13 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
|
|
3
|
+
import equinox as eqx
|
|
1
4
|
import jax
|
|
2
5
|
import jax.numpy as jnp
|
|
3
|
-
from jaxtyping import Array,
|
|
4
|
-
import equinox as eqx
|
|
6
|
+
from jaxtyping import Array, Float, PRNGKeyArray
|
|
5
7
|
|
|
6
|
-
from flowMC.nfmodel.base import
|
|
7
|
-
from flowMC.nfmodel.common import
|
|
8
|
-
|
|
8
|
+
from flowMC.nfmodel.base import Bijection, Distribution, NFModel
|
|
9
|
+
from flowMC.nfmodel.common import (MLP, Gaussian, MaskedCouplingLayer,
|
|
10
|
+
ScalarAffine)
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
@partial(jax.vmap, in_axes=(0, None, None))
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from typing import Callable
|
|
2
|
+
|
|
2
3
|
import jax
|
|
3
4
|
import jax.numpy as jnp
|
|
5
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
4
6
|
from tqdm import tqdm
|
|
7
|
+
|
|
5
8
|
from flowMC.proposal.base import ProposalBase
|
|
6
|
-
from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class GaussianRandomWalk(ProposalBase):
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from typing import Callable
|
|
2
|
+
|
|
2
3
|
import jax
|
|
3
4
|
import jax.numpy as jnp
|
|
5
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
4
6
|
from tqdm import tqdm
|
|
7
|
+
|
|
5
8
|
from flowMC.proposal.base import ProposalBase
|
|
6
|
-
from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class HMC(ProposalBase):
|
|
@@ -1,11 +1,13 @@
|
|
|
1
|
+
from functools import partialmethod
|
|
1
2
|
from typing import Callable
|
|
3
|
+
|
|
2
4
|
import jax
|
|
3
5
|
import jax.numpy as jnp
|
|
4
6
|
from jax.scipy.stats import multivariate_normal
|
|
7
|
+
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, PyTree
|
|
5
8
|
from tqdm import tqdm
|
|
9
|
+
|
|
6
10
|
from flowMC.proposal.base import ProposalBase
|
|
7
|
-
from functools import partialmethod
|
|
8
|
-
from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray, Bool
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class MALA(ProposalBase):
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
+
from math import ceil
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
1
4
|
import jax
|
|
2
5
|
import jax.numpy as jnp
|
|
3
6
|
from jax import random
|
|
7
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
4
8
|
from tqdm import tqdm
|
|
9
|
+
|
|
5
10
|
from flowMC.nfmodel.base import NFModel
|
|
6
|
-
from typing import Callable
|
|
7
11
|
from flowMC.proposal.base import ProposalBase
|
|
8
|
-
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
9
|
-
from math import ceil
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
@jax.tree_util.register_pytree_node_class
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
2
|
from typing import Callable
|
|
3
|
+
|
|
3
4
|
import jax
|
|
4
5
|
import jax.numpy as jnp
|
|
5
|
-
from jaxtyping import
|
|
6
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
@jax.tree_util.register_pytree_node_class
|
|
@@ -1,14 +1,15 @@
|
|
|
1
|
+
from math import ceil
|
|
2
|
+
from typing import Callable
|
|
3
|
+
|
|
1
4
|
import jax
|
|
2
5
|
import jax.numpy as jnp
|
|
6
|
+
from jax import random
|
|
7
|
+
from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
|
|
3
10
|
from flowMC.nfmodel.base import NFModel
|
|
4
|
-
from jaxtyping import Array, PRNGKeyArray, PyTree
|
|
5
|
-
from typing import Callable
|
|
6
11
|
from flowMC.proposal.HMC import HMC
|
|
7
12
|
from flowMC.proposal.NF_proposal import NFProposal
|
|
8
|
-
from jaxtyping import Array, Float, Int, PRNGKeyArray
|
|
9
|
-
from math import ceil
|
|
10
|
-
from jax import random
|
|
11
|
-
from tqdm import tqdm
|
|
12
13
|
|
|
13
14
|
###################################
|
|
14
15
|
# This is not in production yet
|
|
@@ -1,7 +1,11 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
|
+
|
|
3
|
+
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
|
4
|
+
|
|
2
5
|
from flowMC.proposal.base import ProposalBase
|
|
3
6
|
from flowMC.proposal.NF_proposal import NFProposal
|
|
4
|
-
|
|
7
|
+
|
|
8
|
+
|
|
5
9
|
class Strategy:
|
|
6
10
|
"""
|
|
7
11
|
Base class for strategies, which are basically wrapper blocks that modify the state of the sampler
|
|
@@ -1,13 +1,14 @@
|
|
|
1
|
-
|
|
2
|
-
from flowMC.proposal.NF_proposal import NFProposal
|
|
3
|
-
from flowMC.strategy.base import Strategy
|
|
1
|
+
import equinox as eqx
|
|
4
2
|
import jax
|
|
5
3
|
import jax.numpy as jnp
|
|
6
|
-
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
|
7
4
|
import optax
|
|
8
|
-
import
|
|
5
|
+
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
|
9
6
|
from tqdm import tqdm
|
|
10
7
|
|
|
8
|
+
from flowMC.proposal.base import ProposalBase
|
|
9
|
+
from flowMC.proposal.NF_proposal import NFProposal
|
|
10
|
+
from flowMC.strategy.base import Strategy
|
|
11
|
+
|
|
11
12
|
|
|
12
13
|
class GlobalTuning(Strategy):
|
|
13
14
|
|
|
@@ -76,6 +77,7 @@ class GlobalTuning(Strategy):
|
|
|
76
77
|
summary["local_accs"] = jnp.empty((self.n_chains, 0))
|
|
77
78
|
summary["global_accs"] = jnp.empty((self.n_chains, 0))
|
|
78
79
|
summary["loss_vals"] = jnp.empty((0, self.n_epochs))
|
|
80
|
+
current_position = initial_position
|
|
79
81
|
for _ in tqdm(
|
|
80
82
|
range(self.n_loop),
|
|
81
83
|
desc="Global Tuning",
|
|
@@ -90,7 +92,7 @@ class GlobalTuning(Strategy):
|
|
|
90
92
|
) = local_sampler.sample(
|
|
91
93
|
rng_keys_mcmc,
|
|
92
94
|
self.n_local_steps,
|
|
93
|
-
|
|
95
|
+
current_position,
|
|
94
96
|
data,
|
|
95
97
|
verbose=self.verbose,
|
|
96
98
|
)
|
|
@@ -112,6 +114,8 @@ class GlobalTuning(Strategy):
|
|
|
112
114
|
axis=1,
|
|
113
115
|
)
|
|
114
116
|
|
|
117
|
+
current_position = summary["chains"][:, -1]
|
|
118
|
+
|
|
115
119
|
rng_key, rng_keys_nf = jax.random.split(rng_key)
|
|
116
120
|
positions = summary["chains"][:, :: self.train_thinning]
|
|
117
121
|
chain_size = positions.shape[0] * positions.shape[1]
|
|
@@ -142,8 +146,8 @@ class GlobalTuning(Strategy):
|
|
|
142
146
|
|
|
143
147
|
(
|
|
144
148
|
rng_keys_nf,
|
|
145
|
-
|
|
146
|
-
|
|
149
|
+
model,
|
|
150
|
+
optim_state,
|
|
147
151
|
loss_values,
|
|
148
152
|
) = global_sampler.model.train(
|
|
149
153
|
rng_keys_nf,
|
|
@@ -154,6 +158,8 @@ class GlobalTuning(Strategy):
|
|
|
154
158
|
self.batch_size,
|
|
155
159
|
self.verbose,
|
|
156
160
|
)
|
|
161
|
+
global_sampler.model = model
|
|
162
|
+
self.optim_state = optim_state
|
|
157
163
|
summary["loss_vals"] = jnp.append(
|
|
158
164
|
summary["loss_vals"],
|
|
159
165
|
loss_values.reshape(1, -1),
|
|
@@ -168,7 +174,7 @@ class GlobalTuning(Strategy):
|
|
|
168
174
|
) = global_sampler.sample(
|
|
169
175
|
rng_keys_nf,
|
|
170
176
|
self.n_global_steps,
|
|
171
|
-
|
|
177
|
+
current_position,
|
|
172
178
|
data,
|
|
173
179
|
verbose=self.verbose,
|
|
174
180
|
)
|
|
@@ -190,7 +196,9 @@ class GlobalTuning(Strategy):
|
|
|
190
196
|
axis=1,
|
|
191
197
|
)
|
|
192
198
|
|
|
193
|
-
|
|
199
|
+
current_position = summary["chains"][:, -1]
|
|
200
|
+
|
|
201
|
+
return rng_key, current_position, local_sampler, global_sampler, summary
|
|
194
202
|
|
|
195
203
|
|
|
196
204
|
class GlobalSampling(Strategy):
|
|
@@ -239,6 +247,7 @@ class GlobalSampling(Strategy):
|
|
|
239
247
|
summary["local_accs"] = jnp.empty((self.n_chains, 0))
|
|
240
248
|
summary["global_accs"] = jnp.empty((self.n_chains, 0))
|
|
241
249
|
|
|
250
|
+
current_position = initial_position
|
|
242
251
|
for _ in tqdm(
|
|
243
252
|
range(self.n_loop),
|
|
244
253
|
desc="Global Sampling",
|
|
@@ -253,7 +262,7 @@ class GlobalSampling(Strategy):
|
|
|
253
262
|
) = local_sampler.sample(
|
|
254
263
|
rng_keys_mcmc,
|
|
255
264
|
self.n_local_steps,
|
|
256
|
-
|
|
265
|
+
current_position,
|
|
257
266
|
data,
|
|
258
267
|
verbose=self.verbose,
|
|
259
268
|
)
|
|
@@ -275,6 +284,8 @@ class GlobalSampling(Strategy):
|
|
|
275
284
|
axis=1,
|
|
276
285
|
)
|
|
277
286
|
|
|
287
|
+
current_position = summary["chains"][:, -1]
|
|
288
|
+
|
|
278
289
|
rng_key, rng_keys_nf = jax.random.split(rng_key)
|
|
279
290
|
(
|
|
280
291
|
rng_keys_nf,
|
|
@@ -306,4 +317,6 @@ class GlobalSampling(Strategy):
|
|
|
306
317
|
axis=1,
|
|
307
318
|
)
|
|
308
319
|
|
|
320
|
+
current_position = summary["chains"][:, -1]
|
|
321
|
+
|
|
309
322
|
return rng_key, summary['chains'][:, -1], local_sampler, global_sampler, summary
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import optax
|
|
4
|
+
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
from flowMC.proposal.base import ProposalBase
|
|
8
|
+
from flowMC.proposal.NF_proposal import NFProposal
|
|
9
|
+
from flowMC.strategy.base import Strategy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class optimization_Adam(Strategy):
|
|
13
|
+
|
|
14
|
+
"""
|
|
15
|
+
Optimize a set of chains using Adam optimization.
|
|
16
|
+
Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
n_steps: int = 100
|
|
21
|
+
learning_rate: float = 1e-2
|
|
22
|
+
noise_level: float = 10
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def __name__(self):
|
|
26
|
+
return "AdamOptimization"
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
**kwargs,
|
|
31
|
+
):
|
|
32
|
+
class_keys = list(self.__class__.__annotations__.keys())
|
|
33
|
+
for key, value in kwargs.items():
|
|
34
|
+
if key in class_keys:
|
|
35
|
+
if not key.startswith("__"):
|
|
36
|
+
setattr(self, key, value)
|
|
37
|
+
|
|
38
|
+
self.solver = optax.chain(
|
|
39
|
+
optax.adam(learning_rate=self.learning_rate),
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def __call__(
|
|
43
|
+
self,
|
|
44
|
+
rng_key: PRNGKeyArray,
|
|
45
|
+
local_sampler: ProposalBase,
|
|
46
|
+
global_sampler: NFProposal,
|
|
47
|
+
initial_position: Float[Array, " n_chain n_dim"],
|
|
48
|
+
data: dict,
|
|
49
|
+
) -> tuple[
|
|
50
|
+
PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
|
|
51
|
+
]:
|
|
52
|
+
def loss_fn(params: Float[Array, " n_dim"]) -> Float:
|
|
53
|
+
return -local_sampler.logpdf(params, data)
|
|
54
|
+
|
|
55
|
+
grad_fn = jax.jit(jax.grad(loss_fn))
|
|
56
|
+
|
|
57
|
+
def _kernel(carry, data):
|
|
58
|
+
key, params, opt_state = carry
|
|
59
|
+
|
|
60
|
+
key, subkey = jax.random.split(key)
|
|
61
|
+
grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
|
|
62
|
+
updates, opt_state = self.solver.update(grad, opt_state, params)
|
|
63
|
+
params = optax.apply_updates(params, updates)
|
|
64
|
+
return (key, params, opt_state), None
|
|
65
|
+
|
|
66
|
+
def _single_optimize(
|
|
67
|
+
key: PRNGKeyArray,
|
|
68
|
+
initial_position: Float[Array, " n_dim"],
|
|
69
|
+
) -> Float[Array, " n_dim"]:
|
|
70
|
+
|
|
71
|
+
opt_state = self.solver.init(initial_position)
|
|
72
|
+
|
|
73
|
+
(key, params, opt_state), _ = jax.lax.scan(
|
|
74
|
+
_kernel,
|
|
75
|
+
(key, initial_position, opt_state),
|
|
76
|
+
jnp.arange(self.n_steps),
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
return params # type: ignore
|
|
80
|
+
|
|
81
|
+
print("Using Adam optimization")
|
|
82
|
+
rng_key, subkey = jax.random.split(rng_key)
|
|
83
|
+
keys = jax.random.split(subkey, initial_position.shape[0])
|
|
84
|
+
optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
|
|
85
|
+
keys, initial_position
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
summary = {}
|
|
89
|
+
summary["initial_positions"] = initial_position
|
|
90
|
+
summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
|
|
91
|
+
summary["final_positions"] = optimized_positions
|
|
92
|
+
summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)
|
|
93
|
+
|
|
94
|
+
if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
|
|
95
|
+
print("Warning: Optimization accessed infinite or NaN log-probabilities.")
|
|
96
|
+
|
|
97
|
+
return rng_key, optimized_positions, local_sampler, global_sampler, summary
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Evosax_CMA_ES(Strategy):
|
|
101
|
+
|
|
102
|
+
def __init__(
|
|
103
|
+
self,
|
|
104
|
+
**kwargs,
|
|
105
|
+
):
|
|
106
|
+
class_keys = list(self.__class__.__annotations__.keys())
|
|
107
|
+
for key, value in kwargs.items():
|
|
108
|
+
if key in class_keys:
|
|
109
|
+
if not key.startswith("__"):
|
|
110
|
+
setattr(self, key, value)
|
|
111
|
+
|
|
112
|
+
def __call__(
|
|
113
|
+
self,
|
|
114
|
+
rng_key: PRNGKeyArray,
|
|
115
|
+
local_sampler: ProposalBase,
|
|
116
|
+
global_sampler: NFProposal,
|
|
117
|
+
initial_position: Array,
|
|
118
|
+
data: dict,
|
|
119
|
+
) -> tuple[PRNGKeyArray, Array, ProposalBase, NFProposal, PyTree]:
|
|
120
|
+
raise NotImplementedError
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import warnings
|
|
2
2
|
from functools import wraps
|
|
3
|
-
from typing import Any,
|
|
3
|
+
from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Tuple,
|
|
4
|
+
Union)
|
|
4
5
|
|
|
5
6
|
Array = Any
|
|
6
7
|
PyTree = Union[Array, Iterable[Array], Dict[Any, Array], NamedTuple]
|
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
import matplotlib.pyplot as plt
|
|
2
1
|
import jax.numpy as jnp
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
3
|
# from flowMC.sampler.Sampler import Sampler
|
|
4
|
-
from jaxtyping import
|
|
4
|
+
from jaxtyping import Array, Float
|
|
5
|
+
|
|
5
6
|
|
|
6
7
|
def plot_summary(sampler: object, training: bool = False, **plotkwargs) -> None:
|
|
7
8
|
"""
|
|
@@ -25,6 +25,7 @@ src/flowMC/strategy/__init__.py
|
|
|
25
25
|
src/flowMC/strategy/base.py
|
|
26
26
|
src/flowMC/strategy/global_tuning.py
|
|
27
27
|
src/flowMC/strategy/importance_sampling.py
|
|
28
|
+
src/flowMC/strategy/optimization.py
|
|
28
29
|
src/flowMC/utils/EvolutionaryOptimizer.py
|
|
29
30
|
src/flowMC/utils/PythonFunctionWrap.py
|
|
30
31
|
src/flowMC/utils/__init__.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|