mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Problem description passed to the steel GEMM kernels:
// D (M x N) = A (M x K) * B (K x N), optionally batched.
struct GEMMParams {
  // Matrix dimensions.
  const int M;
  const int N;
  const int K;

  // Leading dimensions of A, B, and the output D.
  const int lda;
  const int ldb;
  const int ldd;

  // Threadgroup tile counts along N and M.
  const int tiles_n;
  const int tiles_m;

  // Element strides between consecutive batch entries of A, B, and D.
  const int64_t batch_stride_a;
  const int64_t batch_stride_b;
  const int64_t batch_stride_d;

  // log2 of the threadgroup swizzle factor (consumed by the tile-to-
  // threadgroup remapping in the kernels).
  const int swizzle_log;
  // Number of K-loop iterations whose tiles are fully aligned (no ragged
  // tail handling needed).
  const int gemm_k_iterations_aligned;

  // Number of leading batch dimensions.
  const int batch_ndim;
};

// Parameters for the split-K GEMM variant, where K is partitioned across
// several partial accumulations that are reduced afterwards.
// NOTE(review): "Spilt" (sic) looks like a typo for "Split", but the name is
// part of the public interface and is kept as-is.
struct GEMMSpiltKParams {
  // Matrix dimensions.
  const int M;
  const int N;
  const int K;

  // Leading dimensions of A, B, and the intermediate output C.
  const int lda;
  const int ldb;
  const int ldc;

  // Threadgroup tile counts along N and M.
  const int tiles_n;
  const int tiles_m;

  // Number of K partitions, the stride between partition outputs, and the
  // size of each partition along K.
  const int split_k_partitions;
  const int split_k_partition_stride;
  const int split_k_partition_size;

  // Number of K-loop iterations with fully aligned tiles.
  const int gemm_k_iterations_aligned;
};

// Extra parameters for the fused addmm epilogue: D = alpha * (A @ B) + beta * C.
struct GEMMAddMMParams {
  // Leading dimension and fastest-moving stride of C.
  const int ldc;
  const int fdc;

  // Element stride between consecutive batch entries of C.
  const int64_t batch_stride_c;

  // Epilogue scaling factors.
  const float alpha;
  const float beta;
};

} // namespace steel
} // namespace mlx
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/steel/utils.h"

///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

// Epilogue that only casts the accumulator to the output type.
template <typename OutT, typename InT>
struct TransformNone {
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  // Two-argument form ignores the C operand; present so all transforms share
  // the same call shape.
  static METAL_FUNC OutT apply(InT x, OutT) {
    return static_cast<OutT>(x);
  }
};

// Epilogue that adds the C operand: out = cast(x) + c.
template <typename OutT, typename InT>
struct TransformAdd {
  // Accepts (alpha, beta) for interface parity with TransformAxpby but
  // ignores both.
  TransformAdd(const float, const float) {}

  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  static METAL_FUNC OutT apply(InT x, OutT c) {
    return static_cast<OutT>(x) + c;
  }
};

// Epilogue computing out = alpha * x + beta * c.
template <typename OutT, typename InT>
struct TransformAxpby {
  const float alpha;
  const float beta;

  TransformAxpby(const float alpha_, const float beta_)
      : alpha(alpha_), beta(beta_) {}

  // One-argument form: plain cast (no scaling applied).
  static METAL_FUNC OutT apply(InT x) {
    return static_cast<OutT>(x);
  }

  METAL_FUNC OutT apply(InT x, OutT c) const {
    return static_cast<OutT>(
        x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c));
  }
};

// Maps an element type to its accumulator type; always float here.
template <typename T>
struct AccumHelper {
  typedef float accum_type;
};

// Remaps threadgroup ids so that 2^swizzle_log consecutive tid.x values fan
// out along y, improving L2 locality of neighboring tiles.
struct BlockSwizzle {
  static METAL_FUNC int2
  swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
    const int tid_x = (tid.x) >> swizzle_log;
    const int tid_y =
        ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
    return int2(tid_x, tid_y);
  }
};

} // namespace steel
} // namespace mlx
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>
#include "mlx/backend/metal/kernels/steel/utils/type_traits.h"

#pragma METAL internals : enable

namespace mlx {
namespace steel {

///////////////////////////////////////////////////////////////////////////////
// Integral constant with casting
///////////////////////////////////////////////////////////////////////////////

// Compile-time constant wrapper (analogous to std::integral_constant) with an
// implicit conversion to its value type.
template <typename T, T v>
struct integral_constant {
  static constexpr constant T value = v;
  using value_type = T;
  using type = integral_constant;

