madspace-0.3.1-cp314-cp314-macosx_14_0_arm64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- madspace/__init__.py +1 -0
- madspace/_madspace_py.cpython-314-darwin.so +0 -0
- madspace/_madspace_py.pyi +2189 -0
- madspace/_madspace_py_loader.py +111 -0
- madspace/include/madspace/constants.h +17 -0
- madspace/include/madspace/madcode/function.h +102 -0
- madspace/include/madspace/madcode/function_builder_mixin.h +591 -0
- madspace/include/madspace/madcode/instruction.h +208 -0
- madspace/include/madspace/madcode/opcode_mixin.h +134 -0
- madspace/include/madspace/madcode/optimizer.h +31 -0
- madspace/include/madspace/madcode/type.h +203 -0
- madspace/include/madspace/madcode.h +6 -0
- madspace/include/madspace/phasespace/base.h +74 -0
- madspace/include/madspace/phasespace/channel_weight_network.h +46 -0
- madspace/include/madspace/phasespace/channel_weights.h +51 -0
- madspace/include/madspace/phasespace/chili.h +32 -0
- madspace/include/madspace/phasespace/cross_section.h +47 -0
- madspace/include/madspace/phasespace/cuts.h +34 -0
- madspace/include/madspace/phasespace/discrete_flow.h +44 -0
- madspace/include/madspace/phasespace/discrete_sampler.h +53 -0
- madspace/include/madspace/phasespace/flow.h +53 -0
- madspace/include/madspace/phasespace/histograms.h +26 -0
- madspace/include/madspace/phasespace/integrand.h +204 -0
- madspace/include/madspace/phasespace/invariants.h +26 -0
- madspace/include/madspace/phasespace/luminosity.h +41 -0
- madspace/include/madspace/phasespace/matrix_element.h +70 -0
- madspace/include/madspace/phasespace/mlp.h +37 -0
- madspace/include/madspace/phasespace/multichannel.h +49 -0
- madspace/include/madspace/phasespace/observable.h +85 -0
- madspace/include/madspace/phasespace/pdf.h +78 -0
- madspace/include/madspace/phasespace/phasespace.h +67 -0
- madspace/include/madspace/phasespace/rambo.h +26 -0
- madspace/include/madspace/phasespace/scale.h +52 -0
- madspace/include/madspace/phasespace/t_propagator_mapping.h +34 -0
- madspace/include/madspace/phasespace/three_particle.h +68 -0
- madspace/include/madspace/phasespace/topology.h +116 -0
- madspace/include/madspace/phasespace/two_particle.h +63 -0
- madspace/include/madspace/phasespace/vegas.h +53 -0
- madspace/include/madspace/phasespace.h +27 -0
- madspace/include/madspace/runtime/context.h +147 -0
- madspace/include/madspace/runtime/discrete_optimizer.h +24 -0
- madspace/include/madspace/runtime/event_generator.h +257 -0
- madspace/include/madspace/runtime/format.h +68 -0
- madspace/include/madspace/runtime/io.h +343 -0
- madspace/include/madspace/runtime/lhe_output.h +132 -0
- madspace/include/madspace/runtime/logger.h +46 -0
- madspace/include/madspace/runtime/runtime_base.h +39 -0
- madspace/include/madspace/runtime/tensor.h +603 -0
- madspace/include/madspace/runtime/thread_pool.h +101 -0
- madspace/include/madspace/runtime/vegas_optimizer.h +26 -0
- madspace/include/madspace/runtime.h +12 -0
- madspace/include/madspace/umami.h +202 -0
- madspace/include/madspace/util.h +142 -0
- madspace/lib/libmadspace.dylib +0 -0
- madspace/lib/libmadspace_cpu.dylib +0 -0
- madspace/madnis/__init__.py +44 -0
- madspace/madnis/buffer.py +167 -0
- madspace/madnis/channel_grouping.py +85 -0
- madspace/madnis/distribution.py +103 -0
- madspace/madnis/integrand.py +175 -0
- madspace/madnis/integrator.py +973 -0
- madspace/madnis/interface.py +191 -0
- madspace/madnis/losses.py +186 -0
- madspace/torch.py +82 -0
- madspace-0.3.1.dist-info/METADATA +71 -0
- madspace-0.3.1.dist-info/RECORD +68 -0
- madspace-0.3.1.dist-info/WHEEL +6 -0
- madspace-0.3.1.dist-info/licenses/LICENSE +21 -0
madspace/include/madspace/runtime/tensor.h
@@ -0,0 +1,603 @@
+#pragma once
+
+#include "madspace/madcode/type.h"
+#include "madspace/util.h"
+
+#include <algorithm>
+#include <atomic>
+#include <concepts>
+#include <cstdint>
+#include <functional>
+#include <initializer_list>
+#include <vector>
+
+namespace madspace {
+
+using SizeVec = std::vector<std::size_t>;
+
+class Sizes {
+  public:
+    static constexpr std::size_t max_size = 4;
+
+    Sizes() : _size(0) {};
+    Sizes(std::size_t size) : _size(size) { std::fill(begin(), end(), 0); };
+    Sizes(std::size_t size, std::size_t value) : _size(size) {
+        std::fill(begin(), end(), value);
+    };
+    Sizes(std::initializer_list<std::size_t> values) : _size(values.size()) {
+        if (values.size() > max_size) {
+            throw std::invalid_argument("maximum dimension exceeded");
+        }
+        std::copy(values.begin(), values.end(), begin());
+    }
+    Sizes(const SizeVec& values) : _size(values.size()) {
+        if (values.size() > max_size) {
+            throw std::invalid_argument("maximum dimension exceeded");
+        }
+        std::copy(values.begin(), values.end(), begin());
+    }
+    std::size_t& operator[](std::size_t index) { return _values[index]; }
+    const std::size_t& operator[](std::size_t index) const { return _values[index]; }
+    std::size_t size() const { return _size; }
+    std::size_t* begin() { return &_values[0]; }
+    std::size_t* end() { return &_values[_size]; }
+    const std::size_t* begin() const { return &_values[0]; }
+    const std::size_t* end() const { return &_values[_size]; }
+    void push_back(std::size_t item) {
+        _values[_size] = item;
+        ++_size;
+    }
+    std::size_t* data() { return &_values[0]; }
+    const std::size_t* data() const { return &_values[0]; }
+    std::size_t& back() { return _values[_size - 1]; }
+    const std::size_t& back() const { return _values[_size - 1]; }
+
+  private:
+    std::size_t _values[max_size];
+    std::size_t _size;
+};
+
+inline bool operator==(const Sizes& a, const Sizes& b) {
+    return std::equal(a.begin(), a.end(), b.begin(), b.end());
+}
+inline bool operator!=(const Sizes& a, const Sizes& b) { return !(a == b); }
+
+template <ScalarType T, int _dim>
+struct PackedTensorView {
+    using DType = T;
+    static const int dim = _dim;
+    T* data;
+    Sizes stride;
+    Sizes shape;
+};
+
+template <ScalarType T, int _dim>
+class TensorView {
+  public:
+    using DType = T;
+    static const int dim = _dim;
+
+    TensorView(T* data, std::size_t* stride, std::size_t* shape) :
+        _data(data), _stride(stride), _shape(shape) {}
+
+    TensorView(PackedTensorView<T, _dim>& packed_view) :
+        _data(packed_view.data),
+        _stride(packed_view.stride.data()),
+        _shape(packed_view.shape.data()) {}
+
+    TensorView(T& value) : _data(&value), _stride(nullptr), _shape(nullptr) {}
+
+    const TensorView<T, _dim - 1> operator[](std::size_t index) const
+        requires(_dim != 0)
+    {
+        return {&_data[index * _stride[0]], &_stride[1], &_shape[1]};
+    }
+
+    TensorView<T, _dim - 1> operator[](std::size_t index)
+        requires(_dim != 0)
+    {
+        return {&_data[index * _stride[0]], &_stride[1], &_shape[1]};
+    }
+
+    template <typename... I>
+    const TensorView<T, _dim - sizeof...(I)> get(I... index) const
+        requires(_dim >= sizeof...(I))
+    {
+        T* ptr = _data;
+        int i = 0;
+        ((ptr = &ptr[index * _stride[i++]]), ...);
+        return {ptr, &_stride[sizeof...(I)], &_shape[sizeof...(I)]};
+    }
+
+    template <typename... I>
+    TensorView<T, _dim - sizeof...(I)> get(I... index)
+        requires(_dim >= sizeof...(I))
+    {
+        T* ptr = _data;
+        int i = 0;
+        ((ptr = &ptr[index * _stride[i++]]), ...);
+        return {ptr, &_stride[sizeof...(I)], &_shape[sizeof...(I)]};
+    }
+
+    operator T() const
+        requires(_dim == 0)
+    {
+        return *_data;
+    }
+
+    T operator=(T value)
+        requires(_dim == 0)
+    {
+        *_data = value;
+        return value;
+    }
+
+    T operator+=(T value)
+        requires(_dim == 0)
+    {
+        *_data += value;
+        return value;
+    }
+
+    TensorView<T, _dim>& operator=(TensorView<T, _dim>& value) = delete;
+    std::size_t size(std::size_t index = 0) const { return _shape[index]; }
+    T* data() const { return _data; }
+    std::size_t* stride() const { return _stride; }
+    std::size_t* shape() const { return _shape; }
+    T gather(me_int_t index) const
+        requires(_dim == 1)
+    {
+        return (*this)[index];
+    }
+    void scatter_add(me_int_t index, T value)
+        requires(_dim == 1)
+    {
+        (*this)[index] += value;
+    }
+
+  private:
+    T* _data;
+    std::size_t* _stride;
+    std::size_t* _shape;
+};
+
+class Tensor;
+
+enum class DeviceType { cpu, cuda, hip };
+
+class Device {
+  public:
+    virtual ~Device() = default;
+    virtual void* allocate(std::size_t size) const = 0;
+    virtual void free(void* ptr) const = 0;
+    virtual void memcpy(void* to, void* from, std::size_t size) const = 0;
+    virtual void tensor_copy(const Tensor& source, Tensor& target) const = 0;
+    virtual void tensor_zero(Tensor& tensor) const = 0;
+    virtual void tensor_add(const Tensor& source, Tensor& target) const = 0;
+    virtual void tensor_cpu(const Tensor& source, Tensor& target) const = 0;
+    virtual const Device* device_ptr() const = 0;
+    virtual void sync_barrier() const {}
+    virtual DeviceType device_type() const = 0;
+};
+
+using DevicePtr = const Device*;
+// defined in runtime_base.cpp, but need to declare them here
+DevicePtr cpu_device();
+DevicePtr cuda_device();
+DevicePtr hip_device();
+
+class Tensor {
+  public:
+    Tensor() : impl(nullptr) {}
+
+    Tensor(const Tensor& other) : impl(other.impl) {
+        if (impl != nullptr) {
+            impl->incref();
+        }
+    }
+
+    Tensor(Tensor&& other) noexcept : impl(other.impl) { other.impl = nullptr; }
+
+    Tensor(DataType dtype, const Sizes& shape) : Tensor(dtype, shape, cpu_device()) {}
+
+    Tensor(DataType dtype, const Sizes& shape, DevicePtr device) :
+        impl(new TensorImpl{dtype, shape, device}) {
+        auto size = init_stride();
+        impl->data = device->allocate(size);
+    }
+
+    template <typename D>
+    Tensor(DataType dtype, const Sizes& shape, const D& device) :
+        impl(new TensorImpl{dtype, shape, device.device_ptr()}) {
+        auto size = init_stride();
+        impl->data = device.allocate(size);
+    }
+
+    Tensor(
+        DataType dtype,
+        const Sizes& shape,
+        void* data,
+        std::function<void()> external_reset
+    ) :
+        Tensor(dtype, shape, cpu_device(), data, external_reset) {}
+
+    Tensor(
+        DataType dtype,
+        const Sizes& shape,
+        DevicePtr device,
+        void* data,
+        std::function<void()> external_reset
+    ) :
+        impl(new TensorImpl{dtype, shape, device, data, false, external_reset}) {
+        init_stride();
+    }
+
+    Tensor(
+        DataType dtype,
+        const Sizes& shape,
+        const Sizes& stride,
+        DevicePtr device,
+        void* data,
+        std::function<void()> external_reset
+    ) :
+        impl(new TensorImpl{
+            dtype, shape, device, data, false, external_reset, nullptr, 1, stride
+        }) {
+        std::size_t stride_prod = 1;
+        bool first = true;
+        impl->contiguous_dims = 0;
+        for (auto [size_i, stride_i] : zip(shape, stride)) {
+            if (stride_i == stride_prod) {
+                ++impl->contiguous_dims;
+            }
+            if (first && size_i == 1) {
+                impl->stride[0] = 0;
+            }
+            stride_prod *= size_i;
+            first = false;
+        }
+    }
+
+    Tensor(const SizeVec& batch_sizes) :
+        impl(new TensorImpl{
+            DataType::batch_sizes,
+            {},
+            cpu_device(),
+            nullptr,
+            true,
+            std::nullopt,
+            nullptr,
+            1,
+            {},
+            0,
+            0,
+            batch_sizes
+        }) {}
+
+    template <ScalarType T>
+    Tensor(T value, DevicePtr device) :
+        impl(new TensorImpl{
+            std::is_same_v<T, me_int_t> ? DataType::dt_int : DataType::dt_float,
+            {1},
+            device
+        }) {
+        auto size = init_stride();
+        impl->data = device->allocate(size);
+        device->memcpy(impl->data, &value, sizeof(value));
+        if (std::is_same_v<T, me_int_t> && value >= 0) {
+            impl->batch_sizes.push_back(value);
+        }
+    }
+
+    Tensor(TensorValue value, DevicePtr device) :
+        impl(new TensorImpl{
+            std::visit(
+                Overloaded{
+                    [](std::vector<me_int_t>) { return DataType::dt_int; },
+                    [](std::vector<double>) { return DataType::dt_float; },
+                },
+                std::get<1>(value)
+            ),
+            [&] {
+                auto& val_shape = std::get<0>(value);
+                Sizes full_shape(val_shape.size() + 1);
+                full_shape[0] = 1;
+                std::copy(val_shape.begin(), val_shape.end(), full_shape.begin() + 1);
+                return full_shape;
+            }(),
+            device
+        }) {
+        auto size = init_stride();
+        impl->data = device->allocate(size);
+        std::visit(
+            [&](auto& vec) { device->memcpy(impl->data, vec.data(), size); },
+            std::get<1>(value)
+        );
+    }
+
+    ~Tensor() { reset(); }
+
+    Tensor& operator=(const Tensor& other) {
+        reset();
+        impl = other.impl;
+        if (impl != nullptr) {
+            impl->incref();
+        }
+        return *this;
+    }
+
+    Tensor& operator=(Tensor&& other) noexcept {
+        reset();
+        impl = other.impl;
+        other.impl = nullptr;
+        return *this;
+    }
+
+    operator bool() const { return impl != nullptr; }
+
+    template <class T, int dim>
+    TensorView<T, dim> view() {
+        check_impl();
+        T* data = static_cast<T*>(impl->data);
+        return TensorView<T, dim>(
+            &data[impl->offset], impl->stride.data(), impl->shape.data()
+        );
+    }
+
+    template <class T, int dim>
+    const TensorView<T, dim> view() const {
+        check_impl();
+        T* data = static_cast<T*>(impl->data);
+        return TensorView<T, dim>(
+            &data[impl->offset], impl->stride.data(), impl->shape.data()
+        );
+    }
+
+    template <class T, int dim>
+    PackedTensorView<T, dim> flat_view(std::size_t flatten_count) const {
+        check_impl();
+        T* data = static_cast<T*>(impl->data);
+        if (flatten_count <= 1) {
+            return {&data[impl->offset], impl->stride, impl->shape};
+        }
+        if (flatten_count > impl->contiguous_dims) {
+            throw std::invalid_argument("can only flatten contiguous dimensions");
+        }
+        Sizes stride{1}, shape{1};
+        std::size_t i = 0;
+        for (; i < flatten_count; ++i) {
+            shape[0] *= impl->shape[i];
+        }
+        for (; i < impl->shape.size(); ++i) {
+            shape.push_back(impl->shape[i]);
+            stride.push_back(impl->stride[i]);
+        }
+        return {&data[impl->offset], stride, shape};
+    }
+
+    void* data() {
+        check_impl();
+        return impl->data;
+    }
+    void* data() const {
+        check_impl();
+        return impl->data;
+    }
+    const Sizes& shape() const {
+        check_impl();
+        return impl->shape;
+    }
+    const Sizes& stride() const {
+        check_impl();
+        return impl->stride;
+    }
+    std::size_t size(std::size_t i) const {
+        check_impl();
+        return impl->shape[i];
+    }
+    std::size_t offset() const { return impl->offset; }
+    DataType dtype() const {
+        check_impl();
+        return impl->dtype;
+    }
+    const SizeVec& batch_sizes() const {
+        check_impl();
+        return impl->batch_sizes;
+    }
+    DevicePtr device() const {
+        check_impl();
+        return impl->device;
+    }
+    std::size_t index_value() const {
+        check_impl();
+        if (impl->batch_sizes.size() > 0) {
+            return impl->batch_sizes[0];
+        }
+        auto cpu_tensor = cpu();
+        return cpu_tensor.view<me_int_t, 1>()[0];
+    }
+
+    std::size_t dtype_size() const {
+        check_impl();
+        switch (impl->dtype) {
+        case DataType::dt_int:
+            return sizeof(me_int_t);
+        case DataType::dt_float:
+            return sizeof(double);
+        case DataType::batch_sizes:
+            return 0;
+        default:
+            throw std::logic_error("invalid data type");
+        }
+    }
+
+    std::size_t byte_size() const {
+        check_impl();
+        std::size_t size = dtype_size();
+        for (auto dim_size : impl->shape) {
+            size *= dim_size;
+        }
+        return size;
+    }
+
+    void reset() {
+        if (impl == nullptr) {
+            return;
+        }
+        impl->reset(*impl->device);
+        impl = nullptr;
+    }
+
+    template <typename D>
+    void reset(const D& device) {
+        if (impl == nullptr) {
+            return;
+        }
+        impl->reset(device);
+        impl = nullptr;
+    }
+
+    Tensor select(std::size_t axis, std::size_t index) const;
+    Tensor slice(std::size_t axis, std::size_t start, std::size_t stop) const;
+    std::vector<Tensor> split(std::size_t axis, const SizeVec& sizes) const;
+    std::vector<Tensor> unstack(std::size_t axis) const;
+    Tensor unsqueeze(std::size_t axis) const;
+    Tensor expand(const Sizes& shape) const;
+    Tensor factor_dim(std::size_t axis, std::size_t factor);
+
+    template <typename D>
+    Tensor cpu(const D& device) const {
+        check_impl();
+        if (impl->device == cpu_device()) {
+            return *this;
+        } else {
+            Tensor tensor(impl->dtype, impl->shape);
+            device.tensor_cpu(contiguous(device), tensor);
+            return tensor;
+        }
+    }
+    Tensor cpu() const { return cpu(*impl->device); }
+
+    template <typename D>
+    void zero(const D& device) {
+        check_impl();
+        device.tensor_zero(*this);
+    }
+    void zero() { zero(*impl->device); }
+
+    template <typename D>
+    void copy_from(const Tensor& source, const D& device) {
+        check_impl();
+        if (source.device() == this->device()) {
+            device.tensor_copy(source, *this);
+        } else if (is_contiguous()) {
+            auto contig_source = source.contiguous();
+            device.memcpy(data(), contig_source.data(), byte_size());
+        } else {
+            throw std::runtime_error(
+                "tensor must be contiguous for copy across devices"
+            );
+        }
+    }
+    void copy_from(const Tensor& source) { copy_from(source, *impl->device); }
+
+    template <typename D>
+    void add(const Tensor& source, const D& device) {
+        check_impl();
+        device.tensor_add(source, *this);
+    }
+    void add(const Tensor& source) { add(source, *impl->device); }
+
+    template <typename D>
+    Tensor copy(const D& device) const {
+        check_impl();
+        Tensor tensor(impl->dtype, impl->shape, impl->device);
+        device.tensor_copy(*this, tensor);
+        return tensor;
+    }
+    Tensor copy() const { return copy(*impl->device); }
+
+    bool is_contiguous() const { return impl->contiguous_dims == impl->shape.size(); }
+
+    std::size_t contiguous_dims() const { return impl->contiguous_dims; }
+
+    template <typename D>
+    Tensor contiguous(const D& device) const {
+        check_impl();
+        return is_contiguous() ? *this : copy(device);
+    }
+
+    Tensor contiguous() const { return contiguous(*impl->device); }
+
+    template <typename D>
+    Tensor contiguous(std::size_t batch_size, const D& device) const {
+        check_impl();
+        if (size(0) == batch_size) {
+            return contiguous(device);
+        } else if (size(0) == 1) {
+            auto shape = impl->shape;
+            shape[0] = batch_size;
+            Tensor tensor(impl->dtype, shape, impl->device);
+            device.tensor_copy(*this, tensor);
+            return tensor;
+        } else {
+            throw std::runtime_error("invalid batch size");
+        }
+    }
+
+    Tensor contiguous(std::size_t batch_size) const {
+        return contiguous(batch_size, *impl->device);
+    }
+
+  private:
+    struct TensorImpl {
+        DataType dtype;
+        Sizes shape;
+        DevicePtr device;
+        void* data;
+        bool owns_data = true;
+        std::optional<std::function<void()>> external_reset = std::nullopt;
+        TensorImpl* data_owner;
+        std::atomic<int> ref_count = 1;
+        Sizes stride;
+        std::size_t offset;
+        std::size_t contiguous_dims;
+        SizeVec batch_sizes;
+
+        template <typename D>
+        void reset(const D& device) {
+            if (ref_count.fetch_sub(1, std::memory_order_acq_rel) != 1) {
+                return;
+            }
+            if (owns_data && data != nullptr) {
+                device.free(data);
+            } else if (data_owner != nullptr) {
+                data_owner->reset(device);
+            } else if (external_reset) {
+                (*external_reset)();
+            }
+            delete this;
+        }
+
+        void incref() { ref_count.fetch_add(1, std::memory_order_relaxed); }
+    };
+
+    Tensor(TensorImpl* _impl) : impl(_impl) {
+        if (impl->data_owner != nullptr) {
+            impl->data_owner->incref();
+        }
+    }
+    std::size_t init_stride();
+
+    void check_impl() const {
+        if (impl == nullptr) {
+            throw std::runtime_error("empty tensor");
+        }
+    }
+
+    TensorImpl* impl;
+};
+
+using TensorVec = std::vector<Tensor>;
+
+} // namespace madspace
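The header above only declares the tensor interface; the implementations of init_stride, select, slice and the device backends live in the bundled libmadspace libraries. As a rough orientation, here is a minimal usage sketch of the declared API. It is hypothetical code, not part of the package, and assumes that madspace/include is on the include path and that DataType::dt_float and cpu_device() behave as the declarations suggest.

// Hypothetical usage sketch, not shipped with the package.
#include "madspace/runtime/tensor.h"

using namespace madspace;

int main() {
    // Allocate a contiguous 2x3 double tensor on the CPU device.
    Tensor t(DataType::dt_float, {2, 3}, cpu_device());

    // Obtain a rank-2 view and write through it; v[i][j] yields a
    // rank-0 view whose operator= stores into the underlying buffer.
    auto v = t.view<double, 2>();
    for (std::size_t i = 0; i < v.size(0); ++i) {
        for (std::size_t j = 0; j < v.size(1); ++j) {
            v[i][j] = static_cast<double>(i * v.size(1) + j);
        }
    }

    // Copying the handle shares the same TensorImpl (reference counted).
    Tensor alias = t;
    return alias.is_contiguous() ? 0 : 1;
}

Because Tensor is a reference-counted handle around a shared TensorImpl, copying a Tensor only bumps the reference count; the underlying buffer is released on its device once the last handle is reset or destroyed.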
madspace/include/madspace/runtime/thread_pool.h
@@ -0,0 +1,101 @@
+#pragma once
+
+#include <condition_variable>
+#include <deque>
+#include <functional>
+#include <mutex>
+#include <optional>
+#include <thread>
+#include <vector>
+
+namespace madspace {
+
+class ThreadPool {
+  public:
+    using JobFunc = std::function<std::size_t()>;
+    ThreadPool(int thread_count = -1);
+    ~ThreadPool();
+    ThreadPool(const ThreadPool&) = delete;
+    ThreadPool& operator=(const ThreadPool&) = delete;
+    void set_thread_count(int new_count);
+    std::size_t thread_count() const { return _thread_count; }
+    void submit(JobFunc job);
+    void submit(std::vector<JobFunc>& jobs);
+    std::optional<std::size_t> wait();
+    std::vector<std::size_t> wait_multiple();
+    std::size_t add_listener(std::function<void(std::size_t)> listener);
+    void remove_listener(std::size_t id);
+
+    static std::size_t thread_index() { return _thread_index; }
+
+  private:
+    static inline thread_local std::size_t _thread_index = 0;
+    static const std::size_t QUEUE_SIZE_PER_THREAD = 16384;
+
+    void thread_loop(std::size_t index);
+    bool fill_done_cache();
+
+    std::mutex _mutex;
+    std::condition_variable _cv_run, _cv_done;
+    std::size_t _thread_count;
+    std::vector<std::thread> _threads;
+    std::deque<JobFunc> _job_queue;
+    std::deque<std::size_t> _done_queue;
+    std::vector<std::size_t> _done_buffer;
+    std::size_t _busy_threads;
+    std::size_t _listener_id = 0;
+    std::unordered_map<std::size_t, std::function<void(std::size_t)>> _listeners;
+    bool _buffer_submit;
+};
+
+template <typename T>
+class ThreadResource {
+  public:
+    ThreadResource() = default;
+    ThreadResource(ThreadPool& pool, std::function<T()> constructor) :
+        _pool(&pool),
+        _listener_id(pool.add_listener([this, constructor](std::size_t thread_count) {
+            while (_resources.size() < thread_count) {
+                _resources.push_back(constructor());
+            }
+        })) {
+        for (std::size_t i = 0; i == 0 || i < pool.thread_count(); ++i) {
+            _resources.push_back(constructor());
+        }
+    }
+    ~ThreadResource() {
+        if (_pool) {
+            _pool->remove_listener(_listener_id);
+        }
+    }
+    ThreadResource(ThreadResource&& other) noexcept :
+        _pool(std::move(other._pool)),
+        _resources(std::move(other._resources)),
+        _listener_id(std::move(other._listener_id)) {
+        other._pool = nullptr;
+    }
+
+    ThreadResource& operator=(ThreadResource&& other) noexcept {
+        _pool = std::move(other._pool);
+        _resources = std::move(other._resources);
+        _listener_id = std::move(other._listener_id);
+        other._pool = nullptr;
+        return *this;
+    }
+    ThreadResource(const ThreadResource&) = delete;
+    ThreadResource& operator=(const ThreadResource&) = delete;
+    T& get(std::size_t thread_id) { return _resources.at(thread_id); }
+    const T& get(std::size_t thread_id) const { return _resources.at(thread_id); }
+
+  private:
+    ThreadPool* _pool = nullptr;
+    std::vector<T> _resources;
+    std::size_t _listener_id;
+};
+
+inline ThreadPool& default_thread_pool() {
+    static ThreadPool instance;
+    return instance;
+}
+
+} // namespace madspace
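thread_pool.h likewise only declares the pool; the scheduling logic is in the compiled library. The sketch below is hypothetical and not part of the package: it assumes the exact blocking semantics of wait() from the signatures alone, and shows per-worker state kept in a ThreadResource.

// Hypothetical usage sketch, not shipped with the package.
#include "madspace/runtime/thread_pool.h"

#include <cstdio>
#include <vector>

using namespace madspace;

int main() {
    ThreadPool& pool = default_thread_pool();

    // One scratch vector per worker thread; the listener registered by
    // ThreadResource creates additional entries if the pool grows.
    ThreadResource<std::vector<double>> scratch(pool, [] { return std::vector<double>{}; });

    for (std::size_t job = 0; job < 8; ++job) {
        pool.submit([&scratch, job]() -> std::size_t {
            // thread_index() identifies the worker running this job, so each
            // worker only touches its own scratch vector and no lock is needed.
            scratch.get(ThreadPool::thread_index()).push_back(double(job));
            return job;
        });
    }

    // wait() is assumed here to block until one submitted job has finished
    // and to hand back the std::size_t that the job returned.
    for (int i = 0; i < 8; ++i) {
        if (auto id = pool.wait()) {
            std::printf("job %zu done\n", *id);
        }
    }
    return 0;
}

default_thread_pool() returns a process-wide singleton, so the same workers are shared by every component that uses it.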