mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <algorithm>
|
|
6
|
+
#include <cmath>
|
|
7
|
+
#include <cstdint>
|
|
8
|
+
#include <vector>
|
|
9
|
+
|
|
10
|
+
// Bit pattern that _MLX_Float16's float constructor stores for a NaN input;
// the input's sign bit is OR-ed in separately. Exponent bits are all ones and
// the mantissa is 0x100.
// NOTE(review): the mantissa MSB (0x200, the IEEE 754-2008 "quiet" bit) is
// clear, so under that convention this encodes a signaling NaN (the usual
// quiet NaN would be 0x7E00). Confirm this is intentional before relying on it.
#define __MLX_HALF_NAN__ 0x7D00
|
|
11
|
+
|
|
12
|
+
namespace mlx::core {
|
|
13
|
+
|
|
14
|
+
namespace {
|
|
15
|
+
union float_bits_fp16 {
|
|
16
|
+
float f;
|
|
17
|
+
uint32_t u;
|
|
18
|
+
};
|
|
19
|
+
} // namespace
|
|
20
|
+
|
|
21
|
+
struct _MLX_Float16 {
|
|
22
|
+
uint16_t bits_;
|
|
23
|
+
|
|
24
|
+
// Default constructor
|
|
25
|
+
_MLX_Float16() = default;
|
|
26
|
+
|
|
27
|
+
// Default copy constructor
|
|
28
|
+
_MLX_Float16(_MLX_Float16 const&) = default;
|
|
29
|
+
|
|
30
|
+
// Appease std::vector<bool> for being special
|
|
31
|
+
_MLX_Float16& operator=(std::vector<bool>::reference x) {
|
|
32
|
+
bits_ = x;
|
|
33
|
+
return *this;
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
_MLX_Float16& operator=(const float& x) {
|
|
37
|
+
return (*this = _MLX_Float16(x));
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
// From float32
|
|
41
|
+
_MLX_Float16(const float& x) : bits_(0) {
|
|
42
|
+
// Conversion following
|
|
43
|
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
|
44
|
+
|
|
45
|
+
// Union
|
|
46
|
+
float_bits_fp16 in;
|
|
47
|
+
|
|
48
|
+
// Take fp32 bits
|
|
49
|
+
in.f = x;
|
|
50
|
+
|
|
51
|
+
// Find and take sign bit
|
|
52
|
+
uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
|
|
53
|
+
uint16_t x_sign_16 = (x_sign_32 >> 16);
|
|
54
|
+
|
|
55
|
+
if (std::isnan(x)) {
|
|
56
|
+
bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
|
|
57
|
+
} else {
|
|
58
|
+
// Union
|
|
59
|
+
float_bits_fp16 inf_scale, zero_scale, magic_bits;
|
|
60
|
+
|
|
61
|
+
// Find exponent bits and take the max supported by half
|
|
62
|
+
uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
|
|
63
|
+
uint32_t max_expo_32 = uint32_t(0x38800000);
|
|
64
|
+
x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
|
|
65
|
+
x_expo_32 += uint32_t(15) << 23;
|
|
66
|
+
|
|
67
|
+
// Handle scaling to inf as needed
|
|
68
|
+
inf_scale.u = uint32_t(0x77800000);
|
|
69
|
+
zero_scale.u = uint32_t(0x08800000);
|
|
70
|
+
|
|
71
|
+
// Combine with magic and let addition do rounding
|
|
72
|
+
magic_bits.u = x_expo_32;
|
|
73
|
+
magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
|
|
74
|
+
|
|
75
|
+
// Take the lower 5 bits of the exponent
|
|
76
|
+
uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
|
|
77
|
+
|
|
78
|
+
// Collect the lower 12 bits which have the mantissa
|
|
79
|
+
uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
|
|
80
|
+
|
|
81
|
+
// Combine sign, exp and mantissa
|
|
82
|
+
bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
|
|
83
|
+
}
|
|
84
|
+
}
|
|
85
|
+
|
|
86
|
+
// To float32
|
|
87
|
+
operator float() const {
|
|
88
|
+
// Conversion following
|
|
89
|
+
// https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
|
|
90
|
+
|
|
91
|
+
// Union
|
|
92
|
+
float_bits_fp16 out;
|
|
93
|
+
|
|
94
|
+
uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
|
|
95
|
+
uint32_t base = (bits_ << 16);
|
|
96
|
+
uint32_t two_base = base + base;
|
|
97
|
+
|
|
98
|
+
uint32_t denorm_max = 1u << 27;
|
|
99
|
+
if (two_base < denorm_max) {
|
|
100
|
+
out.u = uint32_t(126) << 23; // magic mask
|
|
101
|
+
out.u |= (two_base >> 17); // Bits from fp16
|
|
102
|
+
out.f -= 0.5f; // magic bias
|
|
103
|
+
} else {
|
|
104
|
+
out.u = uint32_t(0xE0) << 23; // exponent offset
|
|
105
|
+
out.u += (two_base >> 4); // Bits from fp16
|
|
106
|
+
float out_unscaled = out.f; // Store value
|
|
107
|
+
out.u = uint32_t(0x7800000); // exponent scale
|
|
108
|
+
out.f *= out_unscaled;
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
// Add sign
|
|
112
|
+
out.u |= x_sign_32;
|
|
113
|
+
|
|
114
|
+
return out.f;
|
|
115
|
+
}
|
|
116
|
+
};
|
|
117
|
+
|
|
118
|
+
// Defines `OUT_T FN(A_T lhs, B_T rhs)`, which casts both operands to CALC_T
// and applies OP.
#define half_binop_base(OP, FN, OUT_T, A_T, B_T, CALC_T)          \
  inline OUT_T FN(A_T lhs, B_T rhs) {                             \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs);  \
  }

// Defines both mixed overloads, FN(_MLX_Float16, IN_T) and
// FN(IN_T, _MLX_Float16), each computing in CALC_T and returning OUT_T.
// NOTE(review): neither helper macro is #undef'd in this header, so both stay
// defined for every includer.
#define half_binop_helper(OP, FN, OUT_T, IN_T, CALC_T)               \
  half_binop_base(OP, FN, OUT_T, _MLX_Float16, IN_T, CALC_T)         \
  half_binop_base(OP, FN, OUT_T, IN_T, _MLX_Float16, CALC_T)
|
|
130
|
+
|
|
131
|
+
// Operators
|
|
132
|
+
#define half_binop(__op__, __operator__) \
|
|
133
|
+
half_binop_base( \
|
|
134
|
+
__op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
|
|
135
|
+
half_binop_helper(__op__, __operator__, float, float, float); \
|
|
136
|
+
half_binop_helper(__op__, __operator__, double, double, double); \
|
|
137
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
|
|
138
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
|
|
139
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
|
|
140
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
|
|
141
|
+
half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
|
|
142
|
+
|
|
143
|
+
half_binop(+, operator+);
|
|
144
|
+
half_binop(-, operator-);
|
|
145
|
+
half_binop(*, operator*);
|
|
146
|
+
half_binop(/, operator/);
|
|
147
|
+
|
|
148
|
+
#undef half_binop
|
|
149
|
+
|
|
150
|
+
// Comparison ops
|
|
151
|
+
#define half_compop(__op__, __operator__) \
|
|
152
|
+
half_binop_base( \
|
|
153
|
+
__op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
|
|
154
|
+
half_binop_helper(__op__, __operator__, bool, float, float); \
|
|
155
|
+
half_binop_helper(__op__, __operator__, bool, double, double); \
|
|
156
|
+
half_binop_helper(__op__, __operator__, bool, int32_t, float); \
|
|
157
|
+
half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
|
|
158
|
+
half_binop_helper(__op__, __operator__, bool, int64_t, float); \
|
|
159
|
+
half_binop_helper(__op__, __operator__, bool, uint64_t, float);
|
|
160
|
+
|
|
161
|
+
half_compop(>, operator>);
|
|
162
|
+
half_compop(<, operator<);
|
|
163
|
+
half_compop(>=, operator>=);
|
|
164
|
+
half_compop(<=, operator<=);
|
|
165
|
+
half_compop(==, operator==);
|
|
166
|
+
half_compop(!=, operator!=);
|
|
167
|
+
|
|
168
|
+
#undef half_compop
|
|
169
|
+
|
|
170
|
+
// Negative
|
|
171
|
+
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
|
|
172
|
+
return -static_cast<float>(lhs);
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
// Inplace ops
|
|
176
|
+
#define half_inplace_op(__op__, __operator__) \
|
|
177
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
|
|
178
|
+
lhs = lhs __op__ rhs; \
|
|
179
|
+
return lhs; \
|
|
180
|
+
} \
|
|
181
|
+
inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
|
|
182
|
+
lhs = lhs __op__ rhs; \
|
|
183
|
+
return lhs; \
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
half_inplace_op(+, operator+=);
|
|
187
|
+
half_inplace_op(-, operator-=);
|
|
188
|
+
half_inplace_op(*, operator*=);
|
|
189
|
+
half_inplace_op(/, operator/=);
|
|
190
|
+
|
|
191
|
+
#undef half_inplace_op
|
|
192
|
+
|
|
193
|
+
// Bitwise ops
|
|
194
|
+
|
|
195
|
+
#define half_bitop(__op__, __operator__) \
|
|
196
|
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
|
|
197
|
+
_MLX_Float16 out; \
|
|
198
|
+
out.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
|
199
|
+
return out; \
|
|
200
|
+
} \
|
|
201
|
+
inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
|
|
202
|
+
_MLX_Float16 out; \
|
|
203
|
+
out.bits_ = lhs.bits_ __op__ rhs; \
|
|
204
|
+
return out; \
|
|
205
|
+
} \
|
|
206
|
+
inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
|
|
207
|
+
_MLX_Float16 out; \
|
|
208
|
+
out.bits_ = lhs __op__ rhs.bits_; \
|
|
209
|
+
return out; \
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
half_bitop(|, operator|);
|
|
213
|
+
half_bitop(&, operator&);
|
|
214
|
+
half_bitop(^, operator^);
|
|
215
|
+
|
|
216
|
+
#undef half_bitop
|
|
217
|
+
|
|
218
|
+
#define half_inplace_bitop(__op__, __operator__) \
|
|
219
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
|
|
220
|
+
lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
|
|
221
|
+
return lhs; \
|
|
222
|
+
} \
|
|
223
|
+
inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
|
|
224
|
+
lhs.bits_ = lhs.bits_ __op__ rhs; \
|
|
225
|
+
return lhs; \
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
half_inplace_bitop(|, operator|=);
|
|
229
|
+
half_inplace_bitop(&, operator&=);
|
|
230
|
+
half_inplace_bitop(^, operator^=);
|
|
231
|
+
|
|
232
|
+
#undef half_inplace_bitop
|
|
233
|
+
|
|
234
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)

