mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,234 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <algorithm>
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <vector>
9
+
10
+ #define __MLX_HALF_NAN__ 0x7D00
11
+
12
+ namespace mlx::core {
13
+
14
namespace {
// Helper union used to view the same 32 bits either as a float or as an
// unsigned integer during the half <-> single conversions below.
union float_bits_fp16 {
  float f;
  uint32_t u;
};
} // namespace

// Software half-precision value: only the raw 16 bits are stored, and all
// conversions go through float.
struct _MLX_Float16 {
  uint16_t bits_;

  // Default constructor (leaves bits_ uninitialized)
  _MLX_Float16() = default;

  // Default copy constructor
  _MLX_Float16(_MLX_Float16 const&) = default;

  // Appease std::vector<bool> for being special: its elements are proxy
  // objects. The element is converted by value (true -> 1.0, false -> 0.0);
  // copying the bool straight into bits_ would store the bit pattern 0x0001
  // rather than the number one.
  _MLX_Float16& operator=(std::vector<bool>::reference x) {
    return (*this = _MLX_Float16(static_cast<bool>(x) ? 1.0f : 0.0f));
  }

  _MLX_Float16& operator=(const float& x) {
    return (*this = _MLX_Float16(x));
  }

  // From float32
  _MLX_Float16(const float& x) : bits_(0) {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Bit pattern written for NaN inputs (same value as __MLX_HALF_NAN__).
    constexpr uint16_t half_nan_bits = 0x7D00;

    // Take fp32 bits
    float_bits_fp16 in;
    in.f = x;

    // Find and take sign bit
    const uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
    const uint16_t x_sign_16 = (x_sign_32 >> 16);

    if (std::isnan(x)) {
      bits_ = x_sign_16 | half_nan_bits;
    } else {
      float_bits_fp16 inf_scale, zero_scale, magic_bits;

      // Find exponent bits and take the max supported by half
      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
      const uint32_t max_expo_32 = uint32_t(0x38800000);
      x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
      x_expo_32 += uint32_t(15) << 23;

      // Handle scaling to inf as needed
      inf_scale.u = uint32_t(0x77800000);
      zero_scale.u = uint32_t(0x08800000);

      // Combine with magic and let addition do rounding
      magic_bits.u = x_expo_32;
      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;

      // Take the lower 5 bits of the exponent
      const uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));

      // Collect the lower 12 bits which have the mantissa
      const uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);

      // Combine sign, exp and mantissa
      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
    }
  }

  // To float32
  operator float() const {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    float_bits_fp16 out;

    // Widen before shifting: bits_ would otherwise be promoted to a signed
    // int and the shift would move a set sign bit into the int's sign bit.
    const uint32_t base = static_cast<uint32_t>(bits_) << 16;
    const uint32_t x_sign_32 = base & uint32_t(0x80000000);
    const uint32_t two_base = base + base; // drops the sign bit

    const uint32_t denorm_max = 1u << 27;
    if (two_base < denorm_max) {
      out.u = uint32_t(126) << 23; // magic mask
      out.u |= (two_base >> 17); // Bits from fp16
      out.f -= 0.5f; // magic bias
    } else {
      out.u = uint32_t(0xE0) << 23; // exponent offset
      out.u += (two_base >> 4); // Bits from fp16
      const float out_unscaled = out.f; // Store value
      out.u = uint32_t(0x7800000); // exponent scale
      out.f *= out_unscaled;
    }

    // Add sign
    out.u |= x_sign_32;

    return out.f;
  }
};
117
+
118
// half_binop_base(OP, FN, OUT_T, LHS_T, RHS_T, CALC_T)
// Defines `inline OUT_T FN(LHS_T, RHS_T)`: both operands are cast to CALC_T,
// combined with OP, and the result is converted to OUT_T on return.
#define half_binop_base(OP, FN, OUT_T, LHS_T, RHS_T, CALC_T) \
  inline OUT_T FN(LHS_T a, RHS_T b) {                        \
    return static_cast<CALC_T>(a) OP static_cast<CALC_T>(b); \
  }