  METAL_FUNC constexpr operator value_type() const noexcept {
    return value;
  }

  // METAL_FUNC constexpr value_type operator()() const noexcept {
  //   return value;
  // }
};

template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;

// is_integral, extended so that integral_constant wrappers of integral types
// also qualify.
template <class T>
struct is_integral : bool_constant<metal::is_integral<T>::value> {};

template <class T, T v>
struct is_integral<integral_constant<T, v>>
    : bool_constant<metal::is_integral<T>::value> {};

template <typename T>
constexpr constant bool is_integral_v = is_integral<T>::value;

// Shorthand for compile-time ints.
template <int val>
using Int = integral_constant<int, val>;

///////////////////////////////////////////////////////////////////////////////
// Binary Operators on Integral constants
///////////////////////////////////////////////////////////////////////////////

// Generates an operator that combines two integral_constants into a new
// integral_constant carrying the compile-time result.
#define integral_const_binop(__op__, __operator__) \
  template <typename T, T tv, typename U, U uv> \
  METAL_FUNC constexpr auto __operator__( \
      integral_constant<T, tv>, integral_constant<U, uv>) { \
    constexpr auto res = tv __op__ uv; \
    return integral_constant<decltype(res), res>{}; \
  }

integral_const_binop(+, operator+);
integral_const_binop(-, operator-);
integral_const_binop(*, operator*);
integral_const_binop(/, operator/);

integral_const_binop(==, operator==);
integral_const_binop(!=, operator!=);
integral_const_binop(<, operator<);
integral_const_binop(>, operator>);
integral_const_binop(<=, operator<=);
integral_const_binop(>=, operator>=);

integral_const_binop(&&, operator&&);
integral_const_binop(||, operator||);

// Short-circuit overloads: a compile-time true/false absorbs any
// non-integral-constant operand without evaluating it into the result type.
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(true_type, T) {
  return true_type{};
}
template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator||(T, true_type) {
  return true_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(false_type, T) {
  return false_type{};
}

template <typename T, typename = metal::enable_if_t<!is_integral_v<T>>>
METAL_FUNC constexpr auto operator&&(T, false_type) {
  return false_type{};
}

// Dispatch utilities
// Invokes f with a compile-time true_type/false_type matching the runtime
// flag, so f can specialize on the boolean.
template <typename F>
void dispatch_bool(bool v, F f) {
  if (v) {
    f(true_type{});
  } else {
    f(false_type{});
  }
}

// Compile-time for-loop: calls f(Int<start>{}), f(Int<start+step>{}), ...
// while the index is < stop. Recursion is unrolled at compile time.
template <int start, int stop, int step, typename F>
constexpr void const_for_loop(F f) {
  if constexpr (start < stop) {
    constexpr auto idx = Int<start>{};
    f(idx);
    const_for_loop<start + step, stop, step, F>(f);
  }
}

#undef integral_const_binop

///////////////////////////////////////////////////////////////////////////////
// Reduction operators
///////////////////////////////////////////////////////////////////////////////

// Variadic sum: base case returns the single argument.
template <typename T>
METAL_FUNC constexpr T sum(T x) {
  return x;
}

template <typename T, typename... Us>
METAL_FUNC constexpr auto sum(T x, Us... us) {
  return x + sum(us...);
}

} // namespace steel
} // namespace mlx

#pragma METAL internals : disable
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>

#pragma METAL internals : enable

// Supplemental type traits injected into the metal namespace (requires the
// METAL internals pragma above).
namespace metal {

// True when T has no non-static data members (compiler intrinsic).
template <typename T>
struct is_empty : metal::bool_constant<__is_empty(T)> {};

#ifdef __cpp_variable_templates
template <typename T>
constexpr constant bool is_empty_v = is_empty<T>::value;
#endif

// C++17-style void_t via the classic make_void indirection.
template <typename... Ts>
struct make_void {
  typedef void type;
};

template <typename... Ts>
using void_t = typename make_void<Ts...>::type;

// "Static" here means stateless: empty after stripping cv-qualifiers.
template <class T>
struct is_static : metal::bool_constant<is_empty<remove_cv_t<T>>::value> {};

// Extracts the (cv-unqualified) pointee type from a pointer in any Metal
// address space; unspecialized form is intentionally empty (SFINAE-friendly).
template <typename T>
struct pointer_element {};

template <typename T>
struct pointer_element<thread T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<device T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<constant T*> {
  using type = remove_cv_t<T>;
};
template <typename T>
struct pointer_element<threadgroup T*> {
  using type = remove_cv_t<T>;
};

template <typename T>
using pointer_element_t = typename pointer_element<remove_cv_t<T>>::type;

} // namespace metal

#pragma METAL internals : disable
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.

#pragma once

#include <metal_stdlib>

// Converts a flat element index into per-operand strided offsets for two
// operands that share one broadcasted shape. Walks dimensions from innermost
// to outermost, stopping early once the remaining index is exhausted.
METAL_FUNC ulong2 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  ulong a_off{0};
  ulong b_off{0};
  int axis = ndim - 1;
  while (axis >= 0 && elem > 0) {
    const int coord = elem % shape[axis];
    elem /= shape[axis];
    a_off += coord * a_strides[axis];
    b_off += coord * b_strides[axis];
    --axis;
  }
  return ulong2(a_off, b_off);
}

