mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,65 @@
1
+ # This is a basic version file for the Config-mode of find_package().
2
+ # It is used by write_basic_package_version_file() as input file for configure_file()
3
+ # to create a version-file which can be installed along a config.cmake file.
4
+ #
5
+ # The created file sets PACKAGE_VERSION_EXACT if the current version string and
6
+ # the requested version string are exactly the same and it sets
7
+ # PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
8
+ # but only if the requested major version is the same as the current one.
9
+ # The variable CVF_VERSION must be set before calling configure_file().
10
+
11
+
12
+ set(PACKAGE_VERSION "0.30.1")
13
+
14
+ if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
15
+ set(PACKAGE_VERSION_COMPATIBLE FALSE)
16
+ else()
17
+
18
+ if("0.30.1" MATCHES "^([0-9]+)\\.")
19
+ set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
20
+ if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
21
+ string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
22
+ endif()
23
+ else()
24
+ set(CVF_VERSION_MAJOR "0.30.1")
25
+ endif()
26
+
27
+ if(PACKAGE_FIND_VERSION_RANGE)
28
+ # both endpoints of the range must have the expected major version
29
+ math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
30
+ if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
31
+ OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
32
+ OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
33
+ set(PACKAGE_VERSION_COMPATIBLE FALSE)
34
+ elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
35
+ AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
36
+ OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
37
+ set(PACKAGE_VERSION_COMPATIBLE TRUE)
38
+ else()
39
+ set(PACKAGE_VERSION_COMPATIBLE FALSE)
40
+ endif()
41
+ else()
42
+ if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
43
+ set(PACKAGE_VERSION_COMPATIBLE TRUE)
44
+ else()
45
+ set(PACKAGE_VERSION_COMPATIBLE FALSE)
46
+ endif()
47
+
48
+ if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
49
+ set(PACKAGE_VERSION_EXACT TRUE)
50
+ endif()
51
+ endif()
52
+ endif()
53
+
54
+
55
+ # if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
56
+ if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
57
+ return()
58
+ endif()
59
+
60
+ # check that the installed version has the same 32/64bit-ness as the one which is currently searching:
61
+ if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
62
+ math(EXPR installedBits "8 * 8")
63
+ set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
64
+ set(PACKAGE_VERSION_UNSUITABLE TRUE)
65
+ endif()
@@ -0,0 +1,19 @@
1
+ #----------------------------------------------------------------
2
+ # Generated CMake target import file for configuration "Release".
3
+ #----------------------------------------------------------------
4
+
5
+ # Commands may need to know the format version.
6
+ set(CMAKE_IMPORT_FILE_VERSION 1)
7
+
8
+ # Import target "mlx" for configuration "Release"
9
+ set_property(TARGET mlx APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
10
+ set_target_properties(mlx PROPERTIES
11
+ IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libmlx.so"
12
+ IMPORTED_SONAME_RELEASE "libmlx.so"
13
+ )
14
+
15
+ list(APPEND _cmake_import_check_targets mlx )
16
+ list(APPEND _cmake_import_check_files_for_mlx "${_IMPORT_PREFIX}/lib/libmlx.so" )
17
+
18
+ # Commands beyond this point should not need to know the version.
19
+ set(CMAKE_IMPORT_FILE_VERSION)
@@ -0,0 +1,106 @@
1
+ # Generated by CMake
2
+
3
+ if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
4
+ message(FATAL_ERROR "CMake >= 2.8.3 required")
5
+ endif()
6
+ if(CMAKE_VERSION VERSION_LESS "2.8.3")
7
+ message(FATAL_ERROR "CMake >= 2.8.3 required")
8
+ endif()
9
+ cmake_policy(PUSH)
10
+ cmake_policy(VERSION 2.8.3...4.0)
11
+ #----------------------------------------------------------------
12
+ # Generated CMake target import file.
13
+ #----------------------------------------------------------------
14
+
15
+ # Commands may need to know the format version.
16
+ set(CMAKE_IMPORT_FILE_VERSION 1)
17
+
18
+ # Protect against multiple inclusion, which would fail when already imported targets are added once more.
19
+ set(_cmake_targets_defined "")
20
+ set(_cmake_targets_not_defined "")
21
+ set(_cmake_expected_targets "")
22
+ foreach(_cmake_expected_target IN ITEMS mlx)
23
+ list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
24
+ if(TARGET "${_cmake_expected_target}")
25
+ list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
26
+ else()
27
+ list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
28
+ endif()
29
+ endforeach()
30
+ unset(_cmake_expected_target)
31
+ if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
32
+ unset(_cmake_targets_defined)
33
+ unset(_cmake_targets_not_defined)
34
+ unset(_cmake_expected_targets)
35
+ unset(CMAKE_IMPORT_FILE_VERSION)
36
+ cmake_policy(POP)
37
+ return()
38
+ endif()
39
+ if(NOT _cmake_targets_defined STREQUAL "")
40
+ string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
41
+ string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
42
+ message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
43
+ endif()
44
+ unset(_cmake_targets_defined)
45
+ unset(_cmake_targets_not_defined)
46
+ unset(_cmake_expected_targets)
47
+
48
+
49
+ # Compute the installation prefix relative to this file.
50
+ get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
51
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
52
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
53
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
54
+ if(_IMPORT_PREFIX STREQUAL "/")
55
+ set(_IMPORT_PREFIX "")
56
+ endif()
57
+
58
+ # Create imported target mlx
59
+ add_library(mlx SHARED IMPORTED)
60
+
61
+ set_target_properties(mlx PROPERTIES
62
+ INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
63
+ )
64
+
65
+ # Load information for each installed configuration.
66
+ file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/MLXTargets-*.cmake")
67
+ foreach(_cmake_config_file IN LISTS _cmake_config_files)
68
+ include("${_cmake_config_file}")
69
+ endforeach()
70
+ unset(_cmake_config_file)
71
+ unset(_cmake_config_files)
72
+
73
+ # Cleanup temporary variables.
74
+ set(_IMPORT_PREFIX)
75
+
76
+ # Loop over all imported files and verify that they actually exist
77
+ foreach(_cmake_target IN LISTS _cmake_import_check_targets)
78
+ if(CMAKE_VERSION VERSION_LESS "3.28"
79
+ OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
80
+ OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
81
+ foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
82
+ if(NOT EXISTS "${_cmake_file}")
83
+ message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
84
+ \"${_cmake_file}\"
85
+ but this file does not exist. Possible reasons include:
86
+ * The file was deleted, renamed, or moved to another location.
87
+ * An install or uninstall procedure did not complete successfully.
88
+ * The installation package was faulty and contained
89
+ \"${CMAKE_CURRENT_LIST_FILE}\"
90
+ but not all the files it references.
91
+ ")
92
+ endif()
93
+ endforeach()
94
+ endif()
95
+ unset(_cmake_file)
96
+ unset("_cmake_import_check_files_for_${_cmake_target}")
97
+ endforeach()
98
+ unset(_cmake_target)
99
+ unset(_cmake_import_check_targets)
100
+
101
+ # This file does not depend on other imported targets which have
102
+ # been exported from the same project but in a separate export set.
103
+
104
+ # Commands beyond this point should not need to know the version.
105
+ set(CMAKE_IMPORT_FILE_VERSION)
106
+ cmake_policy(POP)
@@ -0,0 +1,50 @@
1
+ include(CMakeParseArguments)
2
+
3
# clang format off
#
# ##############################################################################
# mlx_build_metallib: build a Metal library
#
# Adds a custom target ${TARGET} that builds
# ${OUTPUT_DIRECTORY}/${TITLE}.metallib from the list ${SOURCES}. Every entry
# of ${INCLUDE_DIRS} is passed to the Metal compiler as an -I flag, and the
# output is declared to depend on ${DEPS} and ${SOURCES}.
#
# Args:
#   TARGET:           Custom target to be added for the metal library.
#   TITLE:            Name of the .metallib (without extension).
#   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib.
#   SOURCES:          List of source files.
#   INCLUDE_DIRS:     List of include dirs.
#   DEPS:             List of dependency files (like headers).
#   DEBUG:            Boolean. When true, debug compile options are added for
#                     this specific library. They are also added whenever the
#                     global MLX_METAL_DEBUG is true.
#
# Because this is a macro, the MTLLIB_* variables it sets are visible in the
# caller's scope after it returns.
#
# clang format on

macro(mlx_build_metallib)
  # Parse the keyword arguments into MTLLIB_<KEYWORD> variables.
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # Baseline compile options, plus debug info when requested either globally
  # or for this library only.
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
    list(APPEND MTLLIB_COMPILE_OPTIONS -gline-tables-only -frecord-sources)
  endif()

  # Full path of the library file produced by the custom command below.
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Command that compiles all sources into the metallib. The generator
  # expression prefixes every include dir with -I; COMMAND_EXPAND_LISTS then
  # splits the resulting list into separate arguments.
  add_custom_command(
    OUTPUT ${MTLLIB_BUILD_TARGET}
    COMMAND
      xcrun -sdk macosx metal
      "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
    COMMAND_EXPAND_LISTS
    COMMENT "Building ${MTLLIB_TITLE}.metallib"
    VERBATIM)

  # Named target that callers can depend on to trigger the build.
  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})

