aspire-inference 0.1.0a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,196 @@
+ import logging
+ from typing import Callable
+
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jrandom
+ from flowjax.train import fit_to_data
+
+ from ...transforms import IdentityTransform
+ from ...utils import decode_dtype, encode_dtype, resolve_dtype
+ from ..base import Flow
+ from .utils import get_flow
+
+ logger = logging.getLogger(__name__)
+
+
+ class FlowJax(Flow):
+     xp = jnp
+
+     def __init__(
+         self,
+         dims: int,
+         key=None,
+         data_transform=None,
+         dtype=None,
+         **kwargs,
+     ):
+         device = kwargs.pop("device", None)
+         if device is not None:
+             logger.warning("The device argument is not used in FlowJax.")
+         resolved_dtype = (
+             resolve_dtype(dtype, jnp)
+             if dtype is not None
+             else jnp.dtype(jnp.float32)
+         )
+         if data_transform is None:
+             data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
+         elif getattr(data_transform, "dtype", None) is None:
+             data_transform.dtype = resolved_dtype
+         super().__init__(dims, device=device, data_transform=data_transform)
+         self.dtype = resolved_dtype
+         if key is None:
+             key = jrandom.key(0)
+             logger.warning(
+                 "The key argument is None. "
+                 "A random key will be used for the flow. "
+                 "Results may not be reproducible."
+             )
+         self.key = key
+         self.loc = None
+         self.scale = None
+         self.key, subkey = jrandom.split(self.key)
+         self._flow = get_flow(
+             key=subkey,
+             dims=self.dims,
+             dtype=self.dtype,
+             **kwargs,
+         )
+
+     def fit(self, x, **kwargs):
+         from ...history import FlowHistory
+
+         x = jnp.asarray(x, dtype=self.dtype)
+         x_prime = jnp.asarray(self.fit_data_transform(x), dtype=self.dtype)
+         self.key, subkey = jrandom.split(self.key)
+         self._flow, losses = fit_to_data(subkey, self._flow, x_prime, **kwargs)
+         return FlowHistory(
+             training_loss=list(losses["train"]),
+             validation_loss=list(losses["val"]),
+         )
+
+     def forward(self, x, xp: Callable = jnp):
+         x = jnp.asarray(x, dtype=self.dtype)
+         x_prime, log_abs_det_jacobian = self.rescale(x)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
+         z, log_abs_det_jacobian_flow = self._flow.forward(x_prime)
+         return xp.asarray(z), xp.asarray(
+             log_abs_det_jacobian + log_abs_det_jacobian_flow
+         )
+
+     def inverse(self, z, xp: Callable = jnp):
+         z = jnp.asarray(z, dtype=self.dtype)
+         x_prime, log_abs_det_jacobian_flow = self._flow.inverse(z)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
+         x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
+         return xp.asarray(x), xp.asarray(
+             log_abs_det_jacobian + log_abs_det_jacobian_flow
+         )
+
+     def log_prob(self, x, xp: Callable = jnp):
+         x = jnp.asarray(x, dtype=self.dtype)
+         x_prime, log_abs_det_jacobian = self.rescale(x)
+         x_prime = jnp.asarray(x_prime, dtype=self.dtype)
+         log_prob = self._flow.log_prob(x_prime)
+         return xp.asarray(log_prob + log_abs_det_jacobian)
+
+     def sample(self, n_samples: int, xp: Callable = jnp):
+         self.key, subkey = jrandom.split(self.key)
+         x_prime = self._flow.sample(subkey, (n_samples,))
+         x = self.inverse_rescale(x_prime)[0]
+         return xp.asarray(x)
+
+     def sample_and_log_prob(self, n_samples: int, xp: Callable = jnp):
+         self.key, subkey = jrandom.split(self.key)
+         x_prime = self._flow.sample(subkey, (n_samples,))
+         log_prob = self._flow.log_prob(x_prime)
+         x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
+         return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
+
+     def save(self, h5_file, path="flow"):
+         """Save the flow configuration and parameters to an HDF5 file."""
+         import equinox as eqx
+         from array_api_compat import numpy as np
+
+         from ...utils import recursively_save_to_h5_file
+
+         grp = h5_file.require_group(path)
+
+         # ---- config ----
+         config = self.config_dict().copy()
+         config.pop("key", None)
+         config["key_data"] = jax.random.key_data(self.key)
+         dtype_value = config.get("dtype")
+         if dtype_value is None:
+             dtype_value = self.dtype
+         else:
+             dtype_value = jnp.dtype(dtype_value)
+         config["dtype"] = encode_dtype(jnp, dtype_value)
+
+         data_transform = config.pop("data_transform", None)
+         if data_transform is not None:
+             data_transform.save(grp, "data_transform")
+
+         recursively_save_to_h5_file(grp, "config", config)
+
+         # ---- save arrays ----
+         arrays, _ = eqx.partition(self._flow, eqx.is_array)
+         leaves, _ = jax.tree_util.tree_flatten(arrays)
+
+         params_grp = grp.require_group("params")
+         # clear old datasets
+         for name in list(params_grp.keys()):
+             del params_grp[name]
+
+         for i, p in enumerate(leaves):
+             params_grp.create_dataset(str(i), data=np.asarray(p))
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         """Load a flow and its parameters from an HDF5 file."""
+         import equinox as eqx
+
+         from ...utils import load_from_h5_file
+
+         grp = h5_file[path]
+
+         # ---- config ----
+         config = load_from_h5_file(grp, "config")
+         config["dtype"] = decode_dtype(jnp, config.get("dtype"))
+         if "data_transform" in grp:
+             from ...transforms import BaseTransform
+
+             config["data_transform"] = BaseTransform.load(
+                 grp,
+                 "data_transform",
+                 strict=False,
+             )
+
+         key_data = config.pop("key_data", None)
+         if key_data is not None:
+             config["key"] = jax.random.wrap_key_data(key_data)
+
+         kwargs = config.pop("kwargs", {})
+         config.update(kwargs)
+
+         # build object (will replace its _flow)
+         obj = cls(**config)
+
+         # ---- load arrays ----
+         params_grp = grp["params"]
+         loaded_params = [
+             jnp.array(params_grp[str(i)][:]) for i in range(len(params_grp))
+         ]
+
+         # rebuild template flow
+         kwargs.pop("device", None)
+         flow_template = get_flow(key=jrandom.key(0), dims=obj.dims, **kwargs)
+         arrays_template, static = eqx.partition(flow_template, eqx.is_array)
+
+         # use treedef from template
+         treedef = jax.tree_util.tree_structure(arrays_template)
+         arrays = jax.tree_util.tree_unflatten(treedef, loaded_params)
+
+         # recombine
+         obj._flow = eqx.combine(static, arrays)
+
+         return obj
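
A minimal usage sketch of the class above, using only the methods visible in this diff. The top-level import path is a placeholder (the diff only shows relative imports), and the argument values are illustrative:

    import h5py
    import jax.random as jrandom
    import numpy as np

    from aspire_inference.flows.flowjax import FlowJax  # hypothetical import path

    x = np.random.randn(1000, 4)                 # training samples, shape (n_samples, dims)
    flow = FlowJax(dims=4, key=jrandom.key(42))
    history = flow.fit(x)                        # kwargs are forwarded to flowjax.train.fit_to_data
    samples = flow.sample(500)                   # sampled in latent space, mapped back via inverse_rescale
    log_prob = flow.log_prob(samples)

    with h5py.File("flow.h5", "w") as f:         # persist config, key state and parameters
        flow.save(f, path="flow")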
@@ -0,0 +1,57 @@
+ from typing import Callable
+
+ import flowjax.bijections
+ import flowjax.distributions
+ import flowjax.flows
+ import jax
+ import jax.numpy as jnp
+ import jax.random as jrandom
+
+
+ def get_flow_function_class(name: str) -> Callable:
+     try:
+         return getattr(flowjax.flows, name)
+     except AttributeError as err:
+         raise ValueError(f"Unknown flow function: {name}") from err
+
+
+ def get_bijection_class(name: str) -> Callable:
+     try:
+         return getattr(flowjax.bijections, name)
+     except AttributeError as err:
+         raise ValueError(f"Unknown bijection: {name}") from err
+
+
+ def get_flow(
+     *,
+     key: jax.Array,
+     dims: int,
+     flow_type: str | Callable = "masked_autoregressive_flow",
+     bijection_type: str | flowjax.bijections.AbstractBijection | None = None,
+     bijection_kwargs: dict | None = None,
+     dtype=None,
+     **kwargs,
+ ) -> flowjax.distributions.Transformed:
+     dtype = dtype or jnp.float32
+
+     if isinstance(flow_type, str):
+         flow_type = get_flow_function_class(flow_type)
+
+     # Default the kwargs before the transformer is instantiated.
+     if bijection_kwargs is None:
+         bijection_kwargs = {}
+     if isinstance(bijection_type, str):
+         bijection_type = get_bijection_class(bijection_type)
+     if bijection_type is not None:
+         transformer = bijection_type(**bijection_kwargs)
+     else:
+         transformer = None
+
+     base_dist = flowjax.distributions.Normal(jnp.zeros(dims, dtype=dtype))
+     key, subkey = jrandom.split(key)
+     return flow_type(
+         subkey,
+         base_dist=base_dist,
+         transformer=transformer,
+         **kwargs,
+     )
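
The factory above resolves flow_type and bijection_type by name against flowjax.flows and flowjax.bijections, so both can be driven from configuration strings. A small sketch of the two call styles; the names and values are illustrative and must exist in the corresponding flowjax module:

    import jax.random as jrandom

    # Default masked autoregressive flow over 3 dimensions.
    flow = get_flow(key=jrandom.key(0), dims=3)

    # Coupling flow with a rational-quadratic-spline transformer selected by name.
    flow = get_flow(
        key=jrandom.key(0),
        dims=3,
        flow_type="coupling_flow",
        bijection_type="RationalQuadraticSpline",
        bijection_kwargs={"knots": 8, "interval": 4},
    )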
File without changes
@@ -0,0 +1,344 @@
+ import copy
+ import logging
+ from typing import Callable
+
+ import array_api_compat.torch as torch_api
+ import torch
+ import tqdm
+ import zuko
+ from array_api_compat import is_numpy_namespace, is_torch_array
+
+ from ...history import FlowHistory
+ from ...transforms import IdentityTransform
+ from ...utils import decode_dtype, encode_dtype, resolve_dtype
+ from ..base import Flow
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseTorchFlow(Flow):
+     _flow = None
+     xp = torch_api
+
+     def __init__(
+         self,
+         dims: int,
+         seed: int = 1234,
+         device: str = "cpu",
+         data_transform=None,
+         dtype=None,
+     ):
+         resolved_dtype = (
+             resolve_dtype(dtype, torch)
+             if dtype is not None
+             else torch.get_default_dtype()
+         )
+         if data_transform is None:
+             data_transform = IdentityTransform(self.xp, dtype=resolved_dtype)
+         elif getattr(data_transform, "dtype", None) is None:
+             data_transform.dtype = resolved_dtype
+         super().__init__(
+             dims,
+             device=torch.device(device or "cpu"),
+             data_transform=data_transform,
+         )
+         self.dtype = resolved_dtype
+         torch.manual_seed(seed)
+         self.loc = None
+         self.scale = None
+
+     @property
+     def flow(self):
+         return self._flow
+
+     @flow.setter
+     def flow(self, flow):
+         self._flow = flow
+         self._flow.to(device=self.device, dtype=self.dtype)
+         self._flow.compile()
+
+     def fit(self, x) -> FlowHistory:
+         raise NotImplementedError()
+
+     def save(self, h5_file, path="flow"):
+         """Save the weights of the flow to an HDF5 file."""
+         from ...utils import recursively_save_to_h5_file
+
+         flow_grp = h5_file.create_group(path)
+         # Save config
+         config = self.config_dict().copy()
+         data_transform = config.pop("data_transform", None)
+         dtype_value = config.get("dtype")
+         if dtype_value is None:
+             dtype_value = self.dtype
+         else:
+             dtype_value = resolve_dtype(dtype_value, torch)
+         config["dtype"] = encode_dtype(torch, dtype_value)
+         if data_transform is not None:
+             data_transform.save(flow_grp, "data_transform")
+         recursively_save_to_h5_file(flow_grp, "config", config)
+         # Save weights
+         weights_grp = flow_grp.create_group("weights")
+         for name, tensor in self._flow.state_dict().items():
+             weights_grp.create_dataset(name, data=tensor.cpu().numpy())
+
+     @classmethod
+     def load(cls, h5_file, path="flow"):
+         """Load the weights of the flow from an HDF5 file."""
+         from ...utils import load_from_h5_file
+
+         flow_grp = h5_file[path]
+         # Load config
+         config = load_from_h5_file(flow_grp, "config")
+         config["dtype"] = decode_dtype(torch, config.get("dtype"))
+         if "data_transform" in flow_grp:
+             from ...transforms import BaseTransform
+
+             data_transform = BaseTransform.load(
+                 flow_grp,
+                 "data_transform",
+                 strict=False,
+             )
+             config["data_transform"] = data_transform
+         obj = cls(**config)
+         # Load weights
+         weights = {
+             name: torch.tensor(data[()])
+             for name, data in flow_grp["weights"].items()
+         }
+         obj._flow.load_state_dict(weights)
+         return obj
+
+
+ class ZukoFlow(BaseTorchFlow):
+     def __init__(
+         self,
+         dims,
+         flow_class: str | Callable = "MAF",
+         data_transform=None,
+         seed=1234,
+         device: str = "cpu",
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(
+             dims,
+             device=device,
+             data_transform=data_transform,
+             seed=seed,
+             dtype=dtype,
+         )
+
+         if isinstance(flow_class, str):
+             FlowClass = getattr(zuko.flows, flow_class)
+         else:
+             FlowClass = flow_class
+
+         # Ints are sometimes passed as strings, so we convert them
+         if hidden_features := kwargs.pop("hidden_features", None):
+             kwargs["hidden_features"] = list(map(int, hidden_features))
+
+         self.flow = FlowClass(self.dims, 0, **kwargs)
+         logger.info(f"Initialized normalizing flow: \n {self.flow}\n")
+
+     def loss_fn(self, x):
+         return -self.flow().log_prob(x).mean()
+
+     def fit(
+         self,
+         x,
+         n_epochs: int = 100,
+         lr: float = 1e-3,
+         batch_size: int = 500,
+         validation_fraction: float = 0.2,
+         clip_grad: float | None = None,
+         lr_annealing: bool = False,
+     ):
+         from ...history import FlowHistory
+
+         if not is_torch_array(x):
+             x = torch.tensor(x, dtype=self.dtype, device=self.device)
+         else:
+             x = torch.clone(x)
+             x = x.type(self.dtype)
+             x = x.to(self.device)
+         x_prime = self.fit_data_transform(x)
+         indices = torch.randperm(x_prime.shape[0])
+         x_prime = x_prime[indices, ...]
+
+         n = x_prime.shape[0]
+         x_train = torch.as_tensor(
+             x_prime[: -int(validation_fraction * n)],
+             dtype=self.dtype,
+             device=self.device,
+         )
+
+         logger.info(
+             f"Training on {x_train.shape[0]} samples, "
+             f"validating on {x_prime.shape[0] - x_train.shape[0]} samples."
+         )
+
+         if torch.isnan(x_train).any():
+             dims_with_nan = (
+                 torch.isnan(x_train).any(dim=0).nonzero(as_tuple=True)[0]
+             )
+             raise ValueError(
+                 f"Training data contains NaN values in dimensions: {dims_with_nan.tolist()}"
+             )
+         if not torch.isfinite(x_train).all():
+             dims_with_inf = (
+                 (~torch.isfinite(x_train)).any(dim=0).nonzero(as_tuple=True)[0]
+             )
+             raise ValueError(
+                 f"Training data contains infinite values in dimensions: {dims_with_inf.tolist()}"
+             )
+
+         x_val = torch.as_tensor(
+             x_prime[-int(validation_fraction * n) :],
+             dtype=self.dtype,
+             device=self.device,
+         )
+         if torch.isnan(x_val).any():
+             raise ValueError("Validation data contains NaN values.")
+
+         if not torch.isfinite(x_val).all():
+             raise ValueError("Validation data contains infinite values.")
+
+         dataset = torch.utils.data.DataLoader(
+             torch.utils.data.TensorDataset(x_train),
+             shuffle=True,
+             batch_size=batch_size,
+         )
+         val_dataset = torch.utils.data.DataLoader(
+             torch.utils.data.TensorDataset(x_val),
+             shuffle=False,
+             batch_size=batch_size,
+         )
+
+         # Train to maximize the log-likelihood
+         optimizer = torch.optim.Adam(self._flow.parameters(), lr=lr)
+         if lr_annealing:
+             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer, n_epochs
+             )
+         history = FlowHistory()
+
+         best_val_loss = float("inf")
+         best_flow_state = None
+
+         with tqdm.tqdm(range(n_epochs), desc="Epochs") as pbar:
+             for _ in pbar:
+                 self.flow.train()
+                 loss_epoch = 0.0
+                 for (x_batch,) in dataset:
+                     loss = self.loss_fn(x_batch)
+                     optimizer.zero_grad()
+                     loss.backward()
+                     if clip_grad is not None:
+                         torch.nn.utils.clip_grad_norm_(
+                             self.flow.parameters(), clip_grad
+                         )
+                     optimizer.step()
+                     loss_epoch += loss.item()
+                 if lr_annealing:
+                     scheduler.step()
+                 avg_train_loss = loss_epoch / len(dataset)
+                 history.training_loss.append(avg_train_loss)
+                 self.flow.eval()
+                 val_loss = 0.0
+                 for (x_batch,) in val_dataset:
+                     with torch.no_grad():
+                         val_loss += self.loss_fn(x_batch).item()
+                 avg_val_loss = val_loss / len(val_dataset)
+                 if avg_val_loss < best_val_loss:
+                     best_val_loss = avg_val_loss
+                     best_flow_state = copy.deepcopy(self.flow.state_dict())
+
+                 history.validation_loss.append(avg_val_loss)
+                 pbar.set_postfix(
+                     train_loss=f"{avg_train_loss:.4f}",
+                     val_loss=f"{avg_val_loss:.4f}",
+                 )
+         if best_flow_state is not None:
+             self.flow.load_state_dict(best_flow_state)
+             logger.info(f"Loaded best model with val loss {best_val_loss:.4f}")
+
+         self.flow.eval()
+         return history
+
+     def sample_and_log_prob(self, n_samples: int, xp=torch_api):
+         with torch.no_grad():
+             x_prime, log_prob = self.flow().rsample_and_log_prob((n_samples,))
+             x, log_abs_det_jacobian = self.inverse_rescale(x_prime)
+         return xp.asarray(x), xp.asarray(log_prob - log_abs_det_jacobian)
+
+     def sample(self, n_samples: int, xp=torch_api):
+         with torch.no_grad():
+             x_prime = self.flow().rsample((n_samples,))
+             x = self.inverse_rescale(x_prime)[0]
+         return xp.asarray(x)
+
+     def log_prob(self, x, xp=torch_api):
+         x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
+         x_prime, log_abs_det_jacobian = self.rescale(x)
+         return xp.asarray(
+             self._flow().log_prob(x_prime) + log_abs_det_jacobian
+         )
+
+     def forward(self, x, xp=torch_api):
+         x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
+         x_prime, log_j_rescale = self.rescale(x)
+         z, log_abs_det_jacobian = self._flow().transform.call_and_ladj(x_prime)
+         if is_numpy_namespace(xp):
+             # Convert to numpy namespace if needed
+             z = z.detach().cpu().numpy()
+             log_abs_det_jacobian = log_abs_det_jacobian.detach().cpu().numpy()
+             log_j_rescale = log_j_rescale.detach().cpu().numpy()
+         return xp.asarray(z), xp.asarray(log_abs_det_jacobian + log_j_rescale)
+
+     def inverse(self, z, xp=torch_api):
+         z = torch.as_tensor(z, dtype=self.dtype, device=self.device)
+         with torch.no_grad():
+             x_prime, log_abs_det_jacobian = (
+                 self._flow().transform.inv.call_and_ladj(z)
+             )
+         x, log_j_rescale = self.inverse_rescale(x_prime)
+         if is_numpy_namespace(xp):
+             # Convert to numpy namespace if needed
+             x = x.detach().cpu().numpy()
+             log_abs_det_jacobian = log_abs_det_jacobian.detach().cpu().numpy()
+             log_j_rescale = log_j_rescale.detach().cpu().numpy()
+         return xp.asarray(x), xp.asarray(log_j_rescale + log_abs_det_jacobian)
+
+
+ class ZukoFlowMatching(ZukoFlow):
+     def __init__(
+         self,
+         dims,
+         data_transform=None,
+         seed=1234,
+         device="cpu",
+         eta: float = 1e-3,
+         dtype=None,
+         **kwargs,
+     ):
+         kwargs.setdefault("hidden_features", 4 * [100])
+         super().__init__(
+             dims,
+             seed=seed,
+             device=device,
+             data_transform=data_transform,
+             flow_class="CNF",
+             dtype=dtype,
+             **kwargs,
+         )
+         self.eta = eta
+
+     def loss_fn(self, theta: torch.Tensor):
+         t = torch.rand(
+             theta.shape[:-1], dtype=theta.dtype, device=theta.device
+         )
+         t_ = t[..., None]
+         eps = torch.randn_like(theta)
+         theta_prime = (1 - t_) * theta + (t_ + self.eta) * eps
+         v = eps - theta
+         return (self._flow.transform.f(t, theta_prime) - v).square().mean()
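
The loss_fn above is a conditional flow-matching style objective: each data point theta is placed on the (slightly noised) straight path x_t = (1 - t) * theta + (t + eta) * eps between the data and a Gaussian sample eps, and the network is regressed onto that path's constant velocity eps - theta, i.e. it minimises E || f(t, x_t) - (eps - theta) ||^2. A self-contained sketch of the same objective, with the vector field passed in as a callable; the name v_theta is hypothetical and nothing here is tied to zuko's internals:

    import torch

    def flow_matching_loss(v_theta, theta, eta=1e-3):
        # theta: data batch of shape (n, dims); v_theta: callable (t, x) -> predicted velocity
        t = torch.rand(theta.shape[:-1], dtype=theta.dtype, device=theta.device)
        t_ = t[..., None]
        eps = torch.randn_like(theta)               # noise endpoint of the path
        x_t = (1 - t_) * theta + (t_ + eta) * eps   # point on the interpolation path
        target = eps - theta                        # constant velocity of that path
        return (v_theta(t, x_t) - target).square().mean()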