aspire-inference 0.1.0a7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aspire/__init__.py +19 -0
- aspire/aspire.py +506 -0
- aspire/flows/__init__.py +40 -0
- aspire/flows/base.py +84 -0
- aspire/flows/jax/__init__.py +3 -0
- aspire/flows/jax/flows.py +196 -0
- aspire/flows/jax/utils.py +57 -0
- aspire/flows/torch/__init__.py +0 -0
- aspire/flows/torch/flows.py +344 -0
- aspire/history.py +148 -0
- aspire/plot.py +50 -0
- aspire/samplers/__init__.py +0 -0
- aspire/samplers/base.py +94 -0
- aspire/samplers/importance.py +22 -0
- aspire/samplers/mcmc.py +160 -0
- aspire/samplers/smc/__init__.py +0 -0
- aspire/samplers/smc/base.py +318 -0
- aspire/samplers/smc/blackjax.py +332 -0
- aspire/samplers/smc/emcee.py +75 -0
- aspire/samplers/smc/minipcn.py +82 -0
- aspire/samples.py +568 -0
- aspire/transforms.py +751 -0
- aspire/utils.py +760 -0
- aspire_inference-0.1.0a7.dist-info/METADATA +52 -0
- aspire_inference-0.1.0a7.dist-info/RECORD +28 -0
- aspire_inference-0.1.0a7.dist-info/WHEEL +5 -0
- aspire_inference-0.1.0a7.dist-info/licenses/LICENSE +21 -0
- aspire_inference-0.1.0a7.dist-info/top_level.txt +1 -0
aspire/flows/jax/flows.py
@@ -0,0 +1,196 @@
import logging
from typing import Callable

import jax
import jax.numpy as jnp
import jax.random as jrandom
from flowjax.train import fit_to_data

from ...transforms import IdentityTransform
from ...utils import decode_dtype, encode_dtype, resolve_dtype
from ..base import Flow
from .utils import get_flow

logger = logging.getLogger(__name__)


class FlowJax(Flow):
    xp = jnp

    def __init__(
        self,
        dims: int,
        key=None,
        data_transform=None,
        dtype=None,
        **kwargs,
    ):
        device = kwargs.pop("device", None)
        if device is not None:
            logger.warning("The device argument is not used in FlowJax. ")
        resolved_dtype = (
            resolve_dtype(dtype, jnp)
            if dtype is not None
            else jnp.dtype(jnp.float32)
        )
        if data_transform is None:
            data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
        elif getattr(data_transform, "dtype", None) is None:
            data_transform.dtype = resolved_dtype
        super().__init__(dims, device=device, data_transform=data_transform)
        self.dtype = resolved_dtype
        if key is None:
            key = jrandom.key(0)
            logger.warning(
                "The key argument is None. "
                "A random key will be used for the flow. "
                "Results may not be reproducible."
            )
        self.key = key
        self.loc = None
        self.scale = None
        self.key, subkey = jrandom.split(self.key)
        self._flow = get_flow(
            key=subkey,
            dims=self.dims,
            dtype=self.dtype,
            **kwargs,
        )

    def fit(self, x, **kwargs):
        from ...history import FlowHistory

        x = jnp.asarray(x, dtype=self.dtype)
        x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
        self.key, subkey = jrandom.split(self.key)
        self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
        return FlowHistory(
            training_loss=list(losses["train"]),
            validation_loss=list(losses["val"]),
        )

    def forward(self, x, xp: Callable = jnp):
        x = jnp.asarray(x, dtype=self.dtype)
        x_prime, log_abs_det_jacobian = self.rescale(x)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
        return xp.asarray(z), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def inverse(self, z, xp: Callable = jnp):
        z = jnp.asarray(z, dtype=self.dtype)
        x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(
            log_abs_det_jacobian + log_abs_det_jacobian_flow
        )

    def log_prob(self, x, xp: Callable = jnp):
        x = jnp.asarray(x, dtype=self.dtype)
        x_prime, log_abs_det_jacobian = self.rescale(x)
        x_prime = jnp.asarray(x_prime, dtype=self.dtype)
        log_prob = self._flow.log_prob(x_prime)
        return xp.asarray(log_prob + log_abs_det_jacobian)

    def sample(self, n_samples: int, xp: Callable = jnp):
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        x = self.inverse_rescale(x_prime)[0]
        return xp.asarray(x)

    def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
        self.key, subkey = jrandom.split(self.key)
        x_prime = self._flow.sample(subkey, (n_samples,))
        log_prob = self._flow.log_prob(x_prime)
        x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)

    def save(self, h5_file, path="flow"):
        import equinox as eqx
        from array_api_compat import numpy as np

        from ...utils import recursively_save_to_h5_file

        grp = h5_file.require_group(path)

        # ---- config ----
        config = self.config_dict().copy()
        config.pop("key", None)
        config["key_data"] = jax.random.key_data(self.key)
        dtype_value = config.get("dtype")
        if dtype_value is None:
            dtype_value = self.dtype
        else:
            dtype_value = jnp.dtype(dtype_value)
        config["dtype"] = encode_dtype(jnp, dtype_value)

        data_transform = config.pop("data_transform", None)
        if data_transform is not None:
            data_transform.save(grp, "data_transform")

        recursively_save_to_h5_file(grp, "config", config)

        # ---- save arrays ----
        arrays, _ = eqx.partition(self._flow, eqx.is_array)
        leaves, _ = jax.tree_util.tree_flatten(arrays)

        params_grp = grp.require_group("params")
        # clear old datasets
        for name in list(params_grp.keys()):
            del params_grp[name]

        for i, p in enumerate(leaves):
            params_grp.create_dataset(str(i), data=np.asarray(p))

    @classmethod
    def load(cls, h5_file, path="flow"):
        import equinox as eqx

        from ...utils import load_from_h5_file

        grp = h5_file[path]

        # ---- config ----
        config = load_from_h5_file(grp, "config")
        config["dtype"] = decode_dtype(jnp, config.get("dtype"))
        if "data_transform" in grp:
            from ...transforms import BaseTransform

            config["data_transform"] = BaseTransform.load(
                grp,
                "data_transform",
                strict=False,
            )

        key_data = config.pop("key_data", None)
        if key_data is not None:
            config["key"] = jax.random.wrap_key_data(key_data)

        kwargs = config.pop("kwargs", {})
        config.update(kwargs)

        # build object (will replace its _flow)
        obj = cls(**config)

        # ---- load arrays ----
        params_grp = grp["params"]
        loaded_params = [
            jnp.array(params_grp[str(i)][:]) for i in range(len(params_grp))
        ]

        # rebuild template flow
        kwargs.pop("device", None)
        flow_template = get_flow(key=jrandom.key(0), dims=obj.dims, **kwargs)
        arrays_template, static = eqx.partition(flow_template, eqx.is_array)

        # use treedef from template
        treedef = jax.tree_util.tree_structure(arrays_template)
        arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)

        # recombine
        obj._flow = eqx.combine(static, arrays)

        return obj
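For orientation, here is a minimal usage sketch of the FlowJax class above. It is not taken from the package; the import path follows the wheel layout, the toy data and the HDF5 file name are illustrative, and the h5_file argument is assumed to be an h5py File or Group (consistent with the require_group calls in save/load).

import h5py
import jax.random as jrandom
import numpy as np

from aspire.flows.jax.flows import FlowJax

# Build a 2-D flow; passing an explicit key avoids the reproducibility warning.
flow = FlowJax(dims=2, key=jrandom.key(42))

# Fit to samples; extra keyword arguments are forwarded to flowjax's fit_to_data.
data = np.random.default_rng(0).normal(size=(1000, 2))
history = flow.fit(data)
print(history.training_loss[-1], history.validation_loss[-1])

# Draw samples and evaluate densities in data space.
samples, log_prob = flow.sample_and_log_prob(100)

# Round-trip through HDF5 (file name is illustrative).
with h5py.File("flow.h5", "w") as f:
    flow.save(f)
with h5py.File("flow.h5", "r") as f:
    restored = FlowJax.load(f)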
aspire/flows/jax/utils.py
@@ -0,0 +1,57 @@
from typing import Callable

import flowjax.bijections
import flowjax.distributions
import flowjax.flows
import jax
import jax.numpy as jnp
import jax.random as jrandom


def get_flow_function_class(name: str) -> Callable:
    try:
        return getattr(flowjax.flows, name)
    except AttributeError:
        raise ValueError(f"Unknown flow function: {name}")


def get_bijection_class(name: str) -> Callable:
    try:
        return getattr(flowjax.bijections, name)
    except AttributeError:
        raise ValueError(f"Unknown bijection: {name}")


def get_flow(
    *,
    key: jax.Array,
    dims: int,
    flow_type: str | Callable = "masked_autoregressive_flow",
    bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
    bijection_kwargs: dict | None = None,
    dtype=None,
    **kwargs,
) -> flowjax.distributions.Transformed:
    dtype = dtype or jnp.float32

    if isinstance(flow_type, str):
        flow_type = get_flow_function_class(flow_type)

    # Default the bijection kwargs before they are used to build the transformer.
    if bijection_kwargs is None:
        bijection_kwargs = {}

    if isinstance(bijection_type, str):
        bijection_type = get_bijection_class(bijection_type)
    if bijection_type is not None:
        transformer = bijection_type(**bijection_kwargs)
    else:
        transformer = None

    base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
    key, subkey = jrandom.split(key)
    return flow_type(
        subkey,
        base_dist=base_dist,
        transformer=transformer,
        **kwargs,
    )
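For reference, a short sketch of calling get_flow. The spline transformer and its keyword arguments are assumptions about flowjax's bijection API, not something this diff pins down; only the string-lookup mechanism itself comes from the code above.

import jax.random as jrandom

from aspire.flows.jax.utils import get_flow

# Default: a masked autoregressive flow over 4 dimensions with a standard-normal base.
flow = get_flow(key=jrandom.key(0), dims=4)

# Names are resolved against flowjax.flows / flowjax.bijections, so a spline
# transformer can be requested by string (kwargs assumed to match flowjax).
spline_flow = get_flow(
    key=jrandom.key(1),
    dims=4,
    flow_type="masked_autoregressive_flow",
    bijection_type="RationalQuadraticSpline",
    bijection_kwargs={"knots": 8, "interval": 4},
)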
aspire/flows/torch/__init__.py
File without changes

aspire/flows/torch/flows.py
@@ -0,0 +1,344 @@
import copy
import logging
from typing import Callable

import array_api_compat.torch as torch_api
import torch
import tqdm
import zuko
from array_api_compat import is_numpy_namespace, is_torch_array

from ...history import FlowHistory
from ...transforms import IdentityTransform
from ...utils import decode_dtype, encode_dtype, resolve_dtype
from ..base import Flow

logger = logging.getLogger(__name__)


class BaseTorchFlow(Flow):
    _flow = None
    xp = torch_api

    def __init__(
        self,
        dims: int,
        seed: int = 1234,
        device: str = "cpu",
        data_transform=None,
        dtype=None,
    ):
        resolved_dtype = (
            resolve_dtype(dtype, torch)
            if dtype is not None
            else torch.get_default_dtype()
        )
        if data_transform is None:
            data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
        elif getattr(data_transform, "dtype", None) is None:
            data_transform.dtype = resolved_dtype
        super().__init__(
            dims,
            device=torch.device(device or "cpu"),
            data_transform=data_transform,
        )
        self.dtype = resolved_dtype
        torch.manual_seed(seed)
        self.loc = None
        self.scale = None

    @property
    def flow(self):
        return self._flow

    @flow.setter
    def flow(self, flow):
        self._flow = flow
        self._flow.to(device=self.device, dtype=self.dtype)
        self._flow.compile()

    def fit(self, x) -> FlowHistory:
        raise NotImplementedError()

    def save(self, h5_file, path="flow"):
        """Save the weights of the flow to an HDF5 file."""
        from ...utils import recursively_save_to_h5_file

        flow_grp = h5_file.create_group(path)
        # Save config
        config = self.config_dict().copy()
        data_transform = config.pop("data_transform", None)
        dtype_value = config.get("dtype")
        if dtype_value is None:
            dtype_value = self.dtype
        else:
            dtype_value = resolve_dtype(dtype_value, torch)
        config["dtype"] = encode_dtype(torch, dtype_value)
        if data_transform is not None:
            data_transform.save(flow_grp, "data_transform")
        recursively_save_to_h5_file(flow_grp, "config", config)
        # Save weights
        weights_grp = flow_grp.create_group("weights")
        for name, tensor in self._flow.state_dict().items():
            weights_grp.create_dataset(name, data=tensor.cpu().numpy())

    @classmethod
    def load(cls, h5_file, path="flow"):
        """Load the weights of the flow from an HDF5 file."""
        from ...utils import load_from_h5_file

        flow_grp = h5_file[path]
        # Load config
        config = load_from_h5_file(flow_grp, "config")
        config["dtype"] = decode_dtype(torch, config.get("dtype"))
        if "data_transform" in flow_grp:
            from ...transforms import BaseTransform

            data_transform = BaseTransform.load(
                flow_grp,
                "data_transform",
                strict=False,
            )
            config["data_transform"] = data_transform
        obj = cls(**config)
        # Load weights
        weights = {
            name: torch.tensor(data[()])
            for name, data in flow_grp["weights"].items()
        }
        obj._flow.load_state_dict(weights)
        return obj


class ZukoFlow(BaseTorchFlow):
    def __init__(
        self,
        dims,
        flow_class: str | Callable = "MAF",
        data_transform=None,
        seed=1234,
        device: str = "cpu",
        dtype=None,
        **kwargs,
    ):
        super().__init__(
            dims,
            device=device,
            data_transform=data_transform,
            seed=seed,
            dtype=dtype,
        )

        if isinstance(flow_class, str):
            FlowClass = getattr(zuko.flows, flow_class)
        else:
            FlowClass = flow_class

        # Ints are sometimes passed as strings, so we convert them
        if hidden_features := kwargs.pop("hidden_features", None):
            kwargs["hidden_features"] = list(map(int, hidden_features))

        self.flow = FlowClass(self.dims, 0, **kwargs)
        logger.info(f"Initialized normalizing flow: \n {self.flow}\n")

    def loss_fn(self, x):
        return -self.flow().log_prob(x).mean()

    def fit(
        self,
        x,
        n_epochs: int = 100,
        lr: float = 1e-3,
        batch_size: int = 500,
        validation_fraction: float = 0.2,
        clip_grad: float | None = None,
        lr_annealing: bool = False,
    ):
        from ...history import FlowHistory

        if not is_torch_array(x):
            x = torch.tensor(x, dtype=self.dtype, device=self.device)
        else:
            x = torch.clone(x)
            x = x.type(self.dtype)
            x = x.to(self.device)
        x_prime = self.fit_data_transform(x)
        indices = torch.randperm(x_prime.shape[0])
        x_prime = x_prime[indices, ...]

        n = x_prime.shape[0]
        x_train = torch.as_tensor(
            x_prime[: -int(validation_fraction * n)],
            dtype=self.dtype,
            device=self.device,
        )

        logger.info(
            f"Training on {x_train.shape[0]} samples, "
            f"validating on {x_prime.shape[0] - x_train.shape[0]} samples."
        )

        if torch.isnan(x_train).any():
            dims_with_nan = (
                torch.isnan(x_train).any(dim=0).nonzero(as_tuple=True)[0]
            )
            raise ValueError(
                f"Training data contains NaN values in dimensions: {dims_with_nan.tolist()}"
            )
        if not torch.isfinite(x_train).all():
            dims_with_inf = (
                (~torch.isfinite(x_train)).any(dim=0).nonzero(as_tuple=True)[0]
            )
            raise ValueError(
                f"Training data contains infinite values in dimensions: {dims_with_inf.tolist()}"
            )

        x_val = torch.as_tensor(
            x_prime[-int(validation_fraction * n) :],
            dtype=self.dtype,
            device=self.device,
        )
        if torch.isnan(x_val).any():
            raise ValueError("Validation data contains NaN values.")

        if not torch.isfinite(x_val).all():
            raise ValueError("Validation data contains infinite values.")

        dataset = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x_train),
            shuffle=True,
            batch_size=batch_size,
        )
        val_dataset = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x_val),
            shuffle=False,
            batch_size=batch_size,
        )

        # Train to maximize the log-likelihood
        optimizer = torch.optim.Adam(self._flow.parameters(), lr=lr)
        if lr_annealing:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, n_epochs
            )
        history = FlowHistory()

        best_val_loss = float("inf")
        best_flow_state = None

        with tqdm.tqdm(range(n_epochs), desc="Epochs") as pbar:
            for _ in pbar:
                self.flow.train()
                loss_epoch = 0.0
                for (x_batch,) in dataset:
                    loss = self.loss_fn(x_batch)
                    optimizer.zero_grad()
                    loss.backward()
                    if clip_grad is not None:
                        torch.nn.utils.clip_grad_norm_(
                            self.flow.parameters(), clip_grad
                        )
                    optimizer.step()
                    loss_epoch += loss.item()
                if lr_annealing:
                    scheduler.step()
                avg_train_loss = loss_epoch / len(dataset)
                history.training_loss.append(avg_train_loss)
                self.flow.eval()
                val_loss = 0.0
                for (x_batch,) in val_dataset:
                    with torch.no_grad():
                        val_loss += self.loss_fn(x_batch).item()
                avg_val_loss = val_loss / len(val_dataset)
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    best_flow_state = copy.deepcopy(self.flow.state_dict())

                history.validation_loss.append(avg_val_loss)
                pbar.set_postfix(
                    train_loss=f"{avg_train_loss:.4f}",
                    val_loss=f"{avg_val_loss:.4f}",
                )
        if best_flow_state is not None:
            self.flow.load_state_dict(best_flow_state)
            logger.info(f"Loaded best model with val loss {best_val_loss:.4f}")

        self.flow.eval()
        return history

    def sample_and_log_prob(self, n_samples: int, xp=torch_api):
        with torch.no_grad():
            x_prime, log_prob = self.flow().rsample_and_log_prob((n_samples,))
            x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
        return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)

    def sample(self, n_samples: int, xp=torch_api):
        with torch.no_grad():
            x_prime = self.flow().rsample((n_samples,))
            x = self.inverse_rescale(x_prime)[0]
        return xp.asarray(x)

    def log_prob(self, x, xp=torch_api):
        x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        x_prime, log_abs_det_jacobian = self.rescale(x)
        return xp.asarray(
            self._flow().log_prob(x_prime) + log_abs_det_jacobian
        )

    def forward(self, x, xp=torch_api):
        x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
        x_prime, log_j_rescale = self.rescale(x)
        z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
        if is_numpy_namespace(xp):
            # Convert to numpy namespace if needed
            z = z.detach().numpy()
            log_abs_det_jacobian = log_abs_det_jacobian.detach().numpy()
            log_j_rescale = log_j_rescale.detach().numpy()
        return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)

    def inverse(self, z, xp=torch_api):
        z = torch.as_tensor(z, dtype=self.dtype, device=self.device)
        with torch.no_grad():
            x_prime, log_abs_det_jacobian = (
                self._flow().transform.inv.call_and_ladj(z)
            )
            x, log_j_rescale = self.inverse_rescale(x_prime)
        if is_numpy_namespace(xp):
            # Convert to numpy namespace if needed
            x = x.detach().numpy()
            log_abs_det_jacobian = log_abs_det_jacobian.detach().numpy()
            log_j_rescale = log_j_rescale.detach().numpy()
        return xp.asarray(x), xp.asarray(log_j_rescale + log_abs_det_jacobian)


class ZukoFlowMatching(ZukoFlow):
    def __init__(
        self,
        dims,
        data_transform=None,
        seed=1234,
        device="cpu",
        eta: float = 1e-3,
        dtype=None,
        **kwargs,
    ):
        kwargs.setdefault("hidden_features", 4 * [100])
        super().__init__(
            dims,
            seed=seed,
            device=device,
            data_transform=data_transform,
            flow_class="CNF",
            dtype=dtype,
            **kwargs,
        )
        self.eta = eta

    def loss_fn(self, theta: torch.Tensor):
        t = torch.rand(
            theta.shape[:-1], dtype=theta.dtype, device=theta.device
        )
        t_ = t[..., None]
        eps = torch.randn_like(theta)
        theta_prime = (1 - t_) * theta + (t_ + self.eta) * eps
        v = eps - theta
        return (self._flow.transform.f(t, theta_prime) - v).square().mean()
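And a corresponding sketch for the torch backend. Again, this is not from the package: the import path follows the wheel layout, and the data, hyperparameters, and seed are illustrative values rather than anything prescribed by the code above.

import numpy as np

from aspire.flows.torch.flows import ZukoFlow, ZukoFlowMatching

# A zuko "MAF" flow over 3 dimensions, trained on CPU.
flow = ZukoFlow(dims=3, flow_class="MAF", seed=42, device="cpu")

data = np.random.default_rng(0).normal(size=(2000, 3))

# fit() shuffles the data, holds out a validation fraction, trains with Adam,
# and restores the best validation state before returning the loss history.
history = flow.fit(data, n_epochs=50, batch_size=256, validation_fraction=0.2)

samples, log_prob = flow.sample_and_log_prob(500)

# The flow-matching variant wraps zuko's continuous normalizing flow ("CNF").
cnf = ZukoFlowMatching(dims=3, seed=42)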