flowMC 0.3.1__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flowMC-0.3.1/src/flowMC.egg-info → flowmc-0.3.2}/PKG-INFO +1 -1
- {flowMC-0.3.1 → flowmc-0.3.2}/setup.cfg +1 -1
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/Sampler.py +10 -2
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/base.py +1 -0
- flowmc-0.3.2/src/flowMC/strategy/optimization.py +120 -0
- {flowMC-0.3.1 → flowmc-0.3.2/src/flowMC.egg-info}/PKG-INFO +1 -1
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/SOURCES.txt +1 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/LICENSE +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/README.md +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/pyproject.toml +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/__init__.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/__init__.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/common.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/realNVP.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/rqSpline.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/Gaussian_random_walk.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/HMC.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/MALA.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/NF_proposal.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/__init__.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/base.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/flowHMC.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/__init__.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/base.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/global_tuning.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/importance_sampling.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/EvolutionaryOptimizer.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/PythonFunctionWrap.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/__init__.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/postprocessing.py +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/dependency_links.txt +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/requires.txt +0 -0
- {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/top_level.txt +0 -0
|
@@ -131,7 +131,7 @@ class Sampler:
|
|
|
131
131
|
)
|
|
132
132
|
self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))
|
|
133
133
|
|
|
134
|
-
|
|
134
|
+
default_strategies = [
|
|
135
135
|
GlobalTuning(
|
|
136
136
|
n_dim=self.n_dim,
|
|
137
137
|
n_chains=self.n_chains,
|
|
@@ -161,7 +161,15 @@ class Sampler:
|
|
|
161
161
|
if kwargs.get("strategies") is not None:
|
|
162
162
|
kwargs_strategies = kwargs.get("strategies")
|
|
163
163
|
assert isinstance(kwargs_strategies, list)
|
|
164
|
-
self.strategies =
|
|
164
|
+
self.strategies = []
|
|
165
|
+
for strategy in kwargs_strategies:
|
|
166
|
+
if isinstance(strategy, str):
|
|
167
|
+
if strategy.lower() == "default":
|
|
168
|
+
self.strategies += default_strategies
|
|
169
|
+
else:
|
|
170
|
+
self.strategies.append(strategy)
|
|
171
|
+
else:
|
|
172
|
+
self.strategies = default_strategies
|
|
165
173
|
|
|
166
174
|
self.summary = {}
|
|
167
175
|
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
import optax
|
|
4
|
+
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
|
|
5
|
+
from typing import Callable
|
|
6
|
+
|
|
7
|
+
from flowMC.proposal.base import ProposalBase
|
|
8
|
+
from flowMC.proposal.NF_proposal import NFProposal
|
|
9
|
+
from flowMC.strategy.base import Strategy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class optimization_Adam(Strategy):
|
|
13
|
+
|
|
14
|
+
"""
|
|
15
|
+
Optimize a set of chains using Adam optimization.
|
|
16
|
+
Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
n_steps: int = 100
|
|
21
|
+
learning_rate: float = 1e-2
|
|
22
|
+
noise_level: float = 10
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def __name__(self):
|
|
26
|
+
return "AdamOptimization"
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
**kwargs,
|
|
31
|
+
):
|
|
32
|
+
class_keys = list(self.__class__.__annotations__.keys())
|
|
33
|
+
for key, value in kwargs.items():
|
|
34
|
+
if key in class_keys:
|
|
35
|
+
if not key.startswith("__"):
|
|
36
|
+
setattr(self, key, value)
|
|
37
|
+
|
|
38
|
+
self.solver = optax.chain(
|
|
39
|
+
optax.adam(learning_rate=self.learning_rate),
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
def __call__(
|
|
43
|
+
self,
|
|
44
|
+
rng_key: PRNGKeyArray,
|
|
45
|
+
local_sampler: ProposalBase,
|
|
46
|
+
global_sampler: NFProposal,
|
|
47
|
+
initial_position: Float[Array, " n_chain n_dim"],
|
|
48
|
+
data: dict,
|
|
49
|
+
) -> tuple[
|
|
50
|
+
PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
|
|
51
|
+
]:
|
|
52
|
+
def loss_fn(params: Float[Array, " n_dim"]) -> Float:
|
|
53
|
+
return -local_sampler.logpdf(params, data)
|
|
54
|
+
|
|
55
|
+
grad_fn = jax.jit(jax.grad(loss_fn))
|
|
56
|
+
|
|
57
|
+
def _kernel(carry, data):
|
|
58
|
+
key, params, opt_state = carry
|
|
59
|
+
|
|
60
|
+
key, subkey = jax.random.split(key)
|
|
61
|
+
grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
|
|
62
|
+
updates, opt_state = self.solver.update(grad, opt_state, params)
|
|
63
|
+
params = optax.apply_updates(params, updates)
|
|
64
|
+
return (key, params, opt_state), None
|
|
65
|
+
|
|
66
|
+
def _single_optimize(
|
|
67
|
+
key: PRNGKeyArray,
|
|
68
|
+
initial_position: Float[Array, " n_dim"],
|
|
69
|
+
) -> Float[Array, " n_dim"]:
|
|
70
|
+
|
|
71
|
+
opt_state = self.solver.init(initial_position)
|
|
72
|
+
|
|
73
|
+
(key, params, opt_state), _ = jax.lax.scan(
|
|
74
|
+
_kernel,
|
|
75
|
+
(key, initial_position, opt_state),
|
|
76
|
+
jnp.arange(self.n_steps),
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
return params # type: ignore
|
|
80
|
+
|
|
81
|
+
print("Using Adam optimization")
|
|
82
|
+
rng_key, subkey = jax.random.split(rng_key)
|
|
83
|
+
keys = jax.random.split(subkey, initial_position.shape[0])
|
|
84
|
+
optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
|
|
85
|
+
keys, initial_position
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
summary = {}
|
|
89
|
+
summary["initial_positions"] = initial_position
|
|
90
|
+
summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
|
|
91
|
+
summary["final_positions"] = optimized_positions
|
|
92
|
+
summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)
|
|
93
|
+
|
|
94
|
+
if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
|
|
95
|
+
print("Warning: Optimization accessed infinite or NaN log-probabilities.")
|
|
96
|
+
|
|
97
|
+
return rng_key, optimized_positions, local_sampler, global_sampler, summary
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Evosax_CMA_ES(Strategy):
|
|
101
|
+
|
|
102
|
+
def __init__(
|
|
103
|
+
self,
|
|
104
|
+
**kwargs,
|
|
105
|
+
):
|
|
106
|
+
class_keys = list(self.__class__.__annotations__.keys())
|
|
107
|
+
for key, value in kwargs.items():
|
|
108
|
+
if key in class_keys:
|
|
109
|
+
if not key.startswith("__"):
|
|
110
|
+
setattr(self, key, value)
|
|
111
|
+
|
|
112
|
+
def __call__(
|
|
113
|
+
self,
|
|
114
|
+
rng_key: PRNGKeyArray,
|
|
115
|
+
local_sampler: ProposalBase,
|
|
116
|
+
global_sampler: NFProposal,
|
|
117
|
+
initial_position: Array,
|
|
118
|
+
data: dict,
|
|
119
|
+
) -> tuple[PRNGKeyArray, Array, ProposalBase, NFProposal, PyTree]:
|
|
120
|
+
raise NotImplementedError
|
|
@@ -25,6 +25,7 @@ src/flowMC/strategy/__init__.py
|
|
|
25
25
|
src/flowMC/strategy/base.py
|
|
26
26
|
src/flowMC/strategy/global_tuning.py
|
|
27
27
|
src/flowMC/strategy/importance_sampling.py
|
|
28
|
+
src/flowMC/strategy/optimization.py
|
|
28
29
|
src/flowMC/utils/EvolutionaryOptimizer.py
|
|
29
30
|
src/flowMC/utils/PythonFunctionWrap.py
|
|
30
31
|
src/flowMC/utils/__init__.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|