mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
#include "mlx/allocator.h"
|
|
5
|
+
#include "mlx/array.h"
|
|
6
|
+
#include "mlx/backend/common/utils.h"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core {
|
|
9
|
+
|
|
10
|
+
// TODO: Add support for more combinations of input types.
// Layout classification for the three inputs of a ternary op; used to
// route to a specialized kernel (see get_ternary_op_type below).
enum class TernaryOpType {
  // All three inputs have a single element (data_size() == 1).
  ScalarScalarScalar,
  // All three inputs are contiguous (all row- or all col-contiguous).
  VectorVectorVector,
  // Third input is scalar; first two are row-contiguous.
  VectorVectorScalar,
  // Second input is scalar; first and third are row-contiguous.
  VectorScalarVector,
  // Anything else: requires full strided indexing.
  General,
};
|
|
18
|
+
|
|
19
|
+
// Classify the layout combination of the three inputs of a ternary op so
// the caller can select a specialized kernel. Checks, in priority order:
// all-scalar, all-contiguous (row or column), one scalar operand with the
// other two row-contiguous, and finally the general strided case.
inline TernaryOpType
get_ternary_op_type(const array& a, const array& b, const array& c) {
  if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
    return TernaryOpType::ScalarScalarScalar;
  }

  const bool all_row = a.flags().row_contiguous && b.flags().row_contiguous &&
      c.flags().row_contiguous;
  const bool all_col = a.flags().col_contiguous && b.flags().col_contiguous &&
      c.flags().col_contiguous;
  if (all_row || all_col) {
    return TernaryOpType::VectorVectorVector;
  }

  if (b.data_size() == 1 && a.flags().row_contiguous &&
      c.flags().row_contiguous) {
    return TernaryOpType::VectorScalarVector;
  }

  if (c.data_size() == 1 && a.flags().row_contiguous &&
      b.flags().row_contiguous) {
    return TernaryOpType::VectorVectorScalar;
  }

  return TernaryOpType::General;
}
|
|
43
|
+
|
|
44
|
+
// Allocate (or donate into) the output buffer for a ternary op, given the
// layout classification `topt`. When possible, an input whose buffer can be
// safely reused (see is_donatable) is donated to the output instead of
// allocating fresh memory. `mallocfn` lets callers substitute a custom
// allocator; it defaults to allocator::malloc.
inline void set_ternary_op_output_data(
    const array& a,
    const array& b,
    const array& c,
    array& out,
    TernaryOpType topt,
    std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
  // Reuse x's buffer for `out` when x is donatable; returns true on success.
  auto maybe_donate = [&out](const array& x) {
    if (is_donatable(x, out)) {
      out.copy_shared_buffer(x);
      return true;
    }
    return false;
  };

  switch (topt) {
    case TernaryOpType::ScalarScalarScalar:
      // One output element; mirror b's strides/flags for the metadata.
      out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
      break;
    case TernaryOpType::VectorVectorVector:
      // All inputs are contiguous: try donating any of them (checked in
      // a, b, c order); otherwise allocate a buffer shaped like b.
      if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
        out.set_data(
            mallocfn(out.itemsize() * b.data_size()),
            b.data_size(),
            b.strides(),
            b.flags());
      }
      break;
    case TernaryOpType::VectorVectorScalar:
    case TernaryOpType::VectorScalarVector:
    case TernaryOpType::General:
      // Try to donate an input which is row_contiguous
      // (the output is dense row-major in these cases, so only a
      // row-contiguous input's layout matches).
      if (!((a.flags().row_contiguous && maybe_donate(a)) ||
            (b.flags().row_contiguous && maybe_donate(b)) ||
            (c.flags().row_contiguous && maybe_donate(c)))) {
        out.set_data(mallocfn(out.nbytes()));
      }
      break;
  }
}
|
|
84
|
+
|
|
85
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/allocator.h"
|
|
6
|
+
#include "mlx/backend/common/utils.h"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core {
|
|
9
|
+
|
|
10
|
+
inline void set_unary_output_data(
|
|
11
|
+
const array& in,
|
|
12
|
+
array& out,
|
|
13
|
+
std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
|
|
14
|
+
if (in.flags().contiguous) {
|
|
15
|
+
if (is_donatable(in, out)) {
|
|
16
|
+
out.copy_shared_buffer(in);
|
|
17
|
+
} else {
|
|
18
|
+
out.set_data(
|
|
19
|
+
mallocfn(in.data_size() * out.itemsize()),
|
|
20
|
+
in.data_size(),
|
|
21
|
+
in.strides(),
|
|
22
|
+
in.flags());
|
|
23
|
+
}
|
|
24
|
+
} else {
|
|
25
|
+
out.set_data(mallocfn(out.nbytes()));
|
|
26
|
+
}
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <filesystem>
|
|
6
|
+
#include <tuple>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
#include "mlx/array.h"
|
|
10
|
+
|
|
11
|
+
namespace mlx::core {
|
|
12
|
+
|
|
13
|
+
// Return the directory that contains current shared library.
|
|
14
|
+
std::filesystem::path current_binary_dir();
|
|
15
|
+
|
|
16
|
+
inline int64_t
|
|
17
|
+
elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
|
|
18
|
+
int64_t loc = 0;
|
|
19
|
+
for (int i = shape.size() - 1; i >= 0; --i) {
|
|
20
|
+
auto q_and_r = ldiv(elem, shape[i]);
|
|
21
|
+
loc += q_and_r.rem * strides[i];
|
|
22
|
+
elem = q_and_r.quot;
|
|
23
|
+
}
|
|
24
|
+
return loc;
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
inline int64_t elem_to_loc(int elem, const array& a) {
|
|
28
|
+
if (a.flags().row_contiguous) {
|
|
29
|
+
return elem;
|
|
30
|
+
}
|
|
31
|
+
return elem_to_loc(elem, a.shape(), a.strides());
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
// Build row-major (C-contiguous) strides for `shape`: each axis's stride is
// the product of the sizes of all axes to its right; the innermost axis has
// stride 1.
inline Strides make_contiguous_strides(const Shape& shape) {
  Strides strides(shape.size(), 1);
  int64_t running = 1;
  for (int axis = static_cast<int>(shape.size()) - 1; axis > 0; --axis) {
    running *= shape[axis];
    strides[axis - 1] = running;
  }
  return strides;
}
|
|
41
|
+
|
|
42
|
+
// Collapse dims that are contiguous to possibly route to a better kernel
|
|
43
|
+
// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
|
|
44
|
+
// should return {{2, 4}, {{1, 2}}}.
|
|
45
|
+
//
|
|
46
|
+
// When multiple arrays are passed they should all have the same shape. The
|
|
47
|
+
// collapsed axes are also the same so one shape is returned.
|
|
48
|
+
std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
|
|
49
|
+
const Shape& shape,
|
|
50
|
+
const std::vector<Strides>& strides,
|
|
51
|
+
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
|
52
|
+
|
|
53
|
+
// Convenience overload: collapse contiguous dims for a set of arrays that
// all share the same shape (xs[0]'s shape is used). Gathers each array's
// strides and dispatches to the shape/strides overload above.
inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
    const std::vector<array>& xs,
    size_t size_cap = std::numeric_limits<int32_t>::max()) {
  std::vector<Strides> strides;
  for (auto& x : xs) {
    strides.emplace_back(x.strides());
  }
  return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
}
|
|
62
|
+
|
|
63
|
+
// Variadic convenience overload: accepts any number of arrays directly,
// packs them into a vector, and dispatches to the vector-based overload.
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
inline auto collapse_contiguous_dims(Arrays&&... xs) {
  return collapse_contiguous_dims(
      std::vector<array>{std::forward<Arrays>(xs)...});
}
|
|
68
|
+
|
|
69
|
+
// The single array version of the above.
|
|
70
|
+
std::pair<Shape, Strides> collapse_contiguous_dims(
|
|
71
|
+
const Shape& shape,
|
|
72
|
+
const Strides& strides,
|
|
73
|
+
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
|
74
|
+
std::pair<Shape, Strides> collapse_contiguous_dims(
|
|
75
|
+
const array& a,
|
|
76
|
+
int64_t size_cap = std::numeric_limits<int32_t>::max());
|
|
77
|
+
|
|
78
|
+
// Compute the thread block dimensions which fit the given
|
|
79
|
+
// input dimensions.
|
|
80
|
+
// - The thread block dimensions will be powers of two
|
|
81
|
+
// - The thread block size will be less than 2^pow2
|
|
82
|
+
using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
|
|
83
|
+
Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
|
|
84
|
+
|
|
85
|
+
// Computes a 2D grid where each element is < UINT_MAX
|
|
86
|
+
// Assumes:
|
|
87
|
+
// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
|
|
88
|
+
// - shape and strides correspond to a contiguous (no holes) but
|
|
89
|
+
// possibly broadcasted array
|
|
90
|
+
Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
|
|
91
|
+
|
|
92
|
+
// Same as above but we do an implicit division with divisor.
|
|
93
|
+
// Basically, equivalent to factorizing
|
|
94
|
+
// Prod(s \forall s in shape if strides[s] > 0) / divisor.
|
|
95
|
+
Dims get_2d_grid_dims_common(
|
|
96
|
+
const Shape& shape,
|
|
97
|
+
const Strides& strides,
|
|
98
|
+
size_t divisor);
|
|
99
|
+
|
|
100
|
+
// Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
|
|
101
|
+
std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
|
|
102
|
+
|
|
103
|
+
// Iterates over the memory offsets of a (possibly strided) array in
// row-major element order. `loc` holds the current linear memory offset;
// step() advances it by one element. Contiguous axes are merged at
// construction so stepping does less carry work.
struct ContiguousIterator {
  // Advance to the next element in row-major order, updating `loc`.
  inline void step() {
    int dims = shape_.size();
    if (dims == 0) {
      return;
    }
    // Carry: reset any trailing axes that reached their end, rolling their
    // accumulated contribution back out of `loc`.
    int i = dims - 1;
    while (pos_[i] == (shape_[i] - 1) && i > 0) {
      pos_[i] = 0;
      loc -= (shape_[i] - 1) * strides_[i];
      i--;
    }
    pos_[i]++;
    loc += strides_[i];
  }

  // Jump directly to the n-th element (row-major order), recomputing both
  // the position vector and `loc` from scratch.
  void seek(int64_t n) {
    loc = 0;
    for (int i = shape_.size() - 1; i >= 0; --i) {
      auto q_and_r = ldiv(n, shape_[i]);
      loc += q_and_r.rem * strides_[i];
      pos_[i] = q_and_r.rem;
      n = q_and_r.quot;
    }
  }

  // Return to the first element.
  void reset() {
    loc = 0;
    std::fill(pos_.begin(), pos_.end(), 0);
  }

  ContiguousIterator() {};

  // Iterate over the elements of `a` using its shape and strides.
  explicit ContiguousIterator(const array& a)
      : shape_(a.shape()), strides_(a.strides()) {
    if (!shape_.empty()) {
      // Merge contiguous axes to reduce per-step carry handling.
      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
      pos_ = Shape(shape_.size(), 0);
    }
  }

  // Iterate over only the first `dims` axes of the given shape/strides.
  explicit ContiguousIterator(
      const Shape& shape,
      const Strides& strides,
      int dims)
      : shape_(shape.begin(), shape.begin() + dims),
        strides_(strides.begin(), strides.begin() + dims) {
    if (!shape_.empty()) {
      std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
      pos_ = Shape(shape_.size(), 0);
    }
  }

  int64_t loc{0}; // current linear memory offset

 private:
  Shape shape_; // (collapsed) shape being iterated
  Strides strides_; // strides matching shape_
  Shape pos_; // current multi-dimensional position
};
|
|
163
|
+
|
|
164
|
+
// Single pass over the axes that simultaneously checks row-major (C) and
// column-major (Fortran) contiguity and counts elements over axes with
// positive stride. Returns a tuple:
//   (no_broadcast_data_size, is_row_contiguous, is_col_contiguous).
inline auto check_contiguity(const Shape& shape, const Strides& strides) {
  size_t no_broadcast_data_size = 1;
  int64_t f_stride = 1; // expected stride scanning front-to-back (col-major)
  int64_t b_stride = 1; // expected stride scanning back-to-front (row-major)
  bool is_row_contiguous = true;
  bool is_col_contiguous = true;

  // i walks axes front-to-back; ri walks the same axes back-to-front.
  for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
    // Size-1 axes are contiguous regardless of their stride.
    is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
    is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
    f_stride *= shape[i];
    b_stride *= shape[ri];
    // Only axes with positive stride contribute elements to the count.
    if (strides[i] > 0) {
      no_broadcast_data_size *= shape[i];
    }
  }

  return std::make_tuple(
      no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
}
|
|
184
|
+
|
|
185
|
+
// Whether `in` may donate its buffer to `out`: `in` must report itself
// donatable, element sizes must match, and `in`'s allocation must not exceed
// what `out` needs by more than a small slack (so a slightly oversized
// buffer can still be reused without wasting much memory).
inline bool is_donatable(const array& in, const array& out) {
  constexpr size_t donation_extra = 16384;

  return in.is_donatable() && in.itemsize() == out.itemsize() &&
      in.buffer_size() <= out.nbytes() + donation_extra;
}
|
|
191
|
+
|
|
192
|
+
std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
|
|
193
|
+
|
|
194
|
+
void shared_buffer_reshape(
|
|
195
|
+
const array& in,
|
|
196
|
+
const Strides& out_strides,
|
|
197
|
+
array& out);
|
|
198
|
+
|
|
199
|
+
template <typename T>
|
|
200
|
+
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
|
201
|
+
vec.erase(std::next(vec.begin(), index));
|
|
202
|
+
return vec;
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/array.h"
|
|
6
|
+
#include "mlx/backend/cpu/encoder.h"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core {
|
|
9
|
+
|
|
10
|
+
namespace {
|
|
11
|
+
|
|
12
|
+
template <typename T>
|
|
13
|
+
void arange(T start, T next, array& out, size_t size, Stream stream) {
|
|
14
|
+
auto ptr = out.data<T>();
|
|
15
|
+
auto step_size = next - start;
|
|
16
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
17
|
+
encoder.set_output_array(out);
|
|
18
|
+
encoder.dispatch([ptr, start, step_size, size]() mutable {
|
|
19
|
+
for (int i = 0; i < size; ++i) {
|
|
20
|
+
ptr[i] = start;
|
|
21
|
+
start += step_size;
|
|
22
|
+
}
|
|
23
|
+
});
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
} // namespace
|
|
27
|
+
|
|
28
|
+
} // namespace mlx::core
|