flowMC 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. {flowmc-0.3.2/src/flowMC.egg-info → flowmc-0.3.4}/PKG-INFO +1 -1
  2. {flowmc-0.3.2 → flowmc-0.3.4}/setup.cfg +1 -1
  3. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/Sampler.py +44 -1
  4. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/nfmodel/base.py +1 -1
  5. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/strategy/optimization.py +82 -5
  6. {flowmc-0.3.2 → flowmc-0.3.4/src/flowMC.egg-info}/PKG-INFO +1 -1
  7. {flowmc-0.3.2 → flowmc-0.3.4}/LICENSE +0 -0
  8. {flowmc-0.3.2 → flowmc-0.3.4}/README.md +0 -0
  9. {flowmc-0.3.2 → flowmc-0.3.4}/pyproject.toml +0 -0
  10. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/__init__.py +0 -0
  11. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/nfmodel/__init__.py +0 -0
  12. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/nfmodel/common.py +0 -0
  13. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/nfmodel/realNVP.py +0 -0
  14. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/nfmodel/rqSpline.py +0 -0
  15. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/Gaussian_random_walk.py +0 -0
  16. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/HMC.py +0 -0
  17. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/MALA.py +0 -0
  18. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/NF_proposal.py +0 -0
  19. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/__init__.py +0 -0
  20. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/base.py +0 -0
  21. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/proposal/flowHMC.py +0 -0
  22. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/strategy/__init__.py +0 -0
  23. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/strategy/base.py +0 -0
  24. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/strategy/global_tuning.py +0 -0
  25. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/strategy/importance_sampling.py +0 -0
  26. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/utils/EvolutionaryOptimizer.py +0 -0
  27. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/utils/PythonFunctionWrap.py +0 -0
  28. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/utils/__init__.py +0 -0
  29. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC/utils/postprocessing.py +0 -0
  30. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC.egg-info/SOURCES.txt +0 -0
  31. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC.egg-info/dependency_links.txt +0 -0
  32. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC.egg-info/requires.txt +0 -0
  33. {flowmc-0.3.2 → flowmc-0.3.4}/src/flowMC.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = flowMC
3
- version = 0.3.2
3
+ version = 0.3.4
4
4
  author = Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
5
5
  author_email = kazewong.physics@gmail.com
6
6
  url = https://github.com/kazewong/flowMC
@@ -271,7 +271,7 @@ class Sampler:
271
271
  Args:
272
272
  path (str): Path to save the normalizing flow.
273
273
  """
274
- self.nf_model.load_model(path)
274
+ self.nf_model = self.nf_model.load_model(path)
275
275
 
276
276
  def reset(self):
277
277
  """
@@ -388,3 +388,46 @@ class Sampler:
388
388
  """
389
389
  with open(path, "wb") as f:
390
390
  pickle.dump(self.summary, f)
391
+
392
+ def print_summary(self) -> None:
393
+ """
394
+ Print summary to the screen about log probabilities and local/global acceptance rates.
395
+ """
396
+ train_summary = self.get_sampler_state(training=True)
397
+ production_summary = self.get_sampler_state(training=False)
398
+
399
+ training_log_prob = train_summary["log_prob"]
400
+ training_local_acceptance = train_summary["local_accs"]
401
+ training_global_acceptance = train_summary["global_accs"]
402
+ training_loss = train_summary["loss_vals"]
403
+
404
+ production_log_prob = production_summary["log_prob"]
405
+ production_local_acceptance = production_summary["local_accs"]
406
+ production_global_acceptance = production_summary["global_accs"]
407
+
408
+ print("Training summary")
409
+ print("=" * 10)
410
+ print(
411
+ f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
412
+ )
413
+ print(
414
+ f"Local acceptance: {training_local_acceptance.mean():.3f} +/- {training_local_acceptance.std():.3f}"
415
+ )
416
+ print(
417
+ f"Global acceptance: {training_global_acceptance.mean():.3f} +/- {training_global_acceptance.std():.3f}"
418
+ )
419
+ print(
420
+ f"Max loss: {training_loss.max():.3f}, Min loss: {training_loss.min():.3f}"
421
+ )
422
+
423
+ print("Production summary")
424
+ print("=" * 10)
425
+ print(
426
+ f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
427
+ )
428
+ print(
429
+ f"Local acceptance: {production_local_acceptance.mean():.3f} +/- {production_local_acceptance.std():.3f}"
430
+ )
431
+ print(
432
+ f"Global acceptance: {production_global_acceptance.mean():.3f} +/- {production_global_acceptance.std():.3f}"
433
+ )
@@ -78,7 +78,7 @@ class NFModel(eqx.Module):
78
78
  eqx.tree_serialise_leaves(path + ".eqx", self)
79
79
 
80
80
  def load_model(self, path: str):
81
- self = eqx.tree_deserialise_leaves(path + ".eqx", self)
81
+ return eqx.tree_deserialise_leaves(path + ".eqx", self)
82
82
 
83
83
  @eqx.filter_value_and_grad
84
84
  def loss_fn(self, x):
@@ -10,23 +10,31 @@ from flowMC.strategy.base import Strategy
10
10
 
11
11
 
12
12
  class optimization_Adam(Strategy):
13
-
14
13
  """
