mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# This is a basic version file for the Config-mode of find_package().
|
|
2
|
+
# It is used by write_basic_package_version_file() as input file for configure_file()
|
|
3
|
+
# to create a version-file which can be installed along a config.cmake file.
|
|
4
|
+
#
|
|
5
|
+
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
|
|
6
|
+
# the requested version string are exactly the same and it sets
|
|
7
|
+
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
|
|
8
|
+
# but only if the requested major version is the same as the current one.
|
|
9
|
+
# The variable CVF_VERSION must be set before calling configure_file().
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
set(PACKAGE_VERSION "0.30.1")
|
|
13
|
+
|
|
14
|
+
if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
|
|
15
|
+
set(PACKAGE_VERSION_COMPATIBLE FALSE)
|
|
16
|
+
else()
|
|
17
|
+
|
|
18
|
+
if("0.30.1" MATCHES "^([0-9]+)\\.")
|
|
19
|
+
set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
|
|
20
|
+
if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
|
|
21
|
+
string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
|
|
22
|
+
endif()
|
|
23
|
+
else()
|
|
24
|
+
set(CVF_VERSION_MAJOR "0.30.1")
|
|
25
|
+
endif()
|
|
26
|
+
|
|
27
|
+
if(PACKAGE_FIND_VERSION_RANGE)
|
|
28
|
+
# both endpoints of the range must have the expected major version
|
|
29
|
+
math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
|
|
30
|
+
if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
|
|
31
|
+
OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
|
|
32
|
+
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
|
|
33
|
+
set(PACKAGE_VERSION_COMPATIBLE FALSE)
|
|
34
|
+
elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
|
|
35
|
+
AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
|
|
36
|
+
OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
|
|
37
|
+
set(PACKAGE_VERSION_COMPATIBLE TRUE)
|
|
38
|
+
else()
|
|
39
|
+
set(PACKAGE_VERSION_COMPATIBLE FALSE)
|
|
40
|
+
endif()
|
|
41
|
+
else()
|
|
42
|
+
if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
|
|
43
|
+
set(PACKAGE_VERSION_COMPATIBLE TRUE)
|
|
44
|
+
else()
|
|
45
|
+
set(PACKAGE_VERSION_COMPATIBLE FALSE)
|
|
46
|
+
endif()
|
|
47
|
+
|
|
48
|
+
if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
|
|
49
|
+
set(PACKAGE_VERSION_EXACT TRUE)
|
|
50
|
+
endif()
|
|
51
|
+
endif()
|
|
52
|
+
endif()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
|
|
56
|
+
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
|
|
57
|
+
return()
|
|
58
|
+
endif()
|
|
59
|
+
|
|
60
|
+
# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
|
|
61
|
+
if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
|
|
62
|
+
math(EXPR installedBits "8 * 8")
|
|
63
|
+
set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
|
|
64
|
+
set(PACKAGE_VERSION_UNSUITABLE TRUE)
|
|
65
|
+
endif()
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
#----------------------------------------------------------------
|
|
2
|
+
# Generated CMake target import file for configuration "Release".
|
|
3
|
+
#----------------------------------------------------------------
|
|
4
|
+
|
|
5
|
+
# Commands may need to know the format version.
|
|
6
|
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
|
7
|
+
|
|
8
|
+
# Import target "mlx" for configuration "Release"
|
|
9
|
+
set_property(TARGET mlx APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
|
|
10
|
+
set_target_properties(mlx PROPERTIES
|
|
11
|
+
IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libmlx.so"
|
|
12
|
+
IMPORTED_SONAME_RELEASE "libmlx.so"
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
list(APPEND _cmake_import_check_targets mlx )
|
|
16
|
+
list(APPEND _cmake_import_check_files_for_mlx "${_IMPORT_PREFIX}/lib/libmlx.so" )
|
|
17
|
+
|
|
18
|
+
# Commands beyond this point should not need to know the version.
|
|
19
|
+
set(CMAKE_IMPORT_FILE_VERSION)
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# Generated by CMake
|
|
2
|
+
|
|
3
|
+
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
|
|
4
|
+
message(FATAL_ERROR "CMake >= 2.8.3 required")
|
|
5
|
+
endif()
|
|
6
|
+
if(CMAKE_VERSION VERSION_LESS "2.8.3")
|
|
7
|
+
message(FATAL_ERROR "CMake >= 2.8.3 required")
|
|
8
|
+
endif()
|
|
9
|
+
cmake_policy(PUSH)
|
|
10
|
+
cmake_policy(VERSION 2.8.3...4.0)
|
|
11
|
+
#----------------------------------------------------------------
|
|
12
|
+
# Generated CMake target import file.
|
|
13
|
+
#----------------------------------------------------------------
|
|
14
|
+
|
|
15
|
+
# Commands may need to know the format version.
|
|
16
|
+
set(CMAKE_IMPORT_FILE_VERSION 1)
|
|
17
|
+
|
|
18
|
+
# Protect against multiple inclusion, which would fail when already imported targets are added once more.
|
|
19
|
+
set(_cmake_targets_defined "")
|
|
20
|
+
set(_cmake_targets_not_defined "")
|
|
21
|
+
set(_cmake_expected_targets "")
|
|
22
|
+
foreach(_cmake_expected_target IN ITEMS mlx)
|
|
23
|
+
list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
|
|
24
|
+
if(TARGET "${_cmake_expected_target}")
|
|
25
|
+
list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
|
|
26
|
+
else()
|
|
27
|
+
list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
|
|
28
|
+
endif()
|
|
29
|
+
endforeach()
|
|
30
|
+
unset(_cmake_expected_target)
|
|
31
|
+
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
|
|
32
|
+
unset(_cmake_targets_defined)
|
|
33
|
+
unset(_cmake_targets_not_defined)
|
|
34
|
+
unset(_cmake_expected_targets)
|
|
35
|
+
unset(CMAKE_IMPORT_FILE_VERSION)
|
|
36
|
+
cmake_policy(POP)
|
|
37
|
+
return()
|
|
38
|
+
endif()
|
|
39
|
+
if(NOT _cmake_targets_defined STREQUAL "")
|
|
40
|
+
string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
|
|
41
|
+
string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
|
|
42
|
+
message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
|
|
43
|
+
endif()
|
|
44
|
+
unset(_cmake_targets_defined)
|
|
45
|
+
unset(_cmake_targets_not_defined)
|
|
46
|
+
unset(_cmake_expected_targets)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Compute the installation prefix relative to this file.
|
|
50
|
+
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
|
|
51
|
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
|
52
|
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
|
53
|
+
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
|
|
54
|
+
if(_IMPORT_PREFIX STREQUAL "/")
|
|
55
|
+
set(_IMPORT_PREFIX "")
|
|
56
|
+
endif()
|
|
57
|
+
|
|
58
|
+
# Create imported target mlx
|
|
59
|
+
add_library(mlx SHARED IMPORTED)
|
|
60
|
+
|
|
61
|
+
set_target_properties(mlx PROPERTIES
|
|
62
|
+
INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Load information for each installed configuration.
|
|
66
|
+
file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/MLXTargets-*.cmake")
|
|
67
|
+
foreach(_cmake_config_file IN LISTS _cmake_config_files)
|
|
68
|
+
include("${_cmake_config_file}")
|
|
69
|
+
endforeach()
|
|
70
|
+
unset(_cmake_config_file)
|
|
71
|
+
unset(_cmake_config_files)
|
|
72
|
+
|
|
73
|
+
# Cleanup temporary variables.
|
|
74
|
+
set(_IMPORT_PREFIX)
|
|
75
|
+
|
|
76
|
+
# Loop over all imported files and verify that they actually exist
|
|
77
|
+
foreach(_cmake_target IN LISTS _cmake_import_check_targets)
|
|
78
|
+
if(CMAKE_VERSION VERSION_LESS "3.28"
|
|
79
|
+
OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
|
|
80
|
+
OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
|
|
81
|
+
foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
|
|
82
|
+
if(NOT EXISTS "${_cmake_file}")
|
|
83
|
+
message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
|
|
84
|
+
\"${_cmake_file}\"
|
|
85
|
+
but this file does not exist. Possible reasons include:
|
|
86
|
+
* The file was deleted, renamed, or moved to another location.
|
|
87
|
+
* An install or uninstall procedure did not complete successfully.
|
|
88
|
+
* The installation package was faulty and contained
|
|
89
|
+
\"${CMAKE_CURRENT_LIST_FILE}\"
|
|
90
|
+
but not all the files it references.
|
|
91
|
+
")
|
|
92
|
+
endif()
|
|
93
|
+
endforeach()
|
|
94
|
+
endif()
|
|
95
|
+
unset(_cmake_file)
|
|
96
|
+
unset("_cmake_import_check_files_for_${_cmake_target}")
|
|
97
|
+
endforeach()
|
|
98
|
+
unset(_cmake_target)
|
|
99
|
+
unset(_cmake_import_check_targets)
|
|
100
|
+
|
|
101
|
+
# This file does not depend on other imported targets which have
|
|
102
|
+
# been exported from the same project but in a separate export set.
|
|
103
|
+
|
|
104
|
+
# Commands beyond this point should not need to know the version.
|
|
105
|
+
set(CMAKE_IMPORT_FILE_VERSION)
|
|
106
|
+
cmake_policy(POP)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
include(CMakeParseArguments)
|
|
2
|
+
|
|
3
|
+
# clang format off
|
|
4
|
+
#
|
|
5
|
+
# ##############################################################################
|
|
6
|
+
# Build metal library
|
|
7
|
+
#
|
|
8
|
+
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/${TITLE}.metallib
|
|
9
|
+
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
|
|
10
|
+
#
|
|
11
|
+
# Args: TARGET: Custom target to be added for the metal library TITLE: Name of
|
|
12
|
+
# the .metallib OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib SOURCES: List
|
|
13
|
+
# of source files INCLUDE_DIRS: List of include dirs DEPS: List of dependency
|
|
14
|
+
# files (like headers) DEBUG: Boolean, if true, enables debug compile options
|
|
15
|
+
# for this specific library. If not provided, uses global MLX_METAL_DEBUG.
|
|
16
|
+
#
|
|
17
|
+
# clang format on
|
|
18
|
+
|
|
19
|
+
macro(mlx_build_metallib)
  # Split the keyword arguments into MTLLIB_<KEYWORD> variables. Because this
  # is a macro, every variable set here lives in the caller's scope.
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY DEBUG)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(MTLLIB "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  # Base compile options; the debug flags are appended when either the global
  # MLX_METAL_DEBUG switch or the per-library DEBUG argument is true.
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
  if(MLX_METAL_DEBUG OR MTLLIB_DEBUG)
    list(APPEND MTLLIB_COMPILE_OPTIONS -gline-tables-only -frecord-sources)
  endif()

  # Full path of the .metallib file produced by the custom command below.
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Compile all sources into the metallib. Each include directory is turned
  # into a -I<dir> flag by the $<LIST:TRANSFORM,...> generator expression, and
  # COMMAND_EXPAND_LISTS splits the resulting list into separate arguments.
  add_custom_command(
    OUTPUT ${MTLLIB_BUILD_TARGET}
    COMMAND
      xcrun -sdk macosx metal
      "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
      ${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
    COMMAND_EXPAND_LISTS
    COMMENT "Building ${MTLLIB_TITLE}.metallib"
    VERBATIM)

  # Named target that other targets can depend on to trigger the build.
  add_custom_target(${MTLLIB_TARGET} DEPENDS ${MTLLIB_BUILD_TARGET})

endmacro()
|
mlx/utils.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
# Copyright © 2023 Apple Inc.
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from itertools import zip_longest
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def tree_map(
    fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
) -> Any:
    """Map ``fn`` over the leaves of ``tree`` and return the results in a new
    collection with the same structure.

    When extra trees are given through ``rest``, each of them is expected to
    be a superset of ``tree``; the leaves found at the same position are
    passed to ``fn`` as additional positional arguments. This makes
    :meth:`tree_map` behave more like :func:`itertools.starmap` than
    :func:`map`.

    What counts as a leaf can be customized with ``is_leaf``, in the same way
    as for :func:`tree_flatten`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (callable): The function applied to every leaf.
        tree (Any): The main Python tree that drives the traversal.
        rest (tuple[Any]): Additional trees traversed in lockstep with
            ``tree``.
        is_leaf (callable, optional): A callable returning ``True`` when the
            object it receives should be treated as a leaf and ``False``
            otherwise.

    Returns:
        A Python tree holding the values returned by ``fn``.
    """
    # A user-declared leaf takes precedence over the container checks.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)

    if isinstance(tree, (list, tuple)):
        mapped = (
            tree_map(fn, item, *(other[idx] for other in rest), is_leaf=is_leaf)
            for idx, item in enumerate(tree)
        )
        container = type(tree)
        # Objects exposing ``_fields`` (namedtuples) take their items as
        # separate positional arguments; other sequences take one iterable.
        if hasattr(tree, "_fields"):
            return container(*mapped)
        return container(mapped)

    if isinstance(tree, dict):
        result = {}
        for key, item in tree.items():
            result[key] = tree_map(
                fn, item, *(other[key] for other in rest), is_leaf=is_leaf
            )
        return result

    # Anything that is not a list, tuple or dict is a leaf.
    return fn(tree, *rest)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def tree_map_with_path(
    fn: Callable,
    tree: Any,
    *rest: Any,
    is_leaf: Optional[Callable] = None,
    path: Optional[Any] = None,
) -> Any:
    """Applies ``fn`` to the path and leaves of the Python tree ``tree`` and
    returns a new collection with the results.

    This function is the same as :func:`tree_map` but the ``fn`` takes the
    path as the first argument followed by the remaining tree nodes.

    Args:
        fn (callable): The function that processes the leaves of the tree.
        tree (Any): The main Python tree that will be iterated upon.
        rest (tuple[Any]): Extra trees to be iterated together with ``tree``.
        is_leaf (Optional[Callable]): An optional callable that returns ``True``
            if the passed object is considered a leaf or ``False`` otherwise.
        path (Optional[Any]): Prefix will be added to the result.

    Returns:
        A Python tree with the new values returned by ``fn``.

    Example:
        >>> from mlx.utils import tree_map_with_path
        >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
        >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
        model.0.w
        model.0.b
        model.1.w
        model.1.b
    """
    if is_leaf is not None and is_leaf(tree):
        return fn(path, tree, *rest)
    elif isinstance(tree, (list, tuple)):
        # An empty or missing path contributes no leading separator.
        prefix = f"{path}." if path else ""
        TreeType = type(tree)
        subtrees = (
            tree_map_with_path(
                fn, child, *(r[i] for r in rest), is_leaf=is_leaf, path=f"{prefix}{i}"
            )
            for i, child in enumerate(tree)
        )
        # Namedtuples (detected through ``_fields``) take their items as
        # separate positional arguments, so the generator must be unpacked;
        # this mirrors the handling in tree_map.
        return TreeType(*subtrees) if hasattr(tree, "_fields") else TreeType(subtrees)
    elif isinstance(tree, dict):
        prefix = f"{path}." if path else ""
        return {
            k: tree_map_with_path(
                fn, child, *(r[k] for r in rest), is_leaf=is_leaf, path=f"{prefix}{k}"
            )
            for k, child in tree.items()
        }
    else:
        return fn(path, tree, *rest)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def tree_flatten(
    tree: Any,
    prefix: str = "",
    is_leaf: Optional[Callable] = None,
    destination: Optional[Union[List[Tuple[str, Any]], Dict[str, Any]]] = None,
) -> Union[List[Tuple[str, Any]], Dict[str, Any]]:
    """Flattens a Python tree to a list of key, value tuples.

    The keys are using the dot notation to define trees of arbitrary depth and
    complexity.

    .. code-block:: python

        from mlx.utils import tree_flatten

        print(tree_flatten([[[0]]]))
        # [("0.0.0", 0)]

        print(tree_flatten([[[0]]], prefix=".hello"))
        # [("hello.0.0.0", 0)]

        print(tree_flatten({"a": {"b": 1}}, destination={}))
        # {"a.b": 1}

    .. note::
       Dictionaries should have keys that are valid Python identifiers.

    Args:
        tree (Any): The Python tree to be flattened.
        prefix (str): A prefix to use for the keys. The first character is
            always discarded.
        is_leaf (callable): An optional callable that returns True if the
            passed object is considered a leaf or False otherwise.
        destination (list or dict, optional): A list or dictionary to store the
            flattened tree. If None an empty list will be used. Default: ``None``.

    Returns:
        Union[List[Tuple[str, Any]], Dict[str, Any]]: The flat representation of
            the Python tree. This is the ``destination`` object itself when one
            was provided.

    Raises:
        ValueError: If ``destination`` is neither a list, a dict nor ``None``.
    """
    if destination is None:
        destination = []

    # list.extend and dict.update both accept an iterable of (key, value)
    # pairs, so a single bound method serves either destination type. It is
    # resolved once here instead of on every recursive step.
    if isinstance(destination, list):
        _add_to_destination = destination.extend
    elif isinstance(destination, dict):
        _add_to_destination = destination.update
    else:
        raise ValueError("Destination should be either a list or a dictionary or None")

    def _flatten(node: Any, key: str) -> None:
        # ``key`` always carries one leading character (the "." separator or
        # the first character of the user prefix), which is dropped on emit.
        if is_leaf is not None and is_leaf(node):
            # Leaf identified by is_leaf, even if it is a container.
            _add_to_destination([(key[1:], node)])
        elif isinstance(node, (list, tuple)):
            for i, item in enumerate(node):
                _flatten(item, f"{key}.{i}")
        elif isinstance(node, dict):
            for name, value in node.items():
                _flatten(value, f"{key}.{name}")
        else:
            # Any other object is a leaf.
            _add_to_destination([(key[1:], node)])

    _flatten(tree, prefix)
    return destination
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def tree_unflatten(tree: Union[List[Tuple[str, Any]], Dict[str, Any]]) -> Any:
    """Recreate a Python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

        d = tree_unflatten({"hello.world": 42})
        print(d)
        # {"hello": {"world": 42}}

    Keys are split on ``"."``. If every key at a given level parses as an
    integer the level becomes a list (missing indices are filled with empty
    dictionaries), otherwise it becomes a dictionary. When the same key is
    given more than once the last value wins, mirroring ``dict.update``.

    Args:
        tree (list[tuple[str, Any]] or dict[str, Any]): The flat representation of a Python tree.
            For instance as returned by :meth:`tree_flatten`. Any iterable of
            ``(key, value)`` pairs is accepted.

    Returns:
        A Python tree.
    """
    # Materialise once so that dict views and one-shot iterables can be
    # inspected and then iterated again.
    items = list(tree.items()) if isinstance(tree, dict) else list(tree)

    # Every remaining key is empty, so this is a leaf and not a tree. With
    # duplicated keys several entries can end up here; keep the last one
    # instead of regrouping them under "" forever (infinite recursion).
    if items and all(key == "" for key, _ in items):
        return items[-1][1]

    # Group the entries by the first component of their key.
    children = defaultdict(list)
    for key, value in items:
        current_idx, _, next_idx = key.partition(".")
        children[current_idx].append((next_idx, value))

    # Assume they are a list and fall back to a dict if the keys are not all
    # integers. Only the integer conversion is guarded so that errors raised
    # while rebuilding the children are not mistaken for non-integer keys.
    try:
        keys = sorted((int(idx), idx) for idx in children)
    except ValueError:
        return {k: tree_unflatten(v) for k, v in children.items()}

    result = []
    for i, k in keys:
        # Pad skipped indices; if i <= len(result) nothing is appended.
        result.extend([{} for _ in range(i - len(result))])
        result.append(tree_unflatten(children[k]))
    return result
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def tree_reduce(fn, tree, initializer=None, is_leaf=None):
    """Fold the leaves of a Python tree into a single accumulated value.

    The tree is walked depth first and ``fn`` is called once per leaf with
    the running accumulator and that leaf.

    Example:
        >>> from mlx.utils import tree_reduce
        >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
        >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
        15

    Args:
        fn (callable): The reducer, called as ``fn(accumulator, leaf)``; its
            return value becomes the new accumulator.
        tree (Any): The Python tree to reduce: any nesting of lists, tuples
            and dictionaries.
        initializer (Any, optional): The starting accumulator. ``None`` means
            "not provided", in which case the first leaf is used as the
            starting value.
        is_leaf (callable, optional): A predicate returning ``True`` for
            objects that must be treated as leaves and ``False`` otherwise.

    Returns:
        Any: The accumulated value. A container without leaves yields
        ``initializer`` unchanged.
    """
    if is_leaf is None or not is_leaf(tree):
        # Not forced to be a leaf: descend if this is a known container.
        if isinstance(tree, dict):
            subtrees = tree.values()
        elif isinstance(tree, (list, tuple)):
            subtrees = tree
        else:
            subtrees = None

        if subtrees is not None:
            result = initializer
            for subtree in subtrees:
                result = tree_reduce(fn, subtree, result, is_leaf)
            return result

    # Leaf: it either seeds the accumulator or is folded into it.
    if initializer is None:
        return tree
    return fn(initializer, tree)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def tree_merge(tree_a, tree_b, merge_fn=None):
    """Combine two Python trees into one holding the values of both, in the
    spirit of a deep ``dict.update``.

    Empty dictionaries, lists and tuples are treated like ``None``. When only
    one side has a value at a location, that value is returned as is. Two
    lists/tuples are merged element by element (the result has the type of
    ``tree_a``) and two dictionaries are merged key by key. Anything else is
    handed to ``merge_fn``.

    Args:
        tree_a (Any): The first Python tree.
        tree_b (Any): The second Python tree.
        merge_fn (callable, optional): A function to merge leaves, called as
            ``merge_fn(a, b)``.

    Returns:
        The Python tree containing the values of both ``tree_a`` and
        ``tree_b``.

    Raises:
        ValueError: If both trees hold a value at the same location and no
            ``merge_fn`` was provided.
    """

    def _is_empty_container(node):
        return isinstance(node, (dict, list, tuple)) and len(node) == 0

    left = None if _is_empty_container(tree_a) else tree_a
    right = None if _is_empty_container(tree_b) else tree_b

    # Exactly one side is missing: the other side is the result.
    if (left is None) != (right is None):
        return left if right is None else right

    if isinstance(left, dict) and isinstance(right, dict):
        merged = {}
        for key in set(left.keys()) | set(right.keys()):
            merged[key] = tree_merge(left.get(key, None), right.get(key, None), merge_fn)
        return merged

    sequence_types = (list, tuple)
    if isinstance(left, sequence_types) and isinstance(right, sequence_types):
        # zip_longest pads the shorter sequence with None, which the
        # recursive call treats as "missing".
        pairs = zip_longest(left, right)
        return type(left)(tree_merge(a, b, merge_fn) for a, b in pairs)

    # Both sides hold something at this location (or both are missing).
    if merge_fn is None:
        raise ValueError(
            (
                "Trees contain elements at the same locations but no merge "
                "function was provided"
            )
        )
    return merge_fn(left, right)
|