madspace 0.3.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. madspace/__init__.py +1 -0
  2. madspace/_madspace_py.cpython-311-darwin.so +0 -0
  3. madspace/_madspace_py.pyi +2189 -0
  4. madspace/_madspace_py_loader.py +111 -0
  5. madspace/include/madspace/constants.h +17 -0
  6. madspace/include/madspace/madcode/function.h +102 -0
  7. madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
  8. madspace/include/madspace/madcode/instruction.h +208 -0
  9. madspace/include/madspace/madcode/opcode_mixin.h +134 -0
  10. madspace/include/madspace/madcode/optimizer.h +31 -0
  11. madspace/include/madspace/madcode/type.h +203 -0
  12. madspace/include/madspace/madcode.h +6 -0
  13. madspace/include/madspace/phasespace/base.h +74 -0
  14. madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
  15. madspace/include/madspace/phasespace/channel_weights.h +51 -0
  16. madspace/include/madspace/phasespace/chili.h +32 -0
  17. madspace/include/madspace/phasespace/cross_section.h +47 -0
  18. madspace/include/madspace/phasespace/cuts.h +34 -0
  19. madspace/include/madspace/phasespace/discrete_flow.h +44 -0
  20. madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
  21. madspace/include/madspace/phasespace/flow.h +53 -0
  22. madspace/include/madspace/phasespace/histograms.h +26 -0
  23. madspace/include/madspace/phasespace/integrand.h +204 -0
  24. madspace/include/madspace/phasespace/invariants.h +26 -0
  25. madspace/include/madspace/phasespace/luminosity.h +41 -0
  26. madspace/include/madspace/phasespace/matrix_element.h +70 -0
  27. madspace/include/madspace/phasespace/mlp.h +37 -0
  28. madspace/include/madspace/phasespace/multichannel.h +49 -0
  29. madspace/include/madspace/phasespace/observable.h +85 -0
  30. madspace/include/madspace/phasespace/pdf.h +78 -0
  31. madspace/include/madspace/phasespace/phasespace.h +67 -0
  32. madspace/include/madspace/phasespace/rambo.h +26 -0
  33. madspace/include/madspace/phasespace/scale.h +52 -0
  34. madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
  35. madspace/include/madspace/phasespace/three_particle.h +68 -0
  36. madspace/include/madspace/phasespace/topology.h +116 -0
  37. madspace/include/madspace/phasespace/two_particle.h +63 -0
  38. madspace/include/madspace/phasespace/vegas.h +53 -0
  39. madspace/include/madspace/phasespace.h +27 -0
  40. madspace/include/madspace/runtime/context.h +147 -0
  41. madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
  42. madspace/include/madspace/runtime/event_generator.h +257 -0
  43. madspace/include/madspace/runtime/format.h +68 -0
  44. madspace/include/madspace/runtime/io.h +343 -0
  45. madspace/include/madspace/runtime/lhe_output.h +132 -0
  46. madspace/include/madspace/runtime/logger.h +46 -0
  47. madspace/include/madspace/runtime/runtime_base.h +39 -0
  48. madspace/include/madspace/runtime/tensor.h +603 -0
  49. madspace/include/madspace/runtime/thread_pool.h +101 -0
  50. madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
  51. madspace/include/madspace/runtime.h +12 -0
  52. madspace/include/madspace/umami.h +202 -0
  53. madspace/include/madspace/util.h +142 -0
  54. madspace/lib/libmadspace.dylib +0 -0
  55. madspace/lib/libmadspace_cpu.dylib +0 -0
  56. madspace/madnis/__init__.py +44 -0
  57. madspace/madnis/buffer.py +167 -0
  58. madspace/madnis/channel_grouping.py +85 -0
  59. madspace/madnis/distribution.py +103 -0
  60. madspace/madnis/integrand.py +175 -0
  61. madspace/madnis/integrator.py +973 -0
  62. madspace/madnis/interface.py +191 -0
  63. madspace/madnis/losses.py +186 -0
  64. madspace/torch.py +82 -0
  65. madspace-0.3.1.dist-info/METADATA +71 -0
  66. madspace-0.3.1.dist-info/RECORD +68 -0
  67. madspace-0.3.1.dist-info/WHEEL +6 -0
  68. madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
madspace/madnis/integrator.py
@@ -0,0 +1,973 @@
+ import itertools
+ import signal
+ import warnings
+ from collections.abc import Callable, Iterable
+ from dataclasses import astuple, dataclass
+ from typing import Any, Literal
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+
+ from .buffer import Buffer
+ from .distribution import Distribution
+ from .integrand import Integrand
+ from .losses import MultiChannelLoss, kl_divergence, stratified_variance, variance
+
+
+ @dataclass
+ class TrainingStatus:
+     """
+     Contains the MadNIS training status that is passed to a callback function.
+
+     Args:
+         step: optimization step
+         loss: loss from the optimization step
+         buffered: whether the optimization was performed on buffered samples
+         learning_rate: current learning rate if a learning rate scheduler is present
+         dropped_channels: number of channels dropped after this optimization step
+     """
+
+     step: int
+     loss: float
+     buffered: bool
+     learning_rate: float | None
+     dropped_channels: int
+
+
+ @dataclass
+ class SampleBatch:
+     """
+     Contains a batch of samples
+
+     Args:
+         x: samples generated by the flow, shape (n, dim)
+         y: remapped samples returned by the integrand, shape (n, remapped_dim)
+         q_sample: probabilities of the samples, shape (n, )
+         func_vals: integrand value, shape (n, )
+         channels: channel indices for multi-channel integration, shape (n, ), otherwise None
+         alphas_prior: prior channel weights, shape (n, channels), or None for single-channel
+             integration
+         alpha_channel_indices: channel indices if not all prior channel weights are stored,
+             otherwise None
+         integration_channels: index of the channel group in case the integration is performed
+             at the level of channel groups, shape (n, ), otherwise None
+         weights: integration weight, shape (n, ). Only set when returned from the
+             Integrator.sample function, otherwise None.
+         alphas: channel weights including the learned correction, shape (n, channels). Only
+             set when returned from the Integrator.sample function, otherwise None.
+         zero_counts: channel-wise counts of samples with zero weight that are not included in
+             the batch, shape (channels, ). This field is ignored by most methods, as it does
+             not have the batch size as its first dimension.
+     """
+
+     x: torch.Tensor
+     y: torch.Tensor | None
+     q_sample: torch.Tensor
+     func_vals: torch.Tensor
+     channels: torch.Tensor | None
+     alphas_prior: torch.Tensor | None = None
+     alpha_channel_indices: torch.Tensor | None = None
+     integration_channels: torch.Tensor | None = None
+     weights: torch.Tensor | None = None
+     alphas: torch.Tensor | None = None
+     zero_counts: torch.Tensor | None = None
+
+     def __iter__(self) -> Iterable[torch.Tensor | None]:
+         """
+         Returns an iterator over the fields of the class
+         """
+         return iter(astuple(self)[:-1])
+
+     def map(self, func: Callable[[torch.Tensor], torch.Tensor]) -> "SampleBatch":
+         """
+         Applies a function to all fields in the batch that are not None and returns a new
+         SampleBatch
+
+         Args:
+             func: function that is applied to all fields in the batch. Expects a tensor as
+                 argument and returns a new tensor
+         Returns:
+             Transformed SampleBatch
+         """
+         return SampleBatch(*(None if field is None else func(field) for field in self))
+
+     def split(self, batch_size: int) -> Iterable["SampleBatch"]:
+         """
+         Splits up the fields into batches and yields SampleBatch objects for every batch.
+
+         Args:
+             batch_size: maximal size of the batches
+         Returns:
+             Iterator over the batches
+         """
+         for batch in zip(
+             *(
+                 itertools.repeat(None) if field is None else field.split(batch_size)
+                 for field in self
+             )
+         ):
+             yield SampleBatch(*batch)
+
+     @staticmethod
+     def cat(batches: Iterable["SampleBatch"]) -> "SampleBatch":
+         """
+         Concatenates multiple batches. If the field zero_counts is not None, the zero_counts
+         of all batches are added.
+
+         Args:
+             batches: Iterable over SampleBatch objects
+         Returns:
+             New SampleBatch object containing the concatenated batches
+         """
+         cat_batch = SampleBatch(
+             *(
+                 None if item[0] is None else torch.cat(item, dim=0)
+                 for item in zip(*batches)
+             )
+         )
+         if batches[0].zero_counts is not None:
+             cat_batch.zero_counts = torch.stack(
+                 [batch.zero_counts for batch in batches], dim=1
+             ).sum(dim=1)
+         return cat_batch
+
+
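
As a quick illustration of how these helpers compose, a sketch with made-up shapes (not part of the packaged file):

    # SampleBatch round-trip: split into sub-batches, concatenate, map a transform
    batch = SampleBatch(
        x=torch.rand(8, 3),
        y=None,
        q_sample=torch.rand(8),
        func_vals=torch.rand(8),
        channels=None,
    )
    parts = list(batch.split(3))            # sub-batches of size 3, 3, 2
    merged = SampleBatch.cat(parts)         # back to one batch of size 8
    on_cpu = merged.map(lambda t: t.cpu())  # applied to every non-None field
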
+ class Integrator(nn.Module):
+     """
+     Implements MadNIS training and integration logic. MadNIS integrators are torch modules,
+     so their state can easily be saved and loaded using the torch.save and torch.load
+     methods.
+     """
+
+     def __init__(
+         self,
+         integrand: Callable[[torch.Tensor], torch.Tensor] | Integrand,
+         dims: int = 0,
+         flow: Distribution | None = None,
+         flow_kwargs: dict[str, Any] = {},
+         discrete_flow_kwargs: dict[str, Any] = {},
+         discrete_model: Literal["made", "transformer"] = "made",
+         train_channel_weights: bool = True,
+         cwnet: nn.Module | None = None,
+         cwnet_kwargs: dict[str, Any] = {},
+         loss: MultiChannelLoss | None = None,
+         optimizer: (
+             Optimizer | Callable[[Iterable[nn.Parameter]], Optimizer] | None
+         ) = None,
+         batch_size: int = 1024,
+         batch_size_per_channel: int = 0,
+         learning_rate: float = 1e-3,
+         scheduler: LRScheduler | Callable[[Optimizer], LRScheduler] | None = None,
+         uniform_channel_ratio: float = 1.0,
+         integration_history_length: int = 20,
+         drop_zero_integrands: bool = False,
+         batch_size_threshold: float = 0.5,
+         buffer_capacity: int = 0,
+         minimum_buffer_size: int = 50,
+         buffered_steps: int = 0,
+         max_stored_channel_weights: int | None = None,
+         channel_dropping_threshold: float = 0.0,
+         channel_dropping_interval: int = 100,
+         channel_grouping_mode: Literal["none", "uniform", "learned"] = "none",
+         freeze_cwnet_iteration: int | None = None,
+         device: torch.device | None = None,
+         dtype: torch.dtype | None = None,
+     ):
+         """
+         Args:
+             integrand: the function to be integrated. In the case of a simple single-channel
+                 integration, the integrand function can directly be passed to the integrator.
+                 In more complicated cases, like multi-channel integrals, use the ``Integrand``
+                 class.
+             dims: dimension of the integration space. Only required if a simple function is
+                 given as integrand.
+             flow: sampling distribution used for the integration. If None, a flow is
+                 constructed using the ``Flow`` class. Otherwise, it has to be compatible with
+                 a normalizing flow, i.e. have the interface defined in the ``Distribution``
+                 class.
+             flow_kwargs: If flow is None, these keyword arguments are passed to the ``Flow``
+                 constructor.
+             discrete_flow_kwargs: If flow is None, these keyword arguments are passed to the
+                 ``MixedFlow`` or ``DiscreteMADE`` constructor.
+             train_channel_weights: If True, construct a channel weight network and train it.
+                 Only necessary if cwnet is None.
+             cwnet: network used for the trainable channel weights. If None and
+                 train_channel_weights is True, the cwnet is built using the ``MLP`` class.
+             cwnet_kwargs: If cwnet is None and train_channel_weights is True, these keyword
+                 arguments are passed to the ``MLP`` constructor.
+             loss: Loss function used for training. If not provided, the KL divergence is
+                 chosen in the single-channel case and the stratified variance is chosen in
+                 the multi-channel case.
+             optimizer: optimizer for the training. Can be an optimizer object or a function
+                 that is called with the model parameters as argument and returns the
+                 optimizer. If None, the Adam optimizer is used.
+             batch_size: Training batch size
+             batch_size_per_channel: used to compute the batch size as a function of the number
+                 of active channels, ``batch_size + n_active_channels * batch_size_per_channel``
+             learning_rate: learning rate used for the Adam optimizer
+             scheduler: learning rate scheduler for the training. Can be a learning rate
+                 scheduler object or a function that gets the optimizer as argument and
+                 returns the scheduler. If None, a constant learning rate is used.
+             uniform_channel_ratio: fraction of samples in each batch that is distributed
+                 equally between all channels, value has to be between 0 and 1.
+             integration_history_length: number of batches for which the channel-wise means
+                 and variances are stored. This is used for stratified sampling during
+                 integration, and during the training if uniform_channel_ratio is different
+                 from one.
+             drop_zero_integrands: If True, points where the integrand is zero are dropped and
+                 not used for the optimization.
+             batch_size_threshold: New samples are drawn until the number of samples is at
+                 least batch_size_threshold * batch_size.
+             buffer_capacity: number of samples that are stored for buffered training
+             minimum_buffer_size: minimal size of the buffer to run buffered training
+             buffered_steps: number of optimization steps on buffered samples after every
+                 online training step
+             max_stored_channel_weights: number of prior channel weights that are buffered for
+                 each sample. If None, all prior channel weights are saved, otherwise only
+                 those for the channels with the largest contributions.
+             channel_dropping_threshold: all channels whose cumulative contribution to the
+                 integrand is smaller than this threshold are dropped
+             channel_dropping_interval: number of training steps after which channel dropping
+                 is performed
+             channel_grouping_mode: If "none", all channels are treated as separate channels
+                 in the loss and integration, even when they are grouped together. If
+                 "uniform", the channels within each group are sampled with equal probability.
+                 If "learned", a discrete normalizing flow is used to sample the channel index
+                 within a group.
+             freeze_cwnet_iteration: If not None, specifies the training iteration after which
+                 the channel weight network is frozen
+             device: torch device used for training and integration. If None, use the default
+                 device.
+             dtype: torch dtype used for training and integration. If None, use the default
+                 dtype.
+         """
+         super().__init__()
+
+         if not isinstance(integrand, Integrand):
+             integrand = Integrand(integrand, dims)
+         self.integrand = integrand
+         self.multichannel = integrand.channel_count is not None
+         discrete_dims = integrand.discrete_dims
+         input_dim = integrand.input_dim
+         if integrand.channel_grouping is None or channel_grouping_mode == "none":
+             self.integration_channel_count = integrand.channel_count
+             self.group_channels = False
+             self.group_channels_uniform = False
+         elif channel_grouping_mode == "uniform":
+             self.integration_channel_count = integrand.unique_channel_count()
+             self.group_channels = True
+             self.group_channels_uniform = True
+         elif channel_grouping_mode == "learned":
+             self.integration_channel_count = integrand.unique_channel_count()
+             self.group_channels = True
+             self.group_channels_uniform = False
+             self.channel_group_dim = (
+                 0
+                 if integrand.discrete_dims_position == "first"
+                 else input_dim - len(discrete_dims)
+             )
+             # TODO: provide default implementation of discrete prior
+             # discrete_dims.insert(0, max(len(group.channel_indices) for group in integrand.channel_grouping.groups))
+             # input_dim += 1
+         else:
+             raise ValueError(f"Unknown channel grouping mode {channel_grouping_mode}")
+
+         if self.group_channels:
+             self.register_buffer(
+                 "channel_group_sizes",
+                 torch.tensor(
+                     [
+                         len(group.channel_indices)
+                         for group in integrand.channel_grouping.groups
+                     ]
+                 ),
+             )
+             self.register_buffer(
+                 "channel_group_remap",
+                 torch.zeros(
+                     (len(self.channel_group_sizes), max(self.channel_group_sizes)),
+                     dtype=torch.int64,
+                 ),
+             )
+             for group in integrand.channel_grouping.groups:
+                 for i, chan_index in enumerate(group.channel_indices):
+                     self.channel_group_remap[group.group_index][i] = chan_index
+
+         if flow is None:
+             channel_remap_function = (
+                 None
+                 if self.group_channels and not self.group_channels_uniform
+                 else self.integrand.remap_channels
+             )
+             if len(discrete_dims) == 0:
+                 raise NotImplementedError("removed in MadSpace version of MadNIS")
+             elif len(discrete_dims) == input_dim:
+                 if discrete_model == "made":
+                     raise NotImplementedError("removed in MadSpace version of MadNIS")
+                 elif discrete_model == "transformer":
+                     raise NotImplementedError("removed in MadSpace version of MadNIS")
+                 else:
+                     raise ValueError("discrete_model must be 'made' or 'transformer'")
+             else:
+                 discrete_kwargs = dict(
+                     prior_prob_function=integrand.discrete_prior_prob_function,
+                     **discrete_flow_kwargs,
+                 )
+                 if self.multichannel:
+                     discrete_kwargs["channel_remap_function"] = channel_remap_function
+                 raise NotImplementedError("removed in MadSpace version of MadNIS")
+
+         if cwnet is None and train_channel_weights and self.multichannel:
+             raise NotImplementedError("removed in MadSpace version of MadNIS")
+         if cwnet is None:
+             parameters = flow.parameters()
+         else:
+             parameters = itertools.chain(flow.parameters(), cwnet.parameters())
+         if optimizer is None:
+             self.optimizer = torch.optim.Adam(parameters, learning_rate)
+         elif isinstance(optimizer, Optimizer):
+             self.optimizer = optimizer
+         else:
+             self.optimizer = optimizer(parameters)
+         if scheduler is None or isinstance(scheduler, LRScheduler):
+             self.scheduler = scheduler
+         else:
+             self.scheduler = scheduler(self.optimizer)
+
+         self.flow = flow
+         self.cwnet = cwnet
+         self.batch_size_offset = batch_size
+         self.batch_size_per_channel = batch_size_per_channel
+         self.batch_size = batch_size + batch_size_per_channel * (
+             self.integration_channel_count or 1
+         )
+         self.uniform_channel_ratio = uniform_channel_ratio
+         self.drop_zero_integrands = drop_zero_integrands
+         self.batch_size_threshold = batch_size_threshold
+         if loss is None:
+             self.loss = stratified_variance if self.multichannel else kl_divergence
+         else:
+             self.loss = loss
+
+         self.minimum_buffer_size = minimum_buffer_size
+         self.buffered_steps = buffered_steps
+         self.max_stored_channel_weights = (
+             None
+             if max_stored_channel_weights is None
+             or integrand.channel_count is None
+             or max_stored_channel_weights >= integrand.channel_count
+             else max_stored_channel_weights
+         )
+         if buffer_capacity > 0:
+             channel_count = self.max_stored_channel_weights or integrand.channel_count
+             buffer_fields = [
+                 (input_dim,),
+                 None if integrand.remapped_dim is None else (integrand.remapped_dim,),
+                 (),
+                 (),
+                 None if integrand.channel_count is None else (),
+                 None if not integrand.has_channel_weight_prior else (channel_count,),
+                 None if self.max_stored_channel_weights is None else (channel_count,),
+                 () if self.group_channels else None,
+                 None,
+                 None,
+             ]
+             buffer_dtypes = [
+                 None,
+                 None,
+                 None,
+                 None,
+                 torch.int64,
+                 None,
+                 torch.int64,
+                 torch.int64,
+                 None,
+                 None,
+             ]
+             self.buffer = Buffer(
+                 buffer_capacity, buffer_fields, persistent=False, dtypes=buffer_dtypes
+             )
+         else:
+             self.buffer = None
+         self.channel_dropping_threshold = channel_dropping_threshold
+         self.channel_dropping_interval = channel_dropping_interval
+         self.freeze_cwnet_iteration = freeze_cwnet_iteration
+         hist_shape = (self.integration_channel_count or 1,)
+         self.integration_history = Buffer(
+             integration_history_length,
+             [hist_shape, hist_shape, hist_shape],
+             dtypes=[None, None, torch.int64],
+         )
+         self.step = 0
+         self.step_type_count = 0
+         if self.multichannel:
+             self.register_buffer(
+                 "active_channels_mask",
+                 torch.ones((self.integration_channel_count,), dtype=torch.bool),
+             )
+         # Dummy to determine device and dtype
+         self.register_buffer("dummy", torch.zeros((1,)))
+
+         if device is not None:
+             self.to(device)
+         if dtype is not None:
+             self.to(dtype)
+
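
As a usage sketch of the constructor (hypothetical and not part of the packaged file: `my_flow` stands for any object with the ``Distribution`` interface, and the callable is assumed to map points in the unit hypercube to integrand values, as the docstring suggests; in this MadSpace build the flow=None paths raise NotImplementedError, so a flow must be supplied):

    # Hypothetical construction of a single-channel integrator
    integrator = Integrator(
        integrand=lambda x: torch.exp(-((x - 0.5) ** 2).sum(dim=1) / 0.01),
        dims=2,
        flow=my_flow,          # assumed: implements the Distribution interface
        batch_size=2048,
        buffer_capacity=20_000,
        buffered_steps=2,
    )
    torch.save(integrator.state_dict(), "integrator.pt")    # torch module, so its
    integrator.load_state_dict(torch.load("integrator.pt"))  # state round-trips
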
+     def _get_alphas(self, samples: SampleBatch) -> torch.Tensor:
+         """
+         Runs the channel weight network and returns the normalized channel weights, taking
+         prior channel weights and dropped channels into account.
+
+         Args:
+             samples: batch of samples
+         Returns:
+             channel weights, shape (n, channels)
+         """
+         if self.cwnet is None:
+             if samples.alphas_prior is None:
+                 return samples.x.new_full(
+                     (samples.x.shape[0], self.integrand.channel_count),
+                     1 / self.integrand.channel_count,
+                 )
+             return self._restore_prior(samples)
+
+         if samples.alphas_prior is None:
+             alpha_prior = samples.x.new_ones(
+                 (samples.x.shape[0], self.integrand.channel_count)
+             )
+         else:
+             alpha_prior = self._restore_prior(samples)
+
+         if self.group_channels:
+             active_channels_mask = self.active_channels_mask[
+                 self.integrand.remap_channels(
+                     torch.arange(alpha_prior.shape[1], device=alpha_prior.device)
+                 )
+             ]
+         else:
+             active_channels_mask = self.active_channels_mask
+
+         alpha = alpha_prior * active_channels_mask
+         mask = samples.func_vals != 0
+         y = samples.x if samples.y is None else samples.y
+         alpha[mask] *= self.cwnet(y[mask]).exp()
+         ret = alpha / alpha.sum(dim=1, keepdim=True)
+         return ret
+
+     def _compute_integral(
+         self, samples: SampleBatch
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Computes normalized integrand and channel-wise means, variances and counts
+
+         Args:
+             samples: batch of samples
+         Returns:
+             A tuple containing
+
+             - normalized integrand, shape (n, )
+             - channel-wise means of the integral, shape (channels, )
+             - channel-wise variances of the integral, shape (channels, )
+             - channel-wise number of samples, shape (channels, )
+         """
+         if self.multichannel:
+             alphas = torch.gather(
+                 self._get_alphas(samples), index=samples.channels[:, None], dim=1
+             )[:, 0]
+             f_true = alphas * samples.func_vals
+             f_div_q = f_true.detach() / samples.q_sample
+             channels = (
+                 samples.channels
+                 if samples.integration_channels is None
+                 else samples.integration_channels
+             )
+             counts = torch.bincount(channels, minlength=self.integration_channel_count)
+             if samples.zero_counts is not None:
+                 counts += samples.zero_counts
+             means = torch.bincount(
+                 channels,
+                 weights=f_div_q,
+                 minlength=self.integration_channel_count,
+             ) / counts.clip(min=1)
+             variances = (
+                 torch.bincount(
+                     channels,
+                     weights=(f_div_q - means[channels]).square(),
+                     minlength=self.integration_channel_count,
+                 )
+                 / counts
+             )
+         else:
+             f_div_q = samples.func_vals / samples.q_sample
+             f_true = samples.func_vals
+             means = f_div_q.mean(dim=0, keepdim=True)
+             counts = torch.full((1,), f_div_q.shape[0], device=means.device)
+             variances = f_div_q.var(dim=0, keepdim=True)
+         return f_true, means, variances, counts
+
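
The channel-wise accumulation above relies on torch.bincount with float weights; the same pattern in isolation, with invented numbers:

    # Toy version of the per-channel statistics computed above
    channels = torch.tensor([0, 0, 1, 2, 2, 2])
    f_div_q = torch.tensor([1.0, 3.0, 2.0, 4.0, 4.0, 4.0])
    counts = torch.bincount(channels, minlength=3)  # tensor([2, 1, 3])
    means = torch.bincount(channels, weights=f_div_q, minlength=3) / counts.clip(min=1)
    # means == tensor([2., 2., 4.]); the overall estimate is the count-weighted
    # combination (means * counts).sum() / counts.sum() == 3.0
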
+     def _optimization_step(
+         self,
+         samples: SampleBatch,
+     ) -> tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Performs one optimization step of the networks for the given samples
+
+         Args:
+             samples: batch of samples
+         Returns:
+             A tuple containing
+
+             - value of the loss
+             - channel-wise means of the integral, shape (channels, )
+             - channel-wise variances of the integral, shape (channels, )
+             - channel-wise number of samples, shape (channels, )
+         """
+         self.optimizer.zero_grad()
+         # TODO: depending on the loss function and for drop_zero_integrands=False, we can
+         # encounter zero-weight events here and it might be sufficient to evaluate the flow
+         # for events with func_val != 0. That might however give wrong results for other
+         # loss functions
+         q_test = self.flow.prob(
+             samples.x,
+             channel=(
+                 samples.integration_channels
+                 if self.group_channels and not self.group_channels_uniform
+                 else samples.channels
+             ),
+         )
+         f_true, means, variances, counts = self._compute_integral(samples)
+         loss = self.loss(
+             f_true,
+             q_test,
+             q_sample=samples.q_sample,
+             channels=(
+                 samples.channels
+                 if samples.integration_channels is None
+                 else samples.integration_channels
+             ),
+         )
+         if loss.isnan().item():
+             warnings.warn("nan batch: skipping optimization")
+         else:
+             loss.backward()
+             self.optimizer.step()
+         return loss.item(), means, variances, counts
+
+     def _restore_prior(self, samples: SampleBatch) -> torch.Tensor:
+         """
+         Restores the full prior channel weights if only the largest channel weights and
+         their indices were saved.
+
+         Args:
+             samples: batch of samples
+         Returns:
+             Tensor of prior channel weights with shape (n, channels)
+         """
+         if samples.alpha_channel_indices is None:
+             return samples.alphas_prior
+
+         alphas_prior_reduced = samples.alphas_prior
+         epsilon = torch.finfo(alphas_prior_reduced.dtype).eps
+
+         # strategy 1: distribute difference to 1 evenly among non-stored channels
+         # n_rest = self.integrand.channel_count - self.max_stored_channel_weights
+         # alphas_prior = torch.clamp(
+         #     (1 - alphas_prior_reduced.sum(dim=1, keepdims=True)) / n_rest,
+         #     min=epsilon,
+         # ).repeat(1, self.integrand.channel_count)
+         # alphas_prior.scatter_(1, samples.alpha_channel_indices, alphas_prior_reduced)
+         # return alphas_prior
+
+         # strategy 2: set non-stored channel alphas to epsilon, normalize again
+         alphas_prior = alphas_prior_reduced.new_full(
+             (alphas_prior_reduced.shape[0], self.integrand.channel_count), epsilon
+         )
+         alphas_prior.scatter_(1, samples.alpha_channel_indices, alphas_prior_reduced)
+         return alphas_prior / alphas_prior.sum(dim=1, keepdims=True)
+
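
Strategy 2 rebuilds the full weight vector by scattering the stored values into an epsilon-filled tensor and renormalizing; the same operation in isolation, with invented numbers:

    # Toy version of the scatter-based restoration (strategy 2)
    eps = torch.finfo(torch.float32).eps
    reduced = torch.tensor([[0.6, 0.3]])         # two stored weights per sample
    indices = torch.tensor([[4, 1]])             # their channel indices, 5 channels total
    full = torch.full((1, 5), eps)
    full.scatter_(1, indices, reduced)           # -> [eps, 0.3, eps, eps, 0.6]
    full = full / full.sum(dim=1, keepdim=True)  # renormalize to sum to one
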
583
+ def _get_channel_contributions(
584
+ self,
585
+ expect_full_history: bool,
586
+ channel_weight_mode: Literal["variance", "mean"],
587
+ ) -> torch.Tensor:
588
+ """
589
+ Uses the list of saved variances or means to compute the contribution of each channel for
590
+ stratified sampling.
591
+
592
+ Args:
593
+ expect_full_history: If True, the integration history has to be full, otherwise uniform
594
+ weights are returned.
595
+ channel_weight_mode: specifies whether the channels are weighted by their mean or
596
+ variance. Note that weighting by mean can lead to problems for non-positive functions
597
+ Returns:
598
+ weights for sampling the channels with shape (channels,)
599
+ """
600
+ min_len = self.integration_history.capacity if expect_full_history else 1
601
+ if self.integration_history.size < min_len:
602
+ return torch.ones(
603
+ self.integration_channel_count,
604
+ device=self.dummy.device,
605
+ dtype=self.dummy.dtype,
606
+ )
607
+ mean_hist, var_hist, count_hist = self.integration_history
608
+ contrib_hist = mean_hist.abs() if channel_weight_mode == "mean" else var_hist
609
+ count_hist = torch.where(
610
+ contrib_hist.isnan(), np.nan, count_hist.to(contrib_hist.dtype)
611
+ )
612
+ hist_weights = count_hist / count_hist.nansum(dim=0)
613
+ return torch.nansum(hist_weights * contrib_hist, dim=0).sqrt()
614
+
615
+ def _disable_unused_channels(self) -> int:
616
+ """
617
+ Determines channels with a total relative contribution below
618
+ ``channel_dropping_threshold``, disables them and removes them from the buffer.
619
+
620
+ Returns:
621
+ Number of channels that were disabled
622
+ """
623
+ if (
624
+ not self.multichannel
625
+ or self.channel_dropping_threshold == 0.0
626
+ or (self.step + 1) % self.channel_dropping_interval != 0
627
+ ):
628
+ return 0
629
+
630
+ mean_hist, _, count_hist = self.integration_history
631
+ count_hist = count_hist.to(mean_hist.dtype)
632
+ mean_hist = torch.nan_to_num(mean_hist)
633
+ hist_weights = count_hist / count_hist.sum(dim=0)
634
+ channel_integrals = torch.nansum(hist_weights * mean_hist, dim=0)
635
+ channel_rel_integrals = channel_integrals / channel_integrals.sum()
636
+ cri_sort, cri_argsort = torch.sort(channel_rel_integrals)
637
+ n_irrelevant = torch.count_nonzero(
638
+ cri_sort.cumsum(dim=0) < self.channel_dropping_threshold
639
+ )
640
+ n_disabled = torch.count_nonzero(
641
+ self.active_channels_mask[cri_argsort[:n_irrelevant]]
642
+ )
643
+ self.active_channels_mask[cri_argsort[:n_irrelevant]] = False
644
+ self.integrand.update_active_channels_mask(self.active_channels_mask)
645
+ self.batch_size = (
646
+ self.batch_size_offset
647
+ + torch.count_nonzero(self.active_channels_mask).item()
648
+ * self.batch_size_per_channel
649
+ )
650
+ if self.buffer is not None:
651
+
652
+ def filter_func(batch):
653
+ samples = SampleBatch(*batch)
654
+ channels = (
655
+ samples.channels
656
+ if samples.integration_channels is None
657
+ else samples.integration_channels
658
+ )
659
+ return self.active_channels_mask[channels]
660
+
661
+ self.buffer.filter(filter_func)
662
+ return n_disabled
663
+
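
The dropping rule sorts the relative channel integrals and disables the smallest channels whose cumulative sum stays below the threshold; the core of it on toy numbers:

    # Toy version of the cumulative-contribution rule
    rel = torch.tensor([0.50, 0.01, 0.30, 0.02, 0.17])
    sorted_rel, argsort = torch.sort(rel)  # [0.01, 0.02, 0.17, 0.30, 0.50]
    n_drop = torch.count_nonzero(sorted_rel.cumsum(dim=0) < 0.05)  # 2
    to_disable = argsort[:n_drop]          # channels 1 and 3
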
+     def _store_samples(self, samples: SampleBatch):
+         """
+         Stores the generated samples and probabilities for reuse during buffered training.
+         If ``max_stored_channel_weights`` is set, the largest channel weights are determined
+         and only those and their channel indices are stored.
+
+         Args:
+             samples: Object containing a batch of samples
+         """
+         if self.buffer is None:
+             return
+
+         if (
+             self.max_stored_channel_weights is not None
+             and self.integrand.has_channel_weight_prior
+         ):
+             # ensure that the alpha for the channel that the sample was generated with
+             # is always stored
+             alphas_prior_mod = torch.scatter(
+                 samples.alphas_prior,
+                 dim=1,
+                 index=samples.channels[:, None],
+                 src=torch.tensor(
+                     [[2.0]], device=self.dummy.device, dtype=self.dummy.dtype
+                 ).expand(*samples.alphas_prior.shape),
+             )
+             largest_alphas, alpha_indices = torch.sort(
+                 alphas_prior_mod, descending=True, dim=1
+             )
+             largest_alphas[:, 0] = torch.gather(
+                 samples.alphas_prior, dim=1, index=samples.channels[:, None]
+             )[:, 0]
+             samples.alphas_prior = largest_alphas[
+                 :, : self.max_stored_channel_weights
+             ].clone()
+             samples.alpha_channel_indices = alpha_indices[
+                 :, : self.max_stored_channel_weights
+             ].clone()
+
+         self.buffer.store(*samples)
+
+     def _get_channels(
+         self,
+         n: int,
+         channel_weights: torch.Tensor,
+         uniform_channel_ratio: float,
+         return_counts: bool = False,
+     ) -> torch.Tensor:
+         """
+         Creates a tensor of channel indices or the number of samples per channel in two
+         steps:
+
+         1. Split up n * uniform_channel_ratio equally among all the channels
+         2. Sample the rest of the events from the distribution given by channel_weights
+            after correcting for the uniformly distributed samples
+
+         This allows stratified sampling by variance weighting while ensuring stable training
+         because there are events in every channel.
+
+         Args:
+             n: Number of samples as scalar integer tensor
+             channel_weights: Weights of the channels (not normalized) with shape (channels,)
+             uniform_channel_ratio: Number between 0.0 and 1.0 to determine the ratio of
+                 samples that will be distributed uniformly first
+             return_counts: If True, return the number of samples per channel, otherwise the
+                 channel indices
+         Returns:
+             If return_counts is True, tensor with the number of samples per channel, shape
+             (channels,). Otherwise, tensor of channel indices with shape (n,)
+         """
+         assert channel_weights.shape == (self.integration_channel_count,)
+         n_active_channels = torch.count_nonzero(self.active_channels_mask).item()
+         uniform_per_channel = int(
+             np.ceil(n * uniform_channel_ratio / n_active_channels)
+         )
+         n_per_channel = torch.full(
+             (self.integration_channel_count,),
+             uniform_per_channel,
+             device=self.dummy.device,
+         )
+         n_per_channel[~self.active_channels_mask] = 0
+
+         n_weighted = max(n - n_per_channel.sum(), 0)
+         if n_weighted > 0:
+             normed_weights = (
+                 channel_weights / channel_weights[self.active_channels_mask].sum()
+             )
+             normed_weights[~self.active_channels_mask] = 0.0
+             probs = torch.clamp(
+                 normed_weights - uniform_channel_ratio / n_active_channels, min=0
+             )
+             n_per_channel += torch.ceil(probs * n_weighted / probs.sum()).int()
+
+         remove_chan = 0
+         while n_per_channel.sum() > n:
+             if n_per_channel[remove_chan] > 0:
+                 n_per_channel[remove_chan] -= 1
+             remove_chan = (remove_chan + 1) % self.integration_channel_count
+         assert n_per_channel.sum() == n
+
+         if return_counts:
+             return n_per_channel
+
+         return torch.cat(
+             [
+                 torch.full((npc,), i, device=self.dummy.device)
+                 for i, npc in enumerate(n_per_channel)
+             ]
+         )
+
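
Stepping through the two-stage allocation on toy weights (all channels assumed active):

    # Toy version of the two-step allocation: uniform share first, then the
    # stratified remainder after subtracting the uniform part
    n, ratio = 100, 0.2
    weights = torch.tensor([1.0, 3.0, 6.0])  # unnormalized channel weights
    uniform = int(np.ceil(n * ratio / 3))    # 7 samples per channel up front
    probs = torch.clamp(weights / weights.sum() - ratio / 3, min=0)
    n_per_channel = uniform + torch.ceil(probs * (n - 3 * uniform) / probs.sum()).int()
    # [11, 31, 60] sums to 102; the method then trims counts round-robin
    # until they add up to exactly n
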
+     def _get_samples(
+         self,
+         n: int,
+         uniform_channel_ratio: float = 0.0,
+         train: bool = False,
+         channel_weight_mode: Literal["variance", "mean"] = "variance",
+         channel: int | None = None,
+     ) -> SampleBatch:
+         """
+         Draws samples from the flow and evaluates the integrand
+
+         Args:
+             n: number of samples
+             uniform_channel_ratio: Number between 0.0 and 1.0 to determine the ratio of
+                 samples that will be distributed uniformly first
+             train: If True, the function is used in training mode, i.e. samples where the
+                 integrand is zero are removed if drop_zero_integrands is True
+             channel_weight_mode: specifies whether the channels are weighted by their mean
+                 or variance. Note that weighting by mean can lead to problems for
+                 non-positive functions
+             channel: if different from None, samples are only generated for this channel
+         Returns:
+             Object containing a batch of samples
+         """
+         if channel is None:
+             batch_channels = (
+                 self._get_channels(
+                     n,
+                     self._get_channel_contributions(train, channel_weight_mode),
+                     uniform_channel_ratio,
+                 )
+                 if self.multichannel
+                 else None
+             )
+         else:
+             batch_channels = torch.full((n,), channel, device=self.dummy.device)
+
+         batches_out = []
+         current_batch_size = 0
+         while True:
+             integration_channels = None
+             weight_factor = None
+             if self.integrand.function_includes_sampling:
+                 integration_channels = batch_channels
+                 x, prob, weight, y, alphas_prior, channels = self.integrand.function(
+                     batch_channels
+                 )
+             else:
+                 if self.group_channels and self.group_channels_uniform:
+                     group_sizes = self.channel_group_sizes[batch_channels]
+                     chan_in_group = (
+                         torch.rand(
+                             (n,), device=self.dummy.device, dtype=self.dummy.dtype
+                         )
+                         * group_sizes
+                     ).long()
+                     weight_factor = group_sizes
+                     integration_channels = batch_channels
+                     channels = self.channel_group_remap[batch_channels, chan_in_group]
+                 else:
+                     channels = batch_channels
+
+                 with torch.no_grad():
+                     x, prob = self.flow.sample(
+                         n,
+                         channel=channels,
+                         return_prob=True,
+                         device=self.dummy.device,
+                         dtype=self.dummy.dtype,
+                     )
+                 weight, y, alphas_prior = self.integrand(x, channels)
+
+                 if self.group_channels and not self.group_channels_uniform:
+                     chan_in_group = x[:, self.channel_group_dim].long()
+                     integration_channels = batch_channels
+                     channels = self.channel_group_remap[
+                         integration_channels, chan_in_group
+                     ]
+
+             if weight_factor is not None:
+                 weight *= weight_factor
+             batch = SampleBatch(
+                 x,
+                 y,
+                 prob,
+                 weight,
+                 channels,
+                 alphas_prior,
+                 integration_channels=integration_channels,
+             )
+
+             if not train:
+                 current_batch_size += batch.x.shape[0]
+             elif self.drop_zero_integrands:
+                 mask = weight != 0.0
+                 batch = batch.map(lambda t: t[mask])
+                 if self.multichannel:
+                     batch.zero_counts = torch.bincount(
+                         (
+                             channels[~mask]
+                             if integration_channels is None
+                             else integration_channels[~mask]
+                         ),
+                         minlength=self.integration_channel_count,
+                     )
+                 else:
+                     batch.zero_counts = torch.full(
+                         (1,), torch.count_nonzero(~mask), device=x.device
+                     )
+                 current_batch_size += batch.x.shape[0]
+             else:
+                 current_batch_size += weight.count_nonzero()
+
+             # mask = ~(weight.isnan() | x.isnan().any(dim=1)) & mask
+             batches_out.append(batch)
+
+             # check this condition at the end so that the sampling runs at least once
+             if current_batch_size > self.batch_size_threshold * n:
+                 break
+
+         return SampleBatch.cat(batches_out)
+
+     def train_step(self) -> TrainingStatus:
+         """
+         Performs a single training step
+
+         Returns:
+             Training status
+         """
+
+         if self.step == self.freeze_cwnet_iteration and self.cwnet is not None:
+             for param in self.cwnet.parameters():
+                 param.requires_grad = False
+
+         if self.step_type_count == 0:
+             buffered = False
+             samples = self._get_samples(
+                 self.batch_size, self.uniform_channel_ratio, train=True
+             )
+             loss, means, variances, counts = self._optimization_step(samples)
+             self._store_samples(samples)
+             self.integration_history.store(means[None], variances[None], counts[None])
+
+             if self.buffered_steps != 0 and self.buffer.size > self.minimum_buffer_size:
+                 self.step_type_count += 1
+         else:
+             buffered = True
+             samples = SampleBatch(*self.buffer.sample(self.batch_size))
+             loss, _, _, _ = self._optimization_step(samples)
+             self.step_type_count = (self.step_type_count + 1) % (
+                 self.buffered_steps + 1
+             )
+
+         dropped_channels = self._disable_unused_channels()
+         status = TrainingStatus(
+             step=self.step,
+             loss=loss,
+             buffered=buffered,
+             learning_rate=(
+                 None if self.scheduler is None else self.scheduler.get_last_lr()[0]
+             ),
+             dropped_channels=dropped_channels,
+         )
+
+         if self.scheduler is not None:
+             self.scheduler.step()
+         self.step += 1
+         return status
+
+     def train(
+         self,
+         steps: int,
+         callback: Callable[[TrainingStatus], None] | None = None,
+         capture_keyboard_interrupt: bool = False,
+     ):
+         """
+         Performs multiple training steps
+
+         Args:
+             steps: number of training steps
+             callback: function that is called after each training step with the training
+                 status as argument
+             capture_keyboard_interrupt: If True, a keyboard interrupt does not raise an
+                 exception. Instead, the current training step is finished and the training
+                 is aborted afterwards.
+         """
+         interrupted = False
+         if capture_keyboard_interrupt:
+
+             def handler(sig, frame):
+                 nonlocal interrupted
+                 interrupted = True
+
+             old_handler = signal.signal(signal.SIGINT, handler)
+
+         try:
+             for _ in range(steps):
+                 status = self.train_step()
+                 if callback is not None:
+                     callback(status)
+                 if interrupted:
+                     break
+         finally:
+             if capture_keyboard_interrupt:
+                 signal.signal(signal.SIGINT, old_handler)
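
Putting it together, a minimal training run (illustrative only, continuing the hypothetical `integrator` from the constructor sketch above):

    # Minimal usage sketch, not part of the packaged file
    def report(status: TrainingStatus):
        if status.step % 100 == 0:
            kind = "buffered" if status.buffered else "online"
            print(f"step {status.step}: loss={status.loss:.4f} ({kind})")

    integrator.train(1000, callback=report, capture_keyboard_interrupt=True)
    # Ctrl-C now finishes the current step and exits the loop cleanly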