mlx 0.30.7

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/einsum.cpp
@@ -0,0 +1,941 @@
+ // Copyright © 2024 Apple Inc.
+ #include <numeric>
+ #include <sstream>
+ #include <unordered_map>
+ #include <unordered_set>
+
+ #include "mlx/einsum.h"
+ #include "mlx/ops.h"
+
+ namespace mlx::core {
+
+ namespace {
+
+ // The MLX einsum implementation is based on NumPy (which is based on
+ // opt_einsum):
+ // https://github.com/numpy/numpy/blob/1d49c7f7ff527c696fc26ab2278ad51632a66660/numpy/_core/einsumfunc.py#L743
+ // https://github.com/dgasmith/opt_einsum
+
+ using CharSet = std::unordered_set<char>;
+
+ // A helper struct to hold the string and set
+ // representation of a subscript to avoid needing
+ // to recompute the set
+ struct Subscript {
+   Subscript(std::string str, CharSet set)
+       : str(std::move(str)), set(std::move(set)) {};
+   std::string str;
+   CharSet set;
+ };
+
+ struct PathInfo {
+   size_t naive_cost;
+   size_t naive_scaling;
+   size_t optimized_cost;
+   size_t optimized_scaling;
+   size_t largest_term;
+ };
+
+ struct PathNode {
+   PathNode(
+       std::vector<Subscript> inputs,
+       Subscript output,
+       std::vector<int> positions)
+       : inputs(std::move(inputs)),
+         output(std::move(output)),
+         positions(std::move(positions)) {};
+
+   std::vector<Subscript> inputs;
+   Subscript output;
+
+   std::vector<int> positions;
+ };
+
+ // Parse the comma separated subscripts into a vector of strings. If the
+ // output subscripts are missing they are inferred.
+ //
+ // For example:
+ //   "ij,jk -> ik" becomes {{"ij", "jk"}, "ik"}
+ //   "ij,jk" becomes {{"ij", "jk"}, "ik"}
+ std::pair<std::vector<std::string>, std::string> parse(std::string subscripts) {
+   std::string lhs, rhs;
+
+   // Start by removing all white space
+   subscripts.erase(
+       std::remove(subscripts.begin(), subscripts.end(), ' '), subscripts.end());
+
+   if (auto pos = subscripts.find("->"); pos != std::string::npos) {
+     // Explicit mode
+     lhs = subscripts.substr(0, pos);
+     rhs = subscripts.substr(pos + 2);
+   } else {
+     // Implicit mode:
+     // - repeats are summed
+     // - ellipses are placed in the beginning of the output
+     // - remaining output axes are ordered alphabetically
+     lhs = subscripts;
+     std::unordered_map<char, int> temp;
+     for (auto& c : subscripts) {
+       if (c == ',') {
+         continue;
+       }
+       if (c == '.' && rhs.empty()) {
+         rhs += "...";
+         continue;
+       }
+
+       auto inserted = temp.insert({c, 0});
+       inserted.first->second++;
+     }
+     for (auto& k : temp) {
+       if (k.second == 1) {
+         rhs += k.first;
+       }
+     }
+     std::sort(rhs.begin(), rhs.end());
+   }
+   std::vector<std::string> input_list;
+   std::stringstream ss(lhs);
+   std::string token;
+   while (getline(ss, token, ',')) {
+     input_list.push_back(token);
+   }
+   return {input_list, rhs};
+ }
+
+ // Check if two sets are disjoint
+ bool disjoint(const CharSet& x, const CharSet& y) {
+   for (auto& c : x) {
+     if (y.find(c) != y.end()) {
+       return false;
+     }
+   }
+   return true;
+ }
+
+ template <typename T>
+ size_t term_size(const T& term, std::unordered_map<char, ShapeElem> dict) {
+   size_t size = 1;
+   for (auto c : term) {
+     size *= dict[c];
+   }
+   return size;
+ }
+
+ size_t flop_count(
+     const CharSet& term,
+     bool inner,
+     int num_terms,
+     std::unordered_map<char, ShapeElem> dict) {
+   size_t size = term_size(term, dict);
+   auto op_factor = 1;
+   if ((num_terms - 1) > op_factor) {
+     op_factor = num_terms - 1;
+   }
+   if (inner) {
+     op_factor += 1;
+   }
+   return size * op_factor;
+ }
+
+ std::pair<size_t, int> compute_cost_and_scaling(
+     const std::vector<Subscript>& inputs,
+     const Subscript& output,
+     std::unordered_map<char, ShapeElem> dim_map) {
+   CharSet contractions;
+   for (auto& in : inputs) {
+     contractions.insert(in.set.begin(), in.set.end());
+   }
+
+   bool inner = false;
+   for (auto c : contractions) {
+     if (output.set.find(c) == output.set.end()) {
+       inner = true;
+       break;
+     }
+   }
+   auto cost = flop_count(contractions, inner, inputs.size(), dim_map);
+   return {cost, contractions.size()};
+ }
+
+ std::tuple<std::vector<PathNode>, size_t, int> greedy_path(
+     std::vector<Subscript> inputs,
+     const Subscript& output,
+     std::unordered_map<char, ShapeElem> dim_map,
+     size_t cost_limit,
+     size_t memory_limit) {
+   // Helper struct for building the greedy path
+   struct Contraction {
+     Contraction(
+         size_t size,
+         size_t cost,
+         CharSet output,
+         int dims,
+         int x,
+         int y)
+         : size(size),
+           cost(cost),
+           output(std::move(output)),
+           dims(dims),
+           x(x),
+           y(y) {};
+
+     int64_t size; // Size difference, can be negative
+     size_t cost;
+     CharSet output;
+     int dims; // Number of dimensions in the contraction
+     int x;
+     int y;
+   };
+
+   // Start by iterating over all possible combinations
+   std::vector<std::pair<int, int>> pos_pairs;
+   for (int i = 0; i < inputs.size(); ++i) {
+     for (int j = i + 1; j < inputs.size(); ++j) {
+       pos_pairs.emplace_back(i, j);
+     }
+   }
+
+   std::vector<PathNode> path;
+   std::vector<Contraction> possible_contractions;
+   size_t path_cost = 0;
+   int path_scaling = 0;
+   auto num_in = inputs.size();
+   for (int i = 0; i < num_in - 1; ++i) {
+     auto add_contraction = [&](int p1, int p2) {
+       CharSet new_term;
+       CharSet contractions(inputs[p1].set.begin(), inputs[p1].set.end());
+       contractions.insert(inputs[p2].set.begin(), inputs[p2].set.end());
+       for (int i = 0; i < inputs.size(); i++) {
+         if (i == p1 || i == p2) {
+           continue;
+         }
+         auto& in = inputs[i].set;
+         for (auto c : in) {
+           if (contractions.find(c) != contractions.end()) {
+             new_term.insert(c);
+           }
+         }
+       }
+       for (auto c : output.set) {
+         if (contractions.find(c) != contractions.end()) {
+           new_term.insert(c);
+         }
+       }
+
+       // Ignore if:
+       // - The size of the new result is greater than the memory limit
+       // - The cost is larger than the naive cost
+       auto new_size = term_size(new_term, dim_map);
+       if (new_size > memory_limit) {
+         return;
+       }
+       int64_t removed_size = term_size(inputs[p1].set, dim_map) +
+           term_size(inputs[p2].set, dim_map) - new_size;
+
+       bool inner = contractions.size() > new_term.size();
+       auto cost = flop_count(contractions, inner, 2, dim_map);
+       if (path_cost + cost > cost_limit) {
+         return;
+       }
+       possible_contractions.emplace_back(
+           removed_size, cost, std::move(new_term), contractions.size(), p1, p2);
+     };
+
+     for (auto& [p1, p2] : pos_pairs) {
+       // Ignore outer products
+       if (!disjoint(inputs[p1].set, inputs[p2].set)) {
+         add_contraction(p1, p2);
+       }
+     }
+
+     // If there's nothing in the contraction list,
+     // go over the pairs again without ignoring outer products
+     if (possible_contractions.empty()) {
+       for (auto& [p1, p2] : pos_pairs) {
+         add_contraction(p1, p2);
+       }
+     }
+
+     if (possible_contractions.empty()) {
+       // Default to naive einsum for the remaining inputs
+       std::vector<int> positions(inputs.size());
+       std::iota(positions.begin(), positions.end(), 0);
+       auto [cost, scale] = compute_cost_and_scaling(inputs, output, dim_map);
+       path.emplace_back(std::move(inputs), output, std::move(positions));
+
+       path_cost += cost;
+       path_scaling = std::max(scale, path_scaling);
+       break;
+     }
+
+     // Find the best contraction
+     auto& best = *std::min_element(
+         possible_contractions.begin(),
+         possible_contractions.end(),
+         [](const auto& x, const auto& y) {
+           return x.size > y.size || (x.size == y.size && x.cost < y.cost);
+         });
+     path_scaling = std::max(best.dims, path_scaling);
+
+     // Construct the output subscripts
+     std::string out_str(best.output.begin(), best.output.end());
+     // TODO, sorting by dimension size seems suboptimal?
+     std::sort(out_str.begin(), out_str.end(), [&dim_map](auto x, auto y) {
+       return dim_map[x] < dim_map[y];
+     });
+     Subscript new_output(std::move(out_str), std::move(best.output));
+
+     // Add the chosen contraction to the path
+     {
+       std::vector<Subscript> in_terms;
+       in_terms.push_back(std::move(inputs[best.x]));
+       in_terms.push_back(std::move(inputs[best.y]));
+       path.emplace_back(
+           std::move(in_terms), new_output, std::vector<int>{best.x, best.y});
+     }
+     // Remove used terms
+     inputs.erase(inputs.begin() + best.y);
+     inputs.erase(inputs.begin() + best.x);
+
+     // Add the new result
+     inputs.push_back(std::move(new_output));
+
+     // Update the existing contractions based on the selected one
+     std::vector<Contraction> updated_contractions;
+     for (auto& contraction : possible_contractions) {
+       // Drop contractions which contain either selected term
+       if (contraction.x == best.x || contraction.x == best.y ||
+           contraction.y == best.x || contraction.y == best.y) {
+         continue;
+       }
+
+       // Update the positions of other contractions
+       int x =
+           contraction.x - (contraction.x > best.x) - (contraction.x > best.y);
+       int y =
+           contraction.y - (contraction.y > best.x) - (contraction.y > best.y);
+       contraction.x = x;
+       contraction.y = y;
+       updated_contractions.push_back(std::move(contraction));
+     }
+
+     pos_pairs.clear();
+     for (int i = 0; i < inputs.size() - 1; ++i) {
+       pos_pairs.emplace_back(i, inputs.size() - 1);
+     }
+     path_cost += best.cost;
+
+     possible_contractions = std::move(updated_contractions);
+   }
+   return {path, path_cost, path_scaling};
+ }
+
+ // Assumes inputs have already had repeats and single axis sums collapsed
+ bool can_dot(const std::vector<Subscript>& inputs, const Subscript& output) {
+   if (inputs.size() != 2) {
+     return false;
+   }
+
+   for (auto c : inputs[0].set) {
+     // Use batched tensordot if anything is being contracted
+     if (output.set.find(c) == output.set.end()) {
+       return true;
+     }
+   }
+   return false;
+ }
+
+ array batch_tensordot(
+     array a,
+     array b,
+     std::vector<int> a_contract,
+     std::vector<int> a_batch,
+     std::vector<int> a_concat,
+     std::vector<int> b_contract,
+     std::vector<int> b_batch,
+     std::vector<int> b_concat,
+     StreamOrDevice s) {
+   // Broadcast contracting dimensions
+   {
+     auto a_shape = a.shape();
+     auto b_shape = b.shape();
+     for (int i = 0; i < a_contract.size(); ++i) {
+       auto d = std::max(a.shape(a_contract[i]), b.shape(b_contract[i]));
+       a_shape[a_contract[i]] = d;
+       b_shape[b_contract[i]] = d;
+     }
+     a = broadcast_to(a, a_shape, s);
+     b = broadcast_to(b, b_shape, s);
+   }
+   auto transpose_reshape = [&s](
+                                const array& x,
+                                const std::vector<int>& i,
+                                const std::vector<int>& j,
+                                const std::vector<int>& k) {
+     std::vector<int> reorder(i.begin(), i.end());
+     reorder.insert(reorder.end(), j.begin(), j.end());
+     reorder.insert(reorder.end(), k.begin(), k.end());
+
+     int size1 = 1;
+     for (auto s : j) {
+       size1 *= x.shape(s);
+     }
+
+     int size2 = 1;
+     for (auto s : k) {
+       size2 *= x.shape(s);
+     }
+
+     Shape shape;
+     for (auto ax : i) {
+       shape.push_back(x.shape(ax));
+     }
+     shape.push_back(size1);
+     shape.push_back(size2);
+
+     return reshape(transpose(x, reorder, s), std::move(shape), s);
+   };
+
+   Shape out_shape;
+   for (auto ax : a_batch) {
+     out_shape.push_back(a.shape(ax));
+   }
+   for (auto ax : a_concat) {
+     out_shape.push_back(a.shape(ax));
+   }
+   for (auto ax : b_concat) {
+     out_shape.push_back(b.shape(ax));
+   }
+
+   a = transpose_reshape(a, a_batch, a_concat, a_contract);
+   b = transpose_reshape(b, b_batch, b_contract, b_concat);
+
+   return reshape(matmul(a, b, s), std::move(out_shape), s);
+ }
+
+ // Collapse repeated subscripts and return the resulting array. The subscript
+ // is also updated in place. For example:
+ // - Given an input with shape (4, 4) and subscript "ii", returns
+ //   the diagonal of shape (4,) and updates the subscript to "i".
+ // - Given an input with shape (4, 2, 4, 2) and subscript "ijij",
+ //   returns an output with shape (4, 2) and updates the subscript
+ //   to "ij".
+ array collapse_repeats(array in, Subscript& subscript, StreamOrDevice s) {
+   // Build a list of (repeat chars, num repeats)
+   auto& str = subscript.str;
+   std::vector<std::pair<char, int>> repeats;
+   std::string new_str;
+   {
+     std::string repeat_str;
+     std::string no_repeat_str;
+     std::unordered_map<char, int> counts;
+     for (int i = 0; i < str.size(); ++i) {
+       auto [it, _] = counts.insert({str[i], 0});
+       it->second++;
+     }
+
+     for (auto& v : counts) {
+       if (v.second > 1) {
+         repeats.emplace_back(v.first, v.second);
+         repeat_str += v.first;
+       }
+     }
+     for (auto& c : str) {
+       if (counts[c] == 1) {
+         no_repeat_str += c;
+       }
+     }
+     new_str = repeat_str + no_repeat_str;
+   }
+
+   // Build the inputs for gather
+   auto slice_sizes = in.shape();
+   std::vector<int> axes;
+   std::vector<array> indices;
+   int n_expand = repeats.size();
+   for (auto [c, v] : repeats) {
+     for (int i = 0; i < str.size(); ++i) {
+       if (str[i] == c) {
+         slice_sizes[i] = 1;
+         axes.push_back(i);
+       }
+     }
+     Shape idx_shape(n_expand--, 1);
+     idx_shape[0] = in.shape(axes.back());
+     auto idx = reshape(
+         arange(static_cast<ShapeElem>(in.shape(axes.back())), s), idx_shape, s);
+     for (int i = 0; i < v; ++i) {
+       indices.push_back(idx);
+     }
+   }
+
+   in = gather(in, indices, axes, slice_sizes, s);
+
+   // Update subscript string with removed dups
+   str = new_str;
+
+   // Squeeze singleton dimensions left over from the gather
+   for (auto& ax : axes) {
+     ax += indices[0].ndim();
+   }
+
+   return squeeze(in, axes, s);
+ }
+
+ // Collapse repeat indices and sum single dimensions.
+ // For example:
+ // - "aa" becomes "a"
+ // - "ij,jk->k" becomes "j,jk->k"
+ void preprocess_einsum_inputs(
+     std::vector<Subscript>& inputs,
+     const Subscript& output,
+     const std::vector<int>& positions,
+     std::vector<array>& operands,
+     StreamOrDevice s) {
+   // Collapse repeat indices
+   for (int i = 0; i < inputs.size(); ++i) {
+     auto& in = inputs[i];
+     if (in.set.size() < in.str.size()) {
+       operands[positions[i]] = collapse_repeats(operands[positions[i]], in, s);
+     }
+   }
+
+   // Sum indices that are only in a single input
+   {
+     std::unordered_map<char, int> counts;
+     for (auto& in : inputs) {
+       for (auto c : in.set) {
+         auto inserted = counts.insert({c, 0});
+         inserted.first->second++;
+       }
+     }
+     for (auto c : output.set) {
+       auto inserted = counts.insert({c, 0});
+       inserted.first->second++;
+     }
+     for (int i = 0; i < inputs.size(); ++i) {
+       auto& in = inputs[i];
+       std::vector<int> sum_axes;
+       for (int ax = 0; ax < in.str.size(); ++ax) {
+         if (counts[in.str[ax]] == 1) {
+           sum_axes.push_back(ax);
+         }
+       }
+       if (!sum_axes.empty()) {
+         operands[positions[i]] =
+             sum(operands[positions[i]], sum_axes, false, s);
+       }
+       for (auto it = sum_axes.rbegin(); it != sum_axes.rend(); ++it) {
+         in.set.erase(in.str[*it]);
+         in.str.erase(in.str.begin() + *it);
+       }
+     }
+   }
+ }
+
+ array einsum_naive(
+     std::vector<Subscript> inputs,
+     const Subscript& output,
+     const std::vector<int>& positions,
+     std::vector<array> operands,
+     StreamOrDevice s) {
+   // Map each character to an axis
+   std::unordered_map<char, int> char_to_ax;
+   for (auto& in : inputs) {
+     for (auto c : in.str) {
+       char_to_ax.insert({c, char_to_ax.size()});
+     }
+   }
+
+   // Expand and transpose inputs as needed
+   for (int i = 0; i < inputs.size(); ++i) {
+     int pos = positions[i];
+     auto& op = operands[pos];
+
+     // Add missing dimensions at the end
+     if (op.ndim() != char_to_ax.size()) {
+       auto shape = op.shape();
+       shape.insert(shape.end(), char_to_ax.size() - shape.size(), 1);
+       op = reshape(op, std::move(shape), s);
+     }
+
+     // Transpose:
+     // - Build a vector of (char, ax) pairs for the current input
+     // - Sort the vector by the canonical axis in char_to_ax
+     // - Extract the sorted axis to get transpose order
+     std::vector<std::pair<char, int>> str_ax;
+     for (auto c : inputs[i].str) {
+       str_ax.emplace_back(c, str_ax.size());
+     }
+     for (auto [c, ax] : char_to_ax) {
+       if (inputs[i].set.find(c) == inputs[i].set.end()) {
+         str_ax.emplace_back(c, str_ax.size());
+       }
+     }
+     std::sort(
+         str_ax.begin(),
+         str_ax.end(),
+         [&char_to_ax](const auto& x, const auto& y) {
+           return char_to_ax[x.first] < char_to_ax[y.first];
+         });
+
+     // Skip the transpose if not needed
+     if (std::is_sorted(
+             str_ax.begin(), str_ax.end(), [](const auto& x, const auto& y) {
+               return x.second < y.second;
+             })) {
+       continue;
+     }
+
+     std::vector<int> reorder;
+     for (auto [c, ax] : str_ax) {
+       reorder.push_back(ax);
+     }
+     op = transpose(op, reorder, s);
+   }
+
+   // Multiply and sum
+   auto out = operands[positions[0]];
+   for (int i = 1; i < positions.size(); ++i) {
+     out = multiply(out, operands[positions[i]], s);
+   }
+   std::vector<int> sum_axes;
+   for (auto [c, ax] : char_to_ax) {
+     if (output.set.find(c) == output.set.end()) {
+       sum_axes.push_back(ax);
+     }
+   }
+   if (!sum_axes.empty()) {
+     out = sum(out, sum_axes, false, s);
+   }
+
+   // Transpose output if needed
+   std::vector<int> reorder;
+   for (auto c : output.str) {
+     reorder.push_back(char_to_ax[c]);
+   }
+   for (auto& r : reorder) {
+     int offset = 0;
+     for (auto s : sum_axes) {
+       if (r > s) {
+         offset++;
+       }
+     }
+     r -= offset;
+   }
+   return transpose(out, reorder, s);
+ }
+
+ std::pair<std::vector<PathNode>, PathInfo> einsum_path_helper(
+     const std::string& subscripts,
+     const std::vector<array>& operands,
+     const std::string& fn_name) {
+   if (operands.size() == 0) {
+     std::ostringstream msg;
+     msg << "[" << fn_name << "] At least one operand is required.";
+     throw std::invalid_argument(msg.str());
+   }
+
+   auto [in_subscripts, out_subscript] = parse(subscripts);
+
+   if (operands.size() != in_subscripts.size()) {
+     std::ostringstream msg;
+     msg << "[" << fn_name << "] Number of operands, " << operands.size()
+         << ", does not match number of input subscripts, "
+         << in_subscripts.size();
+     throw std::invalid_argument(msg.str());
+   }
+
+   // Expand ellipses
+   // 1. Collect all the characters we can use for the missing axes.
+   // 2. Go over each subscript and check if all the characters are either
+   //    alphanumeric or an ellipsis.
+   // 3. Expand the ellipsis with as many characters from the unused ones as
+   //    necessary. We use the last N characters effectively prepending with
+   //    singleton dims for inputs with fewer dimensions.
+   // 4. For the output use the maximum size of ellipsis that we encountered in
+   //    the input.
+   CharSet used_chars(subscripts.begin(), subscripts.end());
+   std::string remaining_chars;
+   remaining_chars.reserve(52 - used_chars.size());
+   for (char c = 'a'; c <= 'z'; c++) {
+     if (used_chars.find(c) == used_chars.end()) {
+       remaining_chars += c;
+     }
+   }
+   for (char c = 'A'; c <= 'Z'; c++) {
+     if (used_chars.find(c) == used_chars.end()) {
+       remaining_chars += c;
+     }
+   }
+   int max_ellipsis_length = 0;
+   auto check_letters_and_expand_ellipsis = [&](auto& subscript,
+                                                const array* operand,
+                                                int operand_idx) {
+     bool have_ellipsis = false;
+     int cnt_before = 0, cnt_after = 0;
+     for (int i = 0; i < subscript.size(); i++) {
+       if (!isalpha(subscript[i])) {
+         if (i + 2 >= subscript.size() || subscript[i] != '.' ||
+             subscript[i + 1] != '.' || subscript[i + 2] != '.') {
+           std::ostringstream msg;
+           msg << "[" << fn_name << "] Subscripts must be letters, but got '"
+               << subscript[i] << "'.";
+           throw std::invalid_argument(msg.str());
+         }
+
+         if (have_ellipsis) {
+           std::ostringstream msg;
+           msg << "[" << fn_name
+               << "] Only one ellipsis per subscript is allowed but found more in '"
+               << subscript << "'.";
+           throw std::invalid_argument(msg.str());
+         }
+
+         have_ellipsis = true;
+         i += 2;
+         continue;
+       }
+
+       if (have_ellipsis) {
+         cnt_after++;
+       } else {
+         cnt_before++;
+       }
+     }
+
+     if (have_ellipsis) {
+       int ellipsis_length;
+       if (operand != nullptr) {
+         ellipsis_length = operand->ndim() - cnt_before - cnt_after;
+         if (ellipsis_length < 0) {
+           std::ostringstream msg;
+           msg << "[" << fn_name << "] Operand " << operand_idx << " with shape "
+               << operand->shape()
+               << " has insufficient dimensions for subscript '" << subscript
+               << "'. The ellipsis requires at least "
+               << (cnt_before + cnt_after) << " dimensions but the operand has "
+               << operand->ndim() << " dimensions.";
+           throw std::invalid_argument(msg.str());
+         }
+         max_ellipsis_length = std::max(ellipsis_length, max_ellipsis_length);
+       } else {
+         ellipsis_length = max_ellipsis_length;
+       }
+
+       subscript.replace(
+           subscript.begin() + cnt_before,
+           subscript.begin() + cnt_before + 3,
+           remaining_chars.end() - ellipsis_length,
+           remaining_chars.end());
+     }
+   };
+
+   for (int i = 0; i < operands.size(); i++) {
+     check_letters_and_expand_ellipsis(in_subscripts[i], &operands[i], i);
+   }
+   check_letters_and_expand_ellipsis(out_subscript, nullptr, -1);
+
+   CharSet out_set(out_subscript.begin(), out_subscript.end());
+   if (out_set.size() != out_subscript.size()) {
+     std::ostringstream msg;
+     msg << "[" << fn_name << "] Repeat indices not allowed in output.";
+     throw std::invalid_argument(msg.str());
+   }
+   Subscript output(out_subscript, std::move(out_set));
+
+   std::unordered_map<char, ShapeElem> dim_map;
+   std::vector<Subscript> inputs;
+   for (int i = 0; i < in_subscripts.size(); ++i) {
+     auto& in = in_subscripts[i];
+     CharSet in_set(in.begin(), in.end());
+     inputs.emplace_back(in, in_set);
+
+     if (in.size() != operands[i].ndim()) {
+       std::ostringstream msg;
+       msg << "[" << fn_name << "] Invalid number of subscripts " << in.size()
+           << " for input " << i << " with " << operands[i].ndim()
+           << " dimensions.";
+       throw std::invalid_argument(msg.str());
+     }
+
+     // Check repeat subscripts are valid
+     if (in_set.size() < in.size()) {
+       std::unordered_map<char, ShapeElem> local_dims;
+       for (int j = 0; j < in.size(); ++j) {
+         auto dim = operands[i].shape(j);
+         auto inserted = local_dims.insert({in[j], dim});
+         if (!inserted.second) {
+           if (inserted.first->second != dim) {
+             std::ostringstream msg;
+             msg << "[" << fn_name << "] Dimensions of repeated subscripts "
+                 << "do not have the same size (" << inserted.first->second
+                 << " != " << dim << ").";
+             throw std::invalid_argument(msg.str());
+           }
+         }
+       }
+     }
+
+     for (int j = 0; j < in.size(); j++) {
+       auto c = in[j];
+       auto dim = operands[i].shape(j);
+       auto inserted = dim_map.insert({c, dim});
+       auto& in_dim = inserted.first->second;
+       if (dim != 1 && in_dim != 1 && in_dim != dim) {
+         std::ostringstream msg;
+         msg << "[" << fn_name << "] Cannot broadcast dimension " << j
+             << " of input " << i << " with shape " << operands[i].shape()
+             << " to size " << in_dim << ".";
+         throw std::invalid_argument(msg.str());
+       }
+       // Ensure the broadcasted size is used
+       in_dim = std::max(in_dim, dim);
+     }
+   }
+
+   size_t max_size = term_size(out_subscript, dim_map);
+   for (auto& in : in_subscripts) {
+     max_size = std::max(max_size, term_size(in, dim_map));
+   }
+
+   PathInfo path_info;
+
+   // Get the full naive cost
+   std::tie(path_info.naive_cost, path_info.naive_scaling) =
+       compute_cost_and_scaling(inputs, output, dim_map);
+
+   // Calculate the path
+   std::vector<PathNode> path;
+   if (inputs.size() <= 2) {
+     std::vector<int> positions(in_subscripts.size());
+     std::iota(positions.begin(), positions.end(), 0);
+     path.emplace_back(
+         std::move(inputs), std::move(output), std::move(positions));
+   } else {
+     std::tie(path, path_info.optimized_cost, path_info.optimized_scaling) =
+         greedy_path(inputs, output, dim_map, path_info.naive_cost, max_size);
+     // Set the final output subscript to the actual output
+     path.back().output = std::move(output);
+   }
+   return {path, path_info};
+ }
+
+ } // namespace
+
+ std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
+     const std::string& subscripts,
+     const std::vector<array>& operands) {
+   auto [path, path_info] =
+       einsum_path_helper(subscripts, operands, "einsum_path");
+
+   std::vector<std::vector<int>> pos_path;
+   for (auto& p : path) {
+     pos_path.push_back(p.positions);
+   }
+
+   std::ostringstream path_print;
+   path_print << " Complete contraction: " << subscripts << "\n"
+              << " Naive scaling: " << path_info.naive_scaling << "\n"
+              << " Optimized scaling: " << path_info.optimized_scaling
+              << "\n"
+              << " Naive FLOP count: " << path_info.naive_cost << "\n"
+              << " Optimized FLOP count: " << path_info.optimized_cost << "\n";
+   // TODO add more info here
+   return {pos_path, path_print.str()};
+ }
+
+ array einsum(
+     const std::string& subscripts,
+     const std::vector<array>& operands,
+     StreamOrDevice s /* = {} */) {
+   auto [path, path_info] = einsum_path_helper(subscripts, operands, "einsum");
+   auto inputs = operands;
+   for (auto& node : path) {
+     preprocess_einsum_inputs(
+         node.inputs, node.output, node.positions, inputs, s);
+
+     if (can_dot(node.inputs, node.output)) {
+       auto& in_a = node.inputs[0];
+       auto& in_b = node.inputs[1];
+       auto& out = node.output;
+
+       std::vector<int> a_contract;
+       std::vector<int> a_batch;
+       std::vector<int> a_concat;
+       for (int i = 0; i < in_a.str.size(); ++i) {
+         auto c = in_a.str[i];
+         if (out.set.find(c) == out.set.end()) {
+           // Not in the output, contraction
+           a_contract.push_back(i);
+         } else if (in_b.set.find(c) != in_b.set.end()) {
+           // Not a contraction but in both inputs, batch dim
+           a_batch.push_back(i);
+         } else {
+           // Not a batch dim or contract dim, so concat dim
+           a_concat.push_back(i);
+         }
+       }
+
+       std::vector<int> b_contract;
+       std::vector<int> b_batch;
+       std::vector<int> b_concat;
+       for (auto a_i : a_contract) {
+         b_contract.push_back(in_b.str.find(in_a.str[a_i]));
+       }
+       for (auto a_i : a_batch) {
+         b_batch.push_back(in_b.str.find(in_a.str[a_i]));
+       }
+       for (int i = 0; i < in_b.str.size(); ++i) {
+         auto c = in_b.str[i];
+         if (out.set.find(c) != out.set.end() &&
+             in_a.set.find(c) == in_a.set.end()) {
+           b_concat.push_back(i);
+         }
+       }
+
+       auto& a = inputs[node.positions[0]];
+       auto& b = inputs[node.positions[1]];
+
+       std::unordered_map<char, int> char_map;
+       for (auto i : a_batch) {
+         char_map.insert({in_a.str[i], char_map.size()});
+       }
+       for (auto i : a_concat) {
+         char_map.insert({in_a.str[i], char_map.size()});
+       }
+       for (auto i : b_concat) {
+         char_map.insert({in_b.str[i], char_map.size()});
+       }
+       inputs.emplace_back(batch_tensordot(
+           a,
+           b,
+           std::move(a_contract),
+           std::move(a_batch),
+           std::move(a_concat),
+           std::move(b_contract),
+           std::move(b_batch),
+           std::move(b_concat),
+           s));
+
+       std::vector<int> reorder;
+       for (auto c : node.output.str) {
+         reorder.push_back(char_map[c]);
+       }
+       inputs.back() = transpose(inputs.back(), reorder, s);
+
+     } else {
+       inputs.emplace_back(
+           einsum_naive(node.inputs, node.output, node.positions, inputs, s));
+     }
+
+     // Positions are always sorted increasing, so start from the back
+     for (auto it = node.positions.rbegin(); it != node.positions.rend(); ++it) {
+       inputs.erase(inputs.begin() + *it);
+     }
+   }
+   return inputs.front();
+ }
+
+ } // namespace mlx::core
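
For reference, below is a minimal usage sketch of the two public entry points declared above. It is illustrative only and not part of the package diff: it assumes the vendored MLX headers are reachable via "mlx/mlx.h" and relies solely on the signatures shown in this file (einsum_path, which returns the contraction positions plus a printable report, and einsum, whose stream argument defaults to {}).

// Illustrative sketch; not part of the released package contents.
#include <iostream>
#include "mlx/mlx.h"

int main() {
  using namespace mlx::core;
  // Three chained matrix products: "ij,jk,kl->il".
  auto a = ones({8, 16});
  auto b = ones({16, 32});
  auto c = ones({32, 4});

  // einsum_path reports the order chosen by the greedy planner together
  // with the naive vs. optimized scaling and FLOP counts.
  auto [path, report] = einsum_path("ij,jk,kl->il", {a, b, c});
  std::cout << report << std::endl;

  // einsum evaluates the contraction itself; the result has shape (8, 4).
  auto out = einsum("ij,jk,kl->il", {a, b, c});
  std::cout << out << std::endl;
  return 0;
}

Because three operands are involved, einsum_path_helper calls greedy_path rather than falling back to the single-node path, so the report's optimized cost reflects pairwise contraction (here, contracting the two largest terms first) instead of the naive all-at-once product.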