mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/io/gguf.h ADDED
@@ -0,0 +1,20 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+ #pragma once
3
+
4
+ #include "mlx/io.h"
5
+ #include "mlx/primitives.h"
6
+ #include "mlx/transforms.h"
7
+ #include "mlx/utils.h"
8
+
9
+ extern "C" {
10
+ #include <gguflib.h>
11
+ }
12
+
13
+ namespace mlx::core {
14
+
15
+ Shape get_shape(const gguf_tensor& tensor);
16
+ void gguf_load_quantized(
17
+ std::unordered_map<std::string, array>& a,
18
+ const gguf_tensor& tensor);
19
+
20
+ } // namespace mlx::core
mlx/include/mlx/io/load.h ADDED
@@ -0,0 +1,175 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <memory>
6
+ #include <sstream>
7
+
8
+ #include <fcntl.h>
9
+ #ifdef _MSC_VER
10
+ #include <io.h>
11
+ #else
12
+ #include <sys/stat.h>
13
+ #include <unistd.h>
14
+ #endif
15
+
16
+ #include "mlx/threadpool.h"
17
+
18
// Files should be opened in binary mode so that no "\r" is inserted
// automatically. Only Windows still distinguishes binary from text files;
// every other modern platform lacks O_BINARY, so it is defined here as 0,
// which is a no-op when OR-ed into the flags passed to open().
#if !defined(O_BINARY)
#define O_BINARY 0
#endif
25
+
26
+ namespace mlx::core {
27
+
28
+ namespace io {
29
+
30
+ ThreadPool& thread_pool();
31
+
32
+ class Reader {
33
+ public:
34
+ virtual bool is_open() const = 0;
35
+ virtual bool good() const = 0;
36
+ virtual size_t tell() = 0; // tellp is non-const in iostream
37
+ virtual void seek(
38
+ int64_t off,
39
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
40
+ virtual void read(char* data, size_t n) = 0;
41
+ virtual void read(char* data, size_t n, size_t offset) = 0;
42
+ virtual std::string label() const = 0;
43
+ virtual ~Reader() = default;
44
+ };
45
+
46
+ class Writer {
47
+ public:
48
+ virtual bool is_open() const = 0;
49
+ virtual bool good() const = 0;
50
+ virtual size_t tell() = 0;
51
+ virtual void seek(
52
+ int64_t off,
53
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
54
+ virtual void write(const char* data, size_t n) = 0;
55
+ virtual std::string label() const = 0;
56
+ virtual ~Writer() = default;
57
+ };
58
+
59
+ class ParallelFileReader : public Reader {
60
+ public:
61
+ explicit ParallelFileReader(std::string file_path)
62
+ : fd_(open(file_path.c_str(), O_RDONLY | O_BINARY)),
63
+ label_(std::move(file_path)) {}
64
+
65
+ ~ParallelFileReader() override {
66
+ close(fd_);
67
+ }
68
+
69
+ bool is_open() const override {
70
+ return fd_ > 0;
71
+ }
72
+
73
+ bool good() const override {
74
+ return is_open();
75
+ }
76
+
77
+ size_t tell() override {
78
+ return lseek(fd_, 0, SEEK_CUR);
79
+ }
80
+
81
+ // Warning: do not use this function from multiple threads as
82
+ // it advances the file descriptor
83
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
84
+ override {
85
+ if (way == std::ios_base::beg) {
86
+ lseek(fd_, off, 0);
87
+ } else {
88
+ lseek(fd_, off, SEEK_CUR);
89
+ }
90
+ }
91
+
92
+ // Warning: do not use this function from multiple threads as
93
+ // it advances the file descriptor
94
+ void read(char* data, size_t n) override;
95
+
96
+ void read(char* data, size_t n, size_t offset) override;
97
+
98
+ std::string label() const override {
99
+ return "file " + label_;
100
+ }
101
+
102
+ private:
103
+ static constexpr size_t batch_size_ = 1 << 25;
104
+ static ThreadPool& thread_pool();
105
+ int fd_;
106
+ std::string label_;
107
+ };
108
+
109
+ class FileWriter : public Writer {
110
+ public:
111
+ explicit FileWriter() {}
112
+ explicit FileWriter(std::string file_path)
113
+ : fd_(open(
114
+ file_path.c_str(),
115
+ O_CREAT | O_WRONLY | O_TRUNC | O_BINARY,
116
+ 0644)),
117
+ label_(std::move(file_path)) {}
118
+
119
+ FileWriter(const FileWriter&) = delete;
120
+ FileWriter& operator=(const FileWriter&) = delete;
121
+ FileWriter(FileWriter&& other) {
122
+ std::swap(fd_, other.fd_);
123
+ }
124
+
125
+ ~FileWriter() override {
126
+ if (fd_ != 0) {
127
+ close(fd_);
128
+ }
129
+ }
130
+
131
+ bool is_open() const override {
132
+ return fd_ >= 0;
133
+ }
134
+
135
+ bool good() const override {
136
+ return is_open();
137
+ }
138
+
139
+ size_t tell() override {
140
+ return lseek(fd_, 0, SEEK_CUR);
141
+ }
142
+
143
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
144
+ override {
145
+ if (way == std::ios_base::beg) {
146
+ lseek(fd_, off, 0);
147
+ } else {
148
+ lseek(fd_, off, SEEK_CUR);
149
+ }
150
+ }
151
+
152
+ void write(const char* data, size_t n) override {
153
+ while (n != 0) {
154
+ auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
155
+ if (m <= 0) {
156
+ std::ostringstream msg;
157
+ msg << "[write] Unable to write " << n << " bytes to file.";
158
+ throw std::runtime_error(msg.str());
159
+ }
160
+ data += m;
161
+ n -= m;
162
+ }
163
+ }
164
+
165
+ std::string label() const override {
166
+ return "file " + label_;
167
+ }
168
+
169
+ private:
170
+ int fd_{0};
171
+ std::string label_;
172
+ };
173
+
174
+ } // namespace io
175
+ } // namespace mlx::core
mlx/include/mlx/io.h ADDED
@@ -0,0 +1,61 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <unordered_map>
6
+ #include <variant>
7
+
8
+ #include "mlx/array.h"
9
+ #include "mlx/io/load.h"
10
+ #include "mlx/stream.h"
11
+ #include "mlx/utils.h"
12
+
13
+ namespace mlx::core {
14
+ using GGUFMetaData =
15
+ std::variant<std::monostate, array, std::string, std::vector<std::string>>;
16
+ using GGUFLoad = std::pair<
17
+ std::unordered_map<std::string, array>,
18
+ std::unordered_map<std::string, GGUFMetaData>>;
19
+ using SafetensorsLoad = std::pair<
20
+ std::unordered_map<std::string, array>,
21
+ std::unordered_map<std::string, std::string>>;
22
+
23
+ /** Save array to out stream in .npy format */
24
+ void save(std::shared_ptr<io::Writer> out_stream, array a);
25
+
26
+ /** Save array to file in .npy format */
27
+ void save(std::string file, array a);
28
+
29
+ /** Load array from reader in .npy format */
30
+ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s = {});
31
+
32
+ /** Load array from file in .npy format */
33
+ array load(std::string file, StreamOrDevice s = {});
34
+
35
+ /** Load array map from .safetensors file format */
36
+ SafetensorsLoad load_safetensors(
37
+ std::shared_ptr<io::Reader> in_stream,
38
+ StreamOrDevice s = {});
39
+ SafetensorsLoad load_safetensors(
40
+ const std::string& file,
41
+ StreamOrDevice s = {});
42
+
43
+ void save_safetensors(
44
+ std::shared_ptr<io::Writer> in_stream,
45
+ std::unordered_map<std::string, array>,
46
+ std::unordered_map<std::string, std::string> metadata = {});
47
+ void save_safetensors(
48
+ std::string file,
49
+ std::unordered_map<std::string, array>,
50
+ std::unordered_map<std::string, std::string> metadata = {});
51
+
52
+ /** Load array map and metadata from .gguf file format */
53
+
54
+ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {});
55
+
56
+ void save_gguf(
57
+ std::string file,
58
+ std::unordered_map<std::string, array> array_map,
59
+ std::unordered_map<std::string, GGUFMetaData> meta_data = {});
60
+
61
+ } // namespace mlx::core
mlx/include/mlx/linalg.h ADDED
@@ -0,0 +1,111 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <optional>
6
+
7
+ #include "mlx/array.h"
8
+ #include "mlx/device.h"
9
+ #include "mlx/ops.h"
10
+ #include "mlx/stream.h"
11
+
12
+ namespace mlx::core::linalg {
13
+
14
+ /**
15
+ * Compute vector or matrix norms.
16
+ *
17
+ * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
18
+ * - If axis is not provided but ord is, then x must be either 1D or 2D.
19
+ * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
20
+ * for matrices) is computed along the given axes. At most 2 axes can be
21
+ * specified.
22
+ * - If both axis and ord are provided, then the corresponding matrix or vector
23
+ * norm is computed. At most 2 axes can be specified.
24
+ */
25
+ array norm(
26
+ const array& a,
27
+ const double ord,
28
+ const std::optional<std::vector<int>>& axis = std::nullopt,
29
+ bool keepdims = false,
30
+ StreamOrDevice s = {});
31
+ inline array norm(
32
+ const array& a,
33
+ const double ord,
34
+ int axis,
35
+ bool keepdims = false,
36
+ StreamOrDevice s = {}) {
37
+ return norm(a, ord, std::vector<int>{axis}, keepdims, s);
38
+ }
39
+ array norm(
40
+ const array& a,
41
+ const std::string& ord,
42
+ const std::optional<std::vector<int>>& axis = std::nullopt,
43
+ bool keepdims = false,
44
+ StreamOrDevice s = {});
45
+ inline array norm(
46
+ const array& a,
47
+ const std::string& ord,
48
+ int axis,
49
+ bool keepdims = false,
50
+ StreamOrDevice s = {}) {
51
+ return norm(a, ord, std::vector<int>{axis}, keepdims, s);
52
+ }
53
+ array norm(
54
+ const array& a,
55
+ const std::optional<std::vector<int>>& axis = std::nullopt,
56
+ bool keepdims = false,
57
+ StreamOrDevice s = {});
58
+ inline array
59
+ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
60
+ return norm(a, std::vector<int>{axis}, keepdims, s);
61
+ }
62
+
63
+ std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
64
+
65
+ std::vector<array>
66
+ svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */);
67
+ inline std::vector<array> svd(const array& a, StreamOrDevice s = {}) {
68
+ return svd(a, true, s);
69
+ }
70
+
71
+ array inv(const array& a, StreamOrDevice s = {});
72
+
73
+ array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});
74
+
75
+ array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});
76
+
77
+ array pinv(const array& a, StreamOrDevice s = {});
78
+
79
+ array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
80
+
81
+ std::vector<array> lu(const array& a, StreamOrDevice s = {});
82
+
83
+ std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
84
+
85
+ array solve(const array& a, const array& b, StreamOrDevice s = {});
86
+
87
+ array solve_triangular(
88
+ const array& a,
89
+ const array& b,
90
+ bool upper = false,
91
+ StreamOrDevice s = {});
92
+
93
+ /**
94
+ * Compute the cross product of two arrays along the given axis.
95
+ */
96
+ array cross(
97
+ const array& a,
98
+ const array& b,
99
+ int axis = -1,
100
+ StreamOrDevice s = {});
101
+
102
+ std::pair<array, array> eig(const array& a, StreamOrDevice s = {});
103
+
104
+ array eigvals(const array& a, StreamOrDevice s = {});
105
+
106
+ array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
107
+
108
+ std::pair<array, array>
109
+ eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
110
+
111
+ } // namespace mlx::core::linalg
mlx/include/mlx/memory.h ADDED
@@ -0,0 +1,78 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <cstdlib>
6
+
7
namespace mlx::core {

/**
 * Get the actively used memory in bytes.
 *
 * This will not always match the memory use reported by the system, because
 * it does not include cached memory buffers.
 */
size_t get_active_memory();

/**
 * Get the peak amount of used memory in bytes.
 *
 * This is the maximum recorded since the program started or since the last
 * call to reset_peak_memory().
 */
size_t get_peak_memory();

/** Reset the peak memory to zero. */
void reset_peak_memory();

/**
 * Get the cache size in bytes.
 *
 * The cache holds memory that is not currently in use but has not yet been
 * returned to the system allocator.
 */
size_t get_cache_memory();

/**
 * Set the memory limit.
 *
 * The memory limit is a guideline for the maximum amount of memory to use
 * during graph evaluation. If the limit is exceeded and no more RAM
 * (including swap, when available) is left, allocations throw an exception.
 *
 * When Metal is available, the memory limit defaults to 1.5 times the
 * maximum recommended working set size reported by the device.
 *
 * @return The previous memory limit.
 */
size_t set_memory_limit(size_t limit);

/** Get the current memory limit. */
size_t get_memory_limit();

/**
 * Set the cache limit.
 *
 * While usage exceeds the limit, free memory is reclaimed from the cache on
 * the next allocation. Pass 0 to disable the cache. The cache limit defaults
 * to the memory limit.
 *
 * @return The previous cache limit.
 */
size_t set_cache_limit(size_t limit);

/** Clear the memory cache. */
void clear_cache();

/**
 * Set the wired size limit.
 *
 * Only useful with the Metal backend on macOS 15.0 or higher. The wired
 * limit is the total size in bytes of memory kept resident; it defaults to
 * 0. Setting a wired limit larger than the system wired limit is an error.
 *
 * @return The previous wired limit.
 */
size_t set_wired_limit(size_t limit);

} // namespace mlx::core
mlx/include/mlx/mlx.h ADDED
@@ -0,0 +1,25 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/cuda/cuda.h"
7
+ #include "mlx/backend/gpu/available.h"
8
+ #include "mlx/backend/metal/metal.h"
9
+ #include "mlx/compile.h"
10
+ #include "mlx/device.h"
11
+ #include "mlx/distributed/distributed.h"
12
+ #include "mlx/distributed/ops.h"
13
+ #include "mlx/einsum.h"
14
+ #include "mlx/export.h"
15
+ #include "mlx/fast.h"
16
+ #include "mlx/fft.h"
17
+ #include "mlx/io.h"
18
+ #include "mlx/linalg.h"
19
+ #include "mlx/memory.h"
20
+ #include "mlx/ops.h"
21
+ #include "mlx/random.h"
22
+ #include "mlx/stream.h"
23
+ #include "mlx/transforms.h"
24
+ #include "mlx/utils.h"
25
+ #include "mlx/version.h"