mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlx might be problematic.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
@@ -0,0 +1,976 @@
+ # Copyright © 2023-2024 Apple Inc.
+
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import mlx.core as mx
+ from mlx.nn import Module
+ from mlx.utils import tree_flatten, tree_map, tree_merge, tree_reduce, tree_unflatten
+
+
+ class Optimizer:
+     """The base class for all optimizers. It allows us to implement an
+     optimizer on a per-parameter basis and apply it to a parameter tree.
+     """
+
+     def __init__(self, schedulers=None):
+         self._initialized = False
+         self._state = {"step": mx.array(0, mx.uint64)}
+         self._schedulers = {k: v for k, v in (schedulers or {}).items()}
+
+     def update(self, model: Module, gradients: dict):
+         """Apply the gradients to the parameters of the model and update the
+         model with the new parameters.
+
+         Args:
+             model (mlx.nn.Module): An mlx module to be updated.
+             gradients (dict): A Python tree of gradients, most likely computed
+                 via :func:`mlx.nn.value_and_grad`.
+         """
+         model.update(self.apply_gradients(gradients, model))
+
+     def init(self, parameters: dict):
+         """Initialize the optimizer's state
+
+         This function can be used to initialize optimizers which have state
+         (like momentum in :class:`SGD`). Using this method is optional as the
+         optimizer will initialize itself if the state is not yet set. However,
+         there are some cases where explicit initialization is useful in order
+         to have access to the :attr:`Optimizer.state` before the first call to
+         :meth:`Optimizer.update`.
+
+         Args:
+             parameters (dict): A Python tree of parameters.
+
+         Example:
+             >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
+             >>> model = nn.Linear(2, 2)
+             >>> optimizer.init(model.trainable_parameters())
+             >>> optimizer.state.keys()
+             dict_keys(['step', 'learning_rate', 'weight', 'bias'])
+         """
+
+         # Initialize the optimizer state to match the parameter state
+         def update_state(params, state):
+             if isinstance(params, (list, tuple)):
+                 state = list(state)
+                 for i in range(len(state)):
+                     state[i] = update_state(params[i], state[i])
+                 if len(state) != len(params):
+                     state.extend(tree_map(lambda _: {}, params[len(state) :]))
+                 return type(params)(state)
+             elif isinstance(params, dict):
+                 for k, v in params.items():
+                     if k not in state:
+                         state[k] = tree_map(lambda _: {}, v)
+                     else:
+                         state[k] = update_state(v, state[k])
+                 return state
+             else:
+                 return state
+
+         update_state(parameters, self._state)
+         tree_map(lambda p, s: s or self.init_single(p, s), parameters, self._state)
+         self._initialized = True
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """To be extended by the children classes to implement each optimizer's
+         state initialization.
+
+         Args:
+             parameter (mx.array): A single parameter that will be optimized.
+             state (dict): The optimizer's state.
+         """
+         raise NotImplementedError()
+
+     def apply_gradients(self, gradients: dict, parameters: dict):
+         """Apply the gradients to the parameters and return the updated parameters.
+
+         Can be used to update a model via
+         ``model.update(opt.apply_gradients(grads, model))`` which is precisely
+         how :meth:`Optimizer.update` is implemented.
+
+         Args:
+             gradients (dict): A Python tree of gradients.
+             parameters (dict): A Python tree of parameters. It can be a
+                 superset of the gradients. In that case the returned python
+                 tree will be of the same structure as the gradients.
+         """
+         if not self._initialized:
+             self.init(gradients)
+
+         # Update any scheduled variables
+         for param, scheduler in self._schedulers.items():
+             self.state[param] = scheduler(self.step)
+
+         # Increment the step
+         self.state["step"] = self.step + 1
+
+         # Apply the update
+         return tree_map(self.apply_single, gradients, parameters, self.state)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """To be extended by derived classes to implement the optimizer's update.
+
+         Args:
+             gradient (mx.array): The ``parameter`` gradient.
+             parameter (mx.array): The ``parameter`` to update.
+             state (dict): The optimizer's state.
+         """
+         raise NotImplementedError()
+
+     @property
+     def state(self):
+         """The optimizer's state dictionary."""
+         return self._state
+
+     @state.setter
+     def state(self, state: dict):
+         self._initialized = False
+         self._state = state
+
+     @property
+     def step(self):
+         return self.state["step"]
+
+     @property
+     def learning_rate(self):
+         return self.state["learning_rate"]
+
+     @learning_rate.setter
+     def learning_rate(self, learning_rate: Union[float, mx.array]):
+         self.state["learning_rate"] = mx.array(learning_rate)
+
+     def _maybe_schedule(
+         self, name: str, param: Union[float, Callable[[mx.array], mx.array]]
+     ):
+         """
+         To be used by derived classes to optionally put a parameter on a schedule.
+         """
+         if isinstance(param, Callable):
+             self._schedulers[name] = param
+             parameter = param(self.step)
+         else:
+             parameter = mx.array(param)
+         self.state[name] = parameter
+
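The base class above drives the standard MLX training loop: update() is a thin wrapper over apply_gradients(), which lazily initializes state on first use, refreshes any scheduled values, increments the step counter, and maps apply_single() over the gradient tree. A minimal usage sketch of this API (the loss function and data below are illustrative placeholders, not part of this diff):

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(2, 2)
    optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)

    def loss_fn(model, x, y):
        return nn.losses.mse_loss(model(x), y)

    # Placeholder data for illustration only
    x = mx.random.normal((8, 2))
    y = mx.random.normal((8, 2))

    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)

    # Equivalent to: model.update(optimizer.apply_gradients(grads, model))
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
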
+
+ class MultiOptimizer(Optimizer):
+     """Wraps a list of optimizers with corresponding weight predicates/filters
+     to make it easy to use different optimizers for different weights.
+
+     The predicates take the full "path" of the weight and the weight itself and
+     return True if it should be considered for this optimizer. The last
+     optimizer in the list is a fallback optimizer and no predicate should be
+     given for it.
+
+     Args:
+         optimizers (list[Optimizer]): A list of optimizers to delegate to.
+         filters (list[Callable[[str, array], bool]]): A list of predicates that
+             should be one less than the provided optimizers.
+     """
+
+     def __init__(self, optimizers, filters: list = []):
+         super().__init__()
+         self._state = {}
+
+         if len(filters) != len(optimizers) - 1:
+             raise ValueError(
+                 f"Given {len(filters)} filters but {len(optimizers)-1} needed."
+             )
+
+         self.optimizers = optimizers
+         self.filters = filters + [lambda *args, **kwargs: True]
+
+     def _split_dictionary(self, gradients: dict):
+         if len(self.optimizers) == 1:
+             return [gradients]
+
+         parts = [[] for _ in range(len(self.optimizers))]
+         flat_gradients = tree_flatten(gradients)
+         for k, g in flat_gradients:
+             for i, fn in enumerate(self.filters):
+                 if fn(k, g):
+                     parts[i].append((k, g))
+                     break
+
+         return [tree_unflatten(p) for p in parts]
+
+     def init(self, parameters: dict):
+         for o, p in zip(self.optimizers, self._split_dictionary(parameters)):
+             o.init(p)
+
+     def apply_gradients(self, gradients: dict, parameters: dict):
+         tree = {}
+         for o, g in zip(self.optimizers, self._split_dictionary(gradients)):
+             tree = tree_merge(tree, o.apply_gradients(g, parameters))
+         return tree
+
+     @property
+     def state(self):
+         return {"states": [o.state for o in self.optimizers]}
+
+     @state.setter
+     def state(self, state: dict):
+         if "states" not in state or len(state["states"]) != len(self.optimizers):
+             raise ValueError("Invalid state provided")
+
+         for o, s in zip(self.optimizers, state["states"]):
+             o.state = s
+
+     @property
+     def learning_rate(self):
+         return self.optimizers[0].learning_rate
+
+     @learning_rate.setter
+     def learning_rate(self, learning_rate: Union[float, mx.array]):
+         for o in self.optimizers:
+             o.learning_rate = learning_rate
+
230
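+ # A minimal sketch of predicate-based routing, assuming a model whose bias
+ # parameter paths end in ``"bias"``; biases go to SGD and everything else
+ # falls through to AdamW (the last optimizer takes no predicate):
+ #
+ #     optimizer = MultiOptimizer(
+ #         [SGD(learning_rate=1e-2), AdamW(learning_rate=1e-3)],
+ #         filters=[lambda path, weight: path.endswith("bias")],
+ #     )
+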
+ class SGD(Optimizer):
+     r"""The stochastic gradient descent optimizer.
+
+     Updates a parameter :math:`w` with a gradient :math:`g` as follows
+
+     .. math::
+
+         v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
+         w_{t+1} &= w_t - \lambda v_{t+1}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
+         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
+         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
+         nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         momentum: float = 0.0,
+         weight_decay: float = 0.0,
+         dampening: float = 0.0,
+         nesterov: bool = False,
+     ):
+         if nesterov and (momentum <= 0 or dampening != 0):
+             raise ValueError(
+                 "Nesterov momentum requires a positive momentum and zero dampening."
+             )
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.momentum = momentum
+         self.weight_decay = weight_decay
+         self.dampening = dampening
+         self.nesterov = nesterov
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["v"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the SGD parameter update and stores :math:`v` in the
+         optimizer state."""
+
+         if self.weight_decay != 0:
+             gradient += self.weight_decay * parameter
+
+         if self.momentum <= 0:
+             return parameter - self.learning_rate.astype(gradient.dtype) * gradient
+
+         v = self.momentum * state.get("v")
+         if self.dampening > 0:
+             v += (1 - self.dampening) * gradient
+         else:
+             v += gradient
+
+         if self.nesterov:
+             update = gradient + self.momentum * v
+         else:
+             update = v
+
+         state["v"] = v
+         return parameter - self.learning_rate.astype(gradient.dtype) * update
+
+
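+ # A minimal sketch of Nesterov SGD; per the check in ``__init__`` it needs a
+ # positive momentum and zero dampening:
+ #
+ #     optimizer = SGD(learning_rate=1e-1, momentum=0.9, nesterov=True)
+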
+ class RMSprop(Optimizer):
+     r"""The RMSprop optimizer [1].
+
+     [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
+
+     .. math::
+
+         v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
+         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         alpha (float, optional): The smoothing constant :math:`\alpha`.
+             Default: ``0.99``
+         eps (float, optional): The term :math:`\epsilon` added to the denominator
+             to improve numerical stability. Default: ``1e-8``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         alpha: float = 0.99,
+         eps: float = 1e-8,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.alpha = alpha
+         self.eps = eps
+
+         if self.alpha < 0.0:
+             raise ValueError(
+                 f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
+             )
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"RMSprop epsilon should be >=0, {self.eps} was provided instead"
+             )
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["v"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         alpha = self.alpha
+         eps = self.eps
+
+         v = state["v"]
+         v = alpha * v + (1 - alpha) * mx.square(gradient)
+         state["v"] = v
+
+         return parameter - lr * gradient / (mx.sqrt(v) + eps)
+
+
+ class Adagrad(Optimizer):
+     r"""The Adagrad optimizer [1].
+
+     Our Adagrad implementation follows the original paper. In detail,
+
+     [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+     for online learning and stochastic optimization. JMLR 2011.
+
+     .. math::
+
+         v_{t+1} &= v_t + g_t^2 \\
+         w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         eps: float = 1e-8,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.eps = eps
+
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
+             )
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["v"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Adagrad parameter update and stores :math:`v` in the
+         optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         eps = self.eps
+
+         v = state["v"] + mx.square(gradient)
+         state["v"] = v
+
+         return parameter - lr * gradient / (mx.sqrt(v) + eps)
+
+
+ class AdaDelta(Optimizer):
+     r"""The AdaDelta optimizer with a learning rate [1].
+
+     Our AdaDelta implementation follows the original paper. In detail,
+
+     [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
+
+     .. math::
+
+         v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
+         \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
+         u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
+         w_{t+1} &= w_t - \lambda \Delta w_{t+1}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         rho (float, optional): The coefficient :math:`\rho` used for computing a
+             running average of squared gradients. Default: ``0.9``
+         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
+             numerical stability. Default: ``1e-6``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         rho: float = 0.9,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.rho = rho
+         self.eps = eps
+         if self.rho < 0.0:
+             raise ValueError(
+                 f"AdaDelta rho should be >=0, {self.rho} was provided instead"
+             )
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"AdaDelta epsilon should be >=0, {self.eps} was provided instead"
+             )
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["v"] = mx.zeros_like(parameter)
+         state["u"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the AdaDelta parameter update and stores :math:`v` and
+         :math:`u` in the optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         rho = self.rho
+         eps = self.eps
+
+         v = state["v"]
+         u = state["u"]
+
+         v = rho * v + (1 - rho) * mx.square(gradient)
+         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
+         u = rho * u + (1 - rho) * mx.square(d)
+
+         state["v"] = v
+         state["u"] = u
+
+         return parameter - lr * d
+
+
+ class Adam(Optimizer):
+     r"""The Adam optimizer [1]. In detail,
+
+     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+     optimization. ICLR 2015.
+
+     .. math::
+
+         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
+         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+         bias_correction (bool, optional): If set to ``True``, bias correction
+             is applied. Default: ``False``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         betas: List[float] = [0.9, 0.999],
+         eps: float = 1e-8,
+         bias_correction: bool = False,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.betas = betas
+         self.eps = eps
+         self.bias_correction = bias_correction
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["m"] = mx.zeros_like(parameter)
+         state["v"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Adam parameter update and stores :math:`v` and
+         :math:`m` in the optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         b1, b2 = self.betas
+         eps = self.eps
+         bias_correction = self.bias_correction
+         step = self.step
+
+         m = state["m"]
+         v = state["v"]
+         m = b1 * m + (1 - b1) * gradient
+         v = b2 * v + (1 - b2) * mx.square(gradient)
+         state["m"] = m
+         state["v"] = v
+
+         if bias_correction:
+             c1 = (lr / (1 - b1**step)).astype(gradient.dtype)
+             c2 = mx.rsqrt(1 - b2**step).astype(gradient.dtype)
+             numerator = c1 * m
+             denominator = mx.sqrt(v) * c2 + eps
+             return parameter - numerator / denominator
+         else:
+             return parameter - lr * m / (mx.sqrt(v) + eps)
+
+
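+ # A minimal sketch: with ``bias_correction=True`` the branch above applies the
+ # debiased moments m_t / (1 - beta_1^t) and v_t / (1 - beta_2^t) from the
+ # original paper:
+ #
+ #     optimizer = Adam(learning_rate=1e-3, bias_correction=True)
+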
+ class AdamW(Adam):
+     r"""The AdamW optimizer [1]. We update the weights with a weight_decay
+     (:math:`\lambda`) value:
+
+     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
+     regularization. ICLR 2019.
+
+     .. math::
+
+         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+         v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
+         w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\alpha`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+         weight_decay (float, optional): The weight decay :math:`\lambda`.
+             Default: ``0.01``.
+         bias_correction (bool, optional): If set to ``True``, bias correction
+             is applied. Default: ``False``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         betas: List[float] = [0.9, 0.999],
+         eps: float = 1e-8,
+         weight_decay: float = 0.01,
+         bias_correction: bool = False,
+     ):
+         super().__init__(
+             learning_rate=learning_rate,
+             betas=betas,
+             eps=eps,
+             bias_correction=bias_correction,
+         )
+         self.weight_decay = weight_decay
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the AdamW parameter update by modifying the parameters
+         passed into Adam.
+         """
+
+         lr = self.learning_rate.astype(gradient.dtype)
+         return super().apply_single(
+             gradient, parameter * (1 - lr * self.weight_decay), state
+         )
+
+
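+ # A minimal sketch: ``apply_single`` above decays the parameter before the
+ # Adam step, so the decay is decoupled from the adaptive denominator:
+ #
+ #     optimizer = AdamW(learning_rate=1e-3, weight_decay=0.1)
+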
+ class Adamax(Adam):
+     r"""The Adamax optimizer, a variant of Adam based on the infinity norm [1].
+
+     Our Adamax implementation follows the original paper and omits the bias
+     correction in the first and second moment estimates. In detail,
+
+     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+     optimization. ICLR 2015.
+
+     .. math::
+
+         m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+         v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
+         w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\lambda`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         betas: List[float] = [0.9, 0.999],
+         eps: float = 1e-8,
+     ):
+         super().__init__(learning_rate, betas, eps)
+         if eps < 0.0:
+             raise ValueError(
+                 f"Epsilon value should be >=0, {self.eps} was provided instead"
+             )
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["m"] = mx.zeros_like(parameter)
+         state["v"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Adamax parameter update and stores :math:`v` and
+         :math:`m` in the optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         b1, b2 = self.betas
+         eps = self.eps
+
+         m = state["m"]
+         v = state["v"]
+
+         m = b1 * m + (1 - b1) * gradient
+         v = mx.maximum(b2 * v, mx.abs(gradient))
+         state["m"] = m
+         state["v"] = v
+
+         return parameter - lr * m / (v + eps)
+
+
+ class Lion(Optimizer):
+     r"""The Lion optimizer [1].
+
+     Since updates are computed through the sign operation, they tend to
+     have larger norm than for other optimizers such as SGD and Adam.
+     We recommend a learning rate that is 3-10x smaller than AdamW and a
+     weight decay 3-10x larger than AdamW to maintain the strength
+     (lr * wd). Our Lion implementation follows the original paper. In
+     detail,
+
+     [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
+     preprint arXiv:2302.06675.
+
+     .. math::
+
+         c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+         m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
+         w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)
+
+     Args:
+         learning_rate (float or callable): The learning rate :math:`\eta`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing the gradient
+             momentum and update direction. Default: ``(0.9, 0.99)``
+         weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         betas: List[float] = [0.9, 0.99],
+         weight_decay: float = 0.0,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.betas = betas
+         self.weight_decay = weight_decay
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["m"] = mx.zeros_like(parameter)
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Lion parameter update and stores :math:`m`
+         in the optimizer state."""
+         lr = self.learning_rate.astype(gradient.dtype)
+         b1, b2 = self.betas
+         weight_decay = self.weight_decay
+
+         m = state["m"]
+         c = b1 * m + (1 - b1) * gradient
+         state["m"] = b2 * m + (1 - b2) * gradient
+         if weight_decay > 0:
+             parameter = (1 - lr * weight_decay) * parameter
+         return parameter - lr * mx.sign(c)
+
+
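+ # A minimal sketch following the guidance in the docstring: starting from a
+ # hypothetical AdamW baseline of lr=3e-4 and weight_decay=0.1, shrink the
+ # learning rate and grow the decay to keep lr * wd roughly constant:
+ #
+ #     optimizer = Lion(learning_rate=1e-4, weight_decay=0.3)
+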
+ class Adafactor(Optimizer):
+     r"""The Adafactor optimizer.
+
+     Our Adafactor implementation follows the original paper: `Adafactor:
+     Adaptive Learning Rates with Sublinear Memory Cost
+     <https://arxiv.org/abs/1804.04235>`_
+
+     Args:
+         learning_rate (float or callable, optional): The learning rate.
+             Default: ``None``.
+         eps (tuple(float, float), optional): The first term :math:`\epsilon_1`
+             added to the square of the gradients to improve numerical
+             stability and the second term :math:`\epsilon_2` is used for
+             parameter scaling if ``parameter_scale`` is set to ``True``.
+             Default: ``(1e-30, 1e-3)``.
+         clip_threshold (float, optional): Clips the unscaled update at
+             ``clip_threshold``. Default: ``1.0``.
+         decay_rate (float, optional): Coefficient for the running average
+             of the squared gradient. Default: ``-0.8``.
+         beta_1 (float, optional): If set to a value bigger than zero
+             then the first moment will be used. Default: ``None``.
+         weight_decay (float, optional): The weight decay :math:`\lambda`.
+             Default: ``0.0``.
+         scale_parameter (bool, optional): If set to ``True`` the learning rate
+             will be scaled by :math:`\max(\epsilon_1, \text{RMS}(w_{t-1}))`.
+             Default: ``True``.
+         relative_step (bool, optional): If set to ``True`` the ``learning_rate``
+             will be ignored and the relative step size will be computed.
+             Default: ``True``.
+         warmup_init (bool, optional): If set to ``True`` then the relative
+             step size will be calculated by the current step. Default:
+             ``False``.
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array], None] = None,
+         eps: Tuple[float, float] = (1e-30, 1e-3),
+         clip_threshold: float = 1.0,
+         decay_rate: float = -0.8,
+         beta_1: Optional[float] = None,
+         weight_decay: float = 0.0,
+         scale_parameter: bool = True,
+         relative_step: bool = True,
+         warmup_init: bool = False,
+     ):
+         super().__init__()
+         if learning_rate is not None:
+             self._maybe_schedule("learning_rate", learning_rate)
+         self.eps = eps
+         self.clip_threshold = clip_threshold
+         self.decay_rate = decay_rate
+         self.beta_1 = beta_1
+         self.weight_decay = weight_decay
+         self.scale_parameter = scale_parameter
+         self.relative_step = relative_step
+         self.warmup_init = warmup_init
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         if parameter.ndim >= 2:
+             shape = parameter.shape
+             dtype = parameter.dtype
+             state["exp_avg_sq_row"] = mx.zeros(shape[:-1], dtype=dtype)
+             state["exp_avg_sq_col"] = mx.zeros(shape[:-2] + shape[-1:], dtype=dtype)
+         else:
+             state["exp_avg_sq"] = mx.zeros_like(parameter)
+
+         if self.beta_1 is not None:
+             state["exp_avg"] = mx.zeros_like(parameter)
+
+     def _compute_rms(self, inputs):
+         return mx.sqrt(mx.mean(mx.square(inputs)))
+
+     def _compute_learning_rate(self, step, parameter_rms):
+         if self.relative_step:
+             min_step = 1e-6 * step if self.warmup_init else 1e-2
+             relative_step_size = mx.minimum(min_step, mx.rsqrt(step))
+         else:
+             relative_step_size = self.learning_rate
+
+         relative_step_size = relative_step_size.astype(parameter_rms.dtype)
+         parameter_scale = 1.0
+         if self.scale_parameter:
+             parameter_scale = mx.maximum(self.eps[1], parameter_rms)
+         return parameter_scale * relative_step_size
+
+     def _approximate_exp_moving_avg(self, exp_avg_sq_row, exp_avg_sq_col):
+         r_factor = mx.rsqrt(
+             exp_avg_sq_row / mx.mean(exp_avg_sq_row, axis=-1, keepdims=True)
+         )
+         c_factor = mx.rsqrt(exp_avg_sq_col)
+         return mx.matmul(
+             mx.expand_dims(r_factor, axis=-1), mx.expand_dims(c_factor, axis=0)
+         )
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Adafactor parameter and state update."""
+         factored = gradient.ndim >= 2
+
+         step = self.step
+         use_first_moment = self.beta_1 is not None
+
+         parameter_rms = self._compute_rms(parameter)
+         learning_rate = self._compute_learning_rate(step, parameter_rms)
+         beta_2 = 1.0 - (step**self.decay_rate).astype(parameter_rms.dtype)
+         update = mx.square(gradient) + self.eps[0]
+
+         if factored:
+             exp_avg_sq_row = state["exp_avg_sq_row"]
+             exp_avg_sq_col = state["exp_avg_sq_col"]
+             exp_avg_sq_row = (beta_2 * exp_avg_sq_row) + (
+                 (1 - beta_2) * mx.mean(update, axis=-1)
+             )
+             exp_avg_sq_col = (beta_2 * exp_avg_sq_col) + (
+                 (1 - beta_2) * mx.mean(update, axis=-2)
+             )
+             state["exp_avg_sq_row"] = exp_avg_sq_row
+             state["exp_avg_sq_col"] = exp_avg_sq_col
+             update = self._approximate_exp_moving_avg(exp_avg_sq_row, exp_avg_sq_col)
+             update = update * gradient
+         else:
+             exp_avg_sq = state["exp_avg_sq"]
+             exp_avg_sq = (beta_2 * exp_avg_sq) + ((1 - beta_2) * update)
+             state["exp_avg_sq"] = exp_avg_sq
+             update = mx.rsqrt(exp_avg_sq) * gradient
+
+         update = update / mx.maximum(
+             1.0, self._compute_rms(update) / self.clip_threshold
+         )
+         update = learning_rate * update
+
+         if use_first_moment:
+             exp_avg = state["exp_avg"]
+             exp_avg = (self.beta_1 * exp_avg) + ((1 - self.beta_1) * update)
+             state["exp_avg"] = exp_avg
+             update = exp_avg
+
+         if self.weight_decay != 0:
+             parameter += parameter * (-self.weight_decay * learning_rate)
+         return parameter - update
+
+
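+ # A minimal sketch: with the defaults (``relative_step=True``) no learning
+ # rate is required; the step size is derived from the step count and, when
+ # ``scale_parameter=True``, scaled by the parameter RMS:
+ #
+ #     optimizer = Adafactor()  # learning_rate=None
+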
+ class Muon(Optimizer):
+     r"""The Muon optimizer.
+
+     Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+     original implementation: `Muon: An optimizer for hidden layers in neural
+     networks <https://kellerjordan.github.io/posts/muon/>`_
+
+     Note:
+         - Muon may be sub-optimal for the embedding layer, the final fully
+           connected layer, or any 0D/1D parameters. Those should be optimized
+           by a different method (e.g., :class:`AdamW`).
+         - For 4D convolutional filters, it works by flattening their last
+           dimensions.
+
+     Args:
+         learning_rate (float or callable): The learning rate.
+         momentum (float, optional): The momentum strength. Default: ``0.95``
+         weight_decay (float, optional): The weight decay (L2 penalty).
+             Default: ``0.01``
+         nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+             better performance. Default: ``True``
+         ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+             orthogonalization. Default: ``5``
+     """
+
+     def __init__(
+         self,
+         learning_rate: Union[float, Callable[[mx.array], mx.array]],
+         momentum: float = 0.95,
+         weight_decay: float = 0.01,
+         nesterov: bool = True,
+         ns_steps: int = 5,
+     ):
+         super().__init__()
+
+         self._maybe_schedule("learning_rate", learning_rate)
+         self.momentum = momentum
+         self.weight_decay = weight_decay
+         self.nesterov = nesterov
+         self.ns_steps = ns_steps
+
+     def init_single(self, parameter: mx.array, state: dict):
+         """Initialize optimizer state"""
+         state["v"] = mx.zeros_like(parameter)
+
+     def _zeropower_via_newtonschulz5(self, X, steps: int):
+         assert (
+             X.ndim == 2
+         ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
+         a, b, c = (3.4445, -4.7750, 2.0315)
+         transpose_needed = X.shape[-2] > X.shape[-1]
+
+         if transpose_needed:
+             X = X.T
+
+         X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
+
+         for _ in range(steps):
+             A = X @ X.T
+             B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+             X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
+
+         if transpose_needed:
+             X = X.T
+         return X
+
+     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+         """Performs the Muon parameter update"""
+
+         if self.weight_decay != 0:
+             gradient = gradient + self.weight_decay * parameter
+
+         v = self.momentum * state["v"]
+         v = v + (1 - self.momentum) * gradient
+         state["v"] = v
+
+         if self.nesterov:
+             update = gradient * (1 - self.momentum) + v * self.momentum
+         else:
+             update = v
+
+         lr = self.learning_rate.astype(gradient.dtype)
+
+         if update.ndim >= 2:
+             original_shape = update.shape
+             reshape_needed = update.ndim > 2
+
+             if reshape_needed:
+                 update = mx.reshape(update, (update.shape[0], -1))
+
+             update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
+
+             if reshape_needed:
+                 update = mx.reshape(update, original_shape)
+
+             lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
+
+         return parameter - lr * update
+
+
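+ # A minimal sketch pairing Muon with a fallback for 0D/1D and embedding
+ # weights, as the note above recommends; the path predicate is an assumption
+ # about a given model's parameter names:
+ #
+ #     optimizer = MultiOptimizer(
+ #         [AdamW(learning_rate=3e-4), Muon(learning_rate=2e-2)],
+ #         filters=[lambda path, w: w.ndim < 2 or "embedding" in path],
+ #     )
+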
+ def clip_grad_norm(grads, max_norm):
+     """Clips the global norm of the gradients.
+
+     This function ensures that the global norm of the gradients does not exceed
+     ``max_norm``. It scales down the gradients proportionally if their norm is
+     greater than ``max_norm``.
+
+     Example:
+         >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
+         >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
+         >>> print(clipped_grads)
+         {"w1": mx.array([...]), "w2": mx.array([...])}
+
+     Args:
+         grads (dict): A dictionary containing the gradient arrays.
+         max_norm (float): The maximum allowed global norm of the gradients.
+
+     Returns:
+         (dict, float): The possibly rescaled gradients and the original
+         gradient norm.
+     """
+     norm_squared = tree_reduce(lambda acc, g: acc + g.square().sum(), grads, 0.0)
+     total_norm = mx.sqrt(norm_squared)
+     normalizer = mx.minimum(max_norm / (total_norm + 1e-6), 1.0)
+     clipped_grads = tree_map(lambda g: g * normalizer, grads)
+     return clipped_grads, total_norm
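+
+ # A minimal sketch: clip before the optimizer update, assuming hypothetical
+ # ``grads`` and ``model`` trees:
+ #
+ #     grads, norm = clip_grad_norm(grads, max_norm=1.0)
+ #     model.update(optimizer.apply_gradients(grads, model))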