15
14
  Optimize a set of chains using Adam optimization.
16
15
  Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
17
-
16
+
17
+ Args:
18
+ n_steps: int = 100
19
+ Number of optimization steps.
20
+ learning_rate: float = 1e-2
21
+ Learning rate for the optimization.
22
+ noise_level: float = 10
23
+ Variance of the noise added to the gradients.
18
24
  """
19
25
 
20
26
  n_steps: int = 100
21
27
  learning_rate: float = 1e-2
22
28
  noise_level: float = 10
29
+ bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])
23
30
 
24
31
  @property
25
32
  def __name__(self):
26
33
  return "AdamOptimization"
27
-
34
+
28
35
  def __init__(
29
36
  self,
37
+ bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
30
38
  **kwargs,
31
39
  ):
32
40
  class_keys = list(self.__class__.__annotations__.keys())
@@ -39,6 +47,8 @@ class optimization_Adam(Strategy):
39
47
  optax.adam(learning_rate=self.learning_rate),
40
48
  )
41
49
 
50
+ self.bounds = bounds
51
+
42
52
  def __call__(
43
53
  self,
44
54
  rng_key: PRNGKeyArray,
@@ -61,6 +71,7 @@ class optimization_Adam(Strategy):
61
71
  grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
62
72
  updates, opt_state = self.solver.update(grad, opt_state, params)
63
73
  params = optax.apply_updates(params, updates)
74
+ params = optax.projections.projection_box(params, self.bounds[:, 0], self.bounds[:, 1])
64
75
  return (key, params, opt_state), None
65
76
 
66
77
  def _single_optimize(
@@ -91,11 +102,77 @@ class optimization_Adam(Strategy):
91
102
  summary["final_positions"] = optimized_positions
92
103
  summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)
93
104
 
94
- if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
105
+ if (
106
+ jnp.isinf(summary["final_log_prob"]).any()
107
+ or jnp.isnan(summary["final_log_prob"]).any()
108
+ ):
95
109
  print("Warning: Optimization accessed infinite or NaN log-probabilities.")
96
-
110
+
97
111
  return rng_key, optimized_positions, local_sampler, global_sampler, summary
98
112
 
113
+ def optimize(
114
+ self,
115
+ rng_key: PRNGKeyArray,
116
+ objective: Callable,
117
+ initial_position: Float[Array, " n_chain n_dim"],
118
+ ):
119
+ """
120
+ Standalone optimization function that takes an objective function and returns the optimized positions.
121
+
122
+ Args:
123
+ rng_key: PRNGKeyArray
124
+ Random key for the optimization.
125
+ objective: Callable
126
+ Objective function to optimize.
127
+ initial_position: Float[Array, " n_chain n_dim"]
128
+ Initial positions for the optimization.
129
+ """
130
+ grad_fn = jax.jit(jax.grad(objective))
131
+
132
+ def _kernel(carry, data):
133
+ key, params, opt_state = carry
134
+
135
+ key, subkey = jax.random.split(key)
136
+ grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
137
+ updates, opt_state = self.solver.update(grad, opt_state, params)
138
+ params = optax.apply_updates(params, updates)
139
+ return (key, params, opt_state), None
140
+
141
+ def _single_optimize(
142
+ key: PRNGKeyArray,
143
+ initial_position: Float[Array, " n_dim"],
144
+ ) -> Float[Array, " n_dim"]:
145
+
146
+ opt_state = self.solver.init(initial_position)
147
+
148
+ (key, params, opt_state), _ = jax.lax.scan(
149
+ _kernel,
150
+ (key, initial_position, opt_state),
151
+ jnp.arange(self.n_steps),
152
+ )
153
+
154
+ return params # type: ignore
155
+
156
+ print("Using Adam optimization")
157
+ rng_key, subkey = jax.random.split(rng_key)
158
+ keys = jax.random.split(subkey, initial_position.shape[0])
159
+ optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
160
+ keys, initial_position
161
+ )
162
+
163
+ summary = {}
164
+ summary["initial_positions"] = initial_position
165
+ summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
166
+ summary["final_positions"] = optimized_positions
167
+ summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)
168
+
169
+ if (
170
+ jnp.isinf(summary["final_log_prob"]).any()
171
+ or jnp.isnan(summary["final_log_prob"]).any()
172
+ ):
173
+ print("Warning: Optimization accessed infinite or NaN log-probabilities.")
174
+
175
+ return rng_key, optimized_positions, summary
99
176
 
100
177
  class Evosax_CMA_ES(Strategy):
101
178
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
File without changes
File without changes
File without changes
File without changes