madspace 0.3.1__cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. madspace/__init__.py +1 -0
  2. madspace/_madspace_py.cpython-314-x86_64-linux-gnu.so +0 -0
  3. madspace/_madspace_py.pyi +2189 -0
  4. madspace/_madspace_py_loader.py +111 -0
  5. madspace/include/madspace/constants.h +17 -0
  6. madspace/include/madspace/madcode/function.h +102 -0
  7. madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
  8. madspace/include/madspace/madcode/instruction.h +208 -0
  9. madspace/include/madspace/madcode/opcode_mixin.h +134 -0
  10. madspace/include/madspace/madcode/optimizer.h +31 -0
  11. madspace/include/madspace/madcode/type.h +203 -0
  12. madspace/include/madspace/madcode.h +6 -0
  13. madspace/include/madspace/phasespace/base.h +74 -0
  14. madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
  15. madspace/include/madspace/phasespace/channel_weights.h +51 -0
  16. madspace/include/madspace/phasespace/chili.h +32 -0
  17. madspace/include/madspace/phasespace/cross_section.h +47 -0
  18. madspace/include/madspace/phasespace/cuts.h +34 -0
  19. madspace/include/madspace/phasespace/discrete_flow.h +44 -0
  20. madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
  21. madspace/include/madspace/phasespace/flow.h +53 -0
  22. madspace/include/madspace/phasespace/histograms.h +26 -0
  23. madspace/include/madspace/phasespace/integrand.h +204 -0
  24. madspace/include/madspace/phasespace/invariants.h +26 -0
  25. madspace/include/madspace/phasespace/luminosity.h +41 -0
  26. madspace/include/madspace/phasespace/matrix_element.h +70 -0
  27. madspace/include/madspace/phasespace/mlp.h +37 -0
  28. madspace/include/madspace/phasespace/multichannel.h +49 -0
  29. madspace/include/madspace/phasespace/observable.h +85 -0
  30. madspace/include/madspace/phasespace/pdf.h +78 -0
  31. madspace/include/madspace/phasespace/phasespace.h +67 -0
  32. madspace/include/madspace/phasespace/rambo.h +26 -0
  33. madspace/include/madspace/phasespace/scale.h +52 -0
  34. madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
  35. madspace/include/madspace/phasespace/three_particle.h +68 -0
  36. madspace/include/madspace/phasespace/topology.h +116 -0
  37. madspace/include/madspace/phasespace/two_particle.h +63 -0
  38. madspace/include/madspace/phasespace/vegas.h +53 -0
  39. madspace/include/madspace/phasespace.h +27 -0
  40. madspace/include/madspace/runtime/context.h +147 -0
  41. madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
  42. madspace/include/madspace/runtime/event_generator.h +257 -0
  43. madspace/include/madspace/runtime/format.h +68 -0
  44. madspace/include/madspace/runtime/io.h +343 -0
  45. madspace/include/madspace/runtime/lhe_output.h +132 -0
  46. madspace/include/madspace/runtime/logger.h +46 -0
  47. madspace/include/madspace/runtime/runtime_base.h +39 -0
  48. madspace/include/madspace/runtime/tensor.h +603 -0
  49. madspace/include/madspace/runtime/thread_pool.h +101 -0
  50. madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
  51. madspace/include/madspace/runtime.h +12 -0
  52. madspace/include/madspace/umami.h +202 -0
  53. madspace/include/madspace/util.h +142 -0
  54. madspace/lib/libmadspace.so +0 -0
  55. madspace/lib/libmadspace_cpu.so +0 -0
  56. madspace/lib/libmadspace_cpu_avx2.so +0 -0
  57. madspace/lib/libmadspace_cpu_avx512.so +0 -0
  58. madspace/lib/libmadspace_cuda.so +0 -0
  59. madspace/lib/libmadspace_hip.so +0 -0
  60. madspace/madnis/__init__.py +44 -0
  61. madspace/madnis/buffer.py +167 -0
  62. madspace/madnis/channel_grouping.py +85 -0
  63. madspace/madnis/distribution.py +103 -0
  64. madspace/madnis/integrand.py +175 -0
  65. madspace/madnis/integrator.py +973 -0
  66. madspace/madnis/interface.py +191 -0
  67. madspace/madnis/losses.py +186 -0
  68. madspace/torch.py +82 -0
  69. madspace-0.3.1.dist-info/METADATA +71 -0
  70. madspace-0.3.1.dist-info/RECORD +75 -0
  71. madspace-0.3.1.dist-info/WHEEL +6 -0
  72. madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
  73. madspace.libs/libgfortran-83c28eba.so.5.0.0 +0 -0
  74. madspace.libs/libopenblas-r0-11edc3fa.3.15.so +0 -0
  75. madspace.libs/libquadmath-2284e583.so.0.0.0 +0 -0
madspace/include/madspace/runtime/vegas_optimizer.h
@@ -0,0 +1,26 @@
+ #pragma once
+
+ #include "madspace/madcode.h"
+ #include "madspace/runtime/context.h"
+ #include "madspace/runtime/tensor.h"
+
+ namespace madspace {
+
+ class VegasGridOptimizer {
+   public:
+     VegasGridOptimizer(
+         ContextPtr context, const std::string& grid_name, double damping
+     ) :
+         _context(context), _grid_name(grid_name), _damping(damping) {}
+     void add_data(Tensor weights, Tensor inputs);
+     void optimize();
+     std::size_t input_dim() const;
+
+   private:
+     ContextPtr _context;
+     std::string _grid_name;
+     double _damping;
+     std::vector<std::tuple<std::vector<std::size_t>, std::vector<double>>> _data;
+ };
+
+ } // namespace madspace
madspace/include/madspace/runtime.h
@@ -0,0 +1,12 @@
+ #pragma once
+
+ #include "runtime/context.h"
+ #include "runtime/discrete_optimizer.h"
+ #include "runtime/event_generator.h"
+ #include "runtime/format.h"
+ #include "runtime/io.h"
+ #include "runtime/lhe_output.h"
+ #include "runtime/logger.h"
+ #include "runtime/tensor.h"
+ #include "runtime/thread_pool.h"
+ #include "runtime/vegas_optimizer.h"
madspace/include/madspace/umami.h
@@ -0,0 +1,202 @@
+ /*
+  *                                  _
+  *                                 (_)
+  *  _   _ _ __ ___   __ _ _ __ ___  _
+  * | | | | '_ ` _ \ / _` | '_ ` _ \| |
+  * | |_| | | | | | | (_| | | | | | | |
+  *  \__,_|_| |_| |_|\__,_|_| |_| |_|_|
+  *
+  * Unified MAtrix eleMent Interface
+  *
+  *
+  */
+
+ #ifndef UMAMI_HEADER
+ #define UMAMI_HEADER 1
+
+ #include <stddef.h>
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ /**
+  * Major version number of the UMAMI interface. If the major version is the same
+  * between caller and implementation, binary compatibility is ensured.
+  */
+ static const int UMAMI_MAJOR_VERSION = 1;
+ /**
+  * Minor version number of the UMAMI interface. Between minor versions, new keys for
+  * errors, devices, metadata, inputs and outputs can be added.
+  */
+ static const int UMAMI_MINOR_VERSION = 0;
+
+ typedef enum {
+     UMAMI_SUCCESS,
+     UMAMI_ERROR,
+     UMAMI_ERROR_NOT_IMPLEMENTED,
+     UMAMI_ERROR_UNSUPPORTED_INPUT,
+     UMAMI_ERROR_UNSUPPORTED_OUTPUT,
+     UMAMI_ERROR_UNSUPPORTED_META,
+     UMAMI_ERROR_MISSING_INPUT,
+ } UmamiStatus;
+
+ typedef enum {
+     UMAMI_DEVICE_CPU,
+     UMAMI_DEVICE_CUDA,
+     UMAMI_DEVICE_HIP,
+ } UmamiDevice;
+
+ typedef enum {
+     UMAMI_META_DEVICE,
+     UMAMI_META_PARTICLE_COUNT,
+     UMAMI_META_DIAGRAM_COUNT,
+     UMAMI_META_HELICITY_COUNT,
+     UMAMI_META_COLOR_COUNT,
+ } UmamiMetaKey;
+
+ typedef enum {
+     UMAMI_IN_MOMENTA,
+     UMAMI_IN_ALPHA_S,
+     UMAMI_IN_FLAVOR_INDEX,
+     UMAMI_IN_RANDOM_COLOR,
+     UMAMI_IN_RANDOM_HELICITY,
+     UMAMI_IN_RANDOM_DIAGRAM,
+     UMAMI_IN_HELICITY_INDEX,
+     UMAMI_IN_DIAGRAM_INDEX,
+     UMAMI_IN_GPU_STREAM,
+ } UmamiInputKey;
+
+ typedef enum {
+     UMAMI_OUT_MATRIX_ELEMENT,
+     UMAMI_OUT_DIAGRAM_AMP2,
+     UMAMI_OUT_COLOR_INDEX,
+     UMAMI_OUT_HELICITY_INDEX,
+     UMAMI_OUT_DIAGRAM_INDEX,
+     // NLO: born, virtual, poles, counterterms
+     // color: LC-ME, FC-ME
+ } UmamiOutputKey;
+
+ typedef void* UmamiHandle;
+
+ /**
+  * Retrieves global metadata about the matrix element code, such as the particle,
+  * diagram, helicity and color counts or the device it runs on.
+  *
+  * @param meta_key
+  *     key identifying the requested metadata entry
+  * @param result
+  *     pointer to the memory that receives the value. Its type depends on the
+  *     meta key.
+  * @return
+  *     UMAMI_SUCCESS on success, error code otherwise
+  */
+ UmamiStatus umami_get_meta(UmamiMetaKey meta_key, void* result);
+
+ /**
+  * Creates an instance of the matrix element. Each instance is independent, so thread
+  * safety can be achieved by creating a separate one for every thread.
+  *
+  * @param param_card_path
+  *     path to the parameter file
+  * @param handle
+  *     pointer that receives the new subprocess instance. Has to be cleaned
+  *     up by the caller with `umami_free`.
+  * @return
+  *     UMAMI_SUCCESS on success, error code otherwise
+  */
+ UmamiStatus umami_initialize(UmamiHandle* handle, char const* param_card_path);
+
+ /**
+  * Sets the value of a model parameter
+  *
+  * @param handle
+  *     handle of a matrix element instance
+  * @param name
+  *     name of the parameter
+  * @param parameter_real
+  *     real part of the parameter value
+  * @param parameter_imag
+  *     imaginary part of the parameter value. Ignored for real-valued parameters.
+  * @return
+  *     UMAMI_SUCCESS on success, error code otherwise
+  */
+ UmamiStatus umami_set_parameter(
+     UmamiHandle handle, char const* name, double parameter_real, double parameter_imag
+ );
+
+ /**
+  * Retrieves the value of a model parameter
+  *
+  * @param handle
+  *     handle of a matrix element instance
+  * @param name
+  *     name of the parameter
+  * @param parameter_real
+  *     pointer to double to return the real part of the parameter value
+  * @param parameter_imag
+  *     pointer to double to return the imaginary part of the parameter value. Ignored
+  *     for real-valued parameters (i.e. you may pass a null pointer)
+  * @return
+  *     UMAMI_SUCCESS on success, error code otherwise
+  */
+ UmamiStatus umami_get_parameter(
+     UmamiHandle handle, char const* name, double* parameter_real, double* parameter_imag
+ );
+
+ /**
+  * Evaluates the matrix element as a function of the given inputs, filling the
+  * requested outputs.
+  *
+  * @param handle
+  *     handle of a matrix element instance
+  * @param count
+  *     number of events to evaluate the matrix element for
+  * @param stride
+  *     stride of the batch dimension of the input and output arrays, see memory layout
+  * @param offset
+  *     offset of the event index
+  * @param input_count
+  *     number of inputs to the matrix element
+  * @param input_keys
+  *     pointer to an array of input keys, length `input_count`
+  * @param inputs
+  *     pointer to an array of void pointers to the inputs. The type of the inputs
+  *     depends on the input key
+  * @param output_count
+  *     number of outputs of the matrix element
+  * @param output_keys
+  *     pointer to an array of output keys, length `output_count`
+  * @param outputs
+  *     pointer to an array of void pointers to the outputs. The type of the outputs
+  *     depends on the output key. The caller is responsible for allocating memory for
+  *     the outputs.
+  * @return
+  *     UMAMI_SUCCESS on success, error code otherwise
+  */
+ UmamiStatus umami_matrix_element(
+     UmamiHandle handle,
+     size_t count,
+     size_t stride,
+     size_t offset,
+     size_t input_count,
+     UmamiInputKey const* input_keys,
+     void const* const* inputs,
+     size_t output_count,
+     UmamiOutputKey const* output_keys,
+     void* const* outputs
+ );
+
+ /**
+  * Frees a matrix element instance
+  *
+  * @param handle
+  *     handle of a matrix element instance
+  */
+ UmamiStatus umami_free(UmamiHandle handle);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
+ #endif // UMAMI_HEADER
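
Since UMAMI is a plain C ABI, a subprocess library implementing it can be driven from any language with a C FFI. The following Python/ctypes sketch shows the intended call sequence (initialize, evaluate, free). The library name `libsubprocess.so`, the param card file, the event count and the densely packed structure-of-arrays momentum layout with `stride = count` are illustrative assumptions; only the function names, signatures and key values come from the header above.

    import ctypes

    # Key values follow the declaration order of the C enums above (counting from 0).
    UMAMI_SUCCESS = 0
    UMAMI_IN_MOMENTA, UMAMI_IN_ALPHA_S = 0, 1
    UMAMI_OUT_MATRIX_ELEMENT = 0

    lib = ctypes.CDLL("./libsubprocess.so")  # hypothetical UMAMI implementation
    lib.umami_initialize.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p]
    lib.umami_matrix_element.argtypes = [
        ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_size_t,
        ctypes.c_size_t, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_void_p),
        ctypes.c_size_t, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_void_p),
    ]
    lib.umami_free.argtypes = [ctypes.c_void_p]

    handle = ctypes.c_void_p()
    assert lib.umami_initialize(ctypes.byref(handle), b"param_card.dat") == UMAMI_SUCCESS

    count, n_particles = 128, 4
    momenta = (ctypes.c_double * (4 * n_particles * count))()  # assumed E/px/py/pz layout
    alpha_s = (ctypes.c_double * count)(*([0.118] * count))
    me_out = (ctypes.c_double * count)()

    input_keys = (ctypes.c_int * 2)(UMAMI_IN_MOMENTA, UMAMI_IN_ALPHA_S)
    inputs = (ctypes.c_void_p * 2)(
        ctypes.cast(momenta, ctypes.c_void_p), ctypes.cast(alpha_s, ctypes.c_void_p)
    )
    output_keys = (ctypes.c_int * 1)(UMAMI_OUT_MATRIX_ELEMENT)
    outputs = (ctypes.c_void_p * 1)(ctypes.cast(me_out, ctypes.c_void_p))

    # stride = count and offset = 0 assume one densely packed batch.
    status = lib.umami_matrix_element(
        handle, count, count, 0, 2, input_keys, inputs, 1, output_keys, outputs
    )
    assert status == UMAMI_SUCCESS
    lib.umami_free(handle)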
madspace/include/madspace/util.h
@@ -0,0 +1,142 @@
+ #pragma once
+
+ #include <cstdio>
+ #include <format>
+ #include <ranges>
+ #include <tuple>
+ #include <vector>
+
+ namespace madspace {
+
+ template <class... Ts>
+ struct Overloaded : Ts... {
+     using Ts::operator()...;
+ };
+ template <class... Ts>
+ Overloaded(Ts...) -> Overloaded<Ts...>;
+
+ template <typename T>
+ using nested_vector2 = std::vector<std::vector<T>>;
+ template <typename T>
+ using nested_vector3 = std::vector<std::vector<std::vector<T>>>;
+ template <typename T>
+ using nested_vector4 = std::vector<std::vector<std::vector<std::vector<T>>>>;
+
+ // Unfortunately, nvcc does not support C++23 yet, so instead of std::views::zip
+ // we implement our own zip function here (based on
+ // https://github.com/alemuntoni/zip-views).
+
+ namespace detail {
+
+ inline void print_impl(
+     std::FILE* stream, bool new_line, std::string_view fmt, std::format_args args
+ ) {
+     std::string str = std::vformat(fmt, args);
+     if (new_line) {
+         str.push_back('\n');
+     }
+     std::fwrite(str.data(), 1, str.size(), stream);
+ }
+
+ template <typename... Args, std::size_t... Index>
+ bool any_match_impl(
+     const std::tuple<Args...>& lhs,
+     const std::tuple<Args...>& rhs,
+     std::index_sequence<Index...>
+ ) {
+     auto result = false;
+     result = (... || (std::get<Index>(lhs) == std::get<Index>(rhs)));
+     return result;
+ }
+
+ template <typename... Args>
+ bool any_match(const std::tuple<Args...>& lhs, const std::tuple<Args...>& rhs) {
+     return any_match_impl(lhs, rhs, std::index_sequence_for<Args...>{});
+ }
+
+ template <std::ranges::viewable_range... Rng>
+ class zip_iterator {
+   public:
+     using value_type = std::tuple<std::ranges::range_reference_t<Rng>...>;
+
+     zip_iterator() = delete;
+     zip_iterator(std::ranges::iterator_t<Rng>&&... iters) :
+         _iters{std::forward<std::ranges::iterator_t<Rng>>(iters)...} {}
+
+     zip_iterator& operator++() {
+         std::apply([](auto&&... args) { ((++args), ...); }, _iters);
+         return *this;
+     }
+
+     zip_iterator operator++(int) {
+         auto tmp = *this;
+         ++*this;
+         return tmp;
+     }
+
+     bool operator!=(const zip_iterator& other) const { return !(*this == other); }
+
+     bool operator==(const zip_iterator& other) const {
+         return any_match(_iters, other._iters);
+     }
+
+     value_type operator*() {
+         return std::apply([](auto&&... args) { return value_type(*args...); }, _iters);
+     }
+
+   private:
+     std::tuple<std::ranges::iterator_t<Rng>...> _iters;
+ };
+
+ template <std::ranges::viewable_range... T>
+ class zipper {
+   public:
+     using zip_type = zip_iterator<T...>;
+
+     template <typename... Args>
+     zipper(Args&&... args) : _args{std::forward<Args>(args)...} {}
+
+     zip_type begin() {
+         return std::apply(
+             [](auto&&... args) { return zip_type(std::ranges::begin(args)...); }, _args
+         );
+     }
+     zip_type end() {
+         return std::apply(
+             [](auto&&... args) { return zip_type(std::ranges::end(args)...); }, _args
+         );
+     }
+
+   private:
+     std::tuple<T...> _args;
+ };
+
+ } // namespace detail
+
+ template <std::ranges::viewable_range... T>
+ auto zip(T&&... t) {
+     return detail::zipper<T...>{std::forward<T>(t)...};
+ }
+
+ template <typename... Args>
+ inline void print(std::format_string<Args...> fmt, Args&&... args) {
+     detail::print_impl(stdout, false, fmt.get(), std::make_format_args(args...));
+ }
+
+ template <typename... Args>
+ inline void print(std::FILE* stream, std::format_string<Args...> fmt, Args&&... args) {
+     detail::print_impl(stream, false, fmt.get(), std::make_format_args(args...));
+ }
+
+ template <typename... Args>
+ inline void println(std::format_string<Args...> fmt, Args&&... args) {
+     detail::print_impl(stdout, true, fmt.get(), std::make_format_args(args...));
+ }
+
+ template <typename... Args>
+ inline void
+ println(std::FILE* stream, std::format_string<Args...> fmt, Args&&... args) {
+     detail::print_impl(stream, true, fmt.get(), std::make_format_args(args...));
+ }
+
+ } // namespace madspace
madspace/lib/libmadspace.so: Binary file
madspace/lib/libmadspace_cpu.so: Binary file
madspace/lib/libmadspace_cpu_avx2.so: Binary file
madspace/lib/libmadspace_cpu_avx512.so: Binary file
madspace/lib/libmadspace_cuda.so: Binary file
madspace/lib/libmadspace_hip.so: Binary file
madspace/madnis/__init__.py
@@ -0,0 +1,44 @@
+ """
+ This module contains functions and classes to train neural importance sampling networks and
+ evaluate the integration and sampling performance.
+ """
+
+ from .buffer import Buffer
+ from .channel_grouping import ChannelData, ChannelGroup, ChannelGrouping
+ from .distribution import Distribution
+ from .integrand import Integrand
+ from .integrator import Integrator, SampleBatch, TrainingStatus
+ from .interface import (
+     MADNIS_INTEGRAND_FLAGS,
+     IntegrandDistribution,
+     IntegrandFunction,
+     build_madnis_integrand,
+ )
+ from .losses import (
+     kl_divergence,
+     multi_channel_loss,
+     rkl_divergence,
+     stratified_variance,
+     variance,
+ )
+
+ __all__ = [
+     "Integrator",
+     "TrainingStatus",
+     "SampleBatch",
+     "Integrand",
+     "Buffer",
+     "multi_channel_loss",
+     "stratified_variance",
+     "variance",
+     "kl_divergence",
+     "rkl_divergence",
+     "ChannelGroup",
+     "ChannelData",
+     "ChannelGrouping",
+     "Distribution",
+     "MADNIS_INTEGRAND_FLAGS",
+     "IntegrandDistribution",
+     "IntegrandFunction",
+     "build_madnis_integrand",
+ ]
madspace/madnis/buffer.py
@@ -0,0 +1,167 @@
+ from collections.abc import Callable, Iterator
+
+ import torch
+ import torch.nn as nn
+
+
+ class Buffer(nn.Module):
+     """
+     Circular buffer for multiple tensors with different shapes. The class is a torch.nn.Module
+     so that the buffer content can be saved and restored through the module's state_dict.
+     """
+
+     def __init__(
+         self,
+         capacity: int,
+         shapes: list[tuple[int, ...] | None],
+         persistent: bool = True,
+         dtypes: list[torch.dtype | None] | None = None,
+     ):
+         """
+         Args:
+             capacity: maximum number of samples stored in the buffer
+             shapes: shapes of the tensors to be stored, without batch dimension. If a shape is
+                 None, no tensor is stored at that position. This allows for simpler handling of
+                 optional stored fields.
+             persistent: if True, the content of the buffer is part of the module's state_dict
+             dtypes: if different from None, specifies the tensors which have a non-standard dtype
+         """
+         super().__init__()
+         self.keys = []
+         if dtypes is None:
+             dtypes = [None] * len(shapes)
+         for i, (shape, dtype) in enumerate(zip(shapes, dtypes)):
+             key = f"buffer{i}"
+             self.register_buffer(
+                 key,
+                 None if shape is None else torch.zeros((capacity, *shape), dtype=dtype),
+                 persistent,
+             )
+             self.keys.append(key)
+         self.capacity = capacity
+         self.size = 0
+         self.store_index = 0
+
+     def _batch_slices(self, batch_size: int) -> Iterator[slice]:
+         """
+         Yields slices that split the buffer into batches of at most ``batch_size``, respecting
+         the buffer size and the periodic boundary.
+         """
+         start = self.store_index
+         while start < self.size:
+             stop = min(start + batch_size, self.size)
+             yield slice(start, stop)
+             start = stop
+         start = 0
+         while start < self.store_index:
+             stop = min(start + batch_size, self.store_index)
+             yield slice(start, stop)
+             start = stop
+
+     def _buffer_fields(self) -> Iterator[torch.Tensor | None]:
+         """
+         Iterates over the buffered tensors, without removing the padding if the buffer is not full.
+
+         Yields:
+             The buffered tensors, or None if a tensor was initialized with shape None
+         """
+         for key in self.keys:
+             yield getattr(self, key)
+
+     def __iter__(self) -> Iterator[torch.Tensor | None]:
+         """
+         Iterates over the buffered tensors
+
+         Yields:
+             The buffered tensors, or None if a tensor was initialized with shape None
+         """
+         for key in self.keys:
+             buffer = getattr(self, key)
+             yield None if buffer is None else buffer[: self.size]
+
+     def store(self, *tensors: torch.Tensor | None):
+         """
+         Adds the given tensors to the buffer. If the buffer is full, the oldest stored samples
+         are overwritten.
+
+         Args:
+             tensors: samples to be stored. The shapes of the tensors after the batch dimension
+                 must match the shapes given during initialization. The argument can be None if
+                 the corresponding shape was None during initialization.
+         """
+         store_slice1 = None
+         for buffer, data in zip(self._buffer_fields(), tensors):
+             if data is None:
+                 continue
+             if store_slice1 is None:
+                 size = min(data.shape[0], self.capacity)
+                 end_index = self.store_index + size
+                 if end_index < self.capacity:
+                     store_slice1 = slice(self.store_index, end_index)
+                     store_slice2 = slice(0, 0)
+                     load_slice1 = slice(0, size)
+                     load_slice2 = slice(0, 0)
+                 else:
+                     store_slice1 = slice(self.store_index, self.capacity)
+                     store_slice2 = slice(0, end_index - self.capacity)
+                     load_slice1 = slice(0, self.capacity - self.store_index)
+                     load_slice2 = slice(self.capacity - self.store_index, size)
+                 self.store_index = end_index % self.capacity
+                 self.size = min(self.size + size, self.capacity)
+             buffer[store_slice1] = data[load_slice1]
+             buffer[store_slice2] = data[load_slice2]
+
+     def filter(
+         self,
+         predicate: Callable[[tuple[torch.Tensor | None, ...]], torch.Tensor],
+         batch_size: int = 100000,
+     ):
+         """
+         Removes samples from the buffer that do not fulfill the criterion given by the
+         predicate function.
+
+         Args:
+             predicate: function that returns a mask for a batch of samples, given a tuple with
+                 all the buffered fields as argument
+             batch_size: maximal batch size to limit memory usage
+         """
+         masks = []
+         masked_size = 0
+         for batch_slice in self._batch_slices(batch_size):
+             mask = predicate(
+                 tuple(
+                     None if t is None else t[batch_slice] for t in self._buffer_fields()
+                 )
+             )
+             masked_size += int(torch.count_nonzero(mask))
+             masks.append(mask)
+         for buffer in self._buffer_fields():
+             if buffer is None:
+                 continue
+             buffer[:masked_size] = torch.cat(
+                 [
+                     buffer[batch_slice][mask]
+                     for batch_slice, mask in zip(self._batch_slices(batch_size), masks)
+                 ],
+                 dim=0,
+             )
+         self.size = masked_size
+         self.store_index = masked_size % self.capacity
+
+     def sample(self, count: int) -> list[torch.Tensor | None]:
+         """
+         Returns a batch of samples drawn from the buffer without replacement.
+
+         Args:
+             count: number of samples
+         Returns:
+             samples drawn from the buffer
+         """
+         weights = next(b for b in self._buffer_fields() if b is not None).new_ones(
+             self.size, dtype=torch.float32
+         )
+         indices = torch.multinomial(weights, min(count, self.size), replacement=False)
+         return [
+             None if buffer is None else buffer[indices]
+             for buffer in self._buffer_fields()
+         ]
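
To make the wrap-around and filtering semantics concrete, here is a short usage sketch; the capacity, shapes and weight threshold are made up for illustration:

    import torch
    from madspace.madnis.buffer import Buffer

    # One vector-valued field, one scalar weight, one unused optional slot.
    buffer = Buffer(capacity=1000, shapes=[(4,), (), None])
    buffer.store(torch.randn(300, 4), torch.rand(300), None)
    buffer.store(torch.randn(900, 4), torch.rand(900), None)
    print(buffer.size)  # 1000: full, the oldest 200 samples were overwritten

    buffer.filter(lambda fields: fields[1] > 1e-6)  # drop near-zero weights
    x_batch, w_batch, _ = buffer.sample(128)        # uniform draw without replacement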
madspace/madnis/channel_grouping.py
@@ -0,0 +1,85 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class ChannelGroup:
+     """
+     A group of channels
+
+     Args:
+         group_index: index of the group in the list of groups
+         target_index: index of the channel that all other channels in the group are mapped to
+         channel_indices: indices of the channels in the group
+     """
+
+     group_index: int
+     target_index: int
+     channel_indices: list[int]
+
+
+ @dataclass
+ class ChannelData:
+     """
+     Information about a single channel
+
+     Args:
+         channel_index: index of the channel
+         target_index: index of the channel that it is mapped to
+         group: channel group that the channel belongs to
+         remapped: True if the channel is remapped to another channel
+         position_in_group: index of the channel within its group
+     """
+
+     channel_index: int
+     target_index: int
+     group: ChannelGroup
+     remapped: bool
+     position_in_group: int
+
+
+ class ChannelGrouping:
+     """
+     Class that encodes how channels are grouped together for a multi-channel integrand
+     """
+
+     def __init__(self, channel_assignment: list[int | None]):
+         """
+         Args:
+             channel_assignment: list with an entry for each channel. If None, the channel is
+                 not remapped. Otherwise, the index of the channel to which it is mapped.
+         """
+         group_dict = {}
+         for source_channel, target_channel in enumerate(channel_assignment):
+             if target_channel is None:
+                 group_dict[source_channel] = ChannelGroup(
+                     group_index=len(group_dict),
+                     target_index=source_channel,
+                     channel_indices=[source_channel],
+                 )
+
+         self.channels: list[ChannelData] = []
+         self.groups: list[ChannelGroup] = list(group_dict.values())
+
+         for source_channel, target_channel in enumerate(channel_assignment):
+             if target_channel is None:
+                 self.channels.append(
+                     ChannelData(
+                         channel_index=source_channel,
+                         target_index=source_channel,
+                         group=group_dict[source_channel],
+                         remapped=False,
+                         position_in_group=0,
+                     )
+                 )
+             else:
+                 group = group_dict[target_channel]
+                 self.channels.append(
+                     ChannelData(
+                         channel_index=source_channel,
+                         target_index=target_channel,
+                         group=group,
+                         remapped=True,
+                         position_in_group=len(group.channel_indices),
+                     )
+                 )
+                 group.channel_indices.append(source_channel)
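
A small example (channel indices invented for illustration): remapping channels 1 and 2 onto channel 0 while channel 3 stays independent produces two groups:

    from madspace.madnis.channel_grouping import ChannelGrouping

    grouping = ChannelGrouping([None, 0, 0, None])
    print([g.channel_indices for g in grouping.groups])  # [[0, 1, 2], [3]]

    ch = grouping.channels[2]
    print(ch.target_index, ch.remapped, ch.position_in_group)  # 0 True 2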