// Native scalar half-precision arithmetic is available: expose the
// compiler's float16_t inside mlx::core.
#include <arm_fp16.h>
namespace mlx::core {
using ::float16_t;
} // namespace mlx::core

#else

// No native half type: use the software _MLX_Float16 and flag that the mixed
// fp16/bf16 operators guarded by ADD_HALF_BINOPS should be defined.
#define ADD_HALF_BINOPS
#include "mlx/types/fp16.h"
namespace mlx::core {
using float16_t = _MLX_Float16;
} // namespace mlx::core

#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
|
|
21
|
+
|
|
22
|
+
#if defined(__ARM_FEATURE_BF16)

// Native bfloat16 is available: expose the compiler's bfloat16_t inside
// mlx::core.
#include <arm_bf16.h>
namespace mlx::core {
using ::bfloat16_t;
} // namespace mlx::core

#else

// No native bfloat16: use the software _MLX_BFloat16 and flag that the mixed
// fp16/bf16 operators guarded by ADD_HALF_BINOPS should be defined.
#define ADD_HALF_BINOPS
#include "mlx/types/bf16.h"
namespace mlx::core {
using bfloat16_t = _MLX_BFloat16;
} // namespace mlx::core

#endif // __ARM_FEATURE_BF16
|
|
38
|
+
|
|
39
|
+
#ifdef ADD_HALF_BINOPS
namespace mlx::core {

// Arithmetic between float16_t and bfloat16_t: both operands are converted to
// float and the float result is returned. Only defined when at least one of
// the two types is the software emulation (ADD_HALF_BINOPS is set).
// clang-format off
#define fp16_bf16_binop_helper(__op__, __operator__)               \
  inline float __operator__(float16_t lhs, bfloat16_t rhs) {       \
    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
  }                                                                \
  inline float __operator__(bfloat16_t lhs, float16_t rhs) {       \
    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
  }

fp16_bf16_binop_helper(+, operator+)
fp16_bf16_binop_helper(-, operator-)
fp16_bf16_binop_helper(*, operator*)
fp16_bf16_binop_helper(/, operator/)
// clang-format on

// The helper is only needed for the definitions above; don't leak it into
// every file that includes this header.
#undef fp16_bf16_binop_helper

} // namespace mlx::core
#endif
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
#pragma once
|
|
3
|
+
|
|
4
|
+
#include <limits>
|
|
5
|
+
#include "mlx/types/half_types.h"
|
|
6
|
+
|
|
7
|
+
namespace mlx::core {
|
|
8
|
+
|
|
9
|
+
// MLX's own numeric_limits. The primary template is only declared, so using it
// with a type that has no specialization below fails to compile.
template <typename T>
struct numeric_limits;

// float and double reuse the standard library's limits unchanged.
template <>
struct numeric_limits<float> : std::numeric_limits<float> {};

template <>
struct numeric_limits<double> : std::numeric_limits<double> {};
|
|
17
|
+
|
|
18
|
+
template <>
|
|
19
|
+
struct numeric_limits<float16_t> {
|
|
20
|
+
private:
|
|
21
|
+
union half_or_bits {
|
|
22
|
+
uint16_t bits;
|
|
23
|
+
float16_t value;
|
|
24
|
+
};
|
|
25
|
+
constexpr static float16_t bits_to_half(uint16_t v) {
|
|
26
|
+
return half_or_bits{v}.value;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
public:
|
|
30
|
+
constexpr static float16_t lowest() {
|
|
31
|
+
return bits_to_half(0xFBFF);
|
|
32
|
+
}
|
|
33
|
+
static constexpr float16_t max() {
|
|
34
|
+
return bits_to_half(0x7BFF);
|
|
35
|
+
}
|
|
36
|
+
static constexpr float16_t epsilon() {
|
|
37
|
+
return bits_to_half(0x1400);
|
|
38
|
+
}
|
|
39
|
+
static constexpr float16_t infinity() {
|
|
40
|
+
return bits_to_half(0x7C00);
|
|
41
|
+
}
|
|
42
|
+
};
|
|
43
|
+
|
|
44
|
+
template <>
|
|
45
|
+
struct numeric_limits<bfloat16_t> {
|
|
46
|
+
private:
|
|
47
|
+
union bfloat_or_bits {
|
|
48
|
+
uint16_t bits;
|
|
49
|
+
bfloat16_t value;
|
|
50
|
+
};
|
|
51
|
+
constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
|
|
52
|
+
return bfloat_or_bits{v}.value;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
public:
|
|
56
|
+
constexpr static bfloat16_t lowest() {
|
|
57
|
+
return bits_to_bfloat(0xFF7F);
|
|
58
|
+
}
|
|
59
|
+
static constexpr bfloat16_t max() {
|
|
60
|
+
return bits_to_bfloat(0x7F7F);
|
|
61
|
+
}
|
|
62
|
+
static constexpr bfloat16_t epsilon() {
|
|
63
|
+
return bits_to_bfloat(0x3C00);
|
|
64
|
+
}
|
|
65
|
+
static constexpr bfloat16_t infinity() {
|
|
66
|
+
return bits_to_bfloat(0x7F80);
|
|
67
|
+
}
|
|
68
|
+
};
|
|
69
|
+
|
|
70
|
+
} // namespace mlx::core
|
mlx/include/mlx/utils.h
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <cmath>
#include <exception>
#include <stdexcept>
#include <variant>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/dtype.h"
#include "mlx/stream.h"
|
|
12
|
+
|
|
13
|
+
namespace mlx::core {
|
|
14
|
+
|
|
15
|
+
using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
|
|
16
|
+
Stream to_stream(StreamOrDevice s);
|
|
17
|
+
Stream to_stream(StreamOrDevice s, Device default_);
|
|
18
|
+
|
|
19
|
+
struct StreamContext {
|
|
20
|
+
public:
|
|
21
|
+
StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
|
|
22
|
+
if (std::holds_alternative<std::monostate>(s)) {
|
|
23
|
+
throw std::runtime_error(
|
|
24
|
+
"[StreamContext] Invalid argument, please specify a stream or device.");
|
|
25
|
+
}
|
|
26
|
+
auto _s = to_stream(s);
|
|
27
|
+
set_default_device(_s.device);
|
|
28
|
+
set_default_stream(_s);
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
~StreamContext() {
|
|
32
|
+
set_default_device(_stream.device);
|
|
33
|
+
set_default_stream(_stream);
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
private:
|
|
37
|
+
Stream _stream;
|
|
38
|
+
};
|
|
39
|
+
|
|
40
|
+
struct PrintFormatter {
|
|
41
|
+
inline void print(std::ostream& os, bool val);
|
|
42
|
+
inline void print(std::ostream& os, int16_t val);
|
|
43
|
+
inline void print(std::ostream& os, uint16_t val);
|
|
44
|
+
inline void print(std::ostream& os, int32_t val);
|
|
45
|
+
inline void print(std::ostream& os, uint32_t val);
|
|
46
|
+
inline void print(std::ostream& os, int64_t val);
|
|
47
|
+
inline void print(std::ostream& os, uint64_t val);
|
|
48
|
+
inline void print(std::ostream& os, float16_t val);
|
|
49
|
+
inline void print(std::ostream& os, bfloat16_t val);
|
|
50
|
+
inline void print(std::ostream& os, float val);
|
|
51
|
+
inline void print(std::ostream& os, double val);
|
|
52
|
+
inline void print(std::ostream& os, complex64_t val);
|
|
53
|
+
|
|
54
|
+
bool capitalize_bool{false};
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
PrintFormatter& get_global_formatter();
|
|
58
|
+
|
|
59
|
+
/** Print the exception and then abort. */
|
|
60
|
+
void abort_with_exception(const std::exception& error);
|
|
61
|
+
|
|
62
|
+
/** Holds information about floating-point types. */
|
|
63
|
+
struct finfo {
|
|
64
|
+
explicit finfo(Dtype dtype);
|
|
65
|
+
Dtype dtype;
|
|
66
|
+
double min;
|
|
67
|
+
double max;
|
|
68
|
+
double eps;
|
|
69
|
+
};
|
|
70
|
+
|
|
71
|
+
/** Holds information about integral types. */
|
|
72
|
+
struct iinfo {
|
|
73
|
+
explicit iinfo(Dtype dtype);
|
|
74
|
+
Dtype dtype;
|
|
75
|
+
int64_t min;
|
|
76
|
+
uint64_t max;
|
|
77
|
+
};
|
|
78
|
+
|
|
79
|
+
/** The type from promoting the arrays' types with one another. */
|
|
80
|
+
inline Dtype result_type(const array& a, const array& b) {
|
|
81
|
+
return promote_types(a.dtype(), b.dtype());
|
|
82
|
+
}
|
|
83
|
+
inline Dtype result_type(const array& a, const array& b, const array& c) {
|
|
84
|
+
return promote_types(result_type(a, b), c.dtype());
|
|
85
|
+
}
|
|
86
|
+
Dtype result_type(const std::vector<array>& arrays);
|
|
87
|
+
|
|
88
|
+
// Shape that results from broadcasting `s1` and `s2` together.
Shape broadcast_shapes(const Shape& s1, const Shape& s2);

// Returns the axis normalized to be in the range [0, ndim).
// `msg_prefix` is presumably prepended to the error raised for an
// out-of-range axis (TODO confirm against the definition).
int normalize_axis_index(int axis, int ndim, const std::string& msg_prefix = "");
|
|
97
|
+
|
|
98
|
+
std::ostream& operator<<(std::ostream& os, const Device& d);
|
|
99
|
+
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
|
100
|
+
std::ostream& operator<<(std::ostream& os, const Dtype& d);
|
|
101
|
+
std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
|
|
102
|
+
std::ostream& operator<<(std::ostream& os, array a);
|
|
103
|
+
inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
|
|
104
|
+
return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
|
|
105
|
+
}
|
|
106
|
+
inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
|
107
|
+
return os << static_cast<float>(v);
|
|
108
|
+
}
|
|
109
|
+
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
|
110
|
+
return os << static_cast<float>(v);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
// Prints a vector-like container as "(a,b,c)"; an empty container prints "()".
template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
  os << "(";
  const char* separator = "";
  for (const auto& element : v) {
    os << separator << element;
    separator = ",";
  }
  return os << ")";
}
|
|
125
|
+
|
|
126
|
+
// Returns true if `n` is a positive power of two (1, 2, 4, ...).
inline bool is_power_of_2(int n) {
  // Testing n > 0 first also avoids the signed overflow of `n - 1` at INT_MIN.
  return n > 0 && (n & (n - 1)) == 0;
}

// Returns the smallest power of two that is >= `n` (so a power of two is
// returned unchanged), or 0 when `n` <= 0. The result must fit in an int,
// i.e. `n` must not exceed 2^30.
// Uses integer arithmetic only: no pow/log2 floating-point round trip and no
// undefined float-to-int conversion for negative inputs.
inline int next_power_of_2(int n) {
  if (n <= 0) {
    return 0;
  }
  unsigned int p = 1u;
  while (p < static_cast<unsigned int>(n)) {
    p <<= 1;
  }
  return static_cast<int>(p);
}
|
|
136
|
+
|
|
137
|
+
namespace env {

// Reads the integer environment variable `name`; `default_value` is presumably
// returned when it is unset (TODO confirm against the definition).
int get_var(const char* name, int default_value);

// Each accessor below calls get_var once, on its first call, and caches the
// result in a function-local static. Later environment changes are not seen,
// and for accessors taking `default_value` only the first call's argument
// matters.

// MLX_BFS_MAX_WIDTH (fallback 20).
inline int bfs_max_width() {
  static const int cached = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached;
}

// MLX_MAX_OPS_PER_BUFFER.
inline int max_ops_per_buffer(int default_value) {
  static const int cached = get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
  return cached;
}

// MLX_MAX_MB_PER_BUFFER.
inline int max_mb_per_buffer(int default_value) {
  static const int cached = get_var("MLX_MAX_MB_PER_BUFFER", default_value);
  return cached;
}

// MLX_METAL_FAST_SYNCH (fallback 0, i.e. off); any non-zero value enables it.
inline bool metal_fast_synch() {
  static const bool cached = get_var("MLX_METAL_FAST_SYNCH", 0) != 0;
  return cached;
}

// MLX_ENABLE_TF32 (fallback 1, i.e. on); any non-zero value enables it.
inline bool enable_tf32() {
  static const bool cached = get_var("MLX_ENABLE_TF32", 1) != 0;
  return cached;
}

// MLX_NCCL_TIMEOUT.
inline int nccl_timeout(int default_value) {
  static const int cached = get_var("MLX_NCCL_TIMEOUT", default_value);
  return cached;
}

} // namespace env
|
|
174
|
+
|
|
175
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
// Version components of this MLX release.
#define MLX_VERSION_MAJOR 0
#define MLX_VERSION_MINOR 30
#define MLX_VERSION_PATCH 1

// The version folded into one integer, usable in `#if` comparisons:
// major * 100000 + minor * 1000 + patch (so 0.30.1 -> 30001). The encoding
// stays unambiguous only while minor < 100 and patch < 1000.
#define MLX_VERSION_NUMERIC                                      \
  ((MLX_VERSION_MAJOR) * 100000 + (MLX_VERSION_MINOR) * 1000 + \
   (MLX_VERSION_PATCH))
|
|
10
|
+
|
|
11
|
+
namespace mlx::core {

/* A string representation of the MLX version in the format
 * "major.minor.patch".
 *
 * For dev builds, the version will include the suffix ".devYYYYMMDD+hash".
 *
 * The numeric components are also available at compile time through the
 * MLX_VERSION_MAJOR / MLX_VERSION_MINOR / MLX_VERSION_PATCH macros in this
 * header.
 *
 * NOTE(review): ownership and lifetime of the returned C string are not
 * visible in this header; presumably it points to static storage that the
 * caller must not free — confirm against the definition.
 */
const char* version();

} // namespace mlx::core
|
mlx/lib/libmlx.so
ADDED
|
Binary file
|
mlx/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# FindNCCL.cmake: this module finds the NVIDIA NCCL library and its include
|
|
2
|
+
# directories.
|
|
3
|
+
|
|
4
|
+
# Root of an NCCL installation. The NCCL_ROOT_DIR environment variable
# provides the initial cache value; a cache entry that already exists is kept.
set(NCCL_ROOT_DIR
    "$ENV{NCCL_ROOT_DIR}"
    CACHE PATH "Folder contains NVIDIA NCCL")

# Directory containing nccl.h. NCCL_INCLUDE_DIR, when set by the caller, is
# tried as a hint alongside the NCCL root and the CUDA toolkit include dir.
find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)
|
|
13
|
+
|
|
14
|
+
# Pick the library name to search for. A CMake true value (1, ON, YES, TRUE,
# Y or a non-zero number) in the USE_STATIC_NCCL environment variable selects
# the static archive. The expansion is quoted so an unset variable gives a
# well-formed, false `if("")`, and the value is never re-interpreted as a
# variable name or an `if` operator.
if("$ENV{USE_STATIC_NCCL}")
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

