mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,12 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/distributed/distributed.h"
4
+
5
+ namespace mlx::core::distributed::ring {
6
+
7
+ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8
+
9
+ bool is_available();
10
+ std::shared_ptr<GroupImpl> init(bool strict = false);
11
+
12
+ } // namespace mlx::core::distributed::ring
@@ -0,0 +1,206 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include <netdb.h>
4
+ #include <unistd.h>
5
+ #include <cstring>
6
+ #include <sstream>
7
+ #include <thread>
8
+
9
+ #include "mlx/distributed/utils.h"
10
+
11
+ namespace mlx::core::distributed::detail {
12
+
13
+ /**
14
+ * Parse a sockaddr from an ip and port provided as strings.
15
+ */
16
+ address_t parse_address(const std::string& ip, const std::string& port) {
17
+ struct addrinfo hints, *res;
18
+ std::memset(&hints, 0, sizeof(hints));
19
+ hints.ai_family = AF_UNSPEC;
20
+ hints.ai_socktype = SOCK_STREAM;
21
+
22
+ int status = getaddrinfo(ip.c_str(), port.c_str(), &hints, &res);
23
+ if (status != 0) {
24
+ std::ostringstream msg;
25
+ msg << "Can't parse address " << ip << ":" << port;
26
+ throw std::runtime_error(msg.str());
27
+ }
28
+
29
+ address_t result;
30
+ memcpy(&result.addr, res->ai_addr, res->ai_addrlen);
31
+ result.len = res->ai_addrlen;
32
+ freeaddrinfo(res);
33
+
34
+ return result;
35
+ }
36
+
37
+ /**
38
+ * Parse a sockaddr provided as an <ip>:<port> string.
39
+ */
40
+ address_t parse_address(const std::string& ip_port) {
41
+ auto colon = ip_port.find(":");
42
+ if (colon == std::string::npos) {
43
+ std::ostringstream msg;
44
+ msg << "Can't parse address " << ip_port;
45
+ throw std::runtime_error(msg.str());
46
+ }
47
+ std::string ip(ip_port.begin(), ip_port.begin() + colon);
48
+ std::string port(ip_port.begin() + colon + 1, ip_port.end());
49
+
50
+ return parse_address(ip, port);
51
+ }
52
+
53
+ TCPSocket::TCPSocket(const char* tag) {
54
+ sock_ = socket(AF_INET, SOCK_STREAM, 0);
55
+ if (sock_ < 0) {
56
+ std::ostringstream msg;
57
+ msg << tag << " Couldn't create socket (error: " << errno << ")";
58
+ throw std::runtime_error(msg.str());
59
+ }
60
+ }
61
+
62
+ TCPSocket::TCPSocket(TCPSocket&& s) {
63
+ sock_ = s.sock_;
64
+ s.sock_ = -1;
65
+ }
66
+
67
+ TCPSocket& TCPSocket::operator=(TCPSocket&& s) {
68
+ if (this != &s) {
69
+ sock_ = s.sock_;
70
+ s.sock_ = -1;
71
+ }
72
+ return *this;
73
+ }
74
+
75
+ TCPSocket::TCPSocket(int s) : sock_(s) {}
76
+
77
+ TCPSocket::~TCPSocket() {
78
+ if (sock_ > 0) {
79
+ shutdown(sock_, 2);
80
+ close(sock_);
81
+ }
82
+ }
83
+
84
+ int TCPSocket::detach() {
85
+ int s = sock_;
86
+ sock_ = -1;
87
+ return s;
88
+ }
89
+
90
+ void TCPSocket::listen(const char* tag, const address_t& addr) {
91
+ int success;
92
+
93
+ // Make sure we can launch immediately after shutdown by setting the
94
+ // reuseaddr option so that we don't get address already in use errors
95
+ int enable = 1;
96
+ success = setsockopt(sock_, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(int));
97
+ if (success < 0) {
98
+ std::ostringstream msg;
99
+ msg << tag << " Couldn't enable reuseaddr (error: " << errno << ")";
100
+ throw std::runtime_error(msg.str());
101
+ }
102
+ success = setsockopt(sock_, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(int));
103
+ if (success < 0) {
104
+ std::ostringstream msg;
105
+ msg << tag << " Couldn't enable reuseport (error: " << errno << ")";
106
+ throw std::runtime_error(msg.str());
107
+ }
108
+
109
+ // Bind the socket to the address and port
110
+ success = bind(sock_, addr.get(), addr.len);
111
+ if (success < 0) {
112
+ std::ostringstream msg;
113
+ msg << tag << " Couldn't bind socket (error: " << errno << ")";
114
+ throw std::runtime_error(msg.str());
115
+ }
116
+
117
+ // Prepare waiting for connections
118
+ success = ::listen(sock_, 0);
119
+ if (success < 0) {
120
+ std::ostringstream msg;
121
+ msg << tag << " Couldn't listen (error: " << errno << ")";
122
+ throw std::runtime_error(msg.str());
123
+ }
124
+ }
125
+
126
+ TCPSocket TCPSocket::accept(const char* tag) {
127
+ int peer = ::accept(sock_, nullptr, nullptr);
128
+ if (peer < 0) {
129
+ std::ostringstream msg;
130
+ msg << tag << " Accept failed (error: " << errno << ")";
131
+ throw std::runtime_error(msg.str());
132
+ }
133
+
134
+ return TCPSocket(peer);
135
+ }
136
+
137
+ void TCPSocket::send(const char* tag, const void* data, size_t len) {
138
+ while (len > 0) {
139
+ auto n = ::send(sock_, data, len, 0);
140
+ if (n <= 0) {
141
+ std::ostringstream msg;
142
+ msg << tag << " Send failed with errno=" << errno;
143
+ throw std::runtime_error(msg.str());
144
+ }
145
+ len -= n;
146
+ data = static_cast<const char*>(data) + n;
147
+ }
148
+ }
149
+
150
+ void TCPSocket::recv(const char* tag, void* data, size_t len) {
151
+ while (len > 0) {
152
+ auto n = ::recv(sock_, data, len, 0);
153
+ if (n <= 0) {
154
+ std::ostringstream msg;
155
+ msg << tag << " Recv failed with errno=" << errno;
156
+ throw std::runtime_error(msg.str());
157
+ }
158
+ len -= n;
159
+ data = static_cast<char*>(data) + n;
160
+ }
161
+ }
162
+
163
+ TCPSocket TCPSocket::connect(
164
+ const char* tag,
165
+ const address_t& addr,
166
+ int num_retries,
167
+ int wait,
168
+ std::function<void(int, int)> cb) {
169
+ int sock, success;
170
+
171
+ // Attempt to connect `num_retries` times with exponential backoff.
172
+ for (int attempt = 0; attempt < num_retries; attempt++) {
173
+ // Create the socket
174
+ sock = socket(AF_INET, SOCK_STREAM, 0);
175
+ if (sock < 0) {
176
+ std::ostringstream msg;
177
+ msg << tag << " Couldn't create socket to connect (error: " << errno
178
+ << ")";
179
+ throw std::runtime_error(msg.str());
180
+ }
181
+
182
+ success = ::connect(sock, addr.get(), addr.len);
183
+ if (success == 0) {
184
+ break;
185
+ }
186
+
187
+ if (cb != nullptr) {
188
+ cb(attempt, wait);
189
+ }
190
+ if (wait > 0) {
191
+ std::this_thread::sleep_for(std::chrono::milliseconds(wait));
192
+ }
193
+
194
+ wait <<= 1;
195
+ }
196
+
197
+ if (success < 0) {
198
+ std::ostringstream msg;
199
+ msg << tag << " Couldn't connect (error: " << errno << ")";
200
+ throw std::runtime_error(msg.str());
201
+ }
202
+
203
+ return TCPSocket(sock);
204
+ }
205
+
206
+ } // namespace mlx::core::distributed::detail
@@ -0,0 +1,67 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <sys/socket.h>
6
+ #include <functional>
7
+ #include <string>
8
+
9
+ namespace mlx::core::distributed::detail {
10
+
11
/**
 * A socket address together with its length, in the form expected by
 * bind() and connect().
 */
struct address_t {
  sockaddr_storage addr; // storage for the address itself
  socklen_t len; // number of bytes of `addr` that are in use

  /** View the stored address as a generic `sockaddr`. */
  const sockaddr* get() const {
    return reinterpret_cast<const sockaddr*>(&addr);
  }
};
19
+
20
+ /**
21
+ * Parse a sockaddr from an ip and port provided as strings.
22
+ */
23
+ address_t parse_address(const std::string& ip, const std::string& port);
24
+
25
+ /**
26
+ * Parse a sockaddr provided as an <ip>:<port> string.
27
+ */
28
+ address_t parse_address(const std::string& ip_port);
29
+
30
+ /**
31
+ * Small wrapper over a TCP socket to simplify initiating connections.
32
+ */
33
+ class TCPSocket {
34
+ public:
35
+ TCPSocket(const char* tag);
36
+ TCPSocket(const TCPSocket&) = delete;
37
+ TCPSocket& operator=(const TCPSocket&) = delete;
38
+ TCPSocket(TCPSocket&& s);
39
+ TCPSocket& operator=(TCPSocket&&);
40
+ ~TCPSocket();
41
+
42
+ void listen(const char* tag, const address_t& addr);
43
+ TCPSocket accept(const char* tag);
44
+
45
+ void send(const char* tag, const void* data, size_t len);
46
+ void recv(const char* tag, void* data, size_t len);
47
+
48
+ int detach();
49
+
50
+ operator int() const {
51
+ return sock_;
52
+ }
53
+
54
+ static TCPSocket connect(
55
+ const char* tag,
56
+ const address_t& addr,
57
+ int num_retries = 1,
58
+ int wait = 0,
59
+ std::function<void(int, int)> cb = nullptr);
60
+
61
+ private:
62
+ TCPSocket(int sock);
63
+
64
+ int sock_;
65
+ };
66
+
67
+ } // namespace mlx::core::distributed::detail
data/mlx/mlx/dtype.cpp ADDED
@@ -0,0 +1,197 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <cstdint>
4
+
5
+ #include "mlx/dtype.h"
6
+
7
+ namespace mlx::core {
8
+
9
+ namespace {
10
+
11
+ constexpr int num_types = 14;
12
+ constexpr int num_cats = 8;
13
+
14
+ constexpr Dtype::Kind type_kinds[num_types] = {
15
+ Dtype::Kind::b, // bool_,
16
+ Dtype::Kind::u, // uint8,
17
+ Dtype::Kind::u, // uint16,
18
+ Dtype::Kind::u, // uint32,
19
+ Dtype::Kind::u, // uint64,
20
+ Dtype::Kind::i, // int8,
21
+ Dtype::Kind::i, // int16,
22
+ Dtype::Kind::i, // int32,
23
+ Dtype::Kind::i, // int64,
24
+ Dtype::Kind::f, // float16,
25
+ Dtype::Kind::f, // float32,
26
+ Dtype::Kind::f, // float64,
27
+ Dtype::Kind::V, // bfloat16,
28
+ Dtype::Kind::c // complex64,
29
+ };
30
+
31
+ // Following Jax type promotion rules:
32
+ // https://jax.readthedocs.io/en/latest/type_promotion.html
33
+ // clang-format off
34
+ constexpr Dtype type_rules[num_types][num_types] = {
35
+ // bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 float64 bfloat16 complex64
36
+ {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // bool
37
+ {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint8
38
+ {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // uint16
39
+ {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // uint32
40
+ {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, float64, bfloat16, complex64}, // uint64
41
+ {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int8
42
+ {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, float64, bfloat16, complex64}, // int16
43
+ {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, float64, bfloat16, complex64}, // int32
44
+ {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, float64, bfloat16, complex64}, // int64
45
+ {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float64, float32, complex64}, // float16
46
+ {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float64, float32, complex64}, // float32
47
+ {float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, float64, complex64}, // float64
48
+ {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, float64, bfloat16, complex64}, // bfloat16
49
+ {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64,complex64, complex64, complex64}, // complex64
50
+ };
51
+
52
+
53
+ constexpr bool subcategory_to_category[num_cats][num_cats] = {
54
+ // complexfloating floating inexact signedinteger unsignedinteger integer number generic
55
+ {true, false, true, false, false, false, true, true}, // complexfloating
56
+ {false, true, true, false, false, false, true, true}, // floating
57
+ {false, false, true, false, false, false, true, true}, // inexact
58
+ {false, false, false, true, false, true, true, true}, // signedinteger
59
+ {false, false, false, false, true, true, true, true}, // unsignedinteger
60
+ {false, false, false, false, false, true, true, true}, // integer
61
+ {false, false, false, false, false, false, true, true}, // number
62
+ {false, false, false, false, false, false, false, true}, // generic
63
+ };
64
+
65
+ constexpr Dtype::Category type_to_category[num_types] = {
66
+ Dtype::Category::generic, // bool_,
67
+ Dtype::Category::unsignedinteger, // uint8,
68
+ Dtype::Category::unsignedinteger, // uint16,
69
+ Dtype::Category::unsignedinteger, // uint32,
70
+ Dtype::Category::unsignedinteger, // uint64,
71
+ Dtype::Category::signedinteger, // int8,
72
+ Dtype::Category::signedinteger, // int16,
73
+ Dtype::Category::signedinteger, // int32,
74
+ Dtype::Category::signedinteger, // int64,
75
+ Dtype::Category::floating, // float16,
76
+ Dtype::Category::floating, // float32,
77
+ Dtype::Category::floating, // float64,
78
+ Dtype::Category::floating, // bfloat16,
79
+ Dtype::Category::complexfloating, // complex64,
80
+ };
81
+
82
+ // clang-format on
83
+
84
+ } // namespace
85
+
86
+ Dtype promote_types(const Dtype& t1, const Dtype& t2) {
87
+ return Dtype(
88
+ type_rules[static_cast<int>(t1.val())][static_cast<int>(t2.val())]);
89
+ }
90
+
91
+ Dtype::Kind kindof(const Dtype& t) {
92
+ return type_kinds[static_cast<int>(t.val())];
93
+ }
94
+
95
+ template class MLX_API TypeToDtype<bool>;
96
+ template class MLX_API TypeToDtype<uint8_t>;
97
+ template class MLX_API TypeToDtype<uint16_t>;
98
+ template class MLX_API TypeToDtype<uint32_t>;
99
+ template class MLX_API TypeToDtype<uint64_t>;
100
+ template class MLX_API TypeToDtype<int8_t>;
101
+ template class MLX_API TypeToDtype<int16_t>;
102
+ template class MLX_API TypeToDtype<int32_t>;
103
+ template class MLX_API TypeToDtype<int64_t>;
104
+ template class MLX_API TypeToDtype<float16_t>;
105
+ template class MLX_API TypeToDtype<float>;
106
+ template class MLX_API TypeToDtype<double>;
107
+ template class MLX_API TypeToDtype<bfloat16_t>;
108
+ template class MLX_API TypeToDtype<complex64_t>;
109
+
110
+ template <>
111
+ TypeToDtype<bool>::operator Dtype() {
112
+ return bool_;
113
+ }
114
+
115
+ template <>
116
+ TypeToDtype<uint8_t>::operator Dtype() {
117
+ return uint8;
118
+ }
119
+
120
+ template <>
121
+ TypeToDtype<uint16_t>::operator Dtype() {
122
+ return uint16;
123
+ }
124
+
125
+ template <>
126
+ TypeToDtype<uint32_t>::operator Dtype() {
127
+ return uint32;
128
+ }
129
+
130
+ template <>
131
+ TypeToDtype<uint64_t>::operator Dtype() {
132
+ return uint64;
133
+ }
134
+
135
+ template <>
136
+ TypeToDtype<int8_t>::operator Dtype() {
137
+ return int8;
138
+ }
139
+
140
+ template <>
141
+ TypeToDtype<int16_t>::operator Dtype() {
142
+ return int16;
143
+ }
144
+
145
+ template <>
146
+ TypeToDtype<int32_t>::operator Dtype() {
147
+ return int32;
148
+ }
149
+
150
+ template <>
151
+ TypeToDtype<int64_t>::operator Dtype() {
152
+ return int64;
153
+ }
154
+
155
+ template <>
156
+ TypeToDtype<float16_t>::operator Dtype() {
157
+ return float16;
158
+ }
159
+
160
+ template <>
161
+ TypeToDtype<float>::operator Dtype() {
162
+ return float32;
163
+ }
164
+
165
+ template <>
166
+ TypeToDtype<double>::operator Dtype() {
167
+ return float32;
168
+ }
169
+
170
+ template <>
171
+ TypeToDtype<bfloat16_t>::operator Dtype() {
172
+ return bfloat16;
173
+ }
174
+
175
+ template <>
176
+ TypeToDtype<complex64_t>::operator Dtype() {
177
+ return complex64;
178
+ }
179
+
180
+ bool issubdtype(const Dtype& a, const Dtype& b) {
181
+ return a == b;
182
+ }
183
+
184
+ bool issubdtype(const Dtype::Category& cat, const Dtype& type) {
185
+ return false;
186
+ }
187
+
188
+ bool issubdtype(const Dtype& type, const Dtype::Category& cat) {
189
+ return issubdtype(type_to_category[static_cast<uint32_t>(type.val())], cat);
190
+ }
191
+
192
+ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b) {
193
+ return subcategory_to_category[static_cast<uint32_t>(a)]
194
+ [static_cast<uint32_t>(b)];
195
+ }
196
+
197
+ } // namespace mlx::core
data/mlx/mlx/dtype.h ADDED
@@ -0,0 +1,116 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <complex>
6
+ #include <cstdint>
7
+
8
+ #include "mlx/api.h"
9
+ #include "mlx/types/complex.h"
10
+ #include "mlx/types/half_types.h"
11
+
12
+ namespace mlx::core {
13
+
14
/**
 * Describes an element type: which type it is (`Val`) and how many bytes one
 * element occupies.
 */
struct Dtype {
  // The concrete element types. The enumerator order matters: the integer
  // values are used as indices elsewhere.
  enum class Val {
    bool_,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    float64,
    bfloat16,
    complex64,
  };

  // Single-letter kind codes.
  enum class Kind {
    b, /* bool */
    u, /* unsigned int */
    i, /* signed int */
    f, /* float */
    c, /* complex */
    V, /* void - used for brain float */
  };

  // Type categories, from the most specific ones to `generic`.
  enum class Category {
    complexfloating,
    floating,
    inexact,
    signedinteger,
    unsignedinteger,
    integer,
    number,
    generic
  };

  // `size` is the size of one element in bytes.
  constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}

  // A Dtype converts implicitly to its Val, so it can be used directly in a
  // switch or compared with ==.
  constexpr operator Val() const {
    return val();
  }

  constexpr Val val() const {
    return val_;
  }

  constexpr uint8_t size() const {
    return size_;
  }

 private:
  Val val_;
  uint8_t size_;
};
68
+
69
+ inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
70
+
71
+ inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
72
+ inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
73
+ inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
74
+ inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
75
+
76
+ inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
77
+ inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
78
+ inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
79
+ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
80
+
81
+ inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
82
+ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
83
+ inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
84
+ inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
85
+ inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
86
+
87
+ inline constexpr Dtype::Category complexfloating =
88
+ Dtype::Category::complexfloating;
89
+ inline constexpr Dtype::Category floating = Dtype::Category::floating;
90
+ inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
91
+ inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
92
+ inline constexpr Dtype::Category unsignedinteger =
93
+ Dtype::Category::unsignedinteger;
94
+ inline constexpr Dtype::Category integer = Dtype::Category::integer;
95
+ inline constexpr Dtype::Category number = Dtype::Category::number;
96
+ inline constexpr Dtype::Category generic = Dtype::Category::generic;
97
+
98
+ MLX_API bool issubdtype(const Dtype& a, const Dtype& b);
99
+ MLX_API bool issubdtype(const Dtype::Category& a, const Dtype& b);
100
+ MLX_API bool issubdtype(const Dtype& a, const Dtype::Category& b);
101
+ MLX_API bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
102
+
103
+ MLX_API Dtype promote_types(const Dtype& t1, const Dtype& t2);
104
+
105
+ inline uint8_t size_of(const Dtype& t) {
106
+ return t.size();
107
+ }
108
+
109
+ MLX_API Dtype::Kind kindof(const Dtype& t);
110
+
111
+ template <typename T>
112
+ struct MLX_API TypeToDtype {
113
+ operator Dtype();
114
+ };
115
+
116
+ } // namespace mlx::core
@@ -0,0 +1,42 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/dtype_utils.h"
4
+
5
+ namespace mlx::core {
6
+
7
+ const char* dtype_to_string(Dtype arg) {
8
+ switch (arg) {
9
+ case bool_:
10
+ return "bool";
11
+ case int8:
12
+ return "int8";
13
+ case int16:
14
+ return "int16";
15
+ case int32:
16
+ return "int32";
17
+ case int64:
18
+ return "int64";
19
+ case uint8:
20
+ return "uint8";
21
+ case uint16:
22
+ return "uint16";
23
+ case uint32:
24
+ return "uint32";
25
+ case uint64:
26
+ return "uint64";
27
+ case float16:
28
+ return "float16";
29
+ case bfloat16:
30
+ return "bfloat16";
31
+ case float32:
32
+ return "float32";
33
+ case float64:
34
+ return "float64";
35
+ case complex64:
36
+ return "complex64";
37
+ default:
38
+ return "unknown";
39
+ }
40
+ }
41
+
42
+ } // namespace mlx::core