madspace 0.3.1__cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- madspace/__init__.py +1 -0
- madspace/_madspace_py.cpython-312-x86_64-linux-gnu.so +0 -0
- madspace/_madspace_py.pyi +2189 -0
- madspace/_madspace_py_loader.py +111 -0
- madspace/include/madspace/constants.h +17 -0
- madspace/include/madspace/madcode/function.h +102 -0
- madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
- madspace/include/madspace/madcode/instruction.h +208 -0
- madspace/include/madspace/madcode/opcode_mixin.h +134 -0
- madspace/include/madspace/madcode/optimizer.h +31 -0
- madspace/include/madspace/madcode/type.h +203 -0
- madspace/include/madspace/madcode.h +6 -0
- madspace/include/madspace/phasespace/base.h +74 -0
- madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
- madspace/include/madspace/phasespace/channel_weights.h +51 -0
- madspace/include/madspace/phasespace/chili.h +32 -0
- madspace/include/madspace/phasespace/cross_section.h +47 -0
- madspace/include/madspace/phasespace/cuts.h +34 -0
- madspace/include/madspace/phasespace/discrete_flow.h +44 -0
- madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
- madspace/include/madspace/phasespace/flow.h +53 -0
- madspace/include/madspace/phasespace/histograms.h +26 -0
- madspace/include/madspace/phasespace/integrand.h +204 -0
- madspace/include/madspace/phasespace/invariants.h +26 -0
- madspace/include/madspace/phasespace/luminosity.h +41 -0
- madspace/include/madspace/phasespace/matrix_element.h +70 -0
- madspace/include/madspace/phasespace/mlp.h +37 -0
- madspace/include/madspace/phasespace/multichannel.h +49 -0
- madspace/include/madspace/phasespace/observable.h +85 -0
- madspace/include/madspace/phasespace/pdf.h +78 -0
- madspace/include/madspace/phasespace/phasespace.h +67 -0
- madspace/include/madspace/phasespace/rambo.h +26 -0
- madspace/include/madspace/phasespace/scale.h +52 -0
- madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
- madspace/include/madspace/phasespace/three_particle.h +68 -0
- madspace/include/madspace/phasespace/topology.h +116 -0
- madspace/include/madspace/phasespace/two_particle.h +63 -0
- madspace/include/madspace/phasespace/vegas.h +53 -0
- madspace/include/madspace/phasespace.h +27 -0
- madspace/include/madspace/runtime/context.h +147 -0
- madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
- madspace/include/madspace/runtime/event_generator.h +257 -0
- madspace/include/madspace/runtime/format.h +68 -0
- madspace/include/madspace/runtime/io.h +343 -0
- madspace/include/madspace/runtime/lhe_output.h +132 -0
- madspace/include/madspace/runtime/logger.h +46 -0
- madspace/include/madspace/runtime/runtime_base.h +39 -0
- madspace/include/madspace/runtime/tensor.h +603 -0
- madspace/include/madspace/runtime/thread_pool.h +101 -0
- madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
- madspace/include/madspace/runtime.h +12 -0
- madspace/include/madspace/umami.h +202 -0
- madspace/include/madspace/util.h +142 -0
- madspace/lib/libmadspace.so +0 -0
- madspace/lib/libmadspace_cpu.so +0 -0
- madspace/lib/libmadspace_cpu_avx2.so +0 -0
- madspace/lib/libmadspace_cpu_avx512.so +0 -0
- madspace/lib/libmadspace_cuda.so +0 -0
- madspace/lib/libmadspace_hip.so +0 -0
- madspace/madnis/__init__.py +44 -0
- madspace/madnis/buffer.py +167 -0
- madspace/madnis/channel_grouping.py +85 -0
- madspace/madnis/distribution.py +103 -0
- madspace/madnis/integrand.py +175 -0
- madspace/madnis/integrator.py +973 -0
- madspace/madnis/interface.py +191 -0
- madspace/madnis/losses.py +186 -0
- madspace/torch.py +82 -0
- madspace-0.3.1.dist-info/METADATA +71 -0
- madspace-0.3.1.dist-info/RECORD +75 -0
- madspace-0.3.1.dist-info/WHEEL +6 -0
- madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
- madspace.libs/libgfortran-83c28eba.so.5.0.0 +0 -0
- madspace.libs/libopenblas-r0-11edc3fa.3.15.so +0 -0
- madspace.libs/libquadmath-2284e583.so.0.0.0 +0 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstddef>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "madspace/madcode.h"
#include "madspace/runtime/context.h"
#include "madspace/runtime/tensor.h"
|
|
6
|
+
|
|
7
|
+
namespace madspace {
|
|
8
|
+
|
|
9
|
+
class VegasGridOptimizer {
|
|
10
|
+
public:
|
|
11
|
+
VegasGridOptimizer(
|
|
12
|
+
ContextPtr context, const std::string& grid_name, double damping
|
|
13
|
+
) :
|
|
14
|
+
_context(context), _grid_name(grid_name), _damping(damping) {}
|
|
15
|
+
void add_data(Tensor weights, Tensor inputs);
|
|
16
|
+
void optimize();
|
|
17
|
+
std::size_t input_dim() const;
|
|
18
|
+
|
|
19
|
+
private:
|
|
20
|
+
ContextPtr _context;
|
|
21
|
+
std::string _grid_name;
|
|
22
|
+
double _damping;
|
|
23
|
+
std::vector<std::tuple<std::vector<std::size_t>, std::vector<double>>> _data;
|
|
24
|
+
};
|
|
25
|
+
|
|
26
|
+
} // namespace madspace
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "runtime/context.h"
|
|
4
|
+
#include "runtime/discrete_optimizer.h"
|
|
5
|
+
#include "runtime/event_generator.h"
|
|
6
|
+
#include "runtime/format.h"
|
|
7
|
+
#include "runtime/io.h"
|
|
8
|
+
#include "runtime/lhe_output.h"
|
|
9
|
+
#include "runtime/logger.h"
|
|
10
|
+
#include "runtime/tensor.h"
|
|
11
|
+
#include "runtime/thread_pool.h"
|
|
12
|
+
#include "runtime/vegas_optimizer.h"
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* _
|
|
3
|
+
* (_)
|
|
4
|
+
* _ _ _ __ ___ __ _ _ __ ___ _
|
|
5
|
+
* | | | | '_ ` _ \ / _` | '_ ` _ \| |
|
|
6
|
+
* | |_| | | | | | | (_| | | | | | | |
|
|
7
|
+
* \__,_|_| |_| |_|\__,_|_| |_| |_|_|
|
|
8
|
+
*
|
|
9
|
+
* Unified MAtrix eleMent Interface
|
|
10
|
+
*
|
|
11
|
+
*
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
#ifndef UMAMI_HEADER
|
|
15
|
+
#define UMAMI_HEADER 1
|
|
16
|
+
|
|
17
|
+
#include <stddef.h>
|
|
18
|
+
|
|
19
|
+
#ifdef __cplusplus
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
/*
 * Version numbers of the UMAMI interface.
 *
 * `inline` variables are a C++17 feature. In C, `inline` is only a function
 * specifier, so `const inline int` does not compile there. Since this header is
 * also meant to be consumed from C (see the `extern "C"` guard), C callers get
 * the version numbers as enumeration constants instead. These are integer
 * constant expressions in C and, unlike `static const` variables, trigger no
 * unused-variable warnings. C++ callers keep the original variables.
 */
#ifdef __cplusplus
/**
 * Major version number of the UMAMI interface. If the major version is the same
 * between caller and implementation, binary compatibility is ensured.
 */
const inline int UMAMI_MAJOR_VERSION = 1;
/**
 * Minor version number of the UMAMI interface. Between minor versions, new keys for
 * errors, devices, metadata, inputs and outputs can be added.
 */
const inline int UMAMI_MINOR_VERSION = 0;
#else
/* Same values as the C++ definitions above; keep both in sync. */
enum {
    UMAMI_MAJOR_VERSION = 1,
    UMAMI_MINOR_VERSION = 0
};
#endif
|
|
33
|
+
|
|
34
|
+
/**
 * Status code returned by every UMAMI API function.
 *
 * Values are spelled out explicitly because they cross the binary interface:
 * new codes may only be appended in later minor versions.
 */
typedef enum {
    UMAMI_SUCCESS = 0,
    UMAMI_ERROR = 1,
    UMAMI_ERROR_NOT_IMPLEMENTED = 2,
    UMAMI_ERROR_UNSUPPORTED_INPUT = 3,
    UMAMI_ERROR_UNSUPPORTED_OUTPUT = 4,
    UMAMI_ERROR_UNSUPPORTED_META = 5,
    UMAMI_ERROR_MISSING_INPUT = 6
} UmamiStatus;
|
|
43
|
+
|
|
44
|
+
/**
 * Device on which a matrix element implementation runs (queried via
 * UMAMI_META_DEVICE). Values are explicit to keep the binary interface stable.
 */
typedef enum {
    UMAMI_DEVICE_CPU = 0,
    UMAMI_DEVICE_CUDA = 1,
    UMAMI_DEVICE_HIP = 2
} UmamiDevice;
|
|
49
|
+
|
|
50
|
+
/**
 * Keys accepted by umami_get_meta. Values are explicit to keep the binary
 * interface stable; new keys may only be appended.
 */
typedef enum {
    UMAMI_META_DEVICE = 0,
    UMAMI_META_PARTICLE_COUNT = 1,
    UMAMI_META_DIAGRAM_COUNT = 2,
    UMAMI_META_HELICITY_COUNT = 3,
    UMAMI_META_COLOR_COUNT = 4
} UmamiMetaKey;
|
|
57
|
+
|
|
58
|
+
/**
 * Keys identifying the entries of the `inputs` array of umami_matrix_element.
 * Values are explicit to keep the binary interface stable; new keys may only be
 * appended.
 */
typedef enum {
    UMAMI_IN_MOMENTA = 0,
    UMAMI_IN_ALPHA_S = 1,
    UMAMI_IN_FLAVOR_INDEX = 2,
    UMAMI_IN_RANDOM_COLOR = 3,
    UMAMI_IN_RANDOM_HELICITY = 4,
    UMAMI_IN_RANDOM_DIAGRAM = 5,
    UMAMI_IN_HELICITY_INDEX = 6,
    UMAMI_IN_DIAGRAM_INDEX = 7,
    UMAMI_IN_GPU_STREAM = 8
} UmamiInputKey;
|
|
69
|
+
|
|
70
|
+
/**
 * Keys identifying the entries of the `outputs` array of umami_matrix_element.
 * Values are explicit to keep the binary interface stable; new keys may only be
 * appended.
 */
typedef enum {
    UMAMI_OUT_MATRIX_ELEMENT = 0,
    UMAMI_OUT_DIAGRAM_AMP2 = 1,
    UMAMI_OUT_COLOR_INDEX = 2,
    UMAMI_OUT_HELICITY_INDEX = 3,
    UMAMI_OUT_DIAGRAM_INDEX = 4
    // Not defined yet (candidates for a later minor version):
    // NLO: born, virtual, poles, counterterms
    // color: LC-ME, FC-ME
} UmamiOutputKey;
|
|
79
|
+
|
|
80
|
+
/**
 * Opaque handle to a matrix element instance. It is created by umami_initialize
 * and must be released with umami_free.
 */
typedef void* UmamiHandle;
|
|
81
|
+
|
|
82
|
+
/**
|
|
83
|
+
* Creates an instance of the matrix element. Each instance is independent, so thread
|
|
84
|
+
* safety can be achieved by creating a separate one for every thread.
|
|
85
|
+
*
|
|
86
|
+
* @param meta_key
|
|
87
|
+
* path to the parameter file
|
|
88
|
+
* @param handle
|
|
89
|
+
* pointer to an instance of the subprocess. Has to be cleaned up by
|
|
90
|
+
* the caller with `free_subprocess`.
|
|
91
|
+
* @return
|
|
92
|
+
* UMAMI_SUCCESS on success, error code otherwise
|
|
93
|
+
*/
|
|
94
|
+
UmamiStatus umami_get_meta(UmamiMetaKey meta_key, void* result);
|
|
95
|
+
|
|
96
|
+
/**
|
|
97
|
+
* Creates an instance of the matrix element. Each instance is independent, so thread
|
|
98
|
+
* safety can be achieved by creating a separate one for every thread.
|
|
99
|
+
*
|
|
100
|
+
* @param param_card_path
|
|
101
|
+
* path to the parameter file
|
|
102
|
+
* @param handle
|
|
103
|
+
* pointer to an instance of the subprocess. Has to be cleaned up by
|
|
104
|
+
* the caller with `free_subprocess`.
|
|
105
|
+
* @return
|
|
106
|
+
* UMAMI_SUCCESS on success, error code otherwise
|
|
107
|
+
*/
|
|
108
|
+
UmamiStatus umami_initialize(UmamiHandle* handle, char const* param_card_path);
|
|
109
|
+
|
|
110
|
+
/**
|
|
111
|
+
* Sets the value of a model parameter
|
|
112
|
+
*
|
|
113
|
+
* @param handle
|
|
114
|
+
* handle of a matrix element instance
|
|
115
|
+
* @param name
|
|
116
|
+
* name of the parameter
|
|
117
|
+
* @param parameter_real
|
|
118
|
+
* real part of the parameter value
|
|
119
|
+
* @param parameter_imag
|
|
120
|
+
* imaginary part of the parameter value. Ignored for real valued parameters.
|
|
121
|
+
* @return
|
|
122
|
+
* UMAMI_SUCCESS on success, error code otherwise
|
|
123
|
+
*/
|
|
124
|
+
UmamiStatus umami_set_parameter(
|
|
125
|
+
UmamiHandle handle, char const* name, double parameter_real, double parameter_imag
|
|
126
|
+
);
|
|
127
|
+
|
|
128
|
+
/**
|
|
129
|
+
* Retrieves the value of a model parameter
|
|
130
|
+
*
|
|
131
|
+
* @param handle
|
|
132
|
+
* handle of a matrix element instance
|
|
133
|
+
* @param name
|
|
134
|
+
* name of the parameter
|
|
135
|
+
* @param parameter_real
|
|
136
|
+
* pointer to double to return real part of the parameter value
|
|
137
|
+
* @param parameter_imag
|
|
138
|
+
* pointer to double to return imaginary part of the parameter value. Ignored
|
|
139
|
+
* for real-valued parameters (i.e. you may pass a null pointer)
|
|
140
|
+
* @return
|
|
141
|
+
* UMAMI_SUCCESS on success, error code otherwise
|
|
142
|
+
*/
|
|
143
|
+
UmamiStatus umami_get_parameter(
|
|
144
|
+
UmamiHandle handle, char const* name, double* parameter_real, double* parameter_imag
|
|
145
|
+
);
|
|
146
|
+
|
|
147
|
+
/**
|
|
148
|
+
* Evaluates the matrix element as a function of the given inputs, filling the
|
|
149
|
+
* requested outputs.
|
|
150
|
+
*
|
|
151
|
+
* @param handle
|
|
152
|
+
* handle of a matrix element instance
|
|
153
|
+
* @param count
|
|
154
|
+
* number of events to evaluate the matrix element for
|
|
155
|
+
* @param stride
|
|
156
|
+
* stride of the batch dimension of the input and output arrays, see memory layout
|
|
157
|
+
* @param offset
|
|
158
|
+
* offset of the event index
|
|
159
|
+
* @param input_count
|
|
160
|
+
* number of inputs to the matrix element
|
|
161
|
+
* @param input_keys
|
|
162
|
+
* pointer to an array of input keys, length `input_count`
|
|
163
|
+
* @param inputs
|
|
164
|
+
* pointer to an array of void pointers to the inputs. The type of the inputs
|
|
165
|
+
* depends on the input key
|
|
166
|
+
* @param output_count
|
|
167
|
+
* number of outputs to the matrix element
|
|
168
|
+
* @param output_keys
|
|
169
|
+
* pointer to an array of output keys, length `output_count`
|
|
170
|
+
* @param outputs
|
|
171
|
+
* pointer to an array of void pointers to the outputs. The type of the outputs
|
|
172
|
+
* depends on the output key. The caller is responsible for allocating memory for
|
|
173
|
+
* the outputs.
|
|
174
|
+
* @return
|
|
175
|
+
* UMAMI_SUCCESS on success, error code otherwise
|
|
176
|
+
*/
|
|
177
|
+
UmamiStatus umami_matrix_element(
|
|
178
|
+
UmamiHandle handle,
|
|
179
|
+
size_t count,
|
|
180
|
+
size_t stride,
|
|
181
|
+
size_t offset,
|
|
182
|
+
size_t input_count,
|
|
183
|
+
UmamiInputKey const* input_keys,
|
|
184
|
+
void const* const* inputs,
|
|
185
|
+
size_t output_count,
|
|
186
|
+
UmamiOutputKey const* output_keys,
|
|
187
|
+
void* const* outputs
|
|
188
|
+
);
|
|
189
|
+
|
|
190
|
+
/**
 * Frees a matrix element instance created by umami_initialize. The handle must
 * not be used afterwards.
 *
 * @param handle
 *     handle of a matrix element instance
 * @return
 *     UMAMI_SUCCESS on success, error code otherwise
 */
UmamiStatus umami_free(UmamiHandle handle);
|
|
197
|
+
|
|
198
|
+
#ifdef __cplusplus
|
|
199
|
+
}
|
|
200
|
+
#endif
|
|
201
|
+
|
|
202
|
+
#endif // UMAMI_HEADER
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include <cstdio>
|
|
4
|
+
#include <format>
|
|
5
|
+
#include <ranges>
|
|
6
|
+
#include <tuple>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
namespace madspace {
|
|
10
|
+
|
|
11
|
+
/// Merges several callables into one overload set, typically for std::visit:
/// `std::visit(Overloaded{[](int) { ... }, [](double) { ... }}, value);`
template <typename... Fs>
struct Overloaded : Fs... {
    using Fs::operator()...;
};

/// Deduction guide so that `Overloaded{f, g}` deduces the callable types.
template <typename... Fs>
Overloaded(Fs...) -> Overloaded<Fs...>;
|
|
17
|
+
|
|
18
|
+
/// Shorthands for std::vector nested two, three or four levels deep; each level
/// is defined in terms of the previous one.
template <typename T>
using nested_vector2 = std::vector<std::vector<T>>;
template <typename T>
using nested_vector3 = std::vector<nested_vector2<T>>;
template <typename T>
using nested_vector4 = std::vector<nested_vector3<T>>;
|
|
24
|
+
|
|
25
|
+
// Unfortunately nvcc does not support C++23 yet, so instead of using the standard
// library's std::views::zip we implement our own zip function here (based on
// https://github.com/alemuntoni/zip-views).
|
|
28
|
+
|
|
29
|
+
namespace detail {
|
|
30
|
+
|
|
31
|
+
/// Formats `args` according to `fmt` and writes the result to `stream` with a
/// single write call, optionally followed by a newline.
///
/// @param stream    destination C stream (e.g. stdout)
/// @param new_line  if true, '\n' is appended to the formatted text
/// @param fmt       format string; the print/println wrappers pass the string of
///                  a std::format_string, so it has been checked at compile time
/// @param args      type-erased format arguments matching `fmt`
inline void print_impl(
    std::FILE* stream, bool new_line, std::string_view fmt, std::format_args args
) {
    std::string str = std::vformat(fmt, args);
    if (new_line) {
        // Append rather than issuing a second write, so the whole line is
        // emitted by one fwrite call.
        str.push_back('\n');
    }
    // Qualified: <cstdio> only guarantees that the C library names are declared
    // in namespace std; their presence in the global namespace is unspecified.
    std::fwrite(str.data(), 1, str.size(), stream);
}
|
|
40
|
+
|
|
41
|
+
/// Returns true if at least one pair of same-position elements of `lhs` and
/// `rhs` compares equal. Comparisons run left to right and stop at the first
/// match; two empty tuples yield false.
template <typename... Args, std::size_t... Index>
bool any_match_impl(
    const std::tuple<Args...>& lhs,
    const std::tuple<Args...>& rhs,
    std::index_sequence<Index...>
) {
    return (false || ... || (std::get<Index>(lhs) == std::get<Index>(rhs)));
}
|
|
51
|
+
|
|
52
|
+
/// Element-wise "any equal" comparison of two tuples with identical element
/// types; see any_match_impl.
template <typename... Args>
bool any_match(const std::tuple<Args...>& lhs, const std::tuple<Args...>& rhs) {
    using positions = std::index_sequence_for<Args...>;
    return any_match_impl(lhs, rhs, positions{});
}
|
|
56
|
+
|
|
57
|
+
template <std::ranges::viewable_range... Rng>
|
|
58
|
+
class zip_iterator {
|
|
59
|
+
public:
|
|
60
|
+
using value_type = std::tuple<std::ranges::range_reference_t<Rng>...>;
|
|
61
|
+
|
|
62
|
+
zip_iterator() = delete;
|
|
63
|
+
zip_iterator(std::ranges::iterator_t<Rng>&&... iters) :
|
|
64
|
+
_iters{std::forward<std::ranges::iterator_t<Rng>>(iters)...} {}
|
|
65
|
+
|
|
66
|
+
zip_iterator& operator++() {
|
|
67
|
+
std::apply([](auto&&... args) { ((++args), ...); }, _iters);
|
|
68
|
+
return *this;
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
zip_iterator operator++(int) {
|
|
72
|
+
auto tmp = *this;
|
|
73
|
+
++*this;
|
|
74
|
+
return tmp;
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
bool operator!=(const zip_iterator& other) const { return !(*this == other); }
|
|
78
|
+
|
|
79
|
+
bool operator==(const zip_iterator& other) const {
|
|
80
|
+
return any_match(_iters, other._iters);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
value_type operator*() {
|
|
84
|
+
return std::apply([](auto&&... args) { return value_type(*args...); }, _iters);
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
private:
|
|
88
|
+
std::tuple<std::ranges::iterator_t<Rng>...> _iters;
|
|
89
|
+
};
|
|
90
|
+
|
|
91
|
+
template <std::ranges::viewable_range... T>
|
|
92
|
+
class zipper {
|
|
93
|
+
public:
|
|
94
|
+
using zip_type = zip_iterator<T...>;
|
|
95
|
+
|
|
96
|
+
template <typename... Args>
|
|
97
|
+
zipper(Args&&... args) : _args{std::forward<Args>(args)...} {}
|
|
98
|
+
|
|
99
|
+
zip_type begin() {
|
|
100
|
+
return std::apply(
|
|
101
|
+
[](auto&&... args) { return zip_type(std::ranges::begin(args)...); }, _args
|
|
102
|
+
);
|
|
103
|
+
}
|
|
104
|
+
zip_type end() {
|
|
105
|
+
return std::apply(
|
|
106
|
+
[](auto&&... args) { return zip_type(std::ranges::end(args)...); }, _args
|
|
107
|
+
);
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
private:
|
|
111
|
+
std::tuple<T...> _args;
|
|
112
|
+
};
|
|
113
|
+
|
|
114
|
+
} // namespace detail
|
|
115
|
+
|
|
116
|
+
template <std::ranges::viewable_range... T>
|
|
117
|
+
auto zip(T&&... t) {
|
|
118
|
+
return detail::zipper<T...>{std::forward<T>(t)...};
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
/// Formats `args` according to the compile-time checked `fmt` and writes the
/// result to stdout without a trailing newline.
template <typename... Args>
inline void print(std::format_string<Args...> fmt, Args&&... args) {
    auto packed_args = std::make_format_args(args...);
    detail::print_impl(stdout, false, fmt.get(), packed_args);
}

/// Like print(fmt, args...), but writes to `stream`.
template <typename... Args>
inline void print(std::FILE* stream, std::format_string<Args...> fmt, Args&&... args) {
    auto packed_args = std::make_format_args(args...);
    detail::print_impl(stream, false, fmt.get(), packed_args);
}

/// Formats `args` according to `fmt` and writes the result plus '\n' to stdout.
template <typename... Args>
inline void println(std::format_string<Args...> fmt, Args&&... args) {
    auto packed_args = std::make_format_args(args...);
    detail::print_impl(stdout, true, fmt.get(), packed_args);
}

/// Like println(fmt, args...), but writes to `stream`.
template <typename... Args>
inline void
println(std::FILE* stream, std::format_string<Args...> fmt, Args&&... args) {
    auto packed_args = std::make_format_args(args...);
    detail::print_impl(stream, true, fmt.get(), packed_args);
}
|
|
141
|
+
|
|
142
|
+
} // namespace madspace
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains functions and classes to train neural importance sampling networks and
|
|
3
|
+
evaluate the integration and sampling performance.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .buffer import Buffer
|
|
7
|
+
from .channel_grouping import ChannelData, ChannelGroup, ChannelGrouping
|
|
8
|
+
from .distribution import Distribution
|
|
9
|
+
from .integrand import Integrand
|
|
10
|
+
from .integrator import Integrator, SampleBatch, TrainingStatus
|
|
11
|
+
from .interface import (
|
|
12
|
+
MADNIS_INTEGRAND_FLAGS,
|
|
13
|
+
IntegrandDistribution,
|
|
14
|
+
IntegrandFunction,
|
|
15
|
+
build_madnis_integrand,
|
|
16
|
+
)
|
|
17
|
+
from .losses import (
|
|
18
|
+
kl_divergence,
|
|
19
|
+
multi_channel_loss,
|
|
20
|
+
rkl_divergence,
|
|
21
|
+
stratified_variance,
|
|
22
|
+
variance,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Names exported by ``from madspace.madnis import *``; every entry is imported above.
__all__ = [
    "Integrator",
    "TrainingStatus",
    "SampleBatch",
    "Integrand",
    "Buffer",
    "multi_channel_loss",
    "stratified_variance",
    "variance",
    "kl_divergence",
    "rkl_divergence",
    "ChannelGroup",
    "ChannelData",
    "ChannelGrouping",
    "Distribution",
    "MADNIS_INTEGRAND_FLAGS",
    "IntegrandDistribution",
    "IntegrandFunction",
    "build_madnis_integrand",
]
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Buffer(nn.Module):
|
|
8
|
+
"""
|
|
9
|
+
Circular buffer for multiple tensors with different shapes. The class is a torch.nn.Module to
|
|
10
|
+
allow for simple storage.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
capacity: int,
|
|
16
|
+
shapes: list[tuple[int, ...]],
|
|
17
|
+
persistent: bool = True,
|
|
18
|
+
dtypes: list[torch.dtype | None] | None = None,
|
|
19
|
+
):
|
|
20
|
+
"""
|
|
21
|
+
Args:
|
|
22
|
+
capacity: maximum number of samples stored in the buffer
|
|
23
|
+
shapes: shapes of the tensors to be stored, without batch dimension. If a shape is
|
|
24
|
+
None, no tensor is stored at that position. This allows for simpler handling of
|
|
25
|
+
optional stored fields.
|
|
26
|
+
persistent: if True, the content of the buffer is part of the module's state_dict
|
|
27
|
+
dtypes: if different from None, specifies the tensors which have a non-standard dtype
|
|
28
|
+
"""
|
|
29
|
+
super().__init__()
|
|
30
|
+
self.keys = []
|
|
31
|
+
if dtypes is None:
|
|
32
|
+
dtypes = [None] * len(shapes)
|
|
33
|
+
for i, (shape, dtype) in enumerate(zip(shapes, dtypes)):
|
|
34
|
+
key = f"buffer{i}"
|
|
35
|
+
self.register_buffer(
|
|
36
|
+
key,
|
|
37
|
+
None if shape is None else torch.zeros((capacity, *shape), dtype=dtype),
|
|
38
|
+
persistent,
|
|
39
|
+
)
|
|
40
|
+
self.keys.append(key)
|
|
41
|
+
self.capacity = capacity
|
|
42
|
+
self.size = 0
|
|
43
|
+
self.store_index = 0
|
|
44
|
+
|
|
45
|
+
def _batch_slices(self, batch_size: int) -> slice:
|
|
46
|
+
"""
|
|
47
|
+
Returns slices that split up the buffer into batches of at most ``batch_size``, respecting
|
|
48
|
+
the buffer size and periodic boundary.
|
|
49
|
+
"""
|
|
50
|
+
start = self.store_index
|
|
51
|
+
while start < self.size:
|
|
52
|
+
stop = min(start + batch_size, self.size)
|
|
53
|
+
yield slice(start, stop)
|
|
54
|
+
start = stop
|
|
55
|
+
start = 0
|
|
56
|
+
while start < self.store_index:
|
|
57
|
+
stop = min(start + batch_size, self.store_index)
|
|
58
|
+
yield slice(start, stop)
|
|
59
|
+
start = stop
|
|
60
|
+
|
|
61
|
+
def _buffer_fields(self) -> torch.Tensor | None:
|
|
62
|
+
"""
|
|
63
|
+
Iterates over the buffered tensors, without removing the padding if the buffer is not full.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
The buffered tensors, or None if a tensor was initialized with shape None
|
|
67
|
+
"""
|
|
68
|
+
for key in self.keys:
|
|
69
|
+
yield getattr(self, key)
|
|
70
|
+
|
|
71
|
+
def __iter__(self) -> torch.Tensor | None:
|
|
72
|
+
"""
|
|
73
|
+
Iterates over the buffered tensors
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
The buffered tensors, or None if a tensor was initialized with shape None
|
|
77
|
+
"""
|
|
78
|
+
for key in self.keys:
|
|
79
|
+
buffer = getattr(self, key)
|
|
80
|
+
yield None if buffer is None else buffer[: self.size]
|
|
81
|
+
|
|
82
|
+
def store(self, *tensors: torch.Tensor | None):
|
|
83
|
+
"""
|
|
84
|
+
Adds the given tensors to the buffer. If the buffer is full, the oldest stored samples are
|
|
85
|
+
overwritten.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
tensors: samples to be stored. The shapes of the tensors after the batch dimension must
|
|
89
|
+
match the shapes given during initialization. The argument can be None if the
|
|
90
|
+
corresponding shape was None during initialization.
|
|
91
|
+
"""
|
|
92
|
+
store_slice1 = None
|
|
93
|
+
for buffer, data in zip(self._buffer_fields(), tensors):
|
|
94
|
+
if data is None:
|
|
95
|
+
continue
|
|
96
|
+
if store_slice1 is None:
|
|
97
|
+
size = min(data.shape[0], self.capacity)
|
|
98
|
+
end_index = self.store_index + size
|
|
99
|
+
if end_index < self.capacity:
|
|
100
|
+
store_slice1 = slice(self.store_index, end_index)
|
|
101
|
+
store_slice2 = slice(0, 0)
|
|
102
|
+
load_slice1 = slice(0, size)
|
|
103
|
+
load_slice2 = slice(0, 0)
|
|
104
|
+
else:
|
|
105
|
+
store_slice1 = slice(self.store_index, self.capacity)
|
|
106
|
+
store_slice2 = slice(0, end_index - self.capacity)
|
|
107
|
+
load_slice1 = slice(0, self.capacity - self.store_index)
|
|
108
|
+
load_slice2 = slice(self.capacity - self.store_index, size)
|
|
109
|
+
self.store_index = end_index % self.capacity
|
|
110
|
+
self.size = min(self.size + size, self.capacity)
|
|
111
|
+
buffer[store_slice1] = data[load_slice1]
|
|
112
|
+
buffer[store_slice2] = data[load_slice2]
|
|
113
|
+
|
|
114
|
+
def filter(
|
|
115
|
+
self,
|
|
116
|
+
predicate: Callable[[tuple[torch.Tensor | None, ...]], torch.Tensor],
|
|
117
|
+
batch_size: int = 100000,
|
|
118
|
+
):
|
|
119
|
+
"""
|
|
120
|
+
Removes samples from the buffer that do not fulfill the criterion given by the predicate
|
|
121
|
+
function.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
predicate: function that returns a mask for a batch of samples, given a tuple with
|
|
125
|
+
all the buffered fields as argument
|
|
126
|
+
batch_size: maximal batch size to limit memory usage
|
|
127
|
+
"""
|
|
128
|
+
masks = []
|
|
129
|
+
masked_size = 0
|
|
130
|
+
for batch_slice in self._batch_slices(batch_size):
|
|
131
|
+
mask = predicate(
|
|
132
|
+
tuple(
|
|
133
|
+
None if t is None else t[batch_slice] for t in self._buffer_fields()
|
|
134
|
+
)
|
|
135
|
+
)
|
|
136
|
+
masked_size += torch.count_nonzero(mask)
|
|
137
|
+
masks.append(mask)
|
|
138
|
+
for buffer in self._buffer_fields():
|
|
139
|
+
if buffer is None:
|
|
140
|
+
continue
|
|
141
|
+
buffer[:masked_size] = torch.cat(
|
|
142
|
+
[
|
|
143
|
+
buffer[batch_slice][mask]
|
|
144
|
+
for batch_slice, mask in zip(self._batch_slices(batch_size), masks)
|
|
145
|
+
],
|
|
146
|
+
dim=0,
|
|
147
|
+
)
|
|
148
|
+
self.size = masked_size
|
|
149
|
+
self.store_index = masked_size % self.capacity
|
|
150
|
+
|
|
151
|
+
def sample(self, count: int) -> list[torch.Tensor | None]:
|
|
152
|
+
"""
|
|
153
|
+
Returns a batch of samples drawn from the buffer without replacement.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
count: number of samples
|
|
157
|
+
Returns:
|
|
158
|
+
samples drawn from the buffer
|
|
159
|
+
"""
|
|
160
|
+
weights = next(b for b in self._buffer_fields() if b is not None).new_ones(
|
|
161
|
+
self.size
|
|
162
|
+
)
|
|
163
|
+
indices = torch.multinomial(weights, min(count, self.size), replacement=False)
|
|
164
|
+
return [
|
|
165
|
+
None if buffer is None else buffer[indices]
|
|
166
|
+
for buffer in self._buffer_fields()
|
|
167
|
+
]
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
@dataclass
class ChannelGroup:
    """
    A group of channels that are all mapped onto the same target channel

    Args:
        group_index: index of the group in the list of groups
        target_index: index of the channel that all other channels in the group are mapped to
        channel_indices: indices of the channels in the group. The target channel itself is
            included; in groups built by ChannelGrouping it is the first entry.
    """

    group_index: int
    target_index: int
    channel_indices: list[int]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ChannelData:
    """
    Information about a single channel

    Args:
        channel_index: index of the channel
        target_index: index of the channel that it is mapped to (its own index if the channel
            is not remapped)
        group: channel group that the channel belongs to
        remapped: True if the channel is remapped to another channel, False for the target
            channel of a group
        position_in_group: index of the channel within its group's ``channel_indices``
            (0 for the target channel)
    """

    channel_index: int
    target_index: int
    group: ChannelGroup
    remapped: bool
    position_in_group: int
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ChannelGrouping:
    """
    Class that encodes how channels are grouped together for a multi-channel integrand.

    Every channel that is not remapped forms its own group (in channel order). Remapped channels
    join the group of the channel they are mapped to, in channel order after it.
    """

    def __init__(self, channel_assignment: list[int | None]):
        """
        Args:
            channel_assignment: list with an entry for each channel. If None, the channel is not
                remapped. Otherwise, the index of the channel to which it is mapped. That target
                channel must itself not be remapped, otherwise a KeyError is raised.
        """
        # First pass: every channel that is not remapped opens a new group.
        groups_by_target: dict[int, ChannelGroup] = {}
        for index, target in enumerate(channel_assignment):
            if target is not None:
                continue
            groups_by_target[index] = ChannelGroup(
                group_index=len(groups_by_target),
                target_index=index,
                channel_indices=[index],
            )

        self.channels: list[ChannelData] = []
        self.groups: list[ChannelGroup] = list(groups_by_target.values())

        # Second pass: describe every channel and attach remapped ones to their target's group.
        for index, target in enumerate(channel_assignment):
            remapped = target is not None
            resolved_target = target if remapped else index
            group = groups_by_target[resolved_target]
            self.channels.append(
                ChannelData(
                    channel_index=index,
                    target_index=resolved_target,
                    group=group,
                    remapped=remapped,
                    position_in_group=len(group.channel_indices) if remapped else 0,
                )
            )
            if remapped:
                group.channel_indices.append(index)
|