# Locate the NCCL library. NCCL_LIB_DIR, when set by the caller, is tried
# first. HINTS entries are used verbatim (CMake does not add multiarch
# subdirectories to them), so the platform's multiarch directory
# (CMAKE_LIBRARY_ARCHITECTURE, e.g. aarch64-linux-gnu) is listed explicitly
# next to the x86_64 one.
find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/${CMAKE_LIBRARY_ARCHITECTURE}
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
|
|
32
|
+
|
|
33
|
+
# Set NCCL_FOUND; both the header directory and the library are required.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  # First `#define NCCL_MAJOR <digits>` line of the header, if any.
  file(
    STRINGS "${NCCL_HEADER_FILE}" NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    # Capture only the digits, so trailing text on the #define line (for
    # example a comment) does not end up in NCCL_MAJOR_VERSION.
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+).*$"
                         "\\1" NCCL_MAJOR_VERSION
                         "${NCCL_MAJOR_VERSION_DEFINED}")
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Find MLX
|
|
2
|
+
#
|
|
3
|
+
# Defines the following variables:
|
|
4
|
+
#
|
|
5
|
+
# MLX_FOUND : True if MLX is found
|
|
6
|
+
# MLX_INCLUDE_DIRS : Include directories
|
|
7
|
+
# MLX_LIBRARIES : Libraries to link against
|
|
8
|
+
# MLX_CXX_FLAGS : Additional compiler flags
|
|
9
|
+
# MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
|
|
10
|
+
# MLX_BUILD_METAL : True if MLX was built with metal
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() #######
|
|
14
|
+
####### Any changes to this file will be overwritten by the next CMake run ####
|
|
15
|
+
####### The input file was mlx.pc.in ########
|
|
16
|
+
|
|
17
|
+
# Absolute installation prefix of this package: three directory levels above
# the directory that holds this file (which lives under
# <prefix>/share/cmake/MLX).
get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)

