madspace 0.3.1 (madspace-0.3.1-cp311-cp311-macosx_14_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68):
  1. madspace/__init__.py +1 -0
  2. madspace/_madspace_py.cpython-311-darwin.so +0 -0
  3. madspace/_madspace_py.pyi +2189 -0
  4. madspace/_madspace_py_loader.py +111 -0
  5. madspace/include/madspace/constants.h +17 -0
  6. madspace/include/madspace/madcode/function.h +102 -0
  7. madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
  8. madspace/include/madspace/madcode/instruction.h +208 -0
  9. madspace/include/madspace/madcode/opcode_mixin.h +134 -0
  10. madspace/include/madspace/madcode/optimizer.h +31 -0
  11. madspace/include/madspace/madcode/type.h +203 -0
  12. madspace/include/madspace/madcode.h +6 -0
  13. madspace/include/madspace/phasespace/base.h +74 -0
  14. madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
  15. madspace/include/madspace/phasespace/channel_weights.h +51 -0
  16. madspace/include/madspace/phasespace/chili.h +32 -0
  17. madspace/include/madspace/phasespace/cross_section.h +47 -0
  18. madspace/include/madspace/phasespace/cuts.h +34 -0
  19. madspace/include/madspace/phasespace/discrete_flow.h +44 -0
  20. madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
  21. madspace/include/madspace/phasespace/flow.h +53 -0
  22. madspace/include/madspace/phasespace/histograms.h +26 -0
  23. madspace/include/madspace/phasespace/integrand.h +204 -0
  24. madspace/include/madspace/phasespace/invariants.h +26 -0
  25. madspace/include/madspace/phasespace/luminosity.h +41 -0
  26. madspace/include/madspace/phasespace/matrix_element.h +70 -0
  27. madspace/include/madspace/phasespace/mlp.h +37 -0
  28. madspace/include/madspace/phasespace/multichannel.h +49 -0
  29. madspace/include/madspace/phasespace/observable.h +85 -0
  30. madspace/include/madspace/phasespace/pdf.h +78 -0
  31. madspace/include/madspace/phasespace/phasespace.h +67 -0
  32. madspace/include/madspace/phasespace/rambo.h +26 -0
  33. madspace/include/madspace/phasespace/scale.h +52 -0
  34. madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
  35. madspace/include/madspace/phasespace/three_particle.h +68 -0
  36. madspace/include/madspace/phasespace/topology.h +116 -0
  37. madspace/include/madspace/phasespace/two_particle.h +63 -0
  38. madspace/include/madspace/phasespace/vegas.h +53 -0
  39. madspace/include/madspace/phasespace.h +27 -0
  40. madspace/include/madspace/runtime/context.h +147 -0
  41. madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
  42. madspace/include/madspace/runtime/event_generator.h +257 -0
  43. madspace/include/madspace/runtime/format.h +68 -0
  44. madspace/include/madspace/runtime/io.h +343 -0
  45. madspace/include/madspace/runtime/lhe_output.h +132 -0
  46. madspace/include/madspace/runtime/logger.h +46 -0
  47. madspace/include/madspace/runtime/runtime_base.h +39 -0
  48. madspace/include/madspace/runtime/tensor.h +603 -0
  49. madspace/include/madspace/runtime/thread_pool.h +101 -0
  50. madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
  51. madspace/include/madspace/runtime.h +12 -0
  52. madspace/include/madspace/umami.h +202 -0
  53. madspace/include/madspace/util.h +142 -0
  54. madspace/lib/libmadspace.dylib +0 -0
  55. madspace/lib/libmadspace_cpu.dylib +0 -0
  56. madspace/madnis/__init__.py +44 -0
  57. madspace/madnis/buffer.py +167 -0
  58. madspace/madnis/channel_grouping.py +85 -0
  59. madspace/madnis/distribution.py +103 -0
  60. madspace/madnis/integrand.py +175 -0
  61. madspace/madnis/integrator.py +973 -0
  62. madspace/madnis/interface.py +191 -0
  63. madspace/madnis/losses.py +186 -0
  64. madspace/torch.py +82 -0
  65. madspace-0.3.1.dist-info/METADATA +71 -0
  66. madspace-0.3.1.dist-info/RECORD +68 -0
  67. madspace-0.3.1.dist-info/WHEEL +6 -0
  68. madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,208 @@
1
+ #pragma once
2
+
3
#include <initializer_list>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
11
+
12
+ #include "type.h"
13
+
14
+ namespace madspace {
15
+
16
+ namespace opcodes {
17
+ enum Opcode {
18
+ #include "opcode_mixin.h"
19
+ };
20
+ } // namespace opcodes
21
+
22
+ class Instruction {
23
+ public:
24
+ Instruction(const std::string& name, int opcode, bool differentiable) :
25
+ _name(name), _opcode(opcode), _differentiable(differentiable) {}
26
+ virtual ~Instruction() = default;
27
+ virtual TypeVec signature(const ValueVec& args) const = 0;
28
+ const std::string& name() const { return _name; }
29
+ int opcode() const { return _opcode; }
30
+ bool differentiable() const { return _differentiable; }
31
+
32
+ protected:
33
+ void check_arg_count(const ValueVec& args, std::size_t count) const;
34
+ me_int_t int_literal_arg(
35
+ const ValueVec& args, std::size_t index, bool check_non_negative = true
36
+ ) const;
37
+
38
+ private:
39
+ std::string _name;
40
+ int _opcode;
41
+ bool _differentiable;
42
+ };
43
+
44
/// Symbolic expression for one dimension of a shape in an instruction signature.
///
/// It is parsed from a C string by the (out-of-line) constructor and stored as a list of
/// (single-character variable name, integer) terms.
class ShapeExpr {
public:
    ShapeExpr(const char* expr);

    /// Matches the expression against a concrete dimension `value`, possibly recording
    /// variable bindings in `variables`.
    /// NOTE(review): exact binding/return semantics live in the implementation file.
    bool check_and_update(std::map<char, int>& variables, int value) const;

    /// Evaluates the expression with the given variable bindings. An empty optional
    /// presumably signals an unbound variable — confirm in the implementation.
    std::optional<int> evaluate(const std::map<char, int>& variables) const;

    /// Variable name of the first term. Throws std::out_of_range if there are no terms.
    char first_var_name() const;

private:
    std::vector<std::tuple<char, int>> terms;
};

inline char ShapeExpr::first_var_name() const {
    const auto& leading_term = terms.at(0);
    return std::get<0>(leading_term);
}
54
+
55
+ class SimpleInstruction : public Instruction {
56
+ public:
57
+ using DynShape = std::vector<std::variant<int, ShapeExpr, std::monostate>>;
58
+ using SigType = std::tuple<DataType, bool, DynShape, bool>;
59
+
60
+ SimpleInstruction(
61
+ std::string name,
62
+ int opcode,
63
+ bool differentiable,
64
+ std::initializer_list<SigType> _inputs,
65
+ std::initializer_list<SigType> _outputs
66
+ ) :
67
+ Instruction(name, opcode, differentiable), inputs(_inputs), outputs(_outputs) {}
68
+
69
+ TypeVec signature(const ValueVec& args) const override;
70
+
71
+ private:
72
+ const std::vector<SigType> inputs;
73
+ const std::vector<SigType> outputs;
74
+ };
75
+
76
+ class StackInstruction : public Instruction {
77
+ public:
78
+ StackInstruction(int opcode, bool differentiable) :
79
+ Instruction("stack", opcode, differentiable) {}
80
+ TypeVec signature(const ValueVec& args) const override;
81
+ };
82
+
83
+ class UnstackInstruction : public Instruction {
84
+ public:
85
+ UnstackInstruction(int opcode, bool differentiable) :
86
+ Instruction("unstack", opcode, differentiable) {}
87
+ TypeVec signature(const ValueVec& args) const override;
88
+ };
89
+
90
+ class UnstackSizesInstruction : public Instruction {
91
+ public:
92
+ UnstackSizesInstruction(int opcode, bool differentiable) :
93
+ Instruction("unstack_sizes", opcode, differentiable) {}
94
+ TypeVec signature(const ValueVec& args) const override;
95
+ };
96
+
97
+ class BatchCatInstruction : public Instruction {
98
+ public:
99
+ BatchCatInstruction(int opcode, bool differentiable) :
100
+ Instruction("batch_cat", opcode, differentiable) {}
101
+ TypeVec signature(const ValueVec& args) const override;
102
+ };
103
+
104
+ class BatchSplitInstruction : public Instruction {
105
+ public:
106
+ BatchSplitInstruction(int opcode, bool differentiable) :
107
+ Instruction("batch_split", opcode, differentiable) {}
108
+ TypeVec signature(const ValueVec& args) const override;
109
+ };
110
+
111
+ class CatInstruction : public Instruction {
112
+ public:
113
+ CatInstruction(int opcode, bool differentiable) :
114
+ Instruction("cat", opcode, differentiable) {}
115
+ TypeVec signature(const ValueVec& args) const override;
116
+ };
117
+
118
+ class BatchSizeInstruction : public Instruction {
119
+ public:
120
+ BatchSizeInstruction(int opcode, bool differentiable) :
121
+ Instruction("batch_size", opcode, differentiable) {}
122
+ TypeVec signature(const ValueVec& args) const override;
123
+ };
124
+
125
+ class OffsetIndicesInstruction : public Instruction {
126
+ public:
127
+ OffsetIndicesInstruction(int opcode, bool differentiable) :
128
+ Instruction("offset_indices", opcode, differentiable) {}
129
+ TypeVec signature(const ValueVec& args) const override;
130
+ };
131
+
132
+ class FullInstruction : public Instruction {
133
+ public:
134
+ FullInstruction(int opcode, bool differentiable) :
135
+ Instruction("full", opcode, differentiable) {}
136
+ TypeVec signature(const ValueVec& args) const override;
137
+ };
138
+
139
+ class SqueezeInstruction : public Instruction {
140
+ public:
141
+ SqueezeInstruction(int opcode, bool differentiable) :
142
+ Instruction("squeeze", opcode, differentiable) {}
143
+ TypeVec signature(const ValueVec& args) const override;
144
+ };
145
+
146
+ class UnsqueezeInstruction : public Instruction {
147
+ public:
148
+ UnsqueezeInstruction(int opcode, bool differentiable) :
149
+ Instruction("unsqueeze", opcode, differentiable) {}
150
+ TypeVec signature(const ValueVec& args) const override;
151
+ };
152
+
153
+ class RqsReshapeInstruction : public Instruction {
154
+ public:
155
+ RqsReshapeInstruction(int opcode, bool differentiable) :
156
+ Instruction("rqs_reshape", opcode, differentiable) {}
157
+ TypeVec signature(const ValueVec& args) const override;
158
+ };
159
+
160
+ class NonzeroInstruction : public Instruction {
161
+ public:
162
+ NonzeroInstruction(int opcode, bool differentiable) :
163
+ Instruction("nonzero", opcode, differentiable) {}
164
+ TypeVec signature(const ValueVec& args) const override;
165
+ };
166
+
167
+ class BatchGatherInstruction : public Instruction {
168
+ public:
169
+ BatchGatherInstruction(int opcode, bool differentiable) :
170
+ Instruction("batch_gather", opcode, differentiable) {}
171
+ TypeVec signature(const ValueVec& args) const override;
172
+ };
173
+
174
+ class BatchScatterInstruction : public Instruction {
175
+ public:
176
+ BatchScatterInstruction(int opcode, bool differentiable) :
177
+ Instruction("batch_scatter", opcode, differentiable) {}
178
+ TypeVec signature(const ValueVec& args) const override;
179
+ };
180
+
181
+ class RandomInstruction : public Instruction {
182
+ public:
183
+ RandomInstruction(int opcode, bool differentiable) :
184
+ Instruction("random", opcode, differentiable) {}
185
+ TypeVec signature(const ValueVec& args) const override;
186
+ };
187
+
188
+ class UnweightInstruction : public Instruction {
189
+ public:
190
+ UnweightInstruction(int opcode, bool differentiable) :
191
+ Instruction("unweight", opcode, differentiable) {}
192
+ TypeVec signature(const ValueVec& args) const override;
193
+ };
194
+
195
+ class MatrixElementInstruction : public Instruction {
196
+ public:
197
+ MatrixElementInstruction(int opcode, bool differentiable) :
198
+ Instruction("matrix_element", opcode, differentiable) {}
199
+ TypeVec signature(const ValueVec& args) const override;
200
+ };
201
+
202
+ using InstructionOwner = std::unique_ptr<const Instruction>;
203
+ using InstructionPtr = Instruction const*;
204
+ const std::unordered_map<std::string, InstructionOwner> build_instruction_set();
205
+ const std::unordered_map<std::string, InstructionOwner> instruction_set =
206
+ build_instruction_set();
207
+
208
+ } // namespace madspace
@@ -0,0 +1,134 @@
1
+ stack = 0,
2
+ unstack = 1,
3
+ unstack_sizes = 2,
4
+ pop = 3,
5
+ batch_cat = 4,
6
+ batch_split = 5,
7
+ cat = 6,
8
+ batch_size = 7,
9
+ offset_indices = 8,
10
+ full = 9,
11
+ squeeze = 10,
12
+ unsqueeze = 11,
13
+ add = 12,
14
+ add_int = 13,
15
+ sub = 14,
16
+ mul = 15,
17
+ reduce_sum = 16,
18
+ reduce_sum_vector = 17,
19
+ reduce_product = 18,
20
+ sqrt = 19,
21
+ square = 20,
22
+ min = 21,
23
+ max = 22,
24
+ obs_sqrt_s = 23,
25
+ obs_e = 24,
26
+ obs_px = 25,
27
+ obs_py = 26,
28
+ obs_pz = 27,
29
+ obs_mass = 28,
30
+ obs_pt = 29,
31
+ obs_p_mag = 30,
32
+ obs_phi = 31,
33
+ obs_theta = 32,
34
+ obs_y = 33,
35
+ obs_y_abs = 34,
36
+ obs_eta = 35,
37
+ obs_eta_abs = 36,
38
+ obs_delta_eta = 37,
39
+ obs_delta_phi = 38,
40
+ obs_delta_r = 39,
41
+ boost_beam = 40,
42
+ boost_beam_inverse = 41,
43
+ com_p_in = 42,
44
+ r_to_x1x2 = 43,
45
+ x1x2_to_r = 44,
46
+ diff_cross_section = 45,
47
+ two_body_decay_com = 46,
48
+ two_body_decay_com_inverse = 47,
49
+ two_body_decay = 48,
50
+ two_body_decay_inverse = 49,
51
+ two_to_two_particle_scattering_com = 50,
52
+ two_to_two_particle_scattering_com_inverse = 51,
53
+ two_to_two_particle_scattering = 52,
54
+ two_to_two_particle_scattering_inverse = 53,
55
+ two_to_three_particle_scattering = 54,
56
+ two_to_three_particle_scattering_inverse = 55,
57
+ three_body_decay_com = 56,
58
+ three_body_decay_com_inverse = 57,
59
+ three_body_decay = 58,
60
+ three_body_decay_inverse = 59,
61
+ t_inv_min_max = 60,
62
+ t_inv_value_and_min_max = 61,
63
+ s23_min_max = 62,
64
+ s23_value_and_min_max = 63,
65
+ invariants_from_momenta = 64,
66
+ sde2_channel_weights = 65,
67
+ subchannel_weights = 66,
68
+ apply_subchannel_weights = 67,
69
+ pt_eta_phi_x = 68,
70
+ mirror_momenta = 69,
71
+ momenta_to_x1x2 = 70,
72
+ uniform_invariant = 71,
73
+ uniform_invariant_inverse = 72,
74
+ breit_wigner_invariant = 73,
75
+ breit_wigner_invariant_inverse = 74,
76
+ stable_invariant = 75,
77
+ stable_invariant_inverse = 76,
78
+ stable_invariant_nu = 77,
79
+ stable_invariant_nu_inverse = 78,
80
+ fast_rambo_massless = 79,
81
+ fast_rambo_massless_inverse = 80,
82
+ fast_rambo_massless_com = 81,
83
+ fast_rambo_massive = 82,
84
+ fast_rambo_massive_inverse = 83,
85
+ fast_rambo_massive_com = 84,
86
+ cut_unphysical = 85,
87
+ cut_one = 86,
88
+ cut_all = 87,
89
+ cut_any = 88,
90
+ scale_transverse_energy = 89,
91
+ scale_transverse_mass = 90,
92
+ scale_half_transverse_mass = 91,
93
+ scale_partonic_energy = 92,
94
+ chili_forward = 93,
95
+ chili_inverse = 94,
96
+ matrix_element = 95,
97
+ collect_channel_weights = 96,
98
+ interpolate_pdf = 97,
99
+ interpolate_alpha_s = 98,
100
+ matmul = 99,
101
+ relu = 100,
102
+ leaky_relu = 101,
103
+ elu = 102,
104
+ gelu = 103,
105
+ sigmoid = 104,
106
+ softplus = 105,
107
+ rqs_reshape = 106,
108
+ rqs_find_bin = 107,
109
+ rqs_forward = 108,
110
+ rqs_inverse = 109,
111
+ softmax = 110,
112
+ softmax_prior = 111,
113
+ sample_discrete = 112,
114
+ sample_discrete_inverse = 113,
115
+ sample_discrete_probs = 114,
116
+ sample_discrete_probs_inverse = 115,
117
+ discrete_histogram = 116,
118
+ permute_momenta = 117,
119
+ gather = 118,
120
+ gather_int = 119,
121
+ select_int = 120,
122
+ select = 121,
123
+ select_vector = 122,
124
+ argsort = 123,
125
+ one_hot = 124,
126
+ nonzero = 125,
127
+ batch_gather = 126,
128
+ batch_scatter = 127,
129
+ random = 128,
130
+ unweight = 129,
131
+ vegas_forward = 130,
132
+ vegas_inverse = 131,
133
+ vegas_histogram = 132,
134
+ histogram = 133
@@ -0,0 +1,31 @@
1
+ #pragma once
2
+
3
+ #include "madspace/madcode/function.h"
4
+
5
+ #include <vector>
6
+
7
+ namespace madspace {
8
+
9
+ class InstructionDependencies {
10
+ public:
11
+ InstructionDependencies(const Function& function);
12
+ bool depends(std::size_t test_index, std::size_t dependency_index) {
13
+ return matrix[test_index * size + dependency_index];
14
+ }
15
+
16
+ private:
17
+ std::size_t size;
18
+ std::vector<bool> matrix;
19
+ std::vector<int> ranks;
20
+ };
21
+
22
+ class LastUseOfLocals {
23
+ public:
24
+ LastUseOfLocals(const Function& function);
25
+ std::vector<int>& local_indices(std::size_t index) { return last_used[index]; }
26
+
27
+ private:
28
+ std::vector<std::vector<int>> last_used;
29
+ };
30
+
31
+ } // namespace madspace
@@ -0,0 +1,203 @@
1
+ #pragma once
2
+
3
+ #include <iostream>
4
+ #include <string>
5
+ #include <unordered_map>
6
+ #include <variant>
7
+ #include <vector>
8
+
9
+ #include <nlohmann/json.hpp>
10
+
11
+ namespace madspace {
12
+
13
/// Element type of a madcode value. `batch_sizes` marks values that carry a list of batch
/// sizes instead of numeric data (see the batch-size-list constructor of Type).
enum class DataType {
    dt_int,
    dt_float,
    batch_sizes,
};

/// Integer scalar type of integer tensors and literals.
using me_int_t = int;
16
+
17
+ template <typename T>
18
+ concept ScalarType = std::same_as<T, me_int_t> || std::same_as<T, double>;
19
+
20
+ class BatchSize {
21
+ public:
22
+ using Named = std::string;
23
+ class UnnamedBody {
24
+ public:
25
+ UnnamedBody() : id(counter++) {}
26
+ friend std::ostream& operator<<(std::ostream& out, const BatchSize& batch_size);
27
+ friend void to_json(nlohmann::json& j, const BatchSize& batch_size);
28
+ bool operator==(const UnnamedBody& other) const { return id == other.id; }
29
+ bool operator!=(const UnnamedBody& other) const { return id != other.id; }
30
+
31
+ private:
32
+ static std::size_t counter;
33
+ std::size_t id;
34
+ };
35
+ using Unnamed = std::shared_ptr<UnnamedBody>;
36
+ using One = std::monostate;
37
+ using Compound = std::unordered_map<std::variant<Named, Unnamed, One>, int>;
38
+
39
+ static const BatchSize zero;
40
+ static const BatchSize one;
41
+
42
+ BatchSize(const std::string& name) : value(name) {}
43
+ BatchSize(One value) : value(value) {}
44
+ BatchSize() : value(std::make_shared<UnnamedBody>()) {}
45
+ BatchSize operator+(const BatchSize& other) const { return add(other, 1); }
46
+ BatchSize operator-(const BatchSize& other) const { return add(other, -1); }
47
+ bool operator==(const BatchSize& other) const { return value == other.value; }
48
+ bool operator!=(const BatchSize& other) const { return value != other.value; }
49
+
50
+ friend std::ostream& operator<<(std::ostream& out, const BatchSize& batch_size);
51
+ friend void to_json(nlohmann::json& j, const BatchSize& batch_size);
52
+ friend void from_json(const nlohmann::json& j, BatchSize& batch_size);
53
+
54
+ private:
55
+ BatchSize(Compound value) : value(value) {}
56
+ BatchSize(Unnamed value) : value(value) {}
57
+ BatchSize add(const BatchSize& other, int factor) const;
58
+
59
+ std::variant<Named, Unnamed, One, Compound> value;
60
+ };
61
+
62
+ void to_json(nlohmann::json& j, const BatchSize& batch_size);
63
+ void to_json(nlohmann::json& j, const BatchSize& batch_size);
64
+ void from_json(const nlohmann::json& j, BatchSize& batch_size);
65
+
66
+ struct Type {
67
+ DataType dtype;
68
+ BatchSize batch_size;
69
+ std::vector<int> shape;
70
+ std::vector<BatchSize> batch_size_list;
71
+
72
+ Type(DataType dtype, BatchSize batch_size, const std::vector<int>& shape) :
73
+ dtype(dtype), batch_size(batch_size), shape(shape) {}
74
+ Type(const std::vector<BatchSize>& batch_size_list) :
75
+ dtype(DataType::batch_sizes),
76
+ batch_size(BatchSize::one),
77
+ batch_size_list(batch_size_list) {}
78
+ };
79
+
80
+ std::ostream& operator<<(std::ostream& out, const BatchSize& batch_size);
81
+ std::ostream& operator<<(std::ostream& out, const DataType& dtype);
82
+ std::ostream& operator<<(std::ostream& out, const Type& type);
83
+
84
+ inline bool operator==(const Type& lhs, const Type& rhs) {
85
+ return lhs.dtype == rhs.dtype && lhs.batch_size == rhs.batch_size &&
86
+ lhs.shape == rhs.shape;
87
+ }
88
+
89
+ inline bool operator!=(const Type& lhs, const Type& rhs) {
90
+ return lhs.dtype != rhs.dtype || lhs.batch_size != rhs.batch_size ||
91
+ lhs.shape != rhs.shape;
92
+ }
93
+
94
+ using TypeVec = std::vector<Type>;
95
+
96
+ const Type single_float{DataType::dt_float, BatchSize::One{}, {}};
97
+ const Type single_int{DataType::dt_int, BatchSize::One{}, {}};
98
+ inline Type single_float_array(int count) {
99
+ return {DataType::dt_float, BatchSize::one, {count}};
100
+ }
101
+ inline Type single_int_array(int count) {
102
+ return {DataType::dt_int, BatchSize::one, {count}};
103
+ }
104
+ inline Type single_float_array_2d(int count1, int count2) {
105
+ return {DataType::dt_float, BatchSize::one, {count1, count2}};
106
+ }
107
+ inline Type single_int_array_2d(int count1, int count2) {
108
+ return {DataType::dt_int, BatchSize::one, {count1, count2}};
109
+ }
110
+
111
+ const BatchSize batch_size = BatchSize("batch_size");
112
+ Type multichannel_batch_size(int count);
113
+ const Type batch_float{DataType::dt_float, batch_size, {}};
114
+ const Type batch_int{DataType::dt_int, batch_size, {}};
115
+ const Type batch_four_vec{DataType::dt_float, batch_size, {4}};
116
+ inline Type batch_float_array(int count) {
117
+ return {DataType::dt_float, batch_size, {count}};
118
+ }
119
+ inline Type batch_four_vec_array(int count) {
120
+ return {DataType::dt_float, batch_size, {count, 4}};
121
+ }
122
+
123
+ using TensorValue = std::tuple<
124
+ std::vector<int>,
125
+ std::variant<std::vector<me_int_t>, std::vector<double>>>; // TODO: make this a
126
+ // class
127
+
128
+ using LiteralValue = std::variant<me_int_t, double, TensorValue, std::monostate>;
129
+
130
+ struct Value {
131
+ Type type;
132
+ LiteralValue literal_value;
133
+ int local_index = -1;
134
+
135
+ Value() : type(single_float), literal_value(std::monostate{}) {}
136
+
137
+ Value(me_int_t value) : type(single_int), literal_value(value) {}
138
+ Value(double value) : type(single_float), literal_value(value) {}
139
+
140
+ template <ScalarType T>
141
+ Value(const std::vector<std::vector<T>>& values) :
142
+ Value(
143
+ [&] {
144
+ std::size_t outer_size = values.size();
145
+ std::size_t inner_size = values.at(0).size();
146
+ std::vector<T> flat_values;
147
+ for (auto& vec : values) {
148
+ if (vec.size() != inner_size) {
149
+ throw std::invalid_argument(
150
+ "All inner vectors must have the same size"
151
+ );
152
+ }
153
+ }
154
+ for (std::size_t j = 0; j < inner_size; ++j) {
155
+ for (std::size_t i = 0; i < outer_size; ++i) {
156
+ flat_values.push_back(values.at(i).at(j));
157
+ }
158
+ }
159
+ return flat_values;
160
+ }(),
161
+ {static_cast<int>(values.size()), static_cast<int>(values.at(0).size())}
162
+ ) {}
163
+
164
+ template <ScalarType T>
165
+ Value(const std::vector<T>& values, const std::vector<int>& shape = {}) :
166
+ type{
167
+ std::is_same_v<T, me_int_t> ? DataType::dt_int : DataType::dt_float,
168
+ BatchSize::one,
169
+ shape.size() == 0 ? std::vector<int>{static_cast<int>(values.size())}
170
+ : shape
171
+ },
172
+ literal_value(TensorValue(type.shape, values)) {
173
+ std::size_t prod = 1;
174
+ for (auto size : type.shape) {
175
+ prod *= size;
176
+ }
177
+ if (prod != values.size()) {
178
+ throw std::invalid_argument(
179
+ "size of value vector not compatible with given shape"
180
+ );
181
+ }
182
+ }
183
+
184
+ Value(Type _type, int _local_index) :
185
+ type(_type), literal_value(std::monostate{}), local_index(_local_index) {}
186
+ Value(Type _type, LiteralValue _literal_value, int _local_index = -1) :
187
+ type(_type), literal_value(_literal_value), local_index(_local_index) {}
188
+
189
+ operator bool() {
190
+ return !(
191
+ local_index == -1 && std::holds_alternative<std::monostate>(literal_value)
192
+ );
193
+ }
194
+ };
195
+
196
+ using ValueVec = std::vector<Value>;
197
+
198
+ void to_json(nlohmann::json& j, const DataType& dtype);
199
+ void to_json(nlohmann::json& j, const Value& value);
200
+ void from_json(const nlohmann::json& j, DataType& dtype);
201
+ void from_json(const nlohmann::json& j, Value& dtype);
202
+
203
+ } // namespace madspace
@@ -0,0 +1,6 @@
1
+ #pragma once
2
+
3
+ #include "madcode/function.h"
4
+ #include "madcode/instruction.h"
5
+ #include "madcode/optimizer.h"
6
+ #include "madcode/type.h"
@@ -0,0 +1,74 @@
1
+ #pragma once
2
+
3
+ #include "madspace/madcode.h"
4
+
5
+ namespace madspace {
6
+
7
+ class Mapping {
8
+ public:
9
+ using Result = std::tuple<ValueVec, Value>;
10
+
11
+ Mapping(
12
+ const std::string& name,
13
+ const TypeVec& input_types,
14
+ const TypeVec& output_types,
15
+ const TypeVec& condition_types
16
+ ) :
17
+ _name(name),
18
+ _input_types(input_types),
19
+ _output_types(output_types),
20
+ _condition_types(condition_types) {}
21
+ virtual ~Mapping() = default;
22
+ Result build_forward(
23
+ FunctionBuilder& fb, const ValueVec& inputs, const ValueVec& conditions = {}
24
+ ) const;
25
+ Result build_inverse(
26
+ FunctionBuilder& fb, const ValueVec& inputs, const ValueVec& conditions = {}
27
+ ) const;
28
+ Function forward_function() const;
29
+ Function inverse_function() const;
30
+ const TypeVec& input_types() const { return _input_types; }
31
+ const TypeVec& output_types() const { return _output_types; }
32
+ const TypeVec& condition_types() const { return _condition_types; }
33
+ const std::string& name() const { return _name; }
34
+
35
+ protected:
36
+ // TODO: make parameters const ref
37
+ virtual Result build_forward_impl(
38
+ FunctionBuilder& fb, const ValueVec& inputs, const ValueVec& conditions
39
+ ) const = 0;
40
+ virtual Result build_inverse_impl(
41
+ FunctionBuilder& fb, const ValueVec& inputs, const ValueVec& conditions
42
+ ) const = 0;
43
+
44
+ private:
45
+ std::string _name;
46
+ TypeVec _input_types;
47
+ TypeVec _output_types;
48
+ TypeVec _condition_types;
49
+ };
50
+
51
+ class FunctionGenerator {
52
+ public:
53
+ FunctionGenerator(
54
+ const std::string& name, const TypeVec& arg_types, const TypeVec& return_types
55
+ ) :
56
+ _name(name), _arg_types(arg_types), _return_types(return_types) {}
57
+ virtual ~FunctionGenerator() = default;
58
+ ValueVec build_function(FunctionBuilder& fb, const ValueVec& args) const;
59
+ Function function() const;
60
+ const TypeVec& arg_types() const { return _arg_types; }
61
+ const TypeVec& return_types() const { return _return_types; }
62
+ const std::string& name() const { return _name; }
63
+
64
+ protected:
65
+ virtual ValueVec
66
+ build_function_impl(FunctionBuilder& fb, const ValueVec& args) const = 0;
67
+
68
+ private:
69
+ std::string _name;
70
+ TypeVec _arg_types;
71
+ TypeVec _return_types;
72
+ };
73
+
74
+ } // namespace madspace
@@ -0,0 +1,46 @@
1
+ #pragma once
2
+
3
+ #include "madspace/phasespace/base.h"
4
+ #include "madspace/phasespace/mlp.h"
5
+
6
+ namespace madspace {
7
+
8
+ class MomentumPreprocessing : public FunctionGenerator {
9
+ public:
10
+ MomentumPreprocessing(std::size_t particle_count);
11
+ std::size_t output_dim() const { return _output_dim; };
12
+
13
+ private:
14
+ ValueVec
15
+ build_function_impl(FunctionBuilder& fb, const ValueVec& args) const override;
16
+
17
+ std::size_t _output_dim;
18
+ };
19
+
20
+ class ChannelWeightNetwork : public FunctionGenerator {
21
+ public:
22
+ ChannelWeightNetwork(
23
+ std::size_t channel_count,
24
+ std::size_t particle_count,
25
+ std::size_t hidden_dim = 32,
26
+ std::size_t layers = 3,
27
+ MLP::Activation activation = MLP::leaky_relu,
28
+ const std::string& prefix = ""
29
+ );
30
+
31
+ const MLP& mlp() const { return _mlp; }
32
+ const MomentumPreprocessing& preprocessing() const { return _preprocessing; }
33
+ void initialize_globals(ContextPtr context) const;
34
+ const std::string& mask_name() const { return _mask_name; }
35
+
36
+ private:
37
+ ValueVec
38
+ build_function_impl(FunctionBuilder& fb, const ValueVec& args) const override;
39
+
40
+ MomentumPreprocessing _preprocessing;
41
+ MLP _mlp;
42
+ std::size_t _channel_count;
43
+ std::string _mask_name;
44
+ };
45
+
46
+ } // namespace madspace