madspace 0.3.1__cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- madspace/__init__.py +1 -0
- madspace/_madspace_py.cpython-311-x86_64-linux-gnu.so +0 -0
- madspace/_madspace_py.pyi +2189 -0
- madspace/_madspace_py_loader.py +111 -0
- madspace/include/madspace/constants.h +17 -0
- madspace/include/madspace/madcode/function.h +102 -0
- madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
- madspace/include/madspace/madcode/instruction.h +208 -0
- madspace/include/madspace/madcode/opcode_mixin.h +134 -0
- madspace/include/madspace/madcode/optimizer.h +31 -0
- madspace/include/madspace/madcode/type.h +203 -0
- madspace/include/madspace/madcode.h +6 -0
- madspace/include/madspace/phasespace/base.h +74 -0
- madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
- madspace/include/madspace/phasespace/channel_weights.h +51 -0
- madspace/include/madspace/phasespace/chili.h +32 -0
- madspace/include/madspace/phasespace/cross_section.h +47 -0
- madspace/include/madspace/phasespace/cuts.h +34 -0
- madspace/include/madspace/phasespace/discrete_flow.h +44 -0
- madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
- madspace/include/madspace/phasespace/flow.h +53 -0
- madspace/include/madspace/phasespace/histograms.h +26 -0
- madspace/include/madspace/phasespace/integrand.h +204 -0
- madspace/include/madspace/phasespace/invariants.h +26 -0
- madspace/include/madspace/phasespace/luminosity.h +41 -0
- madspace/include/madspace/phasespace/matrix_element.h +70 -0
- madspace/include/madspace/phasespace/mlp.h +37 -0
- madspace/include/madspace/phasespace/multichannel.h +49 -0
- madspace/include/madspace/phasespace/observable.h +85 -0
- madspace/include/madspace/phasespace/pdf.h +78 -0
- madspace/include/madspace/phasespace/phasespace.h +67 -0
- madspace/include/madspace/phasespace/rambo.h +26 -0
- madspace/include/madspace/phasespace/scale.h +52 -0
- madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
- madspace/include/madspace/phasespace/three_particle.h +68 -0
- madspace/include/madspace/phasespace/topology.h +116 -0
- madspace/include/madspace/phasespace/two_particle.h +63 -0
- madspace/include/madspace/phasespace/vegas.h +53 -0
- madspace/include/madspace/phasespace.h +27 -0
- madspace/include/madspace/runtime/context.h +147 -0
- madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
- madspace/include/madspace/runtime/event_generator.h +257 -0
- madspace/include/madspace/runtime/format.h +68 -0
- madspace/include/madspace/runtime/io.h +343 -0
- madspace/include/madspace/runtime/lhe_output.h +132 -0
- madspace/include/madspace/runtime/logger.h +46 -0
- madspace/include/madspace/runtime/runtime_base.h +39 -0
- madspace/include/madspace/runtime/tensor.h +603 -0
- madspace/include/madspace/runtime/thread_pool.h +101 -0
- madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
- madspace/include/madspace/runtime.h +12 -0
- madspace/include/madspace/umami.h +202 -0
- madspace/include/madspace/util.h +142 -0
- madspace/lib/libmadspace.so +0 -0
- madspace/lib/libmadspace_cpu.so +0 -0
- madspace/lib/libmadspace_cpu_avx2.so +0 -0
- madspace/lib/libmadspace_cpu_avx512.so +0 -0
- madspace/lib/libmadspace_cuda.so +0 -0
- madspace/lib/libmadspace_hip.so +0 -0
- madspace/madnis/__init__.py +44 -0
- madspace/madnis/buffer.py +167 -0
- madspace/madnis/channel_grouping.py +85 -0
- madspace/madnis/distribution.py +103 -0
- madspace/madnis/integrand.py +175 -0
- madspace/madnis/integrator.py +973 -0
- madspace/madnis/interface.py +191 -0
- madspace/madnis/losses.py +186 -0
- madspace/torch.py +82 -0
- madspace-0.3.1.dist-info/METADATA +71 -0
- madspace-0.3.1.dist-info/RECORD +75 -0
- madspace-0.3.1.dist-info/WHEEL +6 -0
- madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
- madspace.libs/libgfortran-83c28eba.so.5.0.0 +0 -0
- madspace.libs/libopenblas-r0-11edc3fa.3.15.so +0 -0
- madspace.libs/libquadmath-2284e583.so.0.0.0 +0 -0
|
@@ -0,0 +1,973 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import signal
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Callable, Iterable
|
|
5
|
+
from dataclasses import astuple, dataclass
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
import torch.nn.functional as F
|
|
12
|
+
from torch.optim import Optimizer
|
|
13
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
14
|
+
|
|
15
|
+
from .buffer import Buffer
|
|
16
|
+
from .distribution import Distribution
|
|
17
|
+
from .integrand import Integrand
|
|
18
|
+
from .losses import MultiChannelLoss, kl_divergence, stratified_variance, variance
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
|
|
22
|
+
class TrainingStatus:
|
|
23
|
+
"""
|
|
24
|
+
Contains the MadNIS training status to pass it to a callback function.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
step: optimization step
|
|
28
|
+
loss: loss from the optimization step
|
|
29
|
+
buffered: whether the optimization was performed on buffered samples
|
|
30
|
+
learning_rate: current learning rate if learning rate scheduler is present
|
|
31
|
+
dropped_channels: number of channels dropped after this optimization step
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
step: int
|
|
35
|
+
loss: float
|
|
36
|
+
buffered: bool
|
|
37
|
+
learning_rate: float | None
|
|
38
|
+
dropped_channels: int
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
|
|
42
|
+
class SampleBatch:
|
|
43
|
+
"""
|
|
44
|
+
Contains a batch of samples
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
x: samples generated by the flow, shape (n, dim)
|
|
48
|
+
y: remapped samples returned by the integrand, shape (n, remapped_dim)
|
|
49
|
+
q_sample: probabilities of the samples, shape (n, )
|
|
50
|
+
func_vals: integrand value, shape (n, )
|
|
51
|
+
channels: channels indices for multi-channel integration, shape (n, ), otherwise None
|
|
52
|
+
alphas_prior: prior channel weights, shape (n, channels), or None for single-channel
|
|
53
|
+
integration
|
|
54
|
+
alpha_channel_indices: channel indices if not all prior channel weights are stored,
|
|
55
|
+
otherwise None
|
|
56
|
+
integration_channels: index of the channel group in case the integration is performed at the
|
|
57
|
+
level of channel groups, shape (n, ), otherwise None
|
|
58
|
+
weights: integration weight, shape (n, ). Only set when returned from Integrator.sample
|
|
59
|
+
function, otherwise None.
|
|
60
|
+
alphas: channel weights including learned correction, shape (n, channels). Only set when
|
|
61
|
+
returned from Integrator.sample function, otherwise None.
|
|
62
|
+
zero_counts: channel-wise counts of samples with zero-weights that are not included in the
|
|
63
|
+
batch, shape (channels, ). This field is ignored by most methods, as it behaves
|
|
64
|
+
does not have the batch size as its first dimension
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
x: torch.Tensor
|
|
68
|
+
y: torch.Tensor | None
|
|
69
|
+
q_sample: torch.Tensor
|
|
70
|
+
func_vals: torch.Tensor
|
|
71
|
+
channels: torch.Tensor | None
|
|
72
|
+
alphas_prior: torch.Tensor | None = None
|
|
73
|
+
alpha_channel_indices: torch.Tensor | None = None
|
|
74
|
+
integration_channels: torch.Tensor | None = None
|
|
75
|
+
weights: torch.Tensor | None = None
|
|
76
|
+
alphas: torch.Tensor | None = None
|
|
77
|
+
zero_counts: torch.Tensor | None = None
|
|
78
|
+
|
|
79
|
+
def __iter__(self) -> Iterable[torch.Tensor | None]:
|
|
80
|
+
"""
|
|
81
|
+
Returns iterator over the fields of the class
|
|
82
|
+
"""
|
|
83
|
+
return iter(astuple(self)[:-1])
|
|
84
|
+
|
|
85
|
+
def map(self, func: Callable[[torch.Tensor], torch.Tensor]) -> "SampleBatch":
|
|
86
|
+
"""
|
|
87
|
+
Applies function to all fields in the batch that are not None and returns a new SampleBatch
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
func: function that is applied to all fields in the batch. Expects a tensor as argument
|
|
91
|
+
and returns a new tensor
|
|
92
|
+
Returns:
|
|
93
|
+
Transformed SampleBatch
|
|
94
|
+
"""
|
|
95
|
+
return SampleBatch(*(None if field is None else func(field) for field in self))
|
|
96
|
+
|
|
97
|
+
def split(self, batch_size: int) -> Iterable["SampleBatch"]:
|
|
98
|
+
"""
|
|
99
|
+
Splits up the fields into batches and yields SampleBatch objects for every batch.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
batch_size: maximal size of the batches
|
|
103
|
+
Returns:
|
|
104
|
+
Iterator over the batches
|
|
105
|
+
"""
|
|
106
|
+
for batch in zip(
|
|
107
|
+
*(
|
|
108
|
+
itertools.repeat(None) if field is None else field.split(batch_size)
|
|
109
|
+
for field in self
|
|
110
|
+
)
|
|
111
|
+
):
|
|
112
|
+
yield SampleBatch(*batch)
|
|
113
|
+
|
|
114
|
+
@staticmethod
|
|
115
|
+
def cat(batches: Iterable["SampleBatch"]) -> "SampleBatch":
|
|
116
|
+
"""
|
|
117
|
+
Concatenates multiple batches. If the field zero_counts is not None, the zero_counts of
|
|
118
|
+
all batches are added.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
batches: Iterable over SampleBatch objects
|
|
122
|
+
Return:
|
|
123
|
+
New SamplaBatch object containing the concatenated batches
|
|
124
|
+
"""
|
|
125
|
+
cat_batch = SampleBatch(
|
|
126
|
+
*(
|
|
127
|
+
None if item[0] is None else torch.cat(item, dim=0)
|
|
128
|
+
for item in zip(*batches)
|
|
129
|
+
)
|
|
130
|
+
)
|
|
131
|
+
if batches[0].zero_counts is not None:
|
|
132
|
+
cat_batch.zero_counts = torch.stack(
|
|
133
|
+
[batch.zero_counts for batch in batches], dim=1
|
|
134
|
+
).sum(dim=1)
|
|
135
|
+
return cat_batch
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Integrator(nn.Module):
|
|
139
|
+
"""
|
|
140
|
+
Implements MadNIS training and integration logic. MadNIS integrators are torch modules, so
|
|
141
|
+
their state can easily be saved and loaded using the torch.save and torch.load methods.
|
|
142
|
+
"""
|
|
143
|
+
|
|
144
|
+
def __init__(
    self,
    integrand: Callable[[torch.Tensor], torch.Tensor] | Integrand,
    dims: int = 0,
    flow: Distribution | None = None,
    flow_kwargs: dict[str, Any] | None = None,
    discrete_flow_kwargs: dict[str, Any] | None = None,
    discrete_model: Literal["made", "transformer"] = "made",
    train_channel_weights: bool = True,
    cwnet: nn.Module | None = None,
    cwnet_kwargs: dict[str, Any] | None = None,
    loss: MultiChannelLoss | None = None,
    optimizer: (
        Optimizer | Callable[[Iterable[nn.Parameter]], Optimizer] | None
    ) = None,
    batch_size: int = 1024,
    batch_size_per_channel: int = 0,
    learning_rate: float = 1e-3,
    scheduler: LRScheduler | Callable[[Optimizer], LRScheduler] | None = None,
    uniform_channel_ratio: float = 1.0,
    integration_history_length: int = 20,
    drop_zero_integrands: bool = False,
    batch_size_threshold: float = 0.5,
    buffer_capacity: int = 0,
    minimum_buffer_size: int = 50,
    buffered_steps: int = 0,
    max_stored_channel_weights: int | None = None,
    channel_dropping_threshold: float = 0.0,
    channel_dropping_interval: int = 100,
    channel_grouping_mode: Literal["none", "uniform", "learned"] = "none",
    freeze_cwnet_iteration: int | None = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
):
    """
    Args:
        integrand: the function to be integrated. In the case of a simple single-channel
            integration, the integrand function can directly be passed to the integrator.
            In more complicated cases, like multi-channel integrals, use the ``Integrand``
            class.
        dims: dimension of the integration space. Only required if a simple function is given
            as integrand.
        flow: sampling distribution used for the integration. It has to be compatible with a
            normalizing flow, i.e. have the interface defined in the ``Distribution`` class.
            Building a default flow (``flow=None``) was removed in the MadSpace version of
            MadNIS and raises ``NotImplementedError`` (or ``ValueError`` for an unknown
            ``discrete_model``).
        flow_kwargs: keyword arguments for the default ``Flow`` constructor. Currently unused,
            since a default flow cannot be built in the MadSpace version. None means no
            keyword arguments.
        discrete_flow_kwargs: keyword arguments for the default ``MixedFlow`` or
            ``DiscreteMADE`` constructor. Only relevant if flow is None. None means no
            keyword arguments.
        discrete_model: kind of default discrete flow, "made" or "transformer". Only relevant
            if flow is None.
        train_channel_weights: If True and cwnet is None, a channel weight network would be
            constructed. In the MadSpace version this raises ``NotImplementedError`` for
            multi-channel integrands, so pass ``cwnet`` explicitly or set this to False.
        cwnet: network used for the trainable channel weights. If None, no learned correction
            of the channel weights is applied.
        cwnet_kwargs: keyword arguments for the default ``MLP`` channel weight network.
            Currently unused, since the default network cannot be built in the MadSpace
            version. None means no keyword arguments.
        loss: Loss function used for training. If not provided, the KL divergence is chosen in
            the single-channel case and the stratified variance is chosen in the multi-channel
            case.
        optimizer: optimizer for the training. Can be an optimizer object or function that is
            called with the model parameters as argument and returns the optimizer. If None,
            the Adam optimizer is used.
        batch_size: Training batch size
        batch_size_per_channel: used to compute the batch size as a function of the number of
            active channels, ``batch_size + n_active_channels * batch_size_per_channel``
        learning_rate: learning rate used for the Adam optimizer
        scheduler: learning rate scheduler for the training. Can be a learning rate scheduler
            object or a function that gets the optimizer as argument and returns the
            scheduler. If None, a constant learning rate is used.
        uniform_channel_ratio: part of samples in each batch that will be distributed equally
            between all channels, value has to be between 0 and 1.
        integration_history_length: number of batches for which the channel-wise means and
            variances are stored. This is used for stratified sampling during integration, and
            during the training if uniform_channel_ratio is different from one.
        drop_zero_integrands: If True, points with integrand zero are dropped and not used for
            the optimization.
        batch_size_threshold: New samples are drawn until the number of samples is at least
            batch_size_threshold * batch_size.
        buffer_capacity: number of samples that are stored for buffered training
        minimum_buffer_size: minimal size of the buffer to run buffered training
        buffered_steps: number of optimization steps on buffered samples after every online
            training step
        max_stored_channel_weights: number of prior channel weights that are buffered for each
            sample. If None, all prior channel weights are saved, otherwise only those for the
            channels with the largest contributions.
        channel_dropping_threshold: all channels with a cumulated contribution to the
            integrand that is smaller than this threshold are dropped
        channel_dropping_interval: number of training steps after which channel dropping
            is performed
        channel_grouping_mode: If "none", all channels are treated as separate channels in
            the loss and integration, even when they are grouped together. If "uniform", the
            channels within each group are sampled with equal probability. If "learned", a
            discrete normalizing flow is used to sample the channel index within a group.
        freeze_cwnet_iteration: If not None, specifies the training iteration after which the
            channel weight network is frozen
        device: torch device used for training and integration. If None, use default device.
        dtype: torch dtype used for training and integration. If None, use default dtype.
    Raises:
        ValueError: for an unknown ``channel_grouping_mode`` (also when the integrand has no
            channel grouping) or an unknown ``discrete_model`` when flow is None
        NotImplementedError: when a default flow or channel weight network would be needed
    """
    super().__init__()

    # Validate up front: previously an invalid mode was silently accepted whenever the
    # integrand had no channel grouping.
    if channel_grouping_mode not in ("none", "uniform", "learned"):
        raise ValueError(f"Unknown channel grouping mode {channel_grouping_mode}")
    # None replaces the former shared mutable ``{}`` default
    if discrete_flow_kwargs is None:
        discrete_flow_kwargs = {}

    if not isinstance(integrand, Integrand):
        integrand = Integrand(integrand, dims)
    self.integrand = integrand
    self.multichannel = integrand.channel_count is not None
    discrete_dims = integrand.discrete_dims
    input_dim = integrand.input_dim
    if integrand.channel_grouping is None or channel_grouping_mode == "none":
        self.integration_channel_count = integrand.channel_count
        self.group_channels = False
        self.group_channels_uniform = False
    elif channel_grouping_mode == "uniform":
        self.integration_channel_count = integrand.unique_channel_count()
        self.group_channels = True
        self.group_channels_uniform = True
    else:  # "learned"
        self.integration_channel_count = integrand.unique_channel_count()
        self.group_channels = True
        self.group_channels_uniform = False
        self.channel_group_dim = (
            0
            if integrand.discrete_dims_position == "first"
            else input_dim - len(discrete_dims)
        )
        # TODO: provide default implementation of discrete prior
        # discrete_dims.insert(0, max(len(group.channel_indices) for group in integrand.channel_grouping.groups))
        # input_dim += 1

    if self.group_channels:
        # channel_group_remap[g, i] is the i-th channel index of group g (zero-padded)
        group_sizes = torch.tensor(
            [len(group.channel_indices) for group in integrand.channel_grouping.groups]
        )
        self.register_buffer("channel_group_sizes", group_sizes)
        self.register_buffer(
            "channel_group_remap",
            torch.zeros(
                (len(group_sizes), int(group_sizes.max())),
                dtype=torch.int64,
            ),
        )
        for group in integrand.channel_grouping.groups:
            for i, chan_index in enumerate(group.channel_indices):
                self.channel_group_remap[group.group_index][i] = chan_index

    if flow is None:
        # Every branch below currently raises: building default flows was removed in the
        # MadSpace version. The structure is kept to preserve the exception types/messages.
        channel_remap_function = (
            None
            if self.group_channels and not self.group_channels_uniform
            else self.integrand.remap_channels
        )
        if len(discrete_dims) == 0:
            raise NotImplementedError("removed in MadSpace version of MadNIS")
        elif len(discrete_dims) == input_dim:
            if discrete_model == "made":
                raise NotImplementedError("removed in MadSpace version of MadNIS")
            elif discrete_model == "transformer":
                raise NotImplementedError("removed in MadSpace version of MadNIS")
            else:
                raise ValueError("discrete_model must be 'made' or 'transformer'")
        else:
            discrete_kwargs = dict(
                prior_prob_function=integrand.discrete_prior_prob_function,
                **discrete_flow_kwargs,
            )
            if self.multichannel:
                discrete_kwargs["channel_remap_function"] = channel_remap_function
            raise NotImplementedError("removed in MadSpace version of MadNIS")

    if cwnet is None and train_channel_weights and self.multichannel:
        raise NotImplementedError("removed in MadSpace version of MadNIS")
    if cwnet is None:
        parameters = flow.parameters()
    else:
        parameters = itertools.chain(flow.parameters(), cwnet.parameters())
    if optimizer is None:
        self.optimizer = torch.optim.Adam(parameters, learning_rate)
    elif isinstance(optimizer, Optimizer):
        self.optimizer = optimizer
    else:
        self.optimizer = optimizer(parameters)
    if scheduler is None or isinstance(scheduler, LRScheduler):
        self.scheduler = scheduler
    else:
        self.scheduler = scheduler(self.optimizer)

    self.flow = flow
    self.cwnet = cwnet
    self.batch_size_offset = batch_size
    self.batch_size_per_channel = batch_size_per_channel
    # initially all (integration) channels are active
    self.batch_size = batch_size + batch_size_per_channel * (
        self.integration_channel_count or 1
    )
    self.uniform_channel_ratio = uniform_channel_ratio
    self.drop_zero_integrands = drop_zero_integrands
    self.batch_size_threshold = batch_size_threshold
    if loss is None:
        self.loss = stratified_variance if self.multichannel else kl_divergence
    else:
        self.loss = loss

    self.minimum_buffer_size = minimum_buffer_size
    self.buffered_steps = buffered_steps
    # storing all channel weights is equivalent to not reducing them at all
    self.max_stored_channel_weights = (
        None
        if max_stored_channel_weights is None
        or integrand.channel_count is None
        or max_stored_channel_weights >= integrand.channel_count
        else max_stored_channel_weights
    )
    if buffer_capacity > 0:
        channel_count = self.max_stored_channel_weights or integrand.channel_count
        # field shapes (without batch dimension) in SampleBatch field order; None = not stored
        buffer_fields = [
            (input_dim,),
            None if integrand.remapped_dim is None else (integrand.remapped_dim,),
            (),
            (),
            None if integrand.channel_count is None else (),
            None if not integrand.has_channel_weight_prior else (channel_count,),
            None if self.max_stored_channel_weights is None else (channel_count,),
            () if self.group_channels else None,
            None,
            None,
        ]
        buffer_dtypes = [
            None,
            None,
            None,
            None,
            torch.int64,
            None,
            torch.int64,
            torch.int64,
            None,
            None,
        ]
        self.buffer = Buffer(
            buffer_capacity, buffer_fields, persistent=False, dtypes=buffer_dtypes
        )
    else:
        self.buffer = None
    self.channel_dropping_threshold = channel_dropping_threshold
    self.channel_dropping_interval = channel_dropping_interval
    self.freeze_cwnet_iteration = freeze_cwnet_iteration
    # per-batch channel-wise means, variances and counts
    hist_shape = (self.integration_channel_count or 1,)
    self.integration_history = Buffer(
        integration_history_length,
        [hist_shape, hist_shape, hist_shape],
        dtypes=[None, None, torch.int64],
    )
    self.step = 0
    self.step_type_count = 0
    if self.multichannel:
        self.register_buffer(
            "active_channels_mask",
            torch.ones((self.integration_channel_count,), dtype=torch.bool),
        )
    # Dummy to determine device and dtype
    self.register_buffer("dummy", torch.zeros((1,)))

    if device is not None:
        self.to(device)
    if dtype is not None:
        self.to(dtype)
