mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,870 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
#include <fcntl.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>

#include <chrono>
#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <list>
#include <mutex>
#include <sstream>
#include <thread>
#include <unordered_map>
#include <vector>

#include <json.hpp>

#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/reduction_ops.h"
#include "mlx/distributed/utils.h"
#include "mlx/threadpool.h"
25
+
26
+ #ifndef SOL_TCP
27
+ #define SOL_TCP IPPROTO_TCP
28
+ #endif
29
+
30
// Expand `__VA_ARGS__` once for the runtime dtype of array `x`, with the
// alias `T` bound to the matching C++ scalar type. Every dtype in the
// enumeration below is covered explicitly; there is no default case, so a
// dtype added in the future would silently do nothing here until this
// macro is extended.
#define SWITCH_TYPE(x, ...)  \
  switch ((x).dtype()) {     \
    case bool_: {            \
      using T = bool;        \
      __VA_ARGS__;           \
    } break;                 \
    case int8: {             \
      using T = int8_t;      \
      __VA_ARGS__;           \
    } break;                 \
    case int16: {            \
      using T = int16_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int32: {            \
      using T = int32_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case int64: {            \
      using T = int64_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint8: {            \
      using T = uint8_t;     \
      __VA_ARGS__;           \
    } break;                 \
    case uint16: {           \
      using T = uint16_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint32: {           \
      using T = uint32_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case uint64: {           \
      using T = uint64_t;    \
      __VA_ARGS__;           \
    } break;                 \
    case bfloat16: {         \
      using T = bfloat16_t;  \
      __VA_ARGS__;           \
    } break;                 \
    case float16: {          \
      using T = float16_t;   \
      __VA_ARGS__;           \
    } break;                 \
    case float32: {          \
      using T = float;       \
      __VA_ARGS__;           \
    } break;                 \
    case float64: {          \
      using T = double;      \
      __VA_ARGS__;           \
    } break;                 \
    case complex64: {        \
      using T = complex64_t; \
      __VA_ARGS__;           \
    } break;                 \
  }
89
+
90
+ namespace mlx::core::distributed::ring {
91
+
92
// Size in bytes of each staging buffer used by the ring all-reduce.
constexpr const size_t ALL_SUM_SIZE = 8 * 1024 * 1024;
// Staging buffers per socket/direction; two lets a new recv be in flight
// while the previously received packet is being reduced.
constexpr const size_t ALL_SUM_BUFFERS = 2;
// Connection retry policy: number of attempts and wait between attempts
// in milliseconds (see make_connections / TCPSocket::connect).
constexpr const int CONN_ATTEMPTS = 5;
constexpr const int CONN_WAIT = 1000;
// Tag prepended to log and error messages emitted by this backend.
constexpr const char* RING_TAG = "[ring]";

using GroupImpl = mlx::core::distributed::detail::GroupImpl;
using json = nlohmann::json;
using namespace std::chrono_literals;
101
+
102
+ namespace {
103
+
104
/** Write a single value to `os` followed by a newline (and flush). */
template <typename T>
void log(std::ostream& os, T first) {
  os << first << std::endl;
}

/**
 * Write all arguments to `os`, separated by single spaces and terminated
 * by a newline (and flush).
 */
template <typename T, typename... Args>
void log(std::ostream& os, T first, Args... args) {
  os << first;
  // Fold over the remaining arguments, prefixing each with a space.
  ((os << ' ' << args), ...);
  os << std::endl;
}

/**
 * Log the arguments to stderr under the "[ring]" tag, but only when
 * `verbose` is set.
 */
template <typename... Args>
void log_info(bool verbose, Args... args) {
  if (verbose) {
    log(std::cerr, "[ring]", args...);
  }
}
122
+
123
/**
 * Integer ceiling division: the smallest integer q with q * denom >= num
 * for non-negative operands. The result type follows the usual arithmetic
 * promotion of T and U.
 */
template <typename T, typename U>
decltype(T() * U()) ceildiv(T num, U denom) {
  return (num + denom - 1) / denom;
}
127
+
128
/**
 * Owns a background thread that drains queued send/recv operations on a
 * single socket. Callers enqueue work via send()/recv() and receive a
 * std::future<void> that completes once the whole buffer has been
 * transferred. The socket is switched to non-blocking mode so the worker
 * can interleave partial sends and recvs in one loop.
 */
class SocketThread {
 public:
  SocketThread(int fd) : fd_(fd), stop_(false) {
    // NOTE(review): the worker starts before the fd is made non-blocking;
    // presumably no task can be queued until after construction returns,
    // so the worker only ever blocks on the condition variable first —
    // confirm this ordering is intentional.
    worker_ = std::thread(&SocketThread::worker, this);
    int flags = fcntl(fd, F_GETFL, 0);
    fcntl(fd, F_SETFL, flags | O_NONBLOCK);
  }
  ~SocketThread() {
    // NOTE(review): stop_ is a plain bool written here and read by the
    // worker under (and once outside) the queue mutex; consider making it
    // std::atomic<bool> to remove the data race.
    stop_ = true;
    condition_.notify_all();
    worker_.join();
    // Restore blocking mode since the fd outlives this object.
    int flags = fcntl(fd_, F_GETFL, 0);
    fcntl(fd_, F_SETFL, flags & ~O_NONBLOCK);
  }

  /** Queue an asynchronous send of `size` elements of T from `buffer`.
   * The buffer must stay alive until the returned future is ready. */
  template <typename T>
  std::future<void> send(const T* buffer, size_t size) {
    return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
  }

  /** Queue an asynchronous receive of `size` elements of T into `buffer`.
   * The buffer must stay alive until the returned future is ready. */
  template <typename T>
  std::future<void> recv(T* buffer, size_t size) {
    return recv_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
  }

 private:
  // One pending transfer: remaining byte range plus the promise fulfilled
  // when the transfer finishes.
  struct SocketTask {
    SocketTask(void* b, size_t s, std::promise<void>&& p)
        : buffer(b), size(s), promise(std::move(p)) {}
    SocketTask(SocketTask&& t)
        : buffer(t.buffer), size(t.size), promise(std::move(t.promise)) {}
    void* buffer;
    size_t size;
    std::promise<void> promise;
  };

  // Enqueue a raw-byte send; completes immediately for size == 0.
  std::future<void> send_impl(const char* buffer, size_t size) {
    std::promise<void> send_completed_promise;
    auto send_completed_future = send_completed_promise.get_future();
    if (size == 0) {
      send_completed_promise.set_value();
      return send_completed_future;
    }

    {
      std::unique_lock lock(queue_mutex_);
      // const_cast is safe: send tasks only ever read through `buffer`.
      sends_.emplace_back(SocketTask(
          const_cast<char*>(buffer), size, std::move(send_completed_promise)));
    }
    condition_.notify_one();
    return send_completed_future;
  }

  // Enqueue a raw-byte receive; completes immediately for size == 0.
  std::future<void> recv_impl(char* buffer, size_t size) {
    std::promise<void> recv_completed_promise;
    auto recv_completed_future = recv_completed_promise.get_future();
    if (size == 0) {
      recv_completed_promise.set_value();
      return recv_completed_future;
    }

    {
      std::unique_lock lock(queue_mutex_);
      recvs_.emplace_back(
          SocketTask(buffer, size, std::move(recv_completed_promise)));
    }
    condition_.notify_one();
    return recv_completed_future;
  }

  // True when either queue has pending work. Caller must hold queue_mutex_.
  bool have_tasks() {
    return !(sends_.empty() && recvs_.empty());
  }

  /**
   * Worker loop: waits for tasks, then alternates partial non-blocking
   * recv/send on the front task of each queue. Finished tasks are removed
   * (and their promises fulfilled) at the top of the next iteration, under
   * the lock. Gives up after 10 consecutive socket errors.
   */
  void worker() {
    int error_count = 0;
    bool delete_recv = false;
    bool delete_send = false;
    while (true) {
      {
        std::unique_lock lock(queue_mutex_);

        if (delete_recv) {
          recvs_.front().promise.set_value();
          recvs_.pop_front();
          delete_recv = false;
        }
        if (delete_send) {
          sends_.front().promise.set_value();
          sends_.pop_front();
          delete_send = false;
        }

        if (stop_) {
          return;
        }

        if (!have_tasks()) {
          condition_.wait(lock, [this] { return stop_ || have_tasks(); });
          if (stop_) {
            return;
          }
        }
      }

      // NOTE(review): the front tasks below are accessed without the lock;
      // producers only push_back and this thread is the only consumer, but
      // list size/head bookkeeping is still shared — verify this is safe on
      // the targeted platforms.
      if (!recvs_.empty()) {
        auto& task = recvs_.front();
        ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
        if (r > 0) {
          // Partial read: advance the cursor and shrink the remainder.
          task.buffer = static_cast<char*>(task.buffer) + r;
          task.size -= r;
          delete_recv = task.size == 0;
          error_count = 0;
        } else if (errno != EAGAIN) {
          error_count++;
          log_info(
              true, "Receiving from socket", fd_, "failed with errno", errno);
        }
      }
      if (!sends_.empty()) {
        auto& task = sends_.front();
        ssize_t r = ::send(fd_, task.buffer, task.size, 0);
        if (r > 0) {
          // Partial write: advance the cursor and shrink the remainder.
          task.buffer = static_cast<char*>(task.buffer) + r;
          task.size -= r;
          delete_send = task.size == 0;
          error_count = 0;
        } else if (errno != EAGAIN) {
          error_count++;
          log_info(true, "Sending to socket", fd_, "failed with errno", errno);
        }
      }

      if (error_count >= 10) {
        log_info(true, "Too many send/recv errors. Aborting...");
        return;
      }
    }
  }

  int fd_; // socket descriptor serviced by this thread (not owned)
  bool stop_; // set by the destructor to shut the worker down
  std::thread worker_;
  std::mutex queue_mutex_; // guards sends_, recvs_ and the condition
  std::condition_variable condition_;
  std::list<SocketTask> sends_;
  std::list<SocketTask> recvs_;
};
276
+
277
+ class CommunicationThreads {
278
+ public:
279
+ void add(const std::vector<int>& sockets) {
280
+ for (int sock : sockets) {
281
+ threads_.emplace(sock, sock);
282
+ }
283
+ }
284
+
285
+ template <typename T>
286
+ std::future<void> send(int socket, T* buffer, size_t size) {
287
+ return threads_.at(socket).send<T>(buffer, size);
288
+ }
289
+
290
+ template <typename T>
291
+ std::future<void> recv(int socket, T* buffer, size_t size) {
292
+ return threads_.at(socket).recv<T>(buffer, size);
293
+ }
294
+
295
+ private:
296
+ std::unordered_map<int, SocketThread> threads_;
297
+ };
298
+
299
+ /**
300
+ * Load all addresses from the json hostfile. The hostfile is a list of
301
+ * addresses in order of rank. For each rank there can be many addresses so
302
+ * that we can have multiple connections between peers.
303
+ *
304
+ * For example:
305
+ * [
306
+ * ["ip1:5000", "ip1:5001"],
307
+ * ["ip2:5000", "ip2:5001"],
308
+ * ["ip3:5000", "ip3:5001"],
309
+ * ]
310
+ */
311
+ std::vector<std::vector<detail::address_t>> load_nodes(const char* hostfile) {
312
+ std::vector<std::vector<detail::address_t>> nodes;
313
+ std::ifstream f(hostfile);
314
+
315
+ json hosts = json::parse(f);
316
+ for (auto& h : hosts) {
317
+ std::vector<detail::address_t> host;
318
+ for (auto& ips : h) {
319
+ host.push_back(std::move(detail::parse_address(ips.get<std::string>())));
320
+ }
321
+ nodes.push_back(std::move(host));
322
+ }
323
+
324
+ return nodes;
325
+ }
326
+
327
+ /**
328
+ * Create a socket and accept one connection for each of the provided
329
+ * addresses.
330
+ */
331
+ std::vector<int> accept_connections(
332
+ const std::vector<detail::address_t>& addresses) {
333
+ std::vector<int> sockets;
334
+ int success;
335
+
336
+ for (auto& address : addresses) {
337
+ detail::TCPSocket socket(RING_TAG);
338
+ socket.listen(RING_TAG, address);
339
+ sockets.push_back(socket.accept(RING_TAG).detach());
340
+ }
341
+
342
+ return sockets;
343
+ }
344
+
345
+ /**
346
+ * The counterpoint of `accept_connections`. Basically connect to each of the
347
+ * provided addresses.
348
+ */
349
+ std::vector<int> make_connections(
350
+ const std::vector<detail::address_t>& addresses,
351
+ bool verbose) {
352
+ std::vector<int> sockets;
353
+ int success;
354
+
355
+ for (auto& address : addresses) {
356
+ sockets.push_back(
357
+ detail::TCPSocket::connect(
358
+ RING_TAG,
359
+ address,
360
+ CONN_ATTEMPTS,
361
+ CONN_WAIT,
362
+ [verbose](int attempt, int wait) {
363
+ log_info(
364
+ verbose,
365
+ "Attempt",
366
+ attempt,
367
+ "waiting",
368
+ wait,
369
+ "ms (error:",
370
+ errno,
371
+ ")");
372
+ })
373
+ .detach());
374
+ }
375
+
376
+ return sockets;
377
+ }
378
+
379
+ } // namespace
380
+
381
/**
 * Distributed group implementing collectives over a ring of TCP
 * connections. Each rank connects to its right neighbor ((rank + 1) % size)
 * and accepts connections from its left neighbor; there may be several
 * sockets per neighbor pair so transfers can be parallelized. Collectives
 * are dispatched on the CPU command encoder and executed by a thread pool
 * with one task per socket/direction.
 */
class RingGroup : public GroupImpl {
 public:
  /**
   * Establish the ring: connect/accept sockets with both neighbors,
   * configure them, start the communication threads and allocate the
   * staging buffers used by all_reduce.
   */
  RingGroup(
      int rank,
      std::vector<std::vector<detail::address_t>> nodes,
      bool verbose)
      : rank_(rank), verbose_(verbose), pool_(0) {
    if (rank_ > 0 && rank_ >= nodes.size()) {
      throw std::runtime_error(
          "[ring] Rank cannot be larger than the size of the group");
    }

    size_ = nodes.size();
    int connect_to = (rank_ + 1) % size_;

    // We define the connection order by having the rank_ == size_ - 1 connect
    // first and accept after.
    if (rank_ < connect_to) {
      log_info(verbose_, "Rank", rank_, "accepting");
      sockets_left_ = accept_connections(nodes[rank_]);
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      sockets_right_ = make_connections(nodes[connect_to], verbose);
    } else {
      log_info(verbose_, "Rank", rank_, "connecting to", connect_to);
      sockets_right_ = make_connections(nodes[connect_to], verbose);
      log_info(verbose_, "Rank", rank_, "accepting");
      sockets_left_ = accept_connections(nodes[rank_]);
    }

    // Failure if we couldn't make right or left sockets
    if (sockets_right_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no sockets to the right.";
      throw std::invalid_argument(msg.str());
    }
    if (sockets_left_.empty()) {
      std::ostringstream msg;
      msg << "[ring] Rank " << rank_ << " has no sockets to the left.";
      throw std::invalid_argument(msg.str());
    }

    // The following could be relaxed since we can define non-homogeneous rings
    // but it makes things a bit simpler for now.
    if (sockets_right_.size() != sockets_left_.size()) {
      std::ostringstream msg;
      msg << "[ring] It is required to have as many connections to the left as "
          << "to the right but rank " << rank_ << " has "
          << sockets_right_.size() << " connections to the right and "
          << sockets_left_.size() << " to the left.";
      throw std::invalid_argument(msg.str());
    }

    // Configure all sockets to use TCP no delay.
    int one = 1;
    for (int i = 0; i < sockets_right_.size(); i++) {
      setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
      setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
    }

    // Start the all reduce threads. One all reduce per direction per ring.
    pool_.resize(sockets_right_.size() + sockets_left_.size());

    // Create a communication thread per socket. This also converts them to
    // non-blocking.
    comm_.add(sockets_right_);
    comm_.add(sockets_left_);

    // Allocate buffers for the all sum
    buffers_.resize(
        (sockets_right_.size() + sockets_left_.size()) * ALL_SUM_BUFFERS *
        ALL_SUM_SIZE);
  }

  // Shut down and close every socket in both directions.
  ~RingGroup() {
    for (auto s : sockets_right_) {
      shutdown(s, 2); // 2 == SHUT_RDWR
      close(s);
    }
    for (auto s : sockets_left_) {
      shutdown(s, 2); // 2 == SHUT_RDWR
      close(s);
    }
  }

  // All collectives of this backend run on the CPU.
  Stream communication_stream(StreamOrDevice s) override {
    return to_stream(s, Device::cpu);
  }

  int rank() override {
    return rank_;
  }

  int size() override {
    return size_;
  }

  // Elementwise collectives: dispatch on the output dtype and run the ring
  // all-reduce with the corresponding reduction operator.
  void all_sum(const array& input, array& output, Stream stream) override {
    SWITCH_TYPE(
        output, all_reduce<T>(input, output, stream, detail::SumOp<T>()));
  }

  void all_max(const array& input, array& output, Stream stream) override {
    SWITCH_TYPE(
        output, all_reduce<T>(input, output, stream, detail::MaxOp<T>()));
  }

  void all_min(const array& input, array& output, Stream stream) override {
    SWITCH_TYPE(
        output, all_reduce<T>(input, output, stream, detail::MinOp<T>()));
  }

  std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
    throw std::runtime_error("[ring] Group split not supported.");
  }

  /**
   * Ring all-gather: split the payload into up to one chunk per socket
   * (alternating direction per chunk) and gather each chunk around the ring
   * in parallel on the thread pool.
   */
  void all_gather(const array& input, array& output, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.set_output_array(output);
    encoder.dispatch([input_ptr = input.data<char>(),
                      nbytes = input.nbytes(),
                      output_ptr = output.data<char>(),
                      this]() {
      // Only split into several gathers when each piece is at least
      // min_send_size bytes.
      constexpr size_t min_send_size = 262144;
      size_t n_gathers = std::max(
          std::min(
              sockets_right_.size() + sockets_left_.size(),
              nbytes / min_send_size),
          size_t(1));
      size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
      std::vector<std::future<void>> all_gathers;
      for (int i = 0; i < n_gathers; i++) {
        auto offset = i * bytes_per_gather;
        // Even chunks travel right (+1), odd chunks travel left (-1), using
        // socket pair i / 2.
        all_gathers.emplace_back(pool_.enqueue(std::bind(
            &RingGroup::all_gather_impl,
            this,
            input_ptr + offset,
            output_ptr + offset,
            nbytes,
            offset + bytes_per_gather > nbytes ? nbytes - offset
                                               : bytes_per_gather,
            sockets_right_[i / 2],
            sockets_left_[i / 2],
            (i % 2) ? -1 : 1)));
      }
      for (auto& f : all_gathers) {
        f.wait();
      }
    });
  }

  /** Point-to-point send; only direct ring neighbors are reachable. */
  void send(const array& input, int dst, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_input_array(input);
    encoder.dispatch(
        [input_ptr = input.data<char>(), nbytes = input.nbytes(), dst, this]() {
          int right = (rank_ + 1) % size_;
          int left = (rank_ + size_ - 1) % size_;
          if (dst == right) {
            send(sockets_right_, input_ptr, nbytes);
          } else if (dst == left) {
            send(sockets_left_, input_ptr, nbytes);
          } else {
            std::ostringstream msg;
            msg << "[ring] Send only supported to direct neighbors "
                << "but tried to send to " << dst << " from " << rank_
                << std::endl;
            throw std::runtime_error(msg.str());
          }
        });
  }

  /** Point-to-point receive; only direct ring neighbors are reachable. */
  void recv(array& out, int src, Stream stream) override {
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_output_array(out);
    encoder.dispatch(
        [out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
          // NOTE: We 'll check the sockets with the opposite order of send so
          // that they work even with 2 nodes where left and right is the same
          // neighbor.
          int right = (rank_ + 1) % size_;
          int left = (rank_ + size_ - 1) % size_;
          if (src == left) {
            recv(sockets_left_, out_ptr, nbytes);
          } else if (src == right) {
            recv(sockets_right_, out_ptr, nbytes);
          } else {
            std::ostringstream msg;
            msg << "[ring] Recv only supported from direct neighbors "
                << "but tried to recv from " << src << " to " << rank_
                << std::endl;
            throw std::runtime_error(msg.str());
          }
        });
  }

  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[ring] sum_scatter not supported.");
  }

 private:
  /**
   * Entry point for all ring reductions. Handles the tiny-input special
   * case, copies input to output for out-of-place calls, then splits the
   * work into up to one all_reduce_impl task per socket/direction.
   */
  template <typename T, typename ReduceOp>
  void all_reduce(
      const array& input,
      array& output,
      Stream stream,
      ReduceOp reduce_op) {
    auto in_ptr = input.data<char>();
    auto out_ptr = output.data<char>();
    auto& encoder = cpu::get_command_encoder(stream);
    encoder.set_output_array(output);
    encoder.dispatch([in_ptr, out_ptr, size = input.size(), this, reduce_op]() {
      // If the input data cannot be split into size_ segments then copy it and
      // all reduce a local buffer prefilled with 0s.
      size_t nbytes = size * sizeof(T);
      if (size < size_) {
        // TODO: Maybe allocate dynamically so we don't have the constraint
        // below?
        if (sizeof(T) * size_ > 1024) {
          std::ostringstream msg;
          msg << "Can't perform the ring all reduce of " << size
              << " elements with a ring of size " << size_;
          throw std::runtime_error(msg.str());
        }

        char buffer[1024];
        std::memset(buffer, 0, size_ * sizeof(T));
        std::memcpy(buffer, in_ptr, nbytes);
        all_reduce_impl<T, ReduceOp>(
            reinterpret_cast<T*>(buffers_.data()),
            reinterpret_cast<T*>(buffer),
            size_,
            sockets_right_[0],
            sockets_left_[0],
            -1,
            reduce_op);
        std::memcpy(out_ptr, buffer, nbytes);
        return;
      }

      // If not inplace all reduce then copy the input to the output first
      if (in_ptr != out_ptr) {
        std::memcpy(out_ptr, in_ptr, nbytes);
      }

      // Split the all reduces so that each member has at least 1 buffer to
      // send/recv per segment.
      constexpr size_t min_send_size = 262144;
      size_t n_reduces = std::max(
          std::min(
              sockets_right_.size() + sockets_left_.size(),
              nbytes / (size_ * min_send_size)),
          size_t(1));
      size_t step = ceildiv(size, n_reduces);
      std::vector<std::future<void>> all_sums;

      for (int i = 0; i < n_reduces; i++) {
        // Even slices go around the ring in one direction, odd slices in
        // the other, each on its own staging buffer region.
        all_sums.emplace_back(pool_.enqueue(std::bind(
            &RingGroup::all_reduce_impl<T, ReduceOp>,
            this,
            reinterpret_cast<T*>(
                buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
            reinterpret_cast<T*>(out_ptr) + i * step,
            std::min(size, (i + 1) * step) - i * step,
            sockets_right_[i / 2],
            sockets_left_[i / 2],
            (i % 2) ? -1 : 1,
            reduce_op)));
      }
      for (auto& f : all_sums) {
        f.wait();
      }
    });
  }

  /**
   * One full ring all-reduce over `data` (reduced in place): a
   * reduce-scatter pass followed by an all-gather pass, pipelined so one
   * send and one recv are always in flight while the previous packet is
   * reduced. `buffer` provides ALL_SUM_BUFFERS staging areas; `direction`
   * (+1/-1) selects which neighbor we send to.
   */
  template <typename T, typename ReduceOp>
  void all_reduce_impl(
      T* buffer,
      T* data,
      size_t data_size,
      int socket_right,
      int socket_left,
      int direction,
      ReduceOp reduce_op) {
    // Choose which socket we send to and recv from
    int socket_send = (direction < 0) ? socket_right : socket_left;
    int socket_recv = (direction < 0) ? socket_left : socket_right;

    // We split the data into `size_` segments of size `segment_size` and each
    // of these in smaller segments of ALL_SUM_SIZE which we 'll call packets.
    size_t segment_size = ceildiv(data_size, size_);
    size_t BUFFER_SIZE = std::max(
        size_t(32768), std::min(ALL_SUM_SIZE / sizeof(T), segment_size / 2));
    size_t n_packets = ceildiv(segment_size, BUFFER_SIZE);

    // Initial segments
    int send_segment = rank_;
    int recv_segment = (rank_ + direction + size_) % size_;

    // Plan the whole reduce in terms of sends and recvs as indices in data.
    // It makes the actual async send and recv a bit simpler to follow when
    // there are less offset calculations around.
    std::vector<std::pair<size_t, size_t>> send_plan;
    std::vector<std::pair<size_t, size_t>> recv_plan;

    // Two times the same send/recv operations, first scatter reduce and then
    // gather.
    for (int k = 0; k < 2; k++) {
      for (int i = 0; i < size_ - 1; i++) {
        size_t send_start = send_segment * segment_size;
        size_t send_stop =
            std::min((send_segment + 1) * segment_size, data_size);
        size_t recv_start = recv_segment * segment_size;
        size_t recv_stop =
            std::min((recv_segment + 1) * segment_size, data_size);

        for (size_t j = 0; j < n_packets; j++) {
          send_plan.emplace_back(
              std::min(send_start + j * BUFFER_SIZE, send_stop),
              std::min(send_start + (j + 1) * BUFFER_SIZE, send_stop));
          recv_plan.emplace_back(
              std::min(recv_start + j * BUFFER_SIZE, recv_stop),
              std::min(recv_start + (j + 1) * BUFFER_SIZE, recv_stop));
        }

        send_segment = (send_segment + size_ + direction) % size_;
        recv_segment = (recv_segment + size_ + direction) % size_;
      }
    }

    // Running the plan is fairly simple, we keep a send and a recv in flight
    // while doing the summation.
    T* recv_buffers[ALL_SUM_BUFFERS];
    for (int i = 0; i < ALL_SUM_BUFFERS; i++) {
      recv_buffers[i] = buffer + i * BUFFER_SIZE;
    }
    // `a` indexes the in-flight operation pair, `b` the one being waited
    // on; they swap every iteration. The first half of the plan (scatter
    // reduce) receives into staging buffers and reduces into `data`; the
    // second half (gather) receives directly into `data`.
    std::future<void> sends[2], recvs[2];
    int a = 0;
    int b = (n_packets > 1) ? 1 : 0;
    for (int i = 0, j = -b; i < send_plan.size(); j++, i++) {
      sends[a] = comm_.send(
          socket_send,
          data + send_plan[i].first,
          send_plan[i].second - send_plan[i].first);
      if (2 * i < send_plan.size()) {
        recvs[a] = comm_.recv(
            socket_recv,
            recv_buffers[i % ALL_SUM_BUFFERS],
            recv_plan[i].second - recv_plan[i].first);
      } else {
        recvs[a] = comm_.recv(
            socket_recv,
            data + recv_plan[i].first,
            recv_plan[i].second - recv_plan[i].first);
      }

      if (j >= 0) {
        sends[b].wait();
        recvs[b].wait();
        if (2 * j < send_plan.size()) {
          reduce_op(
              recv_buffers[j % ALL_SUM_BUFFERS],
              data + recv_plan[j].first,
              recv_plan[j].second - recv_plan[j].first);
        }
      }

      std::swap(a, b);
    }
    // Drain the last in-flight pair.
    sends[b].wait();
    recvs[b].wait();
  }

  /**
   * Gather one chunk around the ring: copy our own piece into place, then
   * for size_ - 1 steps forward the most recently received segment to the
   * next neighbor while receiving a new one.
   */
  void all_gather_impl(
      const char* input,
      char* output,
      size_t input_size,
      size_t data_size,
      int socket_right,
      int socket_left,
      int direction) {
    // Choose which socket we send to and recv from
    int socket_send = (direction < 0) ? socket_right : socket_left;
    int socket_recv = (direction < 0) ? socket_left : socket_right;

    // Initial segments
    int send_segment = rank_;
    int recv_segment = (rank_ + direction + size_) % size_;

    // Copy our own segment in the output
    std::memcpy(output + rank_ * input_size, input, data_size);

    // Simple send/recv all gather. Possible performance improvement by
    // splitting to multiple chunks and allowing send/recv to run a bit ahead.
    // See all_sum_impl for an example.
    for (int i = 0; i < size_ - 1; i++) {
      auto sent = comm_.send(
          socket_send, output + send_segment * input_size, data_size);
      auto recvd = comm_.recv(
          socket_recv, output + recv_segment * input_size, data_size);

      send_segment = (send_segment + size_ + direction) % size_;
      recv_segment = (recv_segment + size_ + direction) % size_;

      sent.wait();
      recvd.wait();
    }
  }

  // Stripe `data` across the given sockets (>= 1024 bytes per stripe) and
  // send all stripes concurrently.
  void
  send(const std::vector<int>& sockets, const char* data, size_t data_size) {
    size_t segment_size =
        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
    std::vector<std::future<void>> sends;
    for (int i = 0; i < sockets.size(); i++) {
      if (i * segment_size >= data_size) {
        break;
      }
      sends.emplace_back(comm_.send(
          sockets[i],
          data + i * segment_size,
          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
    }
    for (auto& f : sends) {
      f.wait();
    }
  }

  // Mirror of send() above: receive the same striping concurrently.
  void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
    size_t segment_size =
        std::max(size_t(1024), ceildiv(data_size, sockets.size()));
    std::vector<std::future<void>> recvs;
    for (int i = 0; i < sockets.size(); i++) {
      if (i * segment_size >= data_size) {
        break;
      }
      recvs.emplace_back(comm_.recv(
          sockets[i],
          data + i * segment_size,
          std::min(data_size, (i + 1) * segment_size) - i * segment_size));
    }
    for (auto& f : recvs) {
      f.wait();
    }
  }

  int rank_; // this process' position in the ring
  int size_; // total number of ranks

  bool verbose_;

  ThreadPool pool_; // runs the per-socket collective tasks
  CommunicationThreads comm_; // one I/O thread per socket

  std::vector<int> sockets_right_; // connections to (rank_ + 1) % size_
  std::vector<int> sockets_left_; // connections to (rank_ - 1 + size_) % size_

  std::vector<char> buffers_; // staging memory for all_reduce packets
};
842
+
843
// The ring backend has no optional runtime dependencies, so it is always
// reported as available when compiled in.
bool is_available() {
  return true;
}
846
+
847
+ std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
848
+ const char* hostfile = std::getenv("MLX_HOSTFILE");
849
+ const char* rank_str = std::getenv("MLX_RANK");
850
+ const char* ring_verbose = std::getenv("MLX_RING_VERBOSE");
851
+
852
+ if (!hostfile || !rank_str) {
853
+ if (strict) {
854
+ std::ostringstream msg;
855
+ msg << "[ring] You need to provide via environment variables both a rank (MLX_RANK) "
856
+ << "and a hostfile (MLX_HOSTFILE) but provided MLX_RANK=\""
857
+ << ((rank_str) ? rank_str : "") << "\" and MLX_HOSTFILE=\""
858
+ << ((hostfile) ? hostfile : "") << "\"";
859
+ throw std::runtime_error(msg.str());
860
+ }
861
+ return nullptr;
862
+ }
863
+
864
+ auto nodes = load_nodes(hostfile);
865
+ int rank = std::atoi(rank_str);
866
+
867
+ return std::make_shared<RingGroup>(rank, nodes, ring_verbose != nullptr);
868
+ }
869
+
870
+ } // namespace mlx::core::distributed::ring