// half_binop_helper(OP, FN, OUT_T, OTHER_T, CALC_T)
// Defines the two mixed overloads of FN between _MLX_Float16 and OTHER_T
// (half on the left, then half on the right), computed in CALC_T.
#define half_binop_helper(OP, FN, OUT_T, OTHER_T, CALC_T)    \
  inline OUT_T FN(_MLX_Float16 a, OTHER_T b) {               \
    return static_cast<CALC_T>(a) OP static_cast<CALC_T>(b); \
  }                                                          \
  inline OUT_T FN(OTHER_T a, _MLX_Float16 b) {               \
    return static_cast<CALC_T>(a) OP static_cast<CALC_T>(b); \
  }
130
+
131
+ // Operators
132
+ #define half_binop(__op__, __operator__) \
133
+ half_binop_base( \
134
+ __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
135
+ half_binop_helper(__op__, __operator__, float, float, float); \
136
+ half_binop_helper(__op__, __operator__, double, double, double); \
137
+ half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \
138
+ half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \
139
+ half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \
140
+ half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \
141
+ half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
142
+
143
+ half_binop(+, operator+);
144
+ half_binop(-, operator-);
145
+ half_binop(*, operator*);
146
+ half_binop(/, operator/);
147
+
148
+ #undef half_binop
149
+
150
+ // Comparison ops
151
+ #define half_compop(__op__, __operator__) \
152
+ half_binop_base( \
153
+ __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
154
+ half_binop_helper(__op__, __operator__, bool, float, float); \
155
+ half_binop_helper(__op__, __operator__, bool, double, double); \
156
+ half_binop_helper(__op__, __operator__, bool, int32_t, float); \
157
+ half_binop_helper(__op__, __operator__, bool, uint32_t, float); \
158
+ half_binop_helper(__op__, __operator__, bool, int64_t, float); \
159
+ half_binop_helper(__op__, __operator__, bool, uint64_t, float);
160
+
161
+ half_compop(>, operator>);
162
+ half_compop(<, operator<);
163
+ half_compop(>=, operator>=);
164
+ half_compop(<=, operator<=);
165
+ half_compop(==, operator==);
166
+ half_compop(!=, operator!=);
167
+
168
+ #undef half_compop
169
+
170
+ // Negative
171
+ inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
172
+ return -static_cast<float>(lhs);
173
+ }
174
+
175
+ // Inplace ops
176
+ #define half_inplace_op(__op__, __operator__) \
177
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
178
+ lhs = lhs __op__ rhs; \
179
+ return lhs; \
180
+ } \
181
+ inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \
182
+ lhs = lhs __op__ rhs; \
183
+ return lhs; \
184
+ }
185
+
186
+ half_inplace_op(+, operator+=);
187
+ half_inplace_op(-, operator-=);
188
+ half_inplace_op(*, operator*=);
189
+ half_inplace_op(/, operator/=);
190
+
191
+ #undef half_inplace_op
192
+
193
+ // Bitwise ops
194
+
195
+ #define half_bitop(__op__, __operator__) \
196
+ inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
197
+ _MLX_Float16 out; \
198
+ out.bits_ = lhs.bits_ __op__ rhs.bits_; \
199
+ return out; \
200
+ } \
201
+ inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \
202
+ _MLX_Float16 out; \
203
+ out.bits_ = lhs.bits_ __op__ rhs; \
204
+ return out; \
205
+ } \
206
+ inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \
207
+ _MLX_Float16 out; \
208
+ out.bits_ = lhs __op__ rhs.bits_; \
209
+ return out; \
210
+ }
211
+
212
+ half_bitop(|, operator|);
213
+ half_bitop(&, operator&);
214
+ half_bitop(^, operator^);
215
+
216
+ #undef half_bitop
217
+
218
+ #define half_inplace_bitop(__op__, __operator__) \
219
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
220
+ lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
221
+ return lhs; \
222
+ } \
223
+ inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \
224
+ lhs.bits_ = lhs.bits_ __op__ rhs; \
225
+ return lhs; \
226
+ }
227
+
228
+ half_inplace_bitop(|, operator|=);
229
+ half_inplace_bitop(&, operator&=);
230
+ half_inplace_bitop(^, operator^=);
231
+
232
+ #undef half_inplace_bitop
233
+
234
+ } // namespace mlx::core
@@ -0,0 +1,58 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
6
+
7
+ #include <arm_fp16.h>
8
+ namespace mlx::core {
9
+ using ::float16_t;
10
+ } // namespace mlx::core
11
+
12
+ #else
13
+
14
+ #define ADD_HALF_BINOPS
15
+ #include "mlx/types/fp16.h"
16
+ namespace mlx::core {
17
+ typedef struct _MLX_Float16 float16_t;
18
+ } // namespace mlx::core
19
+
20
+ #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
21
+
22
+ #ifdef __ARM_FEATURE_BF16
23
+
24
+ #include <arm_bf16.h>
25
+ namespace mlx::core {
26
+ using ::bfloat16_t;
27
+ } // namespace mlx::core
28
+
29
+ #else
30
+
31
+ #define ADD_HALF_BINOPS
32
+ #include "mlx/types/bf16.h"
33
+ namespace mlx::core {
34
+ typedef struct _MLX_BFloat16 bfloat16_t;
35
+ } // namespace mlx::core
36
+
37
+ #endif // __ARM_FEATURE_BF16
38
+
39
#if defined(ADD_HALF_BINOPS)
namespace mlx::core {

// Arithmetic between one float16_t and one bfloat16_t operand (in either
// order): both sides are converted to float and the result is a float.
// clang-format off
#define fp16_bf16_binop_helper(OP, FN)                     \
  inline float FN(float16_t a, bfloat16_t b) {             \
    return static_cast<float>(a) OP static_cast<float>(b); \
  }                                                        \
  inline float FN(bfloat16_t a, float16_t b) {             \
    return static_cast<float>(a) OP static_cast<float>(b); \
  }

fp16_bf16_binop_helper(+, operator+)
fp16_bf16_binop_helper(-, operator-)
fp16_bf16_binop_helper(*, operator*)
fp16_bf16_binop_helper(/, operator/)
// clang-format on

} // namespace mlx::core
#endif
@@ -0,0 +1,70 @@
1
+ // Copyright © 2024 Apple Inc.
2
+ #pragma once
3
+
4
+ #include <limits>
5
+ #include "mlx/types/half_types.h"
6
+
7
+ namespace mlx::core {
8
+
9
// Primary template: declared only. Each supported type provides its own
// specialization.
template <typename T>
struct numeric_limits;

// float and double reuse the standard library's limits unchanged.
template <>
struct numeric_limits<float> : std::numeric_limits<float> {};

template <>
struct numeric_limits<double> : std::numeric_limits<double> {};
17
+
18
+ template <>
19
+ struct numeric_limits<float16_t> {
20
+ private:
21
+ union half_or_bits {
22
+ uint16_t bits;
23
+ float16_t value;
24
+ };
25
+ constexpr static float16_t bits_to_half(uint16_t v) {
26
+ return half_or_bits{v}.value;
27
+ }
28
+
29
+ public:
30
+ constexpr static float16_t lowest() {
31
+ return bits_to_half(0xFBFF);
32
+ }
33
+ static constexpr float16_t max() {
34
+ return bits_to_half(0x7BFF);
35
+ }
36
+ static constexpr float16_t epsilon() {
37
+ return bits_to_half(0x1400);
38
+ }
39
+ static constexpr float16_t infinity() {
40
+ return bits_to_half(0x7C00);
41
+ }
42
+ };
43
+
44
+ template <>
45
+ struct numeric_limits<bfloat16_t> {
46
+ private:
47
+ union bfloat_or_bits {
48
+ uint16_t bits;
49
+ bfloat16_t value;
50
+ };
51
+ constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
52
+ return bfloat_or_bits{v}.value;
53
+ }
54
+
55
+ public:
56
+ constexpr static bfloat16_t lowest() {
57
+ return bits_to_bfloat(0xFF7F);
58
+ }
59
+ static constexpr bfloat16_t max() {
60
+ return bits_to_bfloat(0x7F7F);
61
+ }
62
+ static constexpr bfloat16_t epsilon() {
63
+ return bits_to_bfloat(0x3C00);
64
+ }
65
+ static constexpr bfloat16_t infinity() {
66
+ return bits_to_bfloat(0x7F80);
67
+ }
68
+ };
69
+
70
+ } // namespace mlx::core
data/mlx/mlx/utils.cpp ADDED
@@ -0,0 +1,302 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>

