mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/random.h
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <chrono>
|
|
6
|
+
#include <optional>
|
|
7
|
+
|
|
8
|
+
#include "mlx/array.h"
|
|
9
|
+
#include "mlx/stream.h"
|
|
10
|
+
#include "mlx/utils.h"
|
|
11
|
+
|
|
12
|
+
namespace mlx::core::random {
|
|
13
|
+
|
|
14
|
+
class KeySequence {
|
|
15
|
+
public:
|
|
16
|
+
explicit KeySequence(uint64_t seed);
|
|
17
|
+
|
|
18
|
+
void seed(uint64_t seed);
|
|
19
|
+
array next();
|
|
20
|
+
|
|
21
|
+
// static default
|
|
22
|
+
static KeySequence& default_() {
|
|
23
|
+
static KeySequence ks(get_current_time_seed());
|
|
24
|
+
return ks;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
private:
|
|
28
|
+
array key_;
|
|
29
|
+
static uint64_t get_current_time_seed() {
|
|
30
|
+
auto now = std::chrono::system_clock::now();
|
|
31
|
+
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
|
32
|
+
now.time_since_epoch())
|
|
33
|
+
.count();
|
|
34
|
+
}
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
/**
 * Get a PRNG key from a seed.
 *
 * @param seed integer seed the key is derived from.
 * @return the key as an `array` (defined out of line).
 */
array key(uint64_t seed);

/**
 * Seed the default PRNG key.
 *
 * NOTE(review): presumably re-seeds KeySequence::default_(); the definition
 * is out of line — confirm.
 */
void seed(uint64_t seed);
|
|
42
|
+
|
|
43
|
+
/** Generate an array with type uint32 filled with random bits. */
|
|
44
|
+
array bits(
|
|
45
|
+
const Shape& shape,
|
|
46
|
+
int width,
|
|
47
|
+
const std::optional<array>& key = std::nullopt,
|
|
48
|
+
StreamOrDevice s = {});
|
|
49
|
+
inline array bits(
|
|
50
|
+
const Shape& shape,
|
|
51
|
+
const std::optional<array>& key = std::nullopt,
|
|
52
|
+
StreamOrDevice s = {}) {
|
|
53
|
+
return bits(shape, 4, key, s);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
/**
 * Split the rng key into a pair of keys.
 *
 * @param key the key to split.
 * @param s stream or device to run on.
 * @return the two derived keys.
 */
std::pair<array, array> split(const array& key, StreamOrDevice s = {});

/**
 * Split the rng key into `num` keys.
 *
 * @return a single array holding the `num` derived keys (presumably stacked
 *         along the first axis — confirm against the implementation).
 */
array split(const array& key, int num, StreamOrDevice s = {});
|
|
61
|
+
|
|
62
|
+
/** Generate uniform random numbers between low and high. */
|
|
63
|
+
array uniform(
|
|
64
|
+
const array& low,
|
|
65
|
+
const array& high,
|
|
66
|
+
const Shape& shape,
|
|
67
|
+
Dtype dtype = float32,
|
|
68
|
+
const std::optional<array>& key = std::nullopt,
|
|
69
|
+
StreamOrDevice s = {});
|
|
70
|
+
|
|
71
|
+
template <typename T, typename U>
|
|
72
|
+
array uniform(
|
|
73
|
+
T low,
|
|
74
|
+
U high,
|
|
75
|
+
const Shape& shape,
|
|
76
|
+
Dtype dtype = float32,
|
|
77
|
+
const std::optional<array>& key = std::nullopt,
|
|
78
|
+
StreamOrDevice s = {}) {
|
|
79
|
+
return uniform(array(low), array(high), shape, dtype, key, to_stream(s));
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
/** Generate uniform random numbers between 0 and 1. */
|
|
83
|
+
array uniform(
|
|
84
|
+
const Shape& shape,
|
|
85
|
+
Dtype dtype,
|
|
86
|
+
const std::optional<array>& key = std::nullopt,
|
|
87
|
+
StreamOrDevice s = {});
|
|
88
|
+
inline array uniform(
|
|
89
|
+
const Shape& shape,
|
|
90
|
+
const std::optional<array>& key = std::nullopt,
|
|
91
|
+
StreamOrDevice s = {}) {
|
|
92
|
+
return uniform(shape, float32, key);
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
/** Generate samples from the standard normal distribution. */
|
|
96
|
+
array normal(
|
|
97
|
+
const Shape& shape,
|
|
98
|
+
Dtype dtype,
|
|
99
|
+
const std::optional<array>& loc,
|
|
100
|
+
const std::optional<array>& scale,
|
|
101
|
+
const std::optional<array>& key,
|
|
102
|
+
StreamOrDevice s = {});
|
|
103
|
+
inline array normal(
|
|
104
|
+
const Shape& shape,
|
|
105
|
+
Dtype dtype,
|
|
106
|
+
const float loc,
|
|
107
|
+
const float scale,
|
|
108
|
+
const std::optional<array>& key = std::nullopt,
|
|
109
|
+
StreamOrDevice s = {}) {
|
|
110
|
+
auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype));
|
|
111
|
+
auto scale_ =
|
|
112
|
+
scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype));
|
|
113
|
+
return normal(shape, dtype, loc_, scale_, key, s);
|
|
114
|
+
}
|
|
115
|
+
inline array normal(
|
|
116
|
+
const Shape& shape,
|
|
117
|
+
const float loc,
|
|
118
|
+
const float scale,
|
|
119
|
+
const std::optional<array>& key = std::nullopt,
|
|
120
|
+
StreamOrDevice s = {}) {
|
|
121
|
+
return normal(shape, float32, loc, scale, key, s);
|
|
122
|
+
}
|
|
123
|
+
inline array normal(
|
|
124
|
+
const Shape& shape,
|
|
125
|
+
const Dtype dtype,
|
|
126
|
+
const std::optional<array>& key = std::nullopt,
|
|
127
|
+
StreamOrDevice s = {}) {
|
|
128
|
+
return normal(shape, dtype, std::nullopt, std::nullopt, key, s);
|
|
129
|
+
}
|
|
130
|
+
inline array normal(
|
|
131
|
+
const Shape& shape,
|
|
132
|
+
const std::optional<array>& key = std::nullopt,
|
|
133
|
+
StreamOrDevice s = {}) {
|
|
134
|
+
return normal(shape, float32, std::nullopt, std::nullopt, key, s);
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
/**
 * Generate samples from a multivariate normal distribution.
 *
 * @param mean mean vector(s) (NOTE(review): required rank/batching is not
 *        visible here — confirm against the implementation).
 * @param cov covariance matrix/matrices matching `mean`.
 * @param shape sample shape.
 * @param dtype result dtype.
 * @param key optional PRNG key.
 * @param s stream or device to run on.
 **/
array multivariate_normal(
    const array& mean,
    const array& cov,
    const Shape& shape,
    Dtype dtype,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
|
|
145
|
+
|
|
146
|
+
/** Generate integer samples uniformly at random */
|
|
147
|
+
array randint(
|
|
148
|
+
const array& low,
|
|
149
|
+
const array& high,
|
|
150
|
+
const Shape& shape,
|
|
151
|
+
Dtype dtype = int32,
|
|
152
|
+
const std::optional<array>& key = std::nullopt,
|
|
153
|
+
StreamOrDevice s = {});
|
|
154
|
+
|
|
155
|
+
template <typename T, typename U>
|
|
156
|
+
array randint(
|
|
157
|
+
T low,
|
|
158
|
+
U high,
|
|
159
|
+
const Shape& shape,
|
|
160
|
+
Dtype dtype = int32,
|
|
161
|
+
const std::optional<array>& key = std::nullopt,
|
|
162
|
+
StreamOrDevice s = {}) {
|
|
163
|
+
return randint(array(low), array(high), shape, dtype, key, to_stream(s));
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
/** Generate binary variables with probability to be true equal to p */
|
|
167
|
+
array bernoulli(
|
|
168
|
+
const array& p,
|
|
169
|
+
const Shape& shape,
|
|
170
|
+
const std::optional<array>& key = std::nullopt,
|
|
171
|
+
StreamOrDevice s = {});
|
|
172
|
+
array bernoulli(
|
|
173
|
+
const array& p,
|
|
174
|
+
const std::optional<array>& key = std::nullopt,
|
|
175
|
+
StreamOrDevice s = {});
|
|
176
|
+
|
|
177
|
+
template <typename T>
|
|
178
|
+
array bernoulli(
|
|
179
|
+
T p,
|
|
180
|
+
const std::optional<array>& key = std::nullopt,
|
|
181
|
+
StreamOrDevice s = {}) {
|
|
182
|
+
return bernoulli(array(p), key, s);
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
template <typename T>
|
|
186
|
+
array bernoulli(
|
|
187
|
+
T p,
|
|
188
|
+
const Shape& shape,
|
|
189
|
+
const std::optional<array>& key = std::nullopt,
|
|
190
|
+
StreamOrDevice s = {}) {
|
|
191
|
+
return bernoulli(array(p), shape, key, s);
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
array bernoulli(
|
|
195
|
+
const std::optional<array>& key = std::nullopt,
|
|
196
|
+
StreamOrDevice s = {});
|
|
197
|
+
|
|
198
|
+
/**
 * Presumably samples a normal distribution truncated to the bounds given by
 * `lower` and `upper` (definition out of line — confirm interval semantics).
 */
array truncated_normal(
    const array& lower,
    const array& upper,
    const Shape& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** As above, without an explicit output `shape`. */
array truncated_normal(
    const array& lower,
    const array& upper,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** Samples of the given `shape` and `dtype`, presumably Gumbel-distributed. */
array gumbel(
    const Shape& shape,
    Dtype dtype = float32,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/**
 * Presumably samples category indices from unnormalized `logits` along
 * `axis`, producing an output of the given `shape`.
 */
array categorical(
    const array& logits,
    int axis,
    const Shape& shape,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** As above, drawing `num_samples` samples instead of taking a shape. */
array categorical(
    const array& logits_,
    int axis,
    int num_samples,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/** As above with neither shape nor sample count; `axis` defaults to -1. */
array categorical(
    const array& logits,
    int axis = -1,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
|
|
238
|
+
|
|
239
|
+
/** Generate samples from the laplace distribution. */
|
|
240
|
+
array laplace(
|
|
241
|
+
const Shape& shape,
|
|
242
|
+
Dtype dtype,
|
|
243
|
+
const float loc,
|
|
244
|
+
const float scale,
|
|
245
|
+
const std::optional<array>& key = std::nullopt,
|
|
246
|
+
StreamOrDevice s = {});
|
|
247
|
+
inline array laplace(
|
|
248
|
+
const Shape& shape,
|
|
249
|
+
const float loc,
|
|
250
|
+
const float scale,
|
|
251
|
+
const std::optional<array>& key = std::nullopt,
|
|
252
|
+
StreamOrDevice s = {}) {
|
|
253
|
+
return laplace(shape, float32, loc, scale, key, s);
|
|
254
|
+
}
|
|
255
|
+
inline array laplace(
|
|
256
|
+
const Shape& shape,
|
|
257
|
+
const Dtype dtype,
|
|
258
|
+
const std::optional<array>& key = std::nullopt,
|
|
259
|
+
StreamOrDevice s = {}) {
|
|
260
|
+
return laplace(shape, dtype, 0.0, 1.0, key, s);
|
|
261
|
+
}
|
|
262
|
+
inline array laplace(
|
|
263
|
+
const Shape& shape,
|
|
264
|
+
const std::optional<array>& key = std::nullopt,
|
|
265
|
+
StreamOrDevice s = {}) {
|
|
266
|
+
return laplace(shape, float32, 0.0, 1.0, key, s);
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
/*
 * Randomly permute the elements of x along the given axis.
 *
 * @param x array to permute.
 * @param axis axis along which to permute (0 by default).
 * @param key optional PRNG key.
 * @param s stream or device to run on.
 */
array permutation(
    const array& x,
    int axis = 0,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});

/* A random permutation of `arange(x)`, i.e. of the integers 0 .. x-1. */
array permutation(
    int x,
    const std::optional<array>& key = std::nullopt,
    StreamOrDevice s = {});
|
|
281
|
+
|
|
282
|
+
} // namespace mlx::core::random
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <atomic>
|
|
6
|
+
#include <future>
|
|
7
|
+
#include <queue>
|
|
8
|
+
#include <thread>
|
|
9
|
+
#include <unordered_map>
|
|
10
|
+
|
|
11
|
+
#include "mlx/backend/gpu/eval.h"
|
|
12
|
+
#include "mlx/device.h"
|
|
13
|
+
#include "mlx/stream.h"
|
|
14
|
+
|
|
15
|
+
namespace mlx::core::scheduler {
|
|
16
|
+
|
|
17
|
+
/**
 * A single worker thread that runs queued tasks in FIFO order.
 *
 * The destructor requests a stop, but the worker keeps draining already
 * queued tasks and exits only once the queue is empty; the destructor then
 * joins it. Tasks cannot be enqueued once `stop` is set.
 */
struct StreamThread {
  std::mutex mtx;
  std::queue<std::function<void()>> q;
  std::condition_variable cond;
  bool stop;
  // Declared last so it starts only after every other member is initialized.
  std::thread thread;

  StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {}

  ~StreamThread() {
    std::unique_lock<std::mutex> guard(mtx);
    stop = true;
    guard.unlock();
    cond.notify_one();
    thread.join();
  }

  // Worker loop: sleep until there is work or a stop request, then run one
  // task outside the lock.
  void thread_fn() {
    for (;;) {
      std::function<void()> job;
      {
        std::unique_lock<std::mutex> guard(mtx);
        cond.wait(guard, [this] { return stop || !q.empty(); });
        // Pending work is finished before a stop request is honoured.
        if (stop && q.empty()) {
          return;
        }
        job = std::move(q.front());
        q.pop();
      }
      job();
    }
  }

  // Append a task; throws std::runtime_error once the thread is stopping.
  template <typename F>
  void enqueue(F&& f) {
    std::unique_lock<std::mutex> guard(mtx);
    if (stop) {
      throw std::runtime_error(
          "Cannot enqueue work after stream is stopped.");
    }
    q.emplace(std::forward<F>(f));
    guard.unlock();
    cond.notify_one();
  }
};
|
|
65
|
+
|
|
66
|
+
class Scheduler {
|
|
67
|
+
public:
|
|
68
|
+
Scheduler() : n_active_tasks_(0) {
|
|
69
|
+
if (is_available(Device::gpu)) {
|
|
70
|
+
default_streams_.insert({Device::gpu, new_stream(Device::gpu)});
|
|
71
|
+
}
|
|
72
|
+
default_streams_.insert({Device::cpu, new_stream(Device::cpu)});
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
// Not copyable or moveable
|
|
76
|
+
Scheduler(const Scheduler&) = delete;
|
|
77
|
+
Scheduler(Scheduler&&) = delete;
|
|
78
|
+
Scheduler& operator=(const Scheduler&) = delete;
|
|
79
|
+
Scheduler& operator=(Scheduler&&) = delete;
|
|
80
|
+
|
|
81
|
+
Stream new_stream(const Device& d) {
|
|
82
|
+
streams_.emplace_back(streams_.size(), d);
|
|
83
|
+
if (d == Device::gpu) {
|
|
84
|
+
threads_.push_back(nullptr);
|
|
85
|
+
gpu::new_stream(streams_.back());
|
|
86
|
+
} else {
|
|
87
|
+
threads_.push_back(new StreamThread{});
|
|
88
|
+
}
|
|
89
|
+
return streams_.back();
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
template <typename F>
|
|
93
|
+
void enqueue(const Stream& stream, F&& f);
|
|
94
|
+
|
|
95
|
+
Stream get_default_stream(const Device& d) const {
|
|
96
|
+
return default_streams_.at(d.type);
|
|
97
|
+
}
|
|
98
|
+
Stream get_stream(int index) const {
|
|
99
|
+
return streams_.at(index);
|
|
100
|
+
}
|
|
101
|
+
|
|
102
|
+
void set_default_stream(const Stream& s) {
|
|
103
|
+
default_streams_.at(s.device.type) = s;
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
void notify_new_task(const Stream& stream) {
|
|
107
|
+
{
|
|
108
|
+
std::lock_guard<std::mutex> lk(mtx);
|
|
109
|
+
n_active_tasks_++;
|
|
110
|
+
}
|
|
111
|
+
completion_cv.notify_all();
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
void notify_task_completion(const Stream& stream) {
|
|
115
|
+
{
|
|
116
|
+
std::lock_guard<std::mutex> lk(mtx);
|
|
117
|
+
n_active_tasks_--;
|
|
118
|
+
}
|
|
119
|
+
completion_cv.notify_all();
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
int n_active_tasks() const {
|
|
123
|
+
return n_active_tasks_;
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
void wait_for_one() {
|
|
127
|
+
std::unique_lock<std::mutex> lk(mtx);
|
|
128
|
+
int n_tasks_old = n_active_tasks();
|
|
129
|
+
if (n_tasks_old > 1) {
|
|
130
|
+
completion_cv.wait(lk, [this, n_tasks_old] {
|
|
131
|
+
return this->n_active_tasks() < n_tasks_old;
|
|
132
|
+
});
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
~Scheduler() {
|
|
137
|
+
for (auto s : streams_) {
|
|
138
|
+
try {
|
|
139
|
+
synchronize(s);
|
|
140
|
+
} catch (const std::runtime_error&) {
|
|
141
|
+
// ignore errors if synch fails
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
for (auto t : threads_) {
|
|
145
|
+
if (t != nullptr) {
|
|
146
|
+
delete t;
|
|
147
|
+
}
|
|
148
|
+
}
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
private:
|
|
152
|
+
int n_active_tasks_;
|
|
153
|
+
std::vector<StreamThread*> threads_;
|
|
154
|
+
std::vector<Stream> streams_;
|
|
155
|
+
std::unordered_map<Device::DeviceType, Stream> default_streams_;
|
|
156
|
+
std::condition_variable completion_cv;
|
|
157
|
+
std::mutex mtx;
|
|
158
|
+
};
|
|
159
|
+
|
|
160
|
+
template <typename F>
|
|
161
|
+
void Scheduler::enqueue(const Stream& stream, F&& f) {
|
|
162
|
+
threads_[stream.index]->enqueue(std::forward<F>(f));
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
Scheduler& scheduler();
|
|
166
|
+
|
|
167
|
+
template <typename F>
|
|
168
|
+
void enqueue(const Stream& stream, F&& f) {
|
|
169
|
+
scheduler().enqueue(stream, std::forward<F>(f));
|
|
170
|
+
}
|
|
171
|
+
|
|
172
|
+
inline int n_active_tasks() {
|
|
173
|
+
return scheduler().n_active_tasks();
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
inline void notify_new_task(const Stream& stream) {
|
|
177
|
+
scheduler().notify_new_task(stream);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
inline void notify_task_completion(const Stream& stream) {
|
|
181
|
+
scheduler().notify_task_completion(stream);
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
inline void wait_for_one() {
|
|
185
|
+
scheduler().wait_for_one();
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
} // namespace mlx::core::scheduler
|