# Set variable `_var` to the path `_file`, and stop configuration with
# FATAL_ERROR if that path does not exist. This is a macro rather than a
# function, so `_var` is set directly in the including scope.
macro(set_and_check _var _file)
  set(${_var} "${_file}")
  if(NOT EXISTS "${_file}")
    message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
  endif()
endmacro()
|
|
25
|
+
|
|
26
|
+
####################################################################################
|
|
27
|
+
|
|
28
|
+
# Imported targets (including `mlx`) and the extension helpers shipped with
# this package.
include("${PACKAGE_PREFIX_DIR}/share/cmake/MLX/MLXTargets.cmake")
include("${PACKAGE_PREFIX_DIR}/share/cmake/MLX/extension.cmake")

# Both directories must exist inside this package; set_and_check aborts
# configuration otherwise.
set_and_check(MLX_LIBRARY_DIRS "${PACKAGE_PREFIX_DIR}/lib")
set_and_check(MLX_INCLUDE_DIRS "${PACKAGE_PREFIX_DIR}/include")
set(MLX_LIBRARIES mlx)

# Search only this package's own lib directory. Without NO_DEFAULT_PATH,
# find_library consults system locations before PATHS and could resolve to a
# different libmlx installed elsewhere on the machine.
find_library(MLX_LIBRARY mlx PATHS "${MLX_LIBRARY_DIRS}" NO_DEFAULT_PATH)
|
|
36
|
+
|
|
37
|
+
# Accelerate support as configured when this package was built (disabled in
# this build). The variable is published unconditionally, so the documented
# MLX_BUILD_ACCELERATE is always defined for consumers.
set(MLX_BUILD_ACCELERATE OFF)
if(MLX_BUILD_ACCELERATE)
  list(APPEND MLX_CXX_FLAGS -DACCELERATE_NEW_LAPACK)