#include "mlx/dtype_utils.h"
#include "mlx/types/limits.h"
#include "mlx/utils.h"
11
+
12
+ namespace mlx::core {
13
+
14
+ Stream to_stream(StreamOrDevice s) {
15
+ if (std::holds_alternative<std::monostate>(s)) {
16
+ return default_stream(default_device());
17
+ } else if (std::holds_alternative<Device>(s)) {
18
+ return default_stream(std::get<Device>(s));
19
+ } else {
20
+ return std::get<Stream>(s);
21
+ }
22
+ }
23
+
24
+ Stream to_stream(StreamOrDevice s, Device default_) {
25
+ if (std::holds_alternative<std::monostate>(s)) {
26
+ return default_stream(default_);
27
+ } else if (std::holds_alternative<Device>(s)) {
28
+ return default_stream(std::get<Device>(s));
29
+ } else {
30
+ return std::get<Stream>(s);
31
+ }
32
+ }
33
+
34
+ void PrintFormatter::print(std::ostream& os, bool val) {
35
+ if (capitalize_bool) {
36
+ os << (val ? "True" : "False");
37
+ } else {
38
+ os << val;
39
+ }
40
+ }
41
+ inline void PrintFormatter::print(std::ostream& os, int16_t val) { // This and the following scalar overloads stream the value unchanged.
42
+ os << val;
43
+ }
44
+ inline void PrintFormatter::print(std::ostream& os, uint16_t val) {
45
+ os << val;
46
+ }
47
+ inline void PrintFormatter::print(std::ostream& os, int32_t val) {
48
+ os << val;
49
+ }
50
+ inline void PrintFormatter::print(std::ostream& os, uint32_t val) {
51
+ os << val;
52
+ }
53
+ inline void PrintFormatter::print(std::ostream& os, int64_t val) {
54
+ os << val;
55
+ }
56
+ inline void PrintFormatter::print(std::ostream& os, uint64_t val) {
57
+ os << val;
58
+ }
59
+ inline void PrintFormatter::print(std::ostream& os, float16_t val) { // NOTE(review): presumably printed via float16_t's conversion to float - confirm.
60
+ os << val;
61
+ }
62
+ inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) { // NOTE(review): presumably printed via bfloat16_t's conversion to float - confirm.
63
+ os << val;
64
+ }
65
+ inline void PrintFormatter::print(std::ostream& os, float val) {
66
+ os << val;
67
+ }
68
+ inline void PrintFormatter::print(std::ostream& os, double val) {
69
+ os << val;
70
+ }
71
+ inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
72
+ os << val.real();
73
+ if (val.imag() >= 0 || std::isnan(val.imag())) {
74
+ os << "+" << val.imag() << "j";
75
+ } else {
76
+ os << "-" << -val.imag() << "j";
77
+ }
78
+ }
79
+ // Returns a reference to the single PrintFormatter instance, a function-local static created on first call.
80
+ PrintFormatter& get_global_formatter() {
81
+ static PrintFormatter formatter;
82
+ return formatter;
83
+ }
84
+
85
// Reports `error` on stderr (the message is assembled first and written in
// one piece, followed by a flush) and then terminates via std::abort().
void abort_with_exception(const std::exception& error) {
  std::ostringstream report;
  report << "Terminating due to uncaught exception: " << error.what();
  const std::string text = report.str();
  std::cerr << text << std::endl;
  std::abort();
}
91
+
92
+ Dtype result_type(const std::vector<array>& arrays) {
93
+ Dtype t = bool_;
94
+ for (auto& arr : arrays) {
95
+ t = promote_types(t, arr.dtype());
96
+ }
97
+ return t;
98
+ }
99
+
100
+ Shape broadcast_shapes(const Shape& s1, const Shape& s2) {
101
+ // Use the same broadcasting rules as numpy
102
+ // https://numpy.org/doc/1.20/user/theory.broadcasting.html
103
+ // "The size of the trailing axes for both arrays in an operation must
104
+ // either be the same size or one of them must be one."
105
+ int ndim1 = s1.size();
106
+ int ndim2 = s2.size();
107
+ int ndim = std::max(ndim1, ndim2);
108
+ int diff = std::abs(ndim1 - ndim2);
109
+ const auto& big = ndim1 > ndim2 ? s1 : s2;
110
+ const auto& small = ndim1 > ndim2 ? s2 : s1;
111
+ Shape out_shape(ndim);
112
+ for (int i = ndim - 1; i >= diff; --i) {
113
+ auto a = big[i];
114
+ auto b = small[i - diff];
115
+ if (b == a) {
116
+ out_shape[i] = a;
117
+ } else if (a == 1 || b == 1) {
118
+ // 0 if a or b is 0 otherwise max(a, b)
119
+ out_shape[i] = a * b;
120
+ } else {
121
+ std::ostringstream msg;
122
+ msg << "[broadcast_shapes] Shapes " << s1 << " and " << s2
123
+ << " cannot be broadcast.";
124
+ throw std::invalid_argument(msg.str());
125
+ }
126
+ }
127
+ for (int i = diff - 1; i >= 0; --i) {
128
+ out_shape[i] = big[i];
129
+ }
130
+ return out_shape;
131
+ }
132
+
133
// Maps a possibly negative axis onto [0, ndim).
//
// Valid inputs are -ndim <= axis < ndim; negative values count from the end.
// Anything else throws std::invalid_argument with a message that starts with
// `msg_prefix`.
int normalize_axis_index(
    int axis,
    int ndim,
    const std::string& msg_prefix /* = "" */) {
  const bool in_bounds = axis >= -ndim && axis < ndim;
  if (in_bounds) {
    return axis >= 0 ? axis : axis + ndim;
  }
  std::ostringstream msg;
  msg << msg_prefix << "Axis " << axis << " is out of bounds for array with "
      << ndim << " dimensions.";
  throw std::invalid_argument(msg.str());
}
145
+
146
+ std::ostream& operator<<(std::ostream& os, const Device& d) {
147
+ os << "Device(";
148
+ switch (d.type) {
149
+ case Device::cpu:
150
+ os << "cpu";
151
+ break;
152
+ case Device::gpu:
153
+ os << "gpu";
154
+ break;
155
+ }
156
+ os << ", " << d.index << ")";
157
+ return os;
158
+ }
159
+
160
+ std::ostream& operator<<(std::ostream& os, const Stream& s) {
161
+ os << "Stream(";
162
+ os << s.device;
163
+ os << ", " << s.index << ")";
164
+ return os;
165
+ }
166
+
167
// Prints an int8_t as a number rather than as a character.
std::ostream& operator<<(std::ostream& os, int8_t x) {
  const int widened = static_cast<int>(x);
  return os << widened;
}
171
+
172
// Prints a uint8_t as a number rather than as a character.
std::ostream& operator<<(std::ostream& os, uint8_t x) {
  const unsigned int widened = static_cast<unsigned int>(x);
  return os << widened;
}
176
+
177
+ namespace {
178
+ // Writes the entries of `a` along dimension `dim` as a bracketed list, starting at element offset `index`; recurses for inner dimensions and abbreviates long dimensions with "...".
179
+ template <typename T>
180
+ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) {
181
+ int num_print = 3; // entries shown at each end of an abbreviated dimension
182
+ int n = a.shape(dim);
183
+ size_t s = a.strides()[dim]; // offset step between consecutive entries of this dimension
184
+ bool is_last = dim == a.ndim() - 1;
185
+ auto prefix = is_last ? "" : std::string(7 + dim, ' '); // indentation for rows after the first; presumably lines up with the "array(" header - confirm
186
+ auto postfix = is_last ? ", " : ",\n";
187
+ os << "[";
188
+ for (int i = 0; i < n; ++i) {
189
+ os << (i == 0 ? "" : prefix);
190
+ if (i == num_print && n > 2 * num_print) { // more than 2 * num_print entries: skip the middle
191
+ os << "...";
192
+ i = n - num_print - 1; // after ++i the loop resumes at the first of the trailing num_print entries
193
+ index += s * (n - 2 * num_print - 1); // with the `index += s` below, lands on that same entry
194
+ } else if (is_last) {
195
+ get_global_formatter().print(os, a.data<T>()[index]);
196
+ } else {
197
+ print_subarray<T>(os, a, index, dim + 1);
198
+ }
199
+ os << (i == n - 1 ? "" : postfix); // no separator after the final entry
200
+ index += s;
201
+ }
202
+ os << "]";
203
+ }
204
+
205
+ template <typename T>
206
+ void print_array(std::ostream& os, const array& a) {
207
+ os << std::boolalpha;
208
+ os << "array(";
209
+ if (a.ndim() == 0) {
210
+ auto data = a.data<T>();
211
+ get_global_formatter().print(os, data[0]);
212
+ } else {
213
+ print_subarray<T>(os, a, 0, 0);
214
+ }
215
+ os << ", dtype=" << a.dtype() << ")";
216
+ os << std::noboolalpha;
217
+ }
218
+
219
+ } // namespace
220
+ // Writes the text returned by dtype_to_string(dtype).
221
+ std::ostream& operator<<(std::ostream& os, const Dtype& dtype) {
222
+ return os << dtype_to_string(dtype);
223
+ }
224
+
225
+ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
226
+ switch (k) {
227
+ case Dtype::Kind::b:
228
+ return os << "b";
229
+ case Dtype::Kind::i:
230
+ return os << "i";
231
+ case Dtype::Kind::u:
232
+ return os << "u";
233
+ case Dtype::Kind::f:
234
+ return os << "f";
235
+ case Dtype::Kind::c:
236
+ return os << "c";
237
+ case Dtype::Kind::V:
238
+ return os << "V";
239
+ }
240
+ return os;
241
+ }
242
+ // Prints an array: evaluates it first, then dispatches on its dtype to the matching print_array<T> instantiation.
243
+ std::ostream& operator<<(std::ostream& os, array a) {
244
+ a.eval();
245
+ dispatch_all_types(a.dtype(), [&](auto type_tag) {
246
+ print_array<MLX_GET_TYPE(type_tag)>(os, a);
247
+ });
248
+ return os;
249
+ }
250
+
251
namespace env {

// Reads the environment variable `name` as a base-10 integer.
//
// Returns `default_value` when the variable is not set. When it is set, the
// leading integer in its text is parsed (leading whitespace and an optional
// sign are accepted; text with no leading digits yields 0). Values that do
// not fit in an int are clamped to the int range instead of invoking the
// undefined behaviour atoi has for out-of-range input.
int get_var(const char* name, int default_value) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) {
    return default_value;
  }
  const long parsed = std::strtol(raw, nullptr, 10);
  if (parsed > std::numeric_limits<int>::max()) {
    return std::numeric_limits<int>::max();
  }
  if (parsed < std::numeric_limits<int>::min()) {
    return std::numeric_limits<int>::min();
  }
  return static_cast<int>(parsed);
}

} // namespace env
262
+
263
+ template <typename T>
264
+ void set_finfo_limits(double& min, double& max, double& eps) {
265
+ min = numeric_limits<T>::lowest();
266
+ max = numeric_limits<T>::max();
267
+ eps = numeric_limits<T>::epsilon();
268
+ }
269
+
270
+ finfo::finfo(Dtype dtype) : dtype(dtype) {
271
+ if (!issubdtype(dtype, inexact)) {
272
+ std::ostringstream msg;
273
+ msg << "[finfo] dtype " << dtype << " is not inexact.";
274
+ throw std::invalid_argument(msg.str());
275
+ }
276
+ if (dtype == float32) {
277
+ set_finfo_limits<float>(min, max, eps);
278
+ } else if (dtype == float16) {
279
+ set_finfo_limits<float16_t>(min, max, eps);
280
+ } else if (dtype == bfloat16) {
281
+ set_finfo_limits<bfloat16_t>(min, max, eps);
282
+ } else if (dtype == float64) {
283
+ set_finfo_limits<double>(min, max, eps);
284
+ } else if (dtype == complex64) {
285
+ this->dtype = float32;
286
+ set_finfo_limits<float>(min, max, eps);
287
+ }
288
+ }
289
+
290
// Stores the smallest and largest value of integer type T into `min`
// (signed 64-bit) and `max` (unsigned 64-bit).
template <typename T>
void set_iinfo_limits(int64_t& min, uint64_t& max) {
  using limits = std::numeric_limits<T>;
  min = static_cast<int64_t>(limits::min());
  max = static_cast<uint64_t>(limits::max());
}
295
+ // Fills min/max for `dtype` by dispatching to set_iinfo_limits<T>; "[iinfo]" is the tag handed to dispatch_int_types.
296
+ iinfo::iinfo(Dtype dtype) : dtype(dtype) {
297
+ dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
298
+ set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
299
+ });
300
+ }
301
+
302
+ } // namespace mlx::core