endmacro()
mlx/utils.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright © 2023 Apple Inc.
2
+ from collections import defaultdict
3
+ from itertools import zip_longest
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+
7
def tree_map(
    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any:
    """Build a new Python tree by calling ``fn`` on every leaf of ``tree``.

    Lists, tuples (including named tuples) and dictionaries are treated as
    containers and are rebuilt with the same type; anything else is a leaf.

    When extra trees are passed in ``rest``, each of them is indexed with the
    same keys/positions as ``tree`` and the matching leaves are handed to
    ``fn`` as additional positional arguments, much like
    :func:`itertools.starmap`. Every tree in ``rest`` must therefore contain
    at least the structure of ``tree``.

    ``is_leaf`` can be used to stop the descent early, in the same way as in
    :func:`tree_flatten`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (callable): Called as ``fn(leaf, *matching_leaves_of_rest)``.
        tree (Any): The main Python tree whose structure drives the traversal.
        rest (tuple[Any]): Extra trees traversed in lockstep with ``tree``.
        is_leaf (callable, optional): Returns ``True`` if the object it is
            given should be treated as a leaf, ``False`` otherwise.

    Returns:
        A Python tree shaped like ``tree`` holding the values returned by
        ``fn``.
    """
    # The caller's notion of a leaf takes precedence over the container check.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)

    if isinstance(tree, (list, tuple)):
        container = type(tree)
        mapped_children = [
            tree_map(fn, child, *[other[index] for other in rest], is_leaf=is_leaf)
            for index, child in enumerate(tree)
        ]
        # Named tuples take one positional argument per field, while plain
        # lists and tuples take a single iterable.
        if hasattr(tree, "_fields"):
            return container(*mapped_children)
        return container(mapped_children)

    if isinstance(tree, dict):
        mapped = {}
        for key, child in tree.items():
            matching = [other[key] for other in rest]
            mapped[key] = tree_map(fn, child, *matching, is_leaf=is_leaf)
        return mapped

    return fn(tree, *rest)
59
+
60
+
61
def tree_map_with_path(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None,
    path: Optional[Any] = None,
) -> Any:
    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
    returns a new collection with the results.

    This function is the same as :func:`tree_map` but ``fn`` takes the path as
    the first argument followed by the remaining tree nodes. The path is built
    by joining dictionary keys and list/tuple indices with ``"."``. A ``tree``
    that is itself a leaf receives ``path`` unchanged (``None`` by default).

    Args:
        fn (callable): The function that processes the leaves of the tree.
            Called as ``fn(path, leaf, *matching_leaves_of_rest)``.
        tree (Any): The main Python tree that will be iterated upon.
        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
        is_leaf (Optional[Callable]): An optional callable that returns ``True``
            if the passed object is considered a leaf or ``False`` otherwise.
        path (Optional[Any]): Prefix that will be prepended to every path.

    Returns:
        A Python tree with the new values returned by ``fn``.

    Example:
        >>> from mlx.utils import tree_map_with_path
        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
        model.0.w
        model.0.b
        model.1.w
        model.1.b
    """
    if is_leaf is not None and is_leaf(tree):
        return fn(path, tree, *rest)
    elif isinstance(tree, (list, tuple)):
        # An empty or missing path must not produce a leading ".".
        prefix = f"{path}." if path else ""
        TreeType = type(tree)
        subtrees = (
            tree_map_with_path(
                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
            )
            for i, child in enumerate(tree)
        )
        # Named tuples are constructed from one positional argument per field,
        # so passing them a single generator (as plain tuples accept) raises a
        # TypeError. Mirror the handling done in tree_map.
        return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
    elif isinstance(tree, dict):
        prefix = f"{path}." if path else ""
        return {
            k: tree_map_with_path(
                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
            )
            for k, child in tree.items()
        }
    else:
        return fn(path, tree, *rest)
115
+
116
+
117
def tree_flatten(
    tree: Any,
    prefix: str = "",
    is_leaf: Optional[Callable] = None,
    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
    """Flatten a Python tree into ``(key, value)`` pairs.

    Each key spells out the location of its leaf using dot notation, so trees
    of arbitrary depth and complexity map onto a flat collection.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], prefix=".hello"))
        # [("hello.0.0.0", 0)]

        tree_flatten({"a": {"b": 1}}, destination={})
        # {"a.b": 1}

    .. note::
       Dictionaries should have keys that are valid Python identifiers.

    Args:
        tree (Any): The Python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.
        destination (list or dict, optional): A list or dictionary that
            receives the flattened entries; it is modified in place and
            returned. If None an empty list will be used. Default: ``None``.

    Returns:
        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
        the Python tree.

    Raises:
        ValueError: If ``destination`` is neither a list, a dictionary nor
            ``None``.
    """
    if destination is None:
        destination = []

    # dict.update and list.extend both accept an iterable of (key, value)
    # pairs, so a single bound method covers either kind of destination.
    if isinstance(destination, dict):
        store = destination.update
    elif isinstance(destination, list):
        store = destination.extend
    else:
        raise ValueError("Destination should be either a list or a dictionary or None")

    # A node is a leaf when the caller says so or when it is not a container.
    forced_leaf = is_leaf is not None and is_leaf(tree)
    if forced_leaf or not isinstance(tree, (list, tuple, dict)):
        # Every recursion step prepends ".", hence the dropped first character.
        store([(prefix[1:], tree)])
        return destination

    # Dictionaries are walked by key, sequences by position.
    if isinstance(tree, dict):
        children = tree.items()
    else:
        children = enumerate(tree)

    for name, subtree in children:
        tree_flatten(subtree, f"{prefix}.{name}", is_leaf, destination)

    return destination
191
+
192
+
193
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
    """Rebuild a nested Python tree from its flat, dot-keyed representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

        d = tree_unflatten({"hello.world": 42})
        print(d)
        # {"hello": {"world": 42}}

    A level whose keys can all be parsed as integers becomes a list (missing
    positions are filled with empty dictionaries); any other level becomes a
    dictionary.

    Args:
        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A Python tree.
    """
    pairs = tree.items() if isinstance(tree, dict) else tree

    # A single entry stored under the empty key is a bare value, not a tree.
    if len(pairs) == 1:
        only_key, only_value = next(iter(pairs))
        if only_key == "":
            return only_value

    # Group the entries by the first component of their key and keep the
    # remainder of the key (possibly empty) for the recursive step.
    grouped = defaultdict(list)
    for flat_key, leaf in pairs:
        head, _, tail = flat_key.partition(".")
        grouped[head].append((tail, leaf))

    # Try to interpret this level as a list; int() raising ValueError on any
    # key means the level is a dictionary instead.
    try:
        ordered = sorted((int(head), head) for head in grouped)
        result = []
        for position, head in ordered:
            # Pad skipped positions; nothing is added when position <= len.
            while len(result) < position:
                result.append({})
            result.append(tree_unflatten(grouped[head]))
        return result
    except ValueError:
        return {head: tree_unflatten(entries) for head, entries in grouped.items()}
241
+
242
+
243
def tree_reduce(fn, tree, initializer=None, is_leaf=None):
    """Fold the leaves of a Python tree into a single accumulated value.

    The leaves are visited in iteration order of their containers and combined
    one by one with the reducer ``fn``.

    Example:
        >>> from mlx.utils import tree_reduce
        >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
        15

    Args:
        fn (callable): The reducer function that takes two arguments (accumulator,
            current value) and returns the updated accumulator.
        tree (Any): The Python tree to reduce. It can be any nested combination of
            lists, tuples, or dictionaries.
        initializer (Any, optional): The initial value to start the reduction. If
            not provided, the first leaf value is used.
        is_leaf (callable, optional): A function to determine if an object is a
            leaf, returning ``True`` for leaf nodes and ``False`` otherwise.

    Returns:
        Any: The accumulated value. ``None`` when there are no leaves and no
        ``initializer`` was given.
    """
    # Work out which children to descend into; ``None`` marks a leaf.
    if is_leaf is not None and is_leaf(tree):
        children = None
    elif isinstance(tree, (list, tuple)):
        children = tree
    elif isinstance(tree, dict):
        children = tree.values()
    else:
        children = None

    if children is None:
        # ``None`` doubles as "no accumulator yet": the first leaf seeds it.
        if initializer is None:
            return tree
        return fn(initializer, tree)

    result = initializer
    for child in children:
        result = tree_reduce(fn, child, result, is_leaf)
    return result
283
+
284
+
285
def tree_merge(tree_a, tree_b, merge_fn=None):
    """Combine two Python trees into one that holds the values of both.

    It behaves like a deep ``dict.update``: containers are merged recursively,
    and a location present in only one tree is copied over as is. Empty
    dictionaries, lists and tuples count as absent.

    Args:
        tree_a (Any): The first Python tree.
        tree_b (Any): The second Python tree.
        merge_fn (callable, optional): A function to merge leaves. It is called
            as ``merge_fn(a, b)`` when both trees hold something at the same
            location and the two cannot be merged structurally.

    Returns:
        The Python tree containing the values of both ``tree_a`` and
        ``tree_b``.

    Raises:
        ValueError: If both trees hold an element at the same location and no
            ``merge_fn`` was provided.
    """
    # Normalise empty containers to None so that they read as "nothing here".
    tree_a, tree_b = (
        None if isinstance(node, (dict, list, tuple)) and len(node) == 0 else node
        for node in (tree_a, tree_b)
    )

    # Exactly one side is present: nothing to merge, return that side.
    if (tree_a is None) != (tree_b is None):
        return tree_b if tree_a is None else tree_a

    sequences = (list, tuple)
    if isinstance(tree_a, sequences) and isinstance(tree_b, sequences):
        # zip_longest pads the shorter sequence with None, which the recursive
        # call treats as an absent subtree. The result takes tree_a's type.
        merged_items = (
            tree_merge(left, right, merge_fn)
            for left, right in zip_longest(tree_a, tree_b)
        )
        return type(tree_a)(merged_items)

    if isinstance(tree_a, dict) and isinstance(tree_b, dict):
        merged = {}
        for key in set(tree_a.keys()) | set(tree_b.keys()):
            merged[key] = tree_merge(tree_a.get(key), tree_b.get(key), merge_fn)
        return merged

    # Both sides hold something that cannot be merged structurally.
    if merge_fn is None:
        raise ValueError(
            (
                "Trees contain elements at the same locations but no merge "
                "function was provided"
            )
        )
    return merge_fn(tree_a, tree_b)