// Three-operand variant: same walk, producing offsets for a, b, and c.
METAL_FUNC ulong3 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  ulong a_off{0};
  ulong b_off{0};
  ulong c_off{0};
  int axis = ndim - 1;
  while (axis >= 0 && elem > 0) {
    const int coord = elem % shape[axis];
    elem /= shape[axis];
    a_off += coord * a_strides[axis];
    b_off += coord * b_strides[axis];
    c_off += coord * c_strides[axis];
    --axis;
  }
  return ulong3(a_off, b_off, c_off);
}
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Contiguous ternary select: d[i] = Op(a[i], b[i], c[i]), N elements per
// thread. BSCALAR/CSCALAR pin the b/c operand to element 0 (scalar operand).
template <
    typename T,
    typename Op,
    bool BSCALAR,
    bool CSCALAR,
    int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  // First output element owned by this thread.
  index *= N;
  if (N > 1 && index + N > size) {
    // Ragged tail: stop at the end of the buffer.
    for (uint out = index; out < size; ++out) {
      d[out] = Op()(a[out], b[BSCALAR ? 0u : out], c[CSCALAR ? 0u : out]);
    }
  } else {
    // Full chunk of N elements.
    for (uint out = index, end = index + N; out < end; ++out) {
      d[out] = Op()(a[out], b[BSCALAR ? 0u : out], c[CSCALAR ? 0u : out]);
    }
  }
}
|
|
31
|
+
|
|
32
|
+
// 2-D grid variant of ternary_v for sizes that exceed a 1-D dispatch; the
// flat offset is built from (index.x, index.y) in 64-bit arithmetic.
template <
    typename T,
    typename Op,
    bool BSCALAR,
    bool CSCALAR,
    int N = WorkPerThread<T>::n>