endif()
|
|
41
|
+
|
|
42
|
+
# Metal support as configured when this package was built (disabled in this
# build). The variable is published unconditionally, so the documented
# MLX_BUILD_METAL is always defined for consumers.
set(MLX_BUILD_METAL OFF)
if(MLX_BUILD_METAL)
  list(APPEND MLX_CXX_FLAGS -D_METAL_)
  # list(APPEND) keeps MLX_INCLUDE_DIRS free of empty entries.
  list(APPEND MLX_INCLUDE_DIRS "${PACKAGE_PREFIX_DIR}/include/metal_cpp")
  # NOTE(review): the Metal language version placeholder was configured to an
  # empty value for this build. It is kept as a quoted empty operand so the
  # condition stays well-formed. An empty string is not a number, so
  # GREATER_EQUAL is false and the metal_3_0 kernels directory is chosen.
  if("" GREATER_EQUAL 310)
    list(APPEND MLX_INCLUDE_DIRS
         "${PACKAGE_PREFIX_DIR}/include/mlx/backend/metal/kernels/metal_3_1")
  else()
    list(APPEND MLX_INCLUDE_DIRS
         "${PACKAGE_PREFIX_DIR}/include/mlx/backend/metal/kernels/metal_3_0")
  endif()
endif()
|
|
59
|
+
|
|
60
|
+
# Usage requirements for the imported `mlx` target defined by MLXTargets.cmake.
set_target_properties(mlx PROPERTIES CXX_STANDARD 17)
# CXX_STANDARD only affects targets that are compiled, so it does not reach
# consumers of an imported target. The compile feature does propagate, and
# the MLX headers need C++17 (e.g. nested `namespace mlx::core`).
set_property(TARGET mlx APPEND PROPERTY INTERFACE_COMPILE_FEATURES cxx_std_17)
# Append rather than overwrite, so any INTERFACE_COMPILE_OPTIONS that
# MLXTargets.cmake may already have recorded are preserved. When no extra flags
# were collected, nothing is touched.
if(MLX_CXX_FLAGS)
  set_property(TARGET mlx APPEND PROPERTY INTERFACE_COMPILE_OPTIONS
                                          ${MLX_CXX_FLAGS})
endif()
|
|
64
|
+
|
|
65
|
+
# Set MLX_FOUND and print the standard found / not-found message. The library
# path and the include directories are both required.
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
  MLX
  REQUIRED_VARS MLX_LIBRARY MLX_INCLUDE_DIRS)
|