mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,397 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+ #include <algorithm>
3
+ #include <cstring>
4
+ #include <fstream>
5
+ #include <limits>
6
+ #include <sstream>
7
+
8
+ // Used by pread implementation.
9
+ #ifdef _WIN32
10
+ #include <windows.h>
11
+ #endif // _WIN32
12
+
13
+ #include "mlx/backend/cuda/cuda.h"
14
+ #include "mlx/io.h"
15
+ #include "mlx/io/load.h"
16
+ #include "mlx/ops.h"
17
+ #include "mlx/primitives.h"
18
+ #include "mlx/utils.h"
19
+
20
+ // Adapted from
21
+ // https://github.com/angeloskath/supervised-lda/blob/master/include/ldaplusplus/NumpyFormat.hpp
22
+
23
+ namespace mlx::core {
24
+
25
+ namespace {
26
+
27
+ constexpr uint8_t MAGIC[] = {
28
+ 0x93,
29
+ 0x4e,
30
+ 0x55,
31
+ 0x4d,
32
+ 0x50,
33
+ 0x59,
34
+ };
35
+
36
+ inline bool is_big_endian() {
37
+ union ByteOrder {
38
+ int32_t i;
39
+ uint8_t c[4];
40
+ };
41
+ ByteOrder b = {0x01234567};
42
+
43
+ return b.c[0] == 0x01;
44
+ }
45
+
46
+ // Array protocol typestring for Dtype
47
+ std::string dtype_to_array_protocol(const Dtype& t) {
48
+ std::ostringstream r;
49
+ if (size_of(t) > 1) {
50
+ r << (is_big_endian() ? ">" : "<");
51
+ } else {
52
+ r << "|";
53
+ }
54
+ r << kindof(t) << (int)size_of(t);
55
+ return r.str();
56
+ }
57
+
58
+ // Dtype from array protocol type string
59
+ Dtype dtype_from_array_protocol(std::string_view t) {
60
+ if (t.length() == 2 || t.length() == 3) {
61
+ std::string_view r = t.length() == 3 ? t.substr(1, 2) : t;
62
+
63
+ if (r == "V2") {
64
+ return bfloat16;
65
+ }
66
+
67
+ uint8_t size = r[1] - '0';
68
+
69
+ switch (r[0]) {
70
+ case 'b': {
71
+ if (size == 1)
72
+ return bool_;
73
+ break;
74
+ }
75
+ case 'i': {
76
+ if (size == 1)
77
+ return int8;
78
+ else if (size == 2)
79
+ return int16;
80
+ else if (size == 4)
81
+ return int32;
82
+ else if (size == 8)
83
+ return int64;
84
+ break;
85
+ }
86
+ case 'u': {
87
+ if (size == 1)
88
+ return uint8;
89
+ else if (size == 2)
90
+ return uint16;
91
+ else if (size == 4)
92
+ return uint32;
93
+ else if (size == 8)
94
+ return uint64;
95
+ break;
96
+ }
97
+ case 'f': {
98
+ if (size == 2)
99
+ return float16;
100
+ else if (size == 4)
101
+ return float32;
102
+ else if (size == 8)
103
+ return float64;
104
+ break;
105
+ }
106
+ case 'c': {
107
+ if (size == 8)
108
+ return complex64;
109
+ break;
110
+ }
111
+ }
112
+ }
113
+
114
+ throw std::invalid_argument(
115
+ "[from_str] Unsupported array protocol type-string: " + std::string(t));
116
+ }
117
+
118
+ #ifdef _WIN32
119
+ // There is no pread on Windows, emulate it with ReadFile.
120
+ int64_t pread(int fd, void* buf, uint64_t size, uint64_t offset) {
121
+ HANDLE file = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
122
+ if (file == INVALID_HANDLE_VALUE) {
123
+ return -1;
124
+ }
125
+
126
+ OVERLAPPED overlapped = {0};
127
+ overlapped.Offset = offset & 0xFFFFFFFF;
128
+ overlapped.OffsetHigh = (offset >> 32) & 0xFFFFFFFF;
129
+
130
+ DWORD bytes_read;
131
+ if (!ReadFile(file, buf, size, &bytes_read, &overlapped)) {
132
+ if (GetLastError() != ERROR_HANDLE_EOF) {
133
+ return -1;
134
+ }
135
+ }
136
+
137
+ return bytes_read;
138
+ }
139
+ #endif
140
+
141
+ } // namespace
142
+
143
+ /** Save array to out stream in .npy format */
144
+ void save(std::shared_ptr<io::Writer> out_stream, array a) {
145
+ ////////////////////////////////////////////////////////
146
+ // Check array
147
+
148
+ a = contiguous(a, true);
149
+ a.eval();
150
+
151
+ if (a.nbytes() == 0) {
152
+ throw std::invalid_argument("[save] cannot serialize an empty array");
153
+ }
154
+
155
+ ////////////////////////////////////////////////////////
156
+ // Check file
157
+ if (!out_stream->good() || !out_stream->is_open()) {
158
+ throw std::runtime_error("[save] Failed to open " + out_stream->label());
159
+ }
160
+
161
+ ////////////////////////////////////////////////////////
162
+ // Prepare header
163
+ std::ostringstream magic_ver_len;
164
+ magic_ver_len.write(reinterpret_cast<const char*>(MAGIC), 6);
165
+
166
+ std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
167
+ std::ostringstream header;
168
+ header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
169
+ << " 'fortran_order': " << fortran_order << "," << " 'shape': (";
170
+ for (auto i : a.shape()) {
171
+ header << i << ", ";
172
+ }
173
+ header << ")}";
174
+
175
+ size_t header_len = static_cast<size_t>(header.tellp());
176
+ bool is_v1 = header_len + 15 < std::numeric_limits<uint16_t>::max();
177
+
178
+ // Pad out magic + version + header_len + header + \n to be divisible by 16
179
+ size_t padding = (6 + 2 + (2 + 2 * is_v1) + header_len + 1) % 16;
180
+
181
+ header << std::string(padding, ' ') << '\n';
182
+
183
+ if (is_v1) {
184
+ magic_ver_len << (char)0x01 << (char)0x00;
185
+
186
+ uint16_t v1_header_len = header.tellp();
187
+ const char* len_bytes = reinterpret_cast<const char*>(&v1_header_len);
188
+
189
+ if (!is_big_endian()) {
190
+ magic_ver_len.write(len_bytes, 2);
191
+ } else {
192
+ magic_ver_len.write(len_bytes + 1, 1);
193
+ magic_ver_len.write(len_bytes, 1);
194
+ }
195
+ } else {
196
+ magic_ver_len << (char)0x02 << (char)0x00;
197
+
198
+ uint32_t v2_header_len = header.tellp();
199
+ const char* len_bytes = reinterpret_cast<const char*>(&v2_header_len);
200
+
201
+ if (!is_big_endian()) {
202
+ magic_ver_len.write(len_bytes, 4);
203
+ } else {
204
+ magic_ver_len.write(len_bytes + 3, 1);
205
+ magic_ver_len.write(len_bytes + 2, 1);
206
+ magic_ver_len.write(len_bytes + 1, 1);
207
+ magic_ver_len.write(len_bytes, 1);
208
+ }
209
+ }
210
+ ////////////////////////////////////////////////////////
211
+ // Serialize array
212
+
213
+ out_stream->write(magic_ver_len.str().c_str(), magic_ver_len.str().length());
214
+ out_stream->write(header.str().c_str(), header.str().length());
215
+ out_stream->write(a.data<char>(), a.nbytes());
216
+ }
217
+
218
+ /** Save array to file in .npy format */
219
+ void save(std::string file, array a) {
220
+ // Add .npy to file name if it is not there
221
+ if (file.length() < 4 || file.substr(file.length() - 4, 4) != ".npy")
222
+ file += ".npy";
223
+
224
+ // Serialize array
225
+ save(std::make_shared<io::FileWriter>(std::move(file)), a);
226
+ }
227
+
228
+ /** Load array from reader in .npy format */
229
+ array load(std::shared_ptr<io::Reader> in_stream, StreamOrDevice s) {
230
+ ////////////////////////////////////////////////////////
231
+ // Open and check file
232
+ if (!in_stream->good() || !in_stream->is_open()) {
233
+ throw std::runtime_error("[load] Failed to open " + in_stream->label());
234
+ }
235
+
236
+ auto stream = cu::is_available() ? to_stream(s) : to_stream(s, Device::cpu);
237
+
238
+ ////////////////////////////////////////////////////////
239
+ // Read header and prepare array details
240
+
241
+ // Read and check magic
242
+ char read_magic_and_ver[8];
243
+ in_stream->read(read_magic_and_ver, 8);
244
+ if (std::memcmp(read_magic_and_ver, MAGIC, 6) != 0) {
245
+ throw std::runtime_error("[load] Invalid header in " + in_stream->label());
246
+ }
247
+
248
+ // Read and check version
249
+ if (read_magic_and_ver[6] != 1 && read_magic_and_ver[6] != 2) {
250
+ throw std::runtime_error(
251
+ "[load] Unsupported npy format version in " + in_stream->label());
252
+ }
253
+
254
+ // Read header len and header
255
+ int header_len_size = read_magic_and_ver[6] == 1 ? 2 : 4;
256
+ size_t header_len;
257
+
258
+ if (header_len_size == 2) {
259
+ uint16_t v1_header_len;
260
+ in_stream->read(reinterpret_cast<char*>(&v1_header_len), header_len_size);
261
+ header_len = v1_header_len;
262
+ } else {
263
+ uint32_t v2_header_len;
264
+ in_stream->read(reinterpret_cast<char*>(&v2_header_len), header_len_size);
265
+ header_len = v2_header_len;
266
+ }
267
+
268
+ // Read the header
269
+ std::vector<char> buffer(header_len + 1);
270
+ in_stream->read(&buffer[0], header_len);
271
+ buffer[header_len] = 0;
272
+ std::string header(buffer.data(), header_len);
273
+
274
+ // Read data type from header
275
+ std::string dtype_str = header.substr(11, 3);
276
+ bool read_is_big_endian = dtype_str[0] == '>';
277
+ Dtype dtype = dtype_from_array_protocol(dtype_str);
278
+
279
+ // Read contiguity order
280
+ bool col_contiguous = header.at(34) == 'T';
281
+
282
+ // Read array shape from header
283
+ Shape shape;
284
+
285
+ size_t st = header.find_last_of('(') + 1;
286
+ size_t ed = header.find_last_of(')');
287
+ std::string shape_str = header.substr(st, ed - st);
288
+
289
+ while (!shape_str.empty()) {
290
+ // Read current number and get position of comma
291
+ size_t pos;
292
+ int dim = std::stoi(shape_str, &pos);
293
+ shape.push_back(dim);
294
+
295
+ // Skip the comma and space and read the next number
296
+ if (pos + 2 <= shape_str.length())
297
+ shape_str = shape_str.substr(pos + 2);
298
+ else {
299
+ shape_str = shape_str.substr(pos);
300
+ if (!shape_str.empty() && shape_str != " " && shape_str != ",") {
301
+ throw std::runtime_error(
302
+ "[load] Unknown error while parsing header in " +
303
+ in_stream->label());
304
+ }
305
+ shape_str = "";
306
+ }
307
+ }
308
+
309
+ ////////////////////////////////////////////////////////
310
+ // Build primitive
311
+
312
+ size_t offset = 8 + header_len_size + header.length();
313
+ bool swap_endianness = read_is_big_endian != is_big_endian();
314
+
315
+ if (col_contiguous) {
316
+ std::reverse(shape.begin(), shape.end());
317
+ }
318
+ auto loaded_array = array(
319
+ shape,
320
+ dtype,
321
+ std::make_shared<Load>(stream, in_stream, offset, swap_endianness),
322
+ std::vector<array>{});
323
+ if (col_contiguous) {
324
+ loaded_array = transpose(loaded_array, s);
325
+ }
326
+
327
+ return loaded_array;
328
+ }
329
+
330
+ /** Load array from file in .npy format */
331
+ array load(std::string file, StreamOrDevice s) {
332
+ return load(std::make_shared<io::ParallelFileReader>(std::move(file)), s);
333
+ }
334
+
335
+ namespace io {
336
+
337
+ ThreadPool& thread_pool() {
338
+ static ThreadPool pool_{4};
339
+ return pool_;
340
+ }
341
+
342
+ ThreadPool& ParallelFileReader::thread_pool() {
343
+ static ThreadPool thread_pool{4};
344
+ return thread_pool;
345
+ }
346
+
347
+ void ParallelFileReader::read(char* data, size_t n) {
348
+ while (n != 0) {
349
+ auto m = ::read(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
350
+ if (m <= 0) {
351
+ std::ostringstream msg;
352
+ msg << "[read] Unable to read " << n << " bytes from file.";
353
+ throw std::runtime_error(msg.str());
354
+ }
355
+ data += m;
356
+ n -= m;
357
+ }
358
+ }
359
+
360
+ void ParallelFileReader::read(char* data, size_t n, size_t offset) {
361
+ auto readfn = [fd = fd_](size_t offset, size_t size, char* buffer) -> bool {
362
+ while (size != 0) {
363
+ auto m = pread(fd, buffer, size, offset);
364
+ if (m <= 0) {
365
+ return false;
366
+ }
367
+ buffer += m;
368
+ size -= m;
369
+ }
370
+ return true;
371
+ };
372
+ std::vector<std::future<bool>> futs;
373
+ while (n != 0) {
374
+ if (n < batch_size_) {
375
+ if (!readfn(offset, n, data)) {
376
+ throw std::runtime_error("[read] Unable to read from file.");
377
+ }
378
+ break;
379
+ } else {
380
+ size_t m = batch_size_;
381
+ futs.emplace_back(
382
+ ParallelFileReader::thread_pool().enqueue(readfn, offset, m, data));
383
+ data += m;
384
+ n -= m;
385
+ offset += m;
386
+ }
387
+ }
388
+ for (auto& f : futs) {
389
+ if (!f.get()) {
390
+ throw std::runtime_error("[read] Unable to read from file.");
391
+ }
392
+ }
393
+ }
394
+
395
+ } // namespace io
396
+
397
+ } // namespace mlx::core
data/mlx/mlx/io/load.h ADDED
@@ -0,0 +1,175 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <memory>
6
+ #include <sstream>
7
+
8
+ #include <fcntl.h>
9
+ #ifdef _MSC_VER
10
+ #include <io.h>
11
+ #else
12
+ #include <sys/stat.h>
13
+ #include <unistd.h>
14
+ #endif
15
+
16
+ #include "mlx/threadpool.h"
17
+
18
+ // Strictly we need to operate on files in binary mode (to avoid \r getting
19
+ // automatically inserted), but every modern system except for Windows no
20
+ // longer differentiates between binary and text files and for them define
21
+ // the flag as no-op.
22
+ #ifndef O_BINARY
23
+ #define O_BINARY 0
24
+ #endif
25
+
26
+ namespace mlx::core {
27
+
28
+ namespace io {
29
+
30
+ ThreadPool& thread_pool();
31
+
32
+ class Reader {
33
+ public:
34
+ virtual bool is_open() const = 0;
35
+ virtual bool good() const = 0;
36
+ virtual size_t tell() = 0; // tellp is non-const in iostream
37
+ virtual void seek(
38
+ int64_t off,
39
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
40
+ virtual void read(char* data, size_t n) = 0;
41
+ virtual void read(char* data, size_t n, size_t offset) = 0;
42
+ virtual std::string label() const = 0;
43
+ virtual ~Reader() = default;
44
+ };
45
+
46
+ class Writer {
47
+ public:
48
+ virtual bool is_open() const = 0;
49
+ virtual bool good() const = 0;
50
+ virtual size_t tell() = 0;
51
+ virtual void seek(
52
+ int64_t off,
53
+ std::ios_base::seekdir way = std::ios_base::beg) = 0;
54
+ virtual void write(const char* data, size_t n) = 0;
55
+ virtual std::string label() const = 0;
56
+ virtual ~Writer() = default;
57
+ };
58
+
59
+ class ParallelFileReader : public Reader {
60
+ public:
61
+ explicit ParallelFileReader(std::string file_path)
62
+ : fd_(open(file_path.c_str(), O_RDONLY | O_BINARY)),
63
+ label_(std::move(file_path)) {}
64
+
65
+ ~ParallelFileReader() override {
66
+ close(fd_);
67
+ }
68
+
69
+ bool is_open() const override {
70
+ return fd_ > 0;
71
+ }
72
+
73
+ bool good() const override {
74
+ return is_open();
75
+ }
76
+
77
+ size_t tell() override {
78
+ return lseek(fd_, 0, SEEK_CUR);
79
+ }
80
+
81
+ // Warning: do not use this function from multiple threads as
82
+ // it advances the file descriptor
83
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
84
+ override {
85
+ if (way == std::ios_base::beg) {
86
+ lseek(fd_, off, 0);
87
+ } else {
88
+ lseek(fd_, off, SEEK_CUR);
89
+ }
90
+ }
91
+
92
+ // Warning: do not use this function from multiple threads as
93
+ // it advances the file descriptor
94
+ void read(char* data, size_t n) override;
95
+
96
+ void read(char* data, size_t n, size_t offset) override;
97
+
98
+ std::string label() const override {
99
+ return "file " + label_;
100
+ }
101
+
102
+ private:
103
+ static constexpr size_t batch_size_ = 1 << 25;
104
+ static ThreadPool& thread_pool();
105
+ int fd_;
106
+ std::string label_;
107
+ };
108
+
109
+ class FileWriter : public Writer {
110
+ public:
111
+ explicit FileWriter() {}
112
+ explicit FileWriter(std::string file_path)
113
+ : fd_(open(
114
+ file_path.c_str(),
115
+ O_CREAT | O_WRONLY | O_TRUNC | O_BINARY,
116
+ 0644)),
117
+ label_(std::move(file_path)) {}
118
+
119
+ FileWriter(const FileWriter&) = delete;
120
+ FileWriter& operator=(const FileWriter&) = delete;
121
+ FileWriter(FileWriter&& other) {
122
+ std::swap(fd_, other.fd_);
123
+ }
124
+
125
+ ~FileWriter() override {
126
+ if (fd_ != 0) {
127
+ close(fd_);
128
+ }
129
+ }
130
+
131
+ bool is_open() const override {
132
+ return fd_ >= 0;
133
+ }
134
+
135
+ bool good() const override {
136
+ return is_open();
137
+ }
138
+
139
+ size_t tell() override {
140
+ return lseek(fd_, 0, SEEK_CUR);
141
+ }
142
+
143
+ void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg)
144
+ override {
145
+ if (way == std::ios_base::beg) {
146
+ lseek(fd_, off, 0);
147
+ } else {
148
+ lseek(fd_, off, SEEK_CUR);
149
+ }
150
+ }
151
+
152
+ void write(const char* data, size_t n) override {
153
+ while (n != 0) {
154
+ auto m = ::write(fd_, data, std::min(n, static_cast<size_t>(INT32_MAX)));
155
+ if (m <= 0) {
156
+ std::ostringstream msg;
157
+ msg << "[write] Unable to write " << n << " bytes to file.";
158
+ throw std::runtime_error(msg.str());
159
+ }
160
+ data += m;
161
+ n -= m;
162
+ }
163
+ }
164
+
165
+ std::string label() const override {
166
+ return "file " + label_;
167
+ }
168
+
169
+ private:
170
+ int fd_{0};
171
+ std::string label_;
172
+ };
173
+
174
+ } // namespace io
175
+ } // namespace mlx::core
@@ -0,0 +1,20 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include "mlx/io.h"
4
+
5
+ namespace mlx::core {
6
+
7
+ GGUFLoad load_gguf(const std::string&, StreamOrDevice s) {
8
+ throw std::runtime_error(
9
+ "[load_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
10
+ }
11
+
12
+ void save_gguf(
13
+ std::string,
14
+ std::unordered_map<std::string, array>,
15
+ std::unordered_map<std::string, GGUFMetaData>) {
16
+ throw std::runtime_error(
17
+ "[save_gguf] Compile with MLX_BUILD_GGUF=ON to enable GGUF support.");
18
+ }
19
+
20
+ } // namespace mlx::core
@@ -0,0 +1,37 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include "mlx/io.h"
4
+
5
+ namespace mlx::core {
6
+
7
+ SafetensorsLoad load_safetensors(std::shared_ptr<io::Reader>, StreamOrDevice) {
8
+ throw std::runtime_error(
9
+ "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
10
+ "to enable safetensors support.");
11
+ }
12
+
13
+ SafetensorsLoad load_safetensors(const std::string&, StreamOrDevice) {
14
+ throw std::runtime_error(
15
+ "[load_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
16
+ "to enable safetensors support.");
17
+ }
18
+
19
+ void save_safetensors(
20
+ std::shared_ptr<io::Writer>,
21
+ std::unordered_map<std::string, array>,
22
+ std::unordered_map<std::string, std::string>) {
23
+ throw std::runtime_error(
24
+ "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
25
+ "to enable safetensors support.");
26
+ }
27
+
28
+ void save_safetensors(
29
+ std::string file,
30
+ std::unordered_map<std::string, array>,
31
+ std::unordered_map<std::string, std::string>) {
32
+ throw std::runtime_error(
33
+ "[save_safetensors] Compile with MLX_BUILD_SAFETENSORS=ON "
34
+ "to enable safetensors support.");
35
+ }
36
+
37
+ } // namespace mlx::core