[[kernel]] void ternary_v2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    // Ragged tail: clamp to the buffer end.
    for (int64_t out = offset; out < size; ++out) {
      d[out] = Op()(a[out], b[BSCALAR ? 0 : out], c[CSCALAR ? 0 : out]);
    }
  } else {
    // Full chunk of N elements.
    for (int64_t out = offset, end = offset + N; out < end; ++out) {
      d[out] = Op()(a[out], b[BSCALAR ? 0 : out], c[CSCALAR ? 0 : out]);
    }
  }
}
|
|
61
|
+
|
|
62
|
+
// General strided ternary select over 1-D (non-contiguous) operands; each
// operand carries its own stride.
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd1(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t& a_strides,
    constant const int64_t& b_strides,
    constant const int64_t& c_strides,
    uint index [[thread_position_in_grid]]) {
  // Strided source location for each operand.
  auto loc_a = elem_to_loc_1<IdxT>(index, a_strides);
  auto loc_b = elem_to_loc_1<IdxT>(index, b_strides);
  auto loc_c = elem_to_loc_1<IdxT>(index, c_strides);
  d[index] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
|
77
|
+
|
|
78
|
+
// General strided ternary select over 2-D operands; output is row-major in
// the dispatch grid.
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd2(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t a_strides[2],
    constant const int64_t b_strides[2],
    constant const int64_t c_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  // Strided source location for each operand.
  auto loc_a = elem_to_loc_2<IdxT>(index, a_strides);
  auto loc_b = elem_to_loc_2<IdxT>(index, b_strides);
  auto loc_c = elem_to_loc_2<IdxT>(index, c_strides);
  // Contiguous destination index.
  IdxT dst = index.x + IdxT(grid_dim.x) * index.y;
  d[dst] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
|
95
|
+
|
|
96
|
+
// General strided ternary select over 3-D operands.
template <typename T, typename Op, typename IdxT = int64_t>
[[kernel]] void ternary_g_nd3(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int64_t a_strides[3],
    constant const int64_t b_strides[3],
    constant const int64_t c_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Strided source location for each operand.
  auto loc_a = elem_to_loc_3<IdxT>(index, a_strides);
  auto loc_b = elem_to_loc_3<IdxT>(index, b_strides);
  auto loc_c = elem_to_loc_3<IdxT>(index, c_strides);
  // Contiguous destination index.
  IdxT dst = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z);
  d[dst] = Op()(a[loc_a], b[loc_b], c[loc_c]);
}
|
|
113
|
+
|
|
114
|
+
// Fully general N-dimensional ternary select. Each thread handles up to N
// consecutive elements along the innermost axis; per-operand offsets are
// resolved once with elem_to_loc_3_nd and then advanced by the innermost
// strides inside the loop.
template <typename T, typename Op, int N = 1, typename IdxT = int64_t>
[[kernel]] void ternary_g(
    device const bool* a,
    device const T* b,
    device const T* c,
    device T* d,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // idx.{x,y,z} are the starting strided offsets into a, b, c respectively.
  auto idx = elem_to_loc_3_nd<IdxT>(
      {N * index.x, index.y, index.z},
      shape,
      a_strides,
      b_strides,
      c_strides,
      ndim);
  // Extent of the innermost dimension bounds how far this thread may write.
  auto xshape = shape[ndim - 1];
  // Contiguous destination index for the first element of this thread.
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  // Innermost-axis strides used to step each operand between elements.
  IdxT a_xstride = a_strides[ndim - 1];
  IdxT b_xstride = b_strides[ndim - 1];
  IdxT c_xstride = c_strides[ndim - 1];
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]);
    idx.x += a_xstride;
    idx.y += b_xstride;
    idx.z += c_xstride;
  }
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Contiguous unary map: out[i] = U(Op(in[i])), N elements per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v(
    device const T* in,
    device U* out,
    constant uint& size,
    uint index [[thread_position_in_grid]]) {
  // First element owned by this thread.
  index *= N;
  if (N > 1 && index + N > size) {
    // Ragged tail: stop at the end of the buffer.
    for (uint j = index; j < size; ++j) {
      out[j] = static_cast<U>(Op()(in[j]));
    }
  } else {
    // Full chunk of N elements.
    for (uint j = index, end = index + N; j < end; ++j) {
      out[j] = static_cast<U>(Op()(in[j]));
    }
  }
}
|
|
20
|
+
|
|
21
|
+
// 2-D grid variant of unary_v for buffers too large for a 1-D dispatch;
// flat offset computed in 64-bit arithmetic.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void unary_v2(
    device const T* in,
    device U* out,
    constant int64_t& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
  if (N > 1 && offset + N > size) {
    // Ragged tail: clamp to the buffer end.
    for (int64_t j = offset; j < size; ++j) {
      out[j] = static_cast<U>(Op()(in[j]));
    }
  } else {
    // Full chunk of N elements.
    for (int64_t j = offset, end = offset + N; j < end; ++j) {
      out[j] = static_cast<U>(Op()(in[j]));
    }
  }
}
|
|
39
|
+
|
|
40
|
+
// Fully general N-dimensional unary map over a strided input with a
// contiguous output. Each thread handles up to N consecutive elements along
// the innermost axis, resolving the strided offset once and then stepping by
// the innermost stride.
template <
    typename T,
    typename U,
    typename Op,
    int N = 1,
    typename IdxT = int64_t>
[[kernel]] void unary_g(
    device const T* in,
    device U* out,
    constant const int* in_shape,
    constant const int64_t* in_strides,
    // NOTE(review): ndim is in the device address space here while the
    // sibling ternary kernels take `constant const int&` — confirm this is
    // intentional and matches the host-side binding.
    device const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  // Starting strided offset into the input for this thread.
  auto idx = elem_to_loc<IdxT>(
      {N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
  // Extent of the innermost dimension bounds how far this thread may write.
  auto xshape = in_shape[ndim - 1];
  // Innermost-axis stride used to step between consecutive input elements.
  IdxT xstride = in_strides[ndim - 1];
  // Contiguous destination index for this thread's first element.
  IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z);
  for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
    out[out_idx++] = static_cast<U>(Op()(in[idx]));
    idx += xstride;
  }
}
|