|
|
411
|
+
|
|
412
|
+
def _get_alphas(self, samples: SampleBatch) -> torch.Tensor:
    """
    Computes the normalized channel weights for a batch of samples.

    Without a channel weight network, the (restored) prior channel weights are returned
    as they are, or uniform weights ``1 / channel_count`` if the batch carries no prior.
    With a network, the prior (or ones) is multiplied by the active-channel mask and, for
    samples with non-zero integrand, by ``exp(cwnet(y))`` (``x`` if ``y`` is None). Each
    row is then normalized to sum to one.

    Args:
        samples: batch of samples
    Returns:
        channel weights, shape (n, channels)
    """
    n_samples = samples.x.shape[0]
    n_channels = self.integrand.channel_count
    has_prior = samples.alphas_prior is not None

    if self.cwnet is None:
        if has_prior:
            return self._restore_prior(samples)
        return samples.x.new_full((n_samples, n_channels), 1 / n_channels)

    if has_prior:
        alpha_prior = self._restore_prior(samples)
    else:
        alpha_prior = samples.x.new_ones((n_samples, n_channels))

    if not self.group_channels:
        channel_mask = self.active_channels_mask
    else:
        # the mask is indexed by channel group: map each channel to its group first
        all_channels = torch.arange(alpha_prior.shape[1], device=alpha_prior.device)
        channel_mask = self.active_channels_mask[
            self.integrand.remap_channels(all_channels)
        ]

    alpha = alpha_prior * channel_mask
    nonzero = samples.func_vals != 0
    cwnet_input = samples.y if samples.y is not None else samples.x
    # the network is only evaluated for samples with non-zero integrand
    alpha[nonzero] *= self.cwnet(cwnet_input[nonzero]).exp()
    return alpha / alpha.sum(dim=1, keepdim=True)
|
|
452
|
+
|
|
453
|
+
def _compute_integral(
    self, samples: SampleBatch
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes normalized integrand and channel-wise means, variances and counts

    Multi-channel case: the integrand is multiplied by the channel weight of each sample's
    channel. The statistics are accumulated per integration channel (the channel group if
    ``integration_channels`` is set). Samples counted in ``zero_counts`` are treated as
    points with ``f / q = 0``: they enter the counts, the means and the variances. Variances
    are population variances (divided by the count). Channels without any samples get
    mean 0 and variance nan.

    Single-channel case: the statistics are computed over the whole batch, and the variance
    is the unbiased sample variance (default of ``Tensor.var``).

    Args:
        samples: batch of samples
    Returns:
        A tuple containing

        - normalized integrand, shape (n, )
        - channel-wise means of the integral, shape (channels, )
        - channel-wise variances of the integral, shape (channels, )
        - channel-wise number of samples, shape (channels, )
    """
    if self.multichannel:
        alphas = torch.gather(
            self._get_alphas(samples), index=samples.channels[:, None], dim=1
        )[:, 0]
        f_true = alphas * samples.func_vals
        f_div_q = f_true.detach() / samples.q_sample
        channels = (
            samples.channels
            if samples.integration_channels is None
            else samples.integration_channels
        )
        n_bins = self.integration_channel_count
        zero_counts = samples.zero_counts

        counts = torch.bincount(channels, minlength=n_bins)
        if zero_counts is not None:
            # dropped zero-weight samples still count towards the sample size
            # (out-of-place so that dtype promotion cannot fail)
            counts = counts + zero_counts
        means = torch.bincount(
            channels, weights=f_div_q, minlength=n_bins
        ) / counts.clip(min=1)
        squared_deviations = torch.bincount(
            channels,
            weights=(f_div_q - means[channels]).square(),
            minlength=n_bins,
        )
        if zero_counts is not None:
            # every dropped sample has f / q = 0 and therefore a squared deviation of mean^2;
            # without this term the variance was underestimated
            squared_deviations = squared_deviations + zero_counts * means.square()
        # 0 / 0 = nan for channels without samples; _get_channel_contributions masks nan entries
        variances = squared_deviations / counts
    else:
        f_div_q = samples.func_vals / samples.q_sample
        f_true = samples.func_vals
        means = f_div_q.mean(dim=0, keepdim=True)
        counts = torch.full((1,), f_div_q.shape[0], device=means.device)
        variances = f_div_q.var(dim=0, keepdim=True)
    return f_true, means, variances, counts
|
|
503
|
+
|
|
504
|
+
def _optimization_step(
    self,
    samples: SampleBatch,
) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Runs a single optimizer update of the trainable networks on one batch.

    The flow density of the samples is re-evaluated and combined with the integral statistics
    in the configured loss. A loss that is nan is not backpropagated; a warning is issued and
    the optimizer step is skipped.

    Args:
        samples: batch of samples
    Returns:
        A tuple containing

        - value of the loss
        - channel-wise means of the integral, shape (channels, )
        - channel-wise variances of the integral, shape (channels, )
        - channel-wise number of samples, shape (channels, )
    """
    self.optimizer.zero_grad()

    # with learned channel grouping the flow is conditioned on the channel group
    learned_grouping = self.group_channels and not self.group_channels_uniform
    flow_channels = samples.integration_channels if learned_grouping else samples.channels
    if samples.integration_channels is None:
        loss_channels = samples.channels
    else:
        loss_channels = samples.integration_channels

    # TODO: depending on the loss function and for drop_zero_weights=False, we can encounter
    # zero-weight events here and it might be sufficient to evaluate the flow for events
    # with func_val != 0. That might however give wrong results for other loss functions.
    q_test = self.flow.prob(samples.x, channel=flow_channels)
    f_true, means, variances, counts = self._compute_integral(samples)
    loss = self.loss(f_true, q_test, q_sample=samples.q_sample, channels=loss_channels)

    if not loss.isnan().item():
        loss.backward()
        self.optimizer.step()
    else:
        warnings.warn("nan batch: skipping optimization")
    return loss.item(), means, variances, counts
|
|
550
|
+
|
|
551
|
+
def _restore_prior(self, samples: SampleBatch) -> torch.Tensor:
    """
    Rebuilds the full prior channel weights when only the largest ones were stored.

    If ``samples.alpha_channel_indices`` is None, ``samples.alphas_prior`` already holds all
    prior weights and is returned unchanged. Otherwise, every channel that was not stored
    receives the machine epsilon of the dtype as weight. The stored weights are scattered to
    their channel indices, and each row is renormalized to sum to one.

    Args:
        samples: batch of samples
    Returns:
        Tensor of prior channel weights with shape (n, channels)
    """
    stored_indices = samples.alpha_channel_indices
    stored_alphas = samples.alphas_prior
    if stored_indices is None:
        return stored_alphas

    # An alternative strategy would spread ``1 - sum(stored)`` evenly over the missing
    # channels (clamped at epsilon); filling with epsilon and renormalizing is used instead.
    n_rows = stored_alphas.shape[0]
    tiny = torch.finfo(stored_alphas.dtype).eps
    full_alphas = stored_alphas.new_full((n_rows, self.integrand.channel_count), tiny)
    full_alphas.scatter_(1, stored_indices, stored_alphas)
    row_sums = full_alphas.sum(dim=1, keepdim=True)
    return full_alphas / row_sums
|
|
582
|
+
|
|
583
|
+
def _get_channel_contributions(
    self,
    expect_full_history: bool,
    channel_weight_mode: Literal["variance", "mean"],
) -> torch.Tensor:
    """
    Uses the stored history of variances or means to compute the contribution of each
    channel for stratified sampling.

    History entries whose contribution is nan (e.g. channels without samples in a batch)
    are ignored. The remaining entries are weighted by their sample counts, summed, and the
    square root is taken.

    Args:
        expect_full_history: If True, the integration history has to be full, otherwise
            uniform weights are returned. If False, at least one history entry is required.
        channel_weight_mode: specifies whether the channels are weighted by their mean or
            variance. Any value other than "mean" selects the variance. Note that weighting by
            mean can lead to problems for non-positive functions.
    Returns:
        weights for sampling the channels with shape (channels,)
    """
    min_len = self.integration_history.capacity if expect_full_history else 1
    if self.integration_history.size < min_len:
        # Not enough history yet: uniform weights. ``or 1`` matches the shape of the history
        # buffers in the single-channel case, where integration_channel_count is None.
        return torch.ones(
            self.integration_channel_count or 1,
            device=self.dummy.device,
            dtype=self.dummy.dtype,
        )
    mean_hist, var_hist, count_hist = self.integration_history
    contrib_hist = mean_hist.abs() if channel_weight_mode == "mean" else var_hist
    # exclude the counts of entries without information from the normalization
    count_hist = torch.where(
        contrib_hist.isnan(), np.nan, count_hist.to(contrib_hist.dtype)
    )
    hist_weights = count_hist / count_hist.nansum(dim=0)
    # NOTE(review): the square root is applied in both modes, so "mean" mode yields
    # sqrt(|mean|) rather than |mean| -- confirm this is intended.
    return torch.nansum(hist_weights * contrib_hist, dim=0).sqrt()
|
|
614
|
+
|
|
615
|
+
def _disable_unused_channels(self) -> int:
    """
    Determines channels with a total relative contribution below
    ``channel_dropping_threshold``, disables them and removes them from the buffer.

    The check only runs for multi-channel integrands, with a non-zero threshold, and only
    every ``channel_dropping_interval`` steps. Channels are sorted by their relative
    integral (count-weighted average of the stored means), and the smallest ones are
    dropped as long as their cumulative relative contribution stays below the threshold.
    Afterwards the batch size is recomputed from the number of remaining active channels.

    Returns:
        Number of channels that were newly disabled in this call, as a Python int
    """
    if (
        not self.multichannel
        or self.channel_dropping_threshold == 0.0
        or (self.step + 1) % self.channel_dropping_interval != 0
    ):
        return 0

    mean_hist, _, count_hist = self.integration_history
    count_hist = count_hist.to(mean_hist.dtype)
    mean_hist = torch.nan_to_num(mean_hist)
    # count-weighted average of the stored per-step channel means
    hist_weights = count_hist / count_hist.sum(dim=0)
    channel_integrals = torch.nansum(hist_weights * mean_hist, dim=0)
    channel_rel_integrals = channel_integrals / channel_integrals.sum()

    # smallest channels first; drop them while the cumulative share is below threshold
    cri_sort, cri_argsort = torch.sort(channel_rel_integrals)
    n_irrelevant = int(
        torch.count_nonzero(
            cri_sort.cumsum(dim=0) < self.channel_dropping_threshold
        ).item()
    )
    irrelevant_channels = cri_argsort[:n_irrelevant]
    # only count channels that are still active, so already disabled channels that show
    # up again are not reported twice; .item() keeps the return value a plain int
    n_disabled = int(
        torch.count_nonzero(self.active_channels_mask[irrelevant_channels]).item()
    )
    self.active_channels_mask[irrelevant_channels] = False
    self.integrand.update_active_channels_mask(self.active_channels_mask)
    self.batch_size = (
        self.batch_size_offset
        + torch.count_nonzero(self.active_channels_mask).item()
        * self.batch_size_per_channel
    )
    if self.buffer is not None:

        def filter_func(batch):
            # keep only buffered samples whose (integration) channel is still active
            samples = SampleBatch(*batch)
            channels = (
                samples.channels
                if samples.integration_channels is None
                else samples.integration_channels
            )
            return self.active_channels_mask[channels]

        self.buffer.filter(filter_func)
    return n_disabled
|
|
663
|
+
|
|
664
|
+
def _store_samples(self, samples: SampleBatch):
    """
    Stores the generated samples and probabilities in the buffer for reuse during buffered
    training. Does nothing if there is no buffer. If ``max_stored_channel_weights`` is set
    and the integrand has a channel weight prior, only that many of the largest prior
    channel weights (and their channel indices) are kept per sample, always including the
    channel the sample was generated in.

    Args:
        samples: Object containing a batch of samples. Its ``alphas_prior`` and
            ``alpha_channel_indices`` fields are replaced in place when weights are reduced.
    """
    if self.buffer is None:
        return

    limit = self.max_stored_channel_weights
    if limit is not None and self.integrand.has_channel_weight_prior:
        alphas = samples.alphas_prior
        sample_channel = samples.channels[:, None]
        # Put 2.0 at the generating channel so the descending sort ranks it first
        # (assumes the prior weights stay below 2.0 - TODO confirm), then restore the
        # true weight of that channel after sorting.
        marker = torch.tensor(
            [[2.0]], device=self.dummy.device, dtype=self.dummy.dtype
        )
        boosted = torch.scatter(
            alphas,
            dim=1,
            index=sample_channel,
            src=marker.expand(*alphas.shape),
        )
        ranked_alphas, ranked_indices = torch.sort(boosted, descending=True, dim=1)
        ranked_alphas[:, 0] = torch.gather(alphas, dim=1, index=sample_channel)[:, 0]
        samples.alphas_prior = ranked_alphas[:, :limit].clone()
        samples.alpha_channel_indices = ranked_indices[:, :limit].clone()

    self.buffer.store(*samples)
|
|
704
|
+
|
|
705
|
+
def _get_channels(
    self,
    n: int,
    channel_weights: torch.Tensor,
    uniform_channel_ratio: float,
    return_counts: bool = False,
) -> torch.Tensor:
    """
    Create a tensor of channel indices or number of samples per channel in two steps:
    1. Split up n * uniform_channel_ratio equally among all the active channels
    2. Sample the rest of the events from the distribution given by channel_weights
       after correcting for the uniformly distributed samples
    This allows stratified sampling by variance weighting while ensuring stable training
    because there are events in every channel. Inactive channels (``active_channels_mask``
    is False) never receive samples. If the weights of the active channels do not form a
    valid distribution (e.g. they are all zero or contain NaN), the weighted part is
    distributed uniformly over the active channels instead.

    Args:
        n: Number of samples
        channel_weights: Weights of the channels (not normalized) with shape (channels,)
        uniform_channel_ratio: Number between 0.0 and 1.0 to determine the ratio of samples
            that will be distributed uniformly first
        return_counts: If True, return number of samples per channels, otherwise the channel
            indices
    Returns:
        If return_counts is True, Tensor with number of samples per channel, shape (channels,).
        Otherwise, Tensor of channel numbers with shape (n,), sorted by channel
    """
    assert channel_weights.shape == (self.integration_channel_count,)
    device = self.dummy.device
    active_mask = self.active_channels_mask
    n_active_channels = torch.count_nonzero(active_mask).item()

    # step 1: uniform part, rounded up per active channel
    uniform_per_channel = int(np.ceil(n * uniform_channel_ratio / n_active_channels))
    n_per_channel = torch.full(
        (self.integration_channel_count,),
        uniform_per_channel,
        device=device,
    )
    n_per_channel[~active_mask] = 0

    # step 2: distribute the remaining samples according to the channel weights, reduced
    # by the share each channel already received from the uniform part
    n_weighted = max(n - int(n_per_channel.sum().item()), 0)
    if n_weighted > 0:
        normed_weights = channel_weights / channel_weights[active_mask].sum()
        normed_weights[~active_mask] = 0.0
        probs = torch.clamp(
            normed_weights - uniform_channel_ratio / n_active_channels, min=0
        )
        probs_sum = probs.sum()
        if not bool(probs_sum > 0):
            # degenerate weights (all zero -> 0/0, or NaN): without this fallback NaN
            # would reach ceil/int and produce meaningless sample counts
            probs = active_mask.to(probs.dtype)
            probs_sum = probs.sum()
        n_per_channel += torch.ceil(probs * n_weighted / probs_sum).to(
            n_per_channel.dtype
        )

    # rounding up can overshoot n; remove the excess one sample at a time, cycling over
    # the channels from index 0 and skipping empty ones
    counts = n_per_channel.tolist()
    excess = sum(counts) - n
    chan = 0
    while excess > 0:
        if counts[chan] > 0:
            counts[chan] -= 1
            excess -= 1
        chan = (chan + 1) % self.integration_channel_count
    n_per_channel = torch.tensor(
        counts, dtype=n_per_channel.dtype, device=n_per_channel.device
    )
    assert int(n_per_channel.sum().item()) == n

    if return_counts:
        return n_per_channel

    # channel i repeated n_per_channel[i] times, in ascending channel order
    return torch.repeat_interleave(
        torch.arange(self.integration_channel_count, device=device), n_per_channel
    )
|
|
769
|
+
|
|
770
|
+
def _get_samples(
    self,
    n: int,
    uniform_channel_ratio: float = 0.0,
    train: bool = False,
    channel_weight_mode: Literal["variance", "mean"] = "variance",
    channel: int | None = None,
) -> SampleBatch:
    """
    Draws samples from the flow and evaluates the integrand

    Sampling rounds of ``n`` samples are repeated with the same channel assignment until
    the number of counted samples exceeds ``batch_size_threshold * n``. All rounds are
    concatenated, so the returned batch does not necessarily contain exactly ``n`` samples.

    Args:
        n: number of samples drawn per sampling round
        uniform_channel_ratio: Number between 0.0 and 1.0 to determine the ratio of samples
            that will be distributed uniformly first
        train: If True, the function is used in training mode, i.e. samples where the integrand
            is zero will be removed if drop_zero_integrands is True. It is also passed as
            ``expect_full_history`` to ``_get_channel_contributions``, so history-based channel
            weights are only used once the integration history is full.
        channel_weight_mode: specifies whether the channels are weighted by their mean or
            variance. Note that weighting by mean can lead to problems for non-positive functions
        channel: if different from None, samples are only generated for this channel
    Returns:
        Object containing a batch of samples
    """
    if channel is None:
        # single-channel integrands get no channel assignment (None)
        batch_channels = (
            self._get_channels(
                n,
                self._get_channel_contributions(train, channel_weight_mode),
                uniform_channel_ratio,
            )
            if self.multichannel
            else None
        )
    else:
        batch_channels = torch.full((n,), channel, device=self.dummy.device)

    batches_out = []
    current_batch_size = 0
    while True:
        integration_channels = None
        weight_factor = None
        if self.integrand.function_includes_sampling:
            # the integrand's own function samples and evaluates in a single call
            integration_channels = batch_channels
            x, prob, weight, y, alphas_prior, channels = self.integrand.function(
                batch_channels
            )
        else:
            if self.group_channels and self.group_channels_uniform:
                # pick a channel inside each group uniformly at random and compensate by
                # multiplying the weight with the group size (see weight_factor below)
                group_sizes = self.channel_group_sizes[batch_channels]
                chan_in_group = (
                    torch.rand(
                        (n,), device=self.dummy.device, dtype=self.dummy.dtype
                    )
                    * group_sizes
                ).long()
                weight_factor = group_sizes
                integration_channels = batch_channels
                channels = self.channel_group_remap[batch_channels, chan_in_group]
            else:
                channels = batch_channels

            # sampling from the flow is done without autograd tracking
            with torch.no_grad():
                x, prob = self.flow.sample(
                    n,
                    channel=channels,
                    return_prob=True,
                    device=self.dummy.device,
                    dtype=self.dummy.dtype,
                )
            weight, y, alphas_prior = self.integrand(x, channels)

            if self.group_channels and not self.group_channels_uniform:
                # the channel inside the group is read from the sampled coordinate
                # ``channel_group_dim`` and mapped back to the full channel index
                chan_in_group = x[:, self.channel_group_dim].long()
                integration_channels = batch_channels
                channels = self.channel_group_remap[
                    integration_channels, chan_in_group
                ]

        if weight_factor is not None:
            weight *= weight_factor
        batch = SampleBatch(
            x,
            y,
            prob,
            weight,
            channels,
            alphas_prior,
            integration_channels=integration_channels,
        )

        if not train:
            current_batch_size += batch.x.shape[0]
        elif self.drop_zero_integrands:
            # remove zero-weight samples but record how many were dropped per
            # (integration) channel, or in total for single-channel integrands
            mask = weight != 0.0
            batch = batch.map(lambda t: t[mask])
            if self.multichannel:
                batch.zero_counts = torch.bincount(
                    (
                        channels[~mask]
                        if integration_channels is None
                        else integration_channels[~mask]
                    ),
                    minlength=self.integration_channel_count,
                )
            else:
                batch.zero_counts = torch.full(
                    (1,), torch.count_nonzero(~mask), device=x.device
                )
            current_batch_size += batch.x.shape[0]
        else:
            # training without dropping: only nonzero weights count towards the size
            current_batch_size += weight.count_nonzero()

        # mask = ~(weight.isnan() | x.isnan().any(dim=1)) & mask
        batches_out.append(batch)

        # check this condition at the end such that the sampling runs at least once
        # NOTE(review): in training mode, if every round yields only zero weights the
        # counted size never grows and this loop does not terminate
        if current_batch_size > self.batch_size_threshold * n:
            break

    return SampleBatch.cat(batches_out)
|
|
890
|
+
|
|
891
|
+
def train_step(self) -> TrainingStatus:
    """
    Performs a single training step

    Depending on ``step_type_count``, the step either trains on freshly drawn samples (which
    are then stored in the buffer and appended to the integration history) or on samples
    drawn from the buffer. Afterwards, unused channels may be disabled, the learning rate
    scheduler is advanced and the step counter is incremented.

    Returns:
        Training status
    """
    if self.step == self.freeze_cwnet_iteration and self.cwnet is not None:
        # stop training the channel weight network from this iteration on
        for parameter in self.cwnet.parameters():
            parameter.requires_grad = False

    buffered = self.step_type_count != 0
    if buffered:
        replay = SampleBatch(*self.buffer.sample(self.batch_size))
        loss, _, _, _ = self._optimization_step(replay)
        cycle_length = self.buffered_steps + 1
        self.step_type_count = (self.step_type_count + 1) % cycle_length
    else:
        fresh = self._get_samples(
            self.batch_size, self.uniform_channel_ratio, train=True
        )
        loss, means, variances, counts = self._optimization_step(fresh)
        self._store_samples(fresh)
        self.integration_history.store(means[None], variances[None], counts[None])
        # switch to buffered steps only once the buffer holds enough samples
        if self.buffered_steps != 0 and self.buffer.size > self.minimum_buffer_size:
            self.step_type_count += 1

    dropped_channels = self._disable_unused_channels()
    if self.scheduler is None:
        learning_rate = None
    else:
        learning_rate = self.scheduler.get_last_lr()[0]
    status = TrainingStatus(
        step=self.step,
        loss=loss,
        buffered=buffered,
        learning_rate=learning_rate,
        dropped_channels=dropped_channels,
    )

    if self.scheduler is not None:
        self.scheduler.step()
    self.step += 1
    return status
|
|
937
|
+
|
|
938
|
+
def train(
    self,
    steps: int,
    callback: Callable[[TrainingStatus], None] | None = None,
    capture_keyboard_interrupt: bool = False,
):
    """
    Performs multiple training steps by calling ``train_step`` repeatedly

    Args:
        steps: number of training steps
        callback: function that is called after each training step with the training status
            as argument
        capture_keyboard_interrupt: If True, a keyboard interrupt does not raise an exception.
            Instead, the current training step is finished (including its callback) and the
            training is aborted afterwards. The previous SIGINT handler is restored when
            training ends, also if a step raises.
    """
    interrupted = False
    if capture_keyboard_interrupt:

        def handler(sig, frame):
            # only record the interrupt; the loop checks the flag after each step
            nonlocal interrupted
            interrupted = True

        # NOTE(review): signal.signal may only be called from the main thread
        old_handler = signal.signal(signal.SIGINT, handler)

    try:
        for _ in range(steps):
            status = self.train_step()
            if callback is not None:
                callback(status)
            if interrupted:
                break
    finally:
        # restore the SIGINT handler that was active before training
        if capture_keyboard_interrupt:
            signal.signal(signal.SIGINT, old_handler)
|