mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,540 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
// Copyright © 2018 the V8 project authors.
|
|
3
|
+
//
|
|
4
|
+
// Redistribution and use in source and binary forms, with or without
|
|
5
|
+
// modification, are permitted provided that the following conditions are
|
|
6
|
+
// met:
|
|
7
|
+
//
|
|
8
|
+
// * Redistributions of source code must retain the above copyright
|
|
9
|
+
// notice, this list of conditions and the following disclaimer.
|
|
10
|
+
// * Redistributions in binary form must reproduce the above
|
|
11
|
+
// copyright notice, this list of conditions and the following
|
|
12
|
+
// disclaimer in the documentation and/or other materials provided
|
|
13
|
+
// with the distribution.
|
|
14
|
+
// * Neither the name of Google Inc. nor the names of its
|
|
15
|
+
// contributors may be used to endorse or promote products derived
|
|
16
|
+
// from this software without specific prior written permission.
|
|
17
|
+
//
|
|
18
|
+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
19
|
+
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
20
|
+
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
21
|
+
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
22
|
+
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
23
|
+
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
24
|
+
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
25
|
+
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
26
|
+
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
27
|
+
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
29
|
+
|
|
30
|
+
#pragma once
|
|
31
|
+
|
|
32
|
+
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
|
|
36
|
+
|
|
37
|
+
namespace mlx::core {
|
|
38
|
+
|
|
39
|
+
#if defined(__has_builtin)
|
|
40
|
+
#define MLX_HAS_BUILTIN(x) __has_builtin(x)
|
|
41
|
+
#else
|
|
42
|
+
#define MLX_HAS_BUILTIN(x) 0
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
#if defined(__has_attribute)
|
|
46
|
+
#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x)
|
|
47
|
+
#else
|
|
48
|
+
#define MLX_HAS_ATTRIBUTE(x) 0
|
|
49
|
+
#endif
|
|
50
|
+
|
|
51
|
+
#if MLX_HAS_BUILTIN(__builtin_expect)
|
|
52
|
+
#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1))
|
|
53
|
+
#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0))
|
|
54
|
+
#else
|
|
55
|
+
#define MLX_LIKELY(condition) (condition)
|
|
56
|
+
#define MLX_UNLIKELY(condition) (condition)
|
|
57
|
+
#endif
|
|
58
|
+
|
|
59
|
+
#if MLX_HAS_ATTRIBUTE(noinline)
|
|
60
|
+
#define MLX_NOINLINE __attribute__((noinline))
|
|
61
|
+
#else
|
|
62
|
+
#define MLX_NOINLINE
|
|
63
|
+
#endif
|
|
64
|
+
|
|
65
|
+
template <typename T, typename = void>
|
|
66
|
+
struct is_iterator : std::false_type {};
|
|
67
|
+
|
|
68
|
+
template <typename T>
|
|
69
|
+
struct is_iterator<
|
|
70
|
+
T,
|
|
71
|
+
std::void_t<
|
|
72
|
+
typename std::iterator_traits<T>::difference_type,
|
|
73
|
+
typename std::iterator_traits<T>::iterator_category,
|
|
74
|
+
typename std::iterator_traits<T>::pointer,
|
|
75
|
+
typename std::iterator_traits<T>::reference,
|
|
76
|
+
typename std::iterator_traits<T>::value_type>> : std::true_type {};
|
|
77
|
+
|
|
78
|
+
template <typename T>
|
|
79
|
+
constexpr bool is_iterator_v = is_iterator<T>::value;
|
|
80
|
+
|
|
81
|
+
// Minimal SmallVector implementation. Uses inline storage first, switches to
|
|
82
|
+
// dynamic storage when it overflows.
|
|
83
|
+
//
|
|
84
|
+
// Notes:
|
|
85
|
+
// * The default inline storage size is MAX_NDIM, as it is mainly used for
|
|
86
|
+
// shapes and strides, users should choose a better size for other cases.
|
|
87
|
+
// * The data() returns real address even for empty vector.
|
|
88
|
+
// * The pointer returned by data() will change after moving the vector as it
|
|
89
|
+
// points to the inline storage.
|
|
90
|
+
// * For trivial elements the storage will not be default constructed,
|
|
91
|
+
// i.e. SmallVector<int>(10) will not be filled with 0 by default.
|
|
92
|
+
template <typename T, size_t kSize = 10, typename Allocator = std::allocator<T>>
|
|
93
|
+
class SmallVector {
|
|
94
|
+
public:
|
|
95
|
+
using value_type = T;
|
|
96
|
+
using reference = T&;
|
|
97
|
+
using const_reference = const T&;
|
|
98
|
+
using iterator = T*;
|
|
99
|
+
using const_iterator = const T*;
|
|
100
|
+
using difference_type = std::ptrdiff_t;
|
|
101
|
+
using size_type = std::size_t;
|
|
102
|
+
|
|
103
|
+
SmallVector() = default;
|
|
104
|
+
|
|
105
|
+
explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {}
|
|
106
|
+
|
|
107
|
+
explicit SmallVector(size_t size, const Allocator& allocator = Allocator())
|
|
108
|
+
: allocator_(allocator) {
|
|
109
|
+
resize(size);
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
SmallVector(
|
|
113
|
+
size_t size,
|
|
114
|
+
const T& initial_value,
|
|
115
|
+
const Allocator& allocator = Allocator())
|
|
116
|
+
: allocator_(allocator) {
|
|
117
|
+
resize(size, initial_value);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
SmallVector(
|
|
121
|
+
std::initializer_list<T> init,
|
|
122
|
+
const Allocator& allocator = Allocator())
|
|
123
|
+
: allocator_(allocator) {
|
|
124
|
+
if (init.size() > capacity()) {
|
|
125
|
+
grow(init.size());
|
|
126
|
+
}
|
|
127
|
+
assert(capacity() >= init.size()); // sanity check
|
|
128
|
+
std::uninitialized_move(init.begin(), init.end(), begin_);
|
|
129
|
+
end_ = begin_ + init.size();
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
|
133
|
+
SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator())
|
|
134
|
+
: allocator_(allocator) {
|
|
135
|
+
size_t size = std::distance(begin, end);
|
|
136
|
+
if (size > capacity()) {
|
|
137
|
+
grow(size);
|
|
138
|
+
}
|
|
139
|
+
assert(capacity() >= size); // sanity check
|
|
140
|
+
std::uninitialized_copy(begin, end, begin_);
|
|
141
|
+
end_ = begin_ + size;
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
SmallVector(const SmallVector& other) : allocator_(other.allocator_) {
|
|
145
|
+
*this = other;
|
|
146
|
+
}
|
|
147
|
+
SmallVector(const SmallVector& other, const Allocator& allocator)
|
|
148
|
+
: allocator_(allocator) {
|
|
149
|
+
*this = other;
|
|
150
|
+
}
|
|
151
|
+
SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) {
|
|
152
|
+
*this = std::move(other);
|
|
153
|
+
}
|
|
154
|
+
SmallVector(SmallVector&& other, const Allocator& allocator)
|
|
155
|
+
: allocator_(allocator) {
|
|
156
|
+
*this = std::move(other);
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
~SmallVector() {
|
|
160
|
+
free_storage();
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
SmallVector& operator=(const SmallVector& other) {
|
|
164
|
+
if (this == &other) {
|
|
165
|
+
return *this;
|
|
166
|
+
}
|
|
167
|
+
size_t other_size = other.size();
|
|
168
|
+
if (capacity() < other_size) {
|
|
169
|
+
// Create large-enough heap-allocated storage.
|
|
170
|
+
free_storage();
|
|
171
|
+
begin_ = allocator_.allocate(other_size);
|
|
172
|
+
end_of_storage_ = begin_ + other_size;
|
|
173
|
+
std::uninitialized_copy(other.begin_, other.end_, begin_);
|
|
174
|
+
} else if constexpr (kHasTrivialElement) {
|
|
175
|
+
std::copy(other.begin_, other.end_, begin_);
|
|
176
|
+
} else {
|
|
177
|
+
ptrdiff_t to_copy =
|
|
178
|
+
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
|
|
179
|
+
std::copy(other.begin_, other.begin_ + to_copy, begin_);
|
|
180
|
+
if (other.begin_ + to_copy < other.end_) {
|
|
181
|
+
std::uninitialized_copy(
|
|
182
|
+
other.begin_ + to_copy, other.end_, begin_ + to_copy);
|
|
183
|
+
} else {
|
|
184
|
+
std::destroy_n(begin_ + to_copy, size() - to_copy);
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
end_ = begin_ + other_size;
|
|
188
|
+
return *this;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
SmallVector& operator=(SmallVector&& other) {
|
|
192
|
+
if (this == &other) {
|
|
193
|
+
return *this;
|
|
194
|
+
}
|
|
195
|
+
if (other.is_big()) {
|
|
196
|
+
free_storage();
|
|
197
|
+
begin_ = other.begin_;
|
|
198
|
+
end_ = other.end_;
|
|
199
|
+
end_of_storage_ = other.end_of_storage_;
|
|
200
|
+
} else {
|
|
201
|
+
assert(capacity() >= other.size()); // sanity check
|
|
202
|
+
size_t other_size = other.size();
|
|
203
|
+
if constexpr (kHasTrivialElement) {
|
|
204
|
+
std::move(other.begin_, other.end_, begin_);
|
|
205
|
+
} else {
|
|
206
|
+
ptrdiff_t to_move =
|
|
207
|
+
std::min(static_cast<ptrdiff_t>(other_size), end_ - begin_);
|
|
208
|
+
std::move(other.begin_, other.begin_ + to_move, begin_);
|
|
209
|
+
if (other.begin_ + to_move < other.end_) {
|
|
210
|
+
std::uninitialized_move(
|
|
211
|
+
other.begin_ + to_move, other.end_, begin_ + to_move);
|
|
212
|
+
} else {
|
|
213
|
+
std::destroy_n(begin_ + to_move, size() - to_move);
|
|
214
|
+
}
|
|
215
|
+
}
|
|
216
|
+
end_ = begin_ + other_size;
|
|
217
|
+
}
|
|
218
|
+
other.reset_to_inline_storage();
|
|
219
|
+
return *this;
|
|
220
|
+
}
|
|
221
|
+
|
|
222
|
+
bool operator==(const SmallVector& other) const {
|
|
223
|
+
if (size() != other.size()) {
|
|
224
|
+
return false;
|
|
225
|
+
}
|
|
226
|
+
return std::equal(begin_, end_, other.begin_);
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
bool operator!=(const SmallVector& other) const {
|
|
230
|
+
return !(*this == other);
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
T* data() {
|
|
234
|
+
return begin_;
|
|
235
|
+
}
|
|
236
|
+
const T* data() const {
|
|
237
|
+
return begin_;
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
iterator begin() {
|
|
241
|
+
return begin_;
|
|
242
|
+
}
|
|
243
|
+
const_iterator begin() const {
|
|
244
|
+
return begin_;
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
iterator end() {
|
|
248
|
+
return end_;
|
|
249
|
+
}
|
|
250
|
+
const_iterator end() const {
|
|
251
|
+
return end_;
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
const_iterator cbegin() const {
|
|
255
|
+
return begin_;
|
|
256
|
+
}
|
|
257
|
+
|
|
258
|
+
const_iterator cend() const {
|
|
259
|
+
return end_;
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
auto rbegin() {
|
|
263
|
+
return std::make_reverse_iterator(end_);
|
|
264
|
+
}
|
|
265
|
+
auto rbegin() const {
|
|
266
|
+
return std::make_reverse_iterator(end_);
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
auto rend() {
|
|
270
|
+
return std::make_reverse_iterator(begin_);
|
|
271
|
+
}
|
|
272
|
+
auto rend() const {
|
|
273
|
+
return std::make_reverse_iterator(begin_);
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
size_t size() const {
|
|
277
|
+
return end_ - begin_;
|
|
278
|
+
}
|
|
279
|
+
bool empty() const {
|
|
280
|
+
return end_ == begin_;
|
|
281
|
+
}
|
|
282
|
+
size_t capacity() const {
|
|
283
|
+
return end_of_storage_ - begin_;
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
T& front() {
|
|
287
|
+
assert(size() != 0);
|
|
288
|
+
return begin_[0];
|
|
289
|
+
}
|
|
290
|
+
const T& front() const {
|
|
291
|
+
assert(size() != 0);
|
|
292
|
+
return begin_[0];
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
T& back() {
|
|
296
|
+
assert(size() != 0);
|
|
297
|
+
return end_[-1];
|
|
298
|
+
}
|
|
299
|
+
const T& back() const {
|
|
300
|
+
assert(size() != 0);
|
|
301
|
+
return end_[-1];
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
T& at(size_t index) {
|
|
305
|
+
if (index >= size()) {
|
|
306
|
+
throw std::out_of_range("SmallVector out of range.");
|
|
307
|
+
}
|
|
308
|
+
return begin_[index];
|
|
309
|
+
}
|
|
310
|
+
const T& at(size_t index) const {
|
|
311
|
+
return const_cast<SmallVector*>(this)->at(index);
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
T& operator[](size_t index) {
|
|
315
|
+
assert(size() > index);
|
|
316
|
+
return begin_[index];
|
|
317
|
+
}
|
|
318
|
+
const T& operator[](size_t index) const {
|
|
319
|
+
return const_cast<SmallVector*>(this)->operator[](index);
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
template <typename... Args>
|
|
323
|
+
void emplace_back(Args&&... args) {
|
|
324
|
+
if (MLX_UNLIKELY(end_ == end_of_storage_)) {
|
|
325
|
+
grow();
|
|
326
|
+
}
|
|
327
|
+
void* storage = end_;
|
|
328
|
+
end_ += 1;
|
|
329
|
+
new (storage) T(std::forward<Args>(args)...);
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
void push_back(T x) {
|
|
333
|
+
emplace_back(std::move(x));
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
void pop_back(size_t count = 1) {
|
|
337
|
+
assert(size() >= count);
|
|
338
|
+
end_ -= count;
|
|
339
|
+
std::destroy_n(end_, count);
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
iterator insert(iterator pos, T value) {
|
|
343
|
+
return insert(pos, static_cast<size_t>(1), std::move(value));
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
iterator insert(iterator pos, size_t count, T value) {
|
|
347
|
+
assert(pos <= end_);
|
|
348
|
+
size_t offset = pos - begin_;
|
|
349
|
+
size_t old_size = size();
|
|
350
|
+
resize(old_size + count);
|
|
351
|
+
pos = begin_ + offset;
|
|
352
|
+
iterator old_end = begin_ + old_size;
|
|
353
|
+
assert(old_end <= end_);
|
|
354
|
+
std::move_backward(pos, old_end, end_);
|
|
355
|
+
if constexpr (kHasTrivialElement) {
|
|
356
|
+
std::fill_n(pos, count, value);
|
|
357
|
+
} else {
|
|
358
|
+
std::fill_n(pos + 1, count - 1, value);
|
|
359
|
+
*pos = std::move(value);
|
|
360
|
+
}
|
|
361
|
+
return pos;
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
template <typename Iter, typename = std::enable_if_t<is_iterator_v<Iter>>>
|
|
365
|
+
iterator insert(iterator pos, Iter begin, Iter end) {
|
|
366
|
+
if constexpr (std::is_same_v<std::decay_t<Iter>, iterator>) {
|
|
367
|
+
// The implementation can not take overlapping range.
|
|
368
|
+
assert(!(begin >= pos && begin < pos + std::distance(begin, end)));
|
|
369
|
+
assert(!(end > pos && end <= pos + std::distance(begin, end)));
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
assert(pos <= end_);
|
|
373
|
+
size_t offset = pos - begin_;
|
|
374
|
+
size_t count = std::distance(begin, end);
|
|
375
|
+
size_t old_size = size();
|
|
376
|
+
resize(old_size + count);
|
|
377
|
+
pos = begin_ + offset;
|
|
378
|
+
iterator old_end = begin_ + old_size;
|
|
379
|
+
assert(old_end <= end_);
|
|
380
|
+
std::move_backward(pos, old_end, end_);
|
|
381
|
+
std::copy(begin, end, pos);
|
|
382
|
+
return pos;
|
|
383
|
+
}
|
|
384
|
+
|
|
385
|
+
iterator insert(iterator pos, std::initializer_list<const T> values) {
|
|
386
|
+
return insert(pos, values.begin(), values.end());
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
iterator erase(iterator erase_start, iterator erase_end) {
|
|
390
|
+
assert(erase_start >= begin_);
|
|
391
|
+
assert(erase_start <= erase_end);
|
|
392
|
+
assert(erase_end <= end_);
|
|
393
|
+
iterator new_end = std::move(erase_end, end_, erase_start);
|
|
394
|
+
std::destroy_n(new_end, std::distance(new_end, end_));
|
|
395
|
+
end_ = new_end;
|
|
396
|
+
return erase_start;
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
iterator erase(iterator pos) {
|
|
400
|
+
return erase(pos, pos + 1);
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
void resize(size_t new_size) {
|
|
404
|
+
if (new_size > capacity()) {
|
|
405
|
+
grow(new_size);
|
|
406
|
+
}
|
|
407
|
+
T* new_end = begin_ + new_size;
|
|
408
|
+
if constexpr (!kHasTrivialElement) {
|
|
409
|
+
if (new_end > end_) {
|
|
410
|
+
std::uninitialized_default_construct(end_, new_end);
|
|
411
|
+
} else {
|
|
412
|
+
std::destroy_n(new_end, end_ - new_end);
|
|
413
|
+
}
|
|
414
|
+
}
|
|
415
|
+
end_ = new_end;
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
void resize(size_t new_size, const T& initial_value) {
|
|
419
|
+
if (new_size > capacity()) {
|
|
420
|
+
grow(new_size);
|
|
421
|
+
}
|
|
422
|
+
T* new_end = begin_ + new_size;
|
|
423
|
+
if (new_end > end_) {
|
|
424
|
+
std::uninitialized_fill(end_, new_end, initial_value);
|
|
425
|
+
} else {
|
|
426
|
+
std::destroy_n(new_end, end_ - new_end);
|
|
427
|
+
}
|
|
428
|
+
end_ = new_end;
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
void reserve(size_t new_capacity) {
|
|
432
|
+
if (new_capacity > capacity()) {
|
|
433
|
+
grow(new_capacity);
|
|
434
|
+
}
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
// Clear without reverting back to inline storage.
|
|
438
|
+
void clear() {
|
|
439
|
+
std::destroy_n(begin_, end_ - begin_);
|
|
440
|
+
end_ = begin_;
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
private:
|
|
444
|
+
// Grows the backing store by a factor of two, and at least to {min_capacity}.
|
|
445
|
+
// TODO: Move to private after removing external code using this method.
|
|
446
|
+
MLX_NOINLINE void grow(size_t min_capacity = 0) {
|
|
447
|
+
size_t new_capacity = std::max(min_capacity, 2 * capacity());
|
|
448
|
+
// Round up to power of 2.
|
|
449
|
+
new_capacity--;
|
|
450
|
+
new_capacity |= new_capacity >> 1;
|
|
451
|
+
new_capacity |= new_capacity >> 2;
|
|
452
|
+
new_capacity |= new_capacity >> 4;
|
|
453
|
+
new_capacity |= new_capacity >> 8;
|
|
454
|
+
new_capacity |= new_capacity >> 16;
|
|
455
|
+
if constexpr (sizeof(size_t) == sizeof(uint64_t)) {
|
|
456
|
+
new_capacity |= new_capacity >> 32;
|
|
457
|
+
}
|
|
458
|
+
new_capacity++;
|
|
459
|
+
|
|
460
|
+
T* new_storage = allocator_.allocate(new_capacity);
|
|
461
|
+
if (new_storage == nullptr) {
|
|
462
|
+
throw std::bad_alloc();
|
|
463
|
+
}
|
|
464
|
+
|
|
465
|
+
size_t in_use = end_ - begin_;
|
|
466
|
+
std::uninitialized_move(begin_, end_, new_storage);
|
|
467
|
+
free_storage();
|
|
468
|
+
begin_ = new_storage;
|
|
469
|
+
end_ = new_storage + in_use;
|
|
470
|
+
end_of_storage_ = new_storage + new_capacity;
|
|
471
|
+
}
|
|
472
|
+
|
|
473
|
+
MLX_NOINLINE void free_storage() {
|
|
474
|
+
std::destroy_n(begin_, end_ - begin_);
|
|
475
|
+
if (is_big()) {
|
|
476
|
+
allocator_.deallocate(begin_, end_of_storage_ - begin_);
|
|
477
|
+
}
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
// Clear and go back to inline storage. Dynamic storage is *not* freed. For
|
|
481
|
+
// internal use only.
|
|
482
|
+
void reset_to_inline_storage() {
|
|
483
|
+
if constexpr (!kHasTrivialElement) {
|
|
484
|
+
if (!is_big())
|
|
485
|
+
std::destroy_n(begin_, end_ - begin_);
|
|
486
|
+
}
|
|
487
|
+
begin_ = inline_storage_begin();
|
|
488
|
+
end_ = begin_;
|
|
489
|
+
end_of_storage_ = begin_ + kSize;
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
bool is_big() const {
|
|
493
|
+
return begin_ != inline_storage_begin();
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
T* inline_storage_begin() {
|
|
497
|
+
return reinterpret_cast<T*>(inline_storage_);
|
|
498
|
+
}
|
|
499
|
+
const T* inline_storage_begin() const {
|
|
500
|
+
return reinterpret_cast<const T*>(inline_storage_);
|
|
501
|
+
}
|
|
502
|
+
|
|
503
|
+
// Allocator used only for the dynamic ("big") storage path.
Allocator allocator_;

// Invariants:
// 1. The elements in the range between `begin_` (included) and `end_` (not
// included) will be initialized at all times.
// 2. All other elements outside the range, both in the inline storage and in
// the dynamic storage (if it exists), will be uninitialized at all times.

T* begin_ = inline_storage_begin();
T* end_ = begin_;
T* end_of_storage_ = begin_ + kSize;

// Raw, suitably aligned byte buffer backing the "small" (inline) state.
// Elements are constructed into it in place; it is never default-initialized.
alignas(T) char inline_storage_[sizeof(T) * kSize];

// True when elements can be copied and destroyed without running user code;
// used to skip per-element work (e.g. destruction in reset paths).
static constexpr bool kHasTrivialElement =
    std::is_trivially_copyable<T>::value &&
    std::is_trivially_destructible<T>::value;
|
|
520
|
+
};
|
|
521
|
+
|
|
522
|
+
// Type trait: true for SmallVector and std::vector instantiations, false for
// every other type.
template <typename>
struct is_vector : std::false_type {};

template <typename T, size_t Size, typename Allocator>
struct is_vector<SmallVector<T, Size, Allocator>> : std::true_type {};

template <typename T, typename Allocator>
struct is_vector<std::vector<T, Allocator>> : std::true_type {};

// Convenience variable template mirroring the standard library's *_v naming.
template <typename Vec>
inline constexpr bool is_vector_v = is_vector<Vec>::value;
|
|
533
|
+
|
|
534
|
+
#undef MLX_HAS_BUILTIN
|
|
535
|
+
#undef MLX_HAS_ATTRIBUTE
|
|
536
|
+
#undef MLX_LIKELY
|
|
537
|
+
#undef MLX_UNLIKELY
|
|
538
|
+
#undef MLX_NOINLINE
|
|
539
|
+
|
|
540
|
+
} // namespace mlx::core
|
mlx/include/mlx/stream.h
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/device.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
// A lightweight, copyable handle to a stream of execution on a device.
struct Stream {
  // Stream identifier. Equality for Stream compares only this field, which
  // suggests indices are unique across devices — confirm with the
  // stream-creation code.
  int index;
  // The device this stream runs on.
  Device device;
  explicit Stream(int index, Device device) : index(index), device(device) {}
};
|
|
14
|
+
|
|
15
|
+
/** Get the default stream for the given device. */
Stream default_stream(Device d);

/** Make the stream the default for its device. */
void set_default_stream(Stream s);

/** Make a new stream on the given device. */
Stream new_stream(Device d);

/** Get the stream with the given index. Behavior for an index that was
 * never created is not visible here — check the implementation. */
Stream get_stream(int index);
|
|
26
|
+
|
|
27
|
+
// Streams compare equal when their indices match; `device` is deliberately
// ignored. This assumes a stream index is never reused across devices —
// confirm against the stream-creation code.
inline bool operator==(const Stream& lhs, const Stream& rhs) {
  return lhs.index == rhs.index;
}
|
|
30
|
+
|
|
31
|
+
// Inequality: streams differ exactly when their indices differ (the device,
// as with operator==, plays no part in identity).
inline bool operator!=(const Stream& lhs, const Stream& rhs) {
  return lhs.index != rhs.index;
}
|
|
34
|
+
|
|
35
|
+
/* Synchronize with the default stream (blocks the calling thread until the
 * stream's pending work completes — confirm exact semantics in the
 * implementation). */
void synchronize();

/* Synchronize with the provided stream. */
void synchronize(Stream);
|
|
40
|
+
|
|
41
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
// This code was modified from https://github.com/progschj/ThreadPool
|
|
2
|
+
// The original License is copied below:
|
|
3
|
+
//
|
|
4
|
+
// Copyright (c) 2012 Jakob Progsch, Václav Zeman
|
|
5
|
+
// This software is provided 'as-is', without any express or implied
|
|
6
|
+
// warranty. In no event will the authors be held liable for any damages
|
|
7
|
+
// arising from the use of this software.
|
|
8
|
+
//
|
|
9
|
+
// Permission is granted to anyone to use this software for any purpose,
|
|
10
|
+
// including commercial applications, and to alter it and redistribute it
|
|
11
|
+
// freely, subject to the following restrictions:
|
|
12
|
+
//
|
|
13
|
+
// 1. The origin of this software must not be misrepresented; you must not
|
|
14
|
+
// claim that you wrote the original software. If you use this software
|
|
15
|
+
// in a product, an acknowledgment in the product documentation would be
|
|
16
|
+
// appreciated but is not required.
|
|
17
|
+
//
|
|
18
|
+
// 2. Altered source versions must be plainly marked as such, and must not be
|
|
19
|
+
// misrepresented as being the original software.
|
|
20
|
+
//
|
|
21
|
+
// 3. This notice may not be removed or altered from any source
|
|
22
|
+
// distribution.
|
|
23
|
+
#pragma once
|
|
24
|
+
|
|
25
|
+
#include <condition_variable>
|
|
26
|
+
#include <functional>
|
|
27
|
+
#include <future>
|
|
28
|
+
#include <memory>
|
|
29
|
+
#include <mutex>
|
|
30
|
+
#include <queue>
|
|
31
|
+
#include <stdexcept>
|
|
32
|
+
#include <thread>
|
|
33
|
+
#include <vector>
|
|
34
|
+
|
|
35
|
+
// A pool of worker threads consuming tasks from a shared FIFO queue.
// Modified from https://github.com/progschj/ThreadPool (license above).
class ThreadPool {
 public:
  // Start a pool with the given number of worker threads.
  // NOTE(review): single-argument ctor is not `explicit`; making it so would
  // be safer but could break callers relying on implicit conversion.
  ThreadPool(size_t);
  // Schedule `f(args...)`; returns a future for the result. Throws
  // std::runtime_error if called after the pool has been stopped.
  template <class F, class... Args>
  auto enqueue(F&& f, Args&&... args)
      -> std::future<typename std::invoke_result_t<F, Args...>>;
  // Change the worker count. Shrinking first drains all queued tasks and
  // joins every worker, then restarts with the requested count.
  void resize(size_t);
  // Drains queued tasks and joins all workers.
  ~ThreadPool();

 private:
  // Signal shutdown, join every worker, then reset state so the pool can
  // be restarted.
  void stop_and_wait();
  // Spawn the given number of additional workers running the task loop.
  void start_threads(size_t);

  std::vector<std::thread> workers; // worker threads
  std::queue<std::function<void()>> tasks; // pending tasks (FIFO)
  std::mutex queue_mutex; // guards `tasks` and `stop`
  std::condition_variable condition; // signals new work or shutdown
  bool stop; // true while shutting down; set under queue_mutex
};
|
|
54
|
+
|
|
55
|
+
// Construct the pool and immediately launch `num_threads` workers.
inline ThreadPool::ThreadPool(size_t num_threads) : stop(false) {
  start_threads(num_threads);
}
|
|
58
|
+
|
|
59
|
+
template <class F, class... Args>
|
|
60
|
+
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
|
61
|
+
-> std::future<typename std::invoke_result_t<F, Args...>> {
|
|
62
|
+
using return_type = typename std::invoke_result_t<F, Args...>;
|
|
63
|
+
|
|
64
|
+
auto task = std::make_shared<std::packaged_task<return_type()>>(
|
|
65
|
+
std::bind(std::forward<F>(f), std::forward<Args>(args)...));
|
|
66
|
+
|
|
67
|
+
std::future<return_type> res = task->get_future();
|
|
68
|
+
{
|
|
69
|
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
|
70
|
+
|
|
71
|
+
if (stop) {
|
|
72
|
+
throw std::runtime_error(
|
|
73
|
+
"[ThreadPool::enqueue] Not allowed on stopped ThreadPool");
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
tasks.emplace([task]() { (*task)(); });
|
|
77
|
+
}
|
|
78
|
+
condition.notify_one();
|
|
79
|
+
return res;
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
inline void ThreadPool::resize(size_t threads) {
|
|
83
|
+
if (workers.size() == threads) {
|
|
84
|
+
return;
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
if (workers.size() > threads) {
|
|
88
|
+
stop_and_wait();
|
|
89
|
+
}
|
|
90
|
+
start_threads(threads - workers.size());
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
// Blocks until every queued task has run, then joins all workers.
inline ThreadPool::~ThreadPool() {
  stop_and_wait();
}
|
|
96
|
+
|
|
97
|
+
inline void ThreadPool::stop_and_wait() {
|
|
98
|
+
// Stop the current threads and wait until they finish
|
|
99
|
+
{
|
|
100
|
+
std::unique_lock<std::mutex> lock(queue_mutex);
|
|
101
|
+
stop = true;
|
|
102
|
+
}
|
|
103
|
+
condition.notify_all();
|
|
104
|
+
for (std::thread& worker : workers) {
|
|
105
|
+
worker.join();
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
// Reset the member variables so that the threadpool is reusable
|
|
109
|
+
stop = false;
|
|
110
|
+
workers.clear();
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
inline void ThreadPool::start_threads(size_t threads) {
|
|
114
|
+
for (size_t i = 0; i < threads; ++i) {
|
|
115
|
+
workers.emplace_back([this] {
|
|
116
|
+
for (;;) {
|
|
117
|
+
std::function<void()> task;
|
|
118
|
+
|
|
119
|
+
{
|
|
120
|
+
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
|
121
|
+
this->condition.wait(
|
|
122
|
+
lock, [this] { return this->stop || !this->tasks.empty(); });
|
|
123
|
+
if (this->stop && this->tasks.empty())
|
|
124
|
+
return;
|
|
125
|
+
task = std::move(this->tasks.front());
|
|
126
|
+
this->tasks.pop();
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
task();
|
|
130
|
+
}
|
|
131
|
+
});
|
|
132
|
+
}
|
|
133
|
+
}
|