mlx 0.30.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/ext/mlx/extconf.rb +94 -0
- data/ext/mlx/native.cpp +8027 -0
- data/lib/mlx/core.rb +1678 -0
- data/lib/mlx/distributed_utils/common.rb +116 -0
- data/lib/mlx/distributed_utils/config.rb +600 -0
- data/lib/mlx/distributed_utils/launch.rb +490 -0
- data/lib/mlx/extension.rb +24 -0
- data/lib/mlx/nn/base.rb +388 -0
- data/lib/mlx/nn/init.rb +140 -0
- data/lib/mlx/nn/layers/activations.rb +336 -0
- data/lib/mlx/nn/layers/base.rb +6 -0
- data/lib/mlx/nn/layers/containers.rb +20 -0
- data/lib/mlx/nn/layers/convolution.rb +120 -0
- data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
- data/lib/mlx/nn/layers/distributed.rb +309 -0
- data/lib/mlx/nn/layers/dropout.rb +75 -0
- data/lib/mlx/nn/layers/embedding.rb +28 -0
- data/lib/mlx/nn/layers/linear.rb +79 -0
- data/lib/mlx/nn/layers/normalization.rb +216 -0
- data/lib/mlx/nn/layers/pooling.rb +167 -0
- data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
- data/lib/mlx/nn/layers/quantized.rb +215 -0
- data/lib/mlx/nn/layers/recurrent.rb +135 -0
- data/lib/mlx/nn/layers/transformer.rb +330 -0
- data/lib/mlx/nn/layers/upsample.rb +97 -0
- data/lib/mlx/nn/layers.rb +18 -0
- data/lib/mlx/nn/losses.rb +251 -0
- data/lib/mlx/nn/utils.rb +167 -0
- data/lib/mlx/nn.rb +12 -0
- data/lib/mlx/optimizers/optimizers.rb +808 -0
- data/lib/mlx/optimizers/schedulers.rb +62 -0
- data/lib/mlx/optimizers.rb +9 -0
- data/lib/mlx/utils.rb +171 -0
- data/lib/mlx/version.rb +5 -0
- data/lib/mlx.rb +64 -0
- data/mlx/CMakeLists.txt +449 -0
- data/mlx/cmake/FindCUDNN.cmake +177 -0
- data/mlx/cmake/FindNCCL.cmake +54 -0
- data/mlx/cmake/Findnvpl.cmake +3 -0
- data/mlx/cmake/extension.cmake +50 -0
- data/mlx/mlx/3rdparty/.clang-format +2 -0
- data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
- data/mlx/mlx/CMakeLists.txt +107 -0
- data/mlx/mlx/allocator.h +75 -0
- data/mlx/mlx/api.h +29 -0
- data/mlx/mlx/array.cpp +354 -0
- data/mlx/mlx/array.h +647 -0
- data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
- data/mlx/mlx/backend/common/binary.h +97 -0
- data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
- data/mlx/mlx/backend/common/broadcasting.h +11 -0
- data/mlx/mlx/backend/common/buffer_cache.h +158 -0
- data/mlx/mlx/backend/common/common.cpp +305 -0
- data/mlx/mlx/backend/common/compiled.cpp +243 -0
- data/mlx/mlx/backend/common/compiled.h +77 -0
- data/mlx/mlx/backend/common/copy.h +50 -0
- data/mlx/mlx/backend/common/hadamard.h +109 -0
- data/mlx/mlx/backend/common/load.cpp +57 -0
- data/mlx/mlx/backend/common/matmul.h +67 -0
- data/mlx/mlx/backend/common/reduce.cpp +154 -0
- data/mlx/mlx/backend/common/reduce.h +59 -0
- data/mlx/mlx/backend/common/slicing.cpp +71 -0
- data/mlx/mlx/backend/common/slicing.h +20 -0
- data/mlx/mlx/backend/common/ternary.h +85 -0
- data/mlx/mlx/backend/common/unary.h +29 -0
- data/mlx/mlx/backend/common/utils.cpp +231 -0
- data/mlx/mlx/backend/common/utils.h +205 -0
- data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
- data/mlx/mlx/backend/cpu/arange.h +28 -0
- data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
- data/mlx/mlx/backend/cpu/binary.cpp +269 -0
- data/mlx/mlx/backend/cpu/binary.h +517 -0
- data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
- data/mlx/mlx/backend/cpu/binary_two.h +166 -0
- data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
- data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
- data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
- data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
- data/mlx/mlx/backend/cpu/copy.cpp +386 -0
- data/mlx/mlx/backend/cpu/copy.h +36 -0
- data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
- data/mlx/mlx/backend/cpu/device_info.h +28 -0
- data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
- data/mlx/mlx/backend/cpu/eig.cpp +281 -0
- data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
- data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
- data/mlx/mlx/backend/cpu/encoder.h +67 -0
- data/mlx/mlx/backend/cpu/eval.cpp +40 -0
- data/mlx/mlx/backend/cpu/eval.h +12 -0
- data/mlx/mlx/backend/cpu/fft.cpp +120 -0
- data/mlx/mlx/backend/cpu/gemm.h +26 -0
- data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
- data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
- data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
- data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
- data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
- data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
- data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
- data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
- data/mlx/mlx/backend/cpu/lapack.h +80 -0
- data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
- data/mlx/mlx/backend/cpu/luf.cpp +120 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
- data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
- data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
- data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
- data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
- data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
- data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
- data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
- data/mlx/mlx/backend/cpu/scan.cpp +338 -0
- data/mlx/mlx/backend/cpu/select.cpp +95 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
- data/mlx/mlx/backend/cpu/simd/math.h +193 -0
- data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
- data/mlx/mlx/backend/cpu/simd/type.h +11 -0
- data/mlx/mlx/backend/cpu/slicing.h +21 -0
- data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
- data/mlx/mlx/backend/cpu/sort.cpp +481 -0
- data/mlx/mlx/backend/cpu/svd.cpp +289 -0
- data/mlx/mlx/backend/cpu/ternary.h +154 -0
- data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
- data/mlx/mlx/backend/cpu/threefry.h +21 -0
- data/mlx/mlx/backend/cpu/unary.cpp +238 -0
- data/mlx/mlx/backend/cpu/unary.h +281 -0
- data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
- data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
- data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
- data/mlx/mlx/backend/cuda/allocator.h +94 -0
- data/mlx/mlx/backend/cuda/arange.cu +68 -0
- data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
- data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
- data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
- data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
- data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
- data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
- data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
- data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
- data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
- data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
- data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
- data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
- data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
- data/mlx/mlx/backend/cuda/conv.cpp +403 -0
- data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
- data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
- data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
- data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
- data/mlx/mlx/backend/cuda/copy.cu +132 -0
- data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
- data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
- data/mlx/mlx/backend/cuda/cuda.h +21 -0
- data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
- data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
- data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
- data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
- data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
- data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
- data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
- data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
- data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
- data/mlx/mlx/backend/cuda/device/config.h +12 -0
- data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
- data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
- data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
- data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
- data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
- data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
- data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
- data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
- data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
- data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
- data/mlx/mlx/backend/cuda/device.cpp +522 -0
- data/mlx/mlx/backend/cuda/device.h +195 -0
- data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
- data/mlx/mlx/backend/cuda/distributed.cu +121 -0
- data/mlx/mlx/backend/cuda/eval.cpp +66 -0
- data/mlx/mlx/backend/cuda/event.cu +415 -0
- data/mlx/mlx/backend/cuda/event.h +79 -0
- data/mlx/mlx/backend/cuda/fence.cpp +42 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
- data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
- data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
- data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
- data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
- data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
- data/mlx/mlx/backend/cuda/jit_module.h +120 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
- data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
- data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
- data/mlx/mlx/backend/cuda/load.cpp +60 -0
- data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
- data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
- data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
- data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
- data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
- data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
- data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
- data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
- data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
- data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
- data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
- data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
- data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
- data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
- data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
- data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
- data/mlx/mlx/backend/cuda/random.cu +202 -0
- data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
- data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
- data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
- data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
- data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
- data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
- data/mlx/mlx/backend/cuda/reduce.cu +73 -0
- data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
- data/mlx/mlx/backend/cuda/rope.cu +429 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
- data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
- data/mlx/mlx/backend/cuda/scan.cu +468 -0
- data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
- data/mlx/mlx/backend/cuda/softmax.cu +162 -0
- data/mlx/mlx/backend/cuda/sort.cu +1076 -0
- data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
- data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
- data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
- data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
- data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
- data/mlx/mlx/backend/cuda/ternary.cu +271 -0
- data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
- data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
- data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
- data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
- data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
- data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
- data/mlx/mlx/backend/cuda/utils.cpp +116 -0
- data/mlx/mlx/backend/cuda/utils.h +49 -0
- data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
- data/mlx/mlx/backend/cuda/worker.cpp +79 -0
- data/mlx/mlx/backend/cuda/worker.h +55 -0
- data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
- data/mlx/mlx/backend/gpu/copy.cpp +89 -0
- data/mlx/mlx/backend/gpu/copy.h +57 -0
- data/mlx/mlx/backend/gpu/device_info.h +36 -0
- data/mlx/mlx/backend/gpu/eval.h +18 -0
- data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
- data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
- data/mlx/mlx/backend/gpu/slicing.h +36 -0
- data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
- data/mlx/mlx/backend/metal/allocator.cpp +279 -0
- data/mlx/mlx/backend/metal/allocator.h +79 -0
- data/mlx/mlx/backend/metal/binary.cpp +257 -0
- data/mlx/mlx/backend/metal/binary.h +33 -0
- data/mlx/mlx/backend/metal/compiled.cpp +471 -0
- data/mlx/mlx/backend/metal/conv.cpp +1118 -0
- data/mlx/mlx/backend/metal/copy.cpp +235 -0
- data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
- data/mlx/mlx/backend/metal/device.cpp +816 -0
- data/mlx/mlx/backend/metal/device.h +289 -0
- data/mlx/mlx/backend/metal/device_info.cpp +58 -0
- data/mlx/mlx/backend/metal/distributed.cpp +38 -0
- data/mlx/mlx/backend/metal/eval.cpp +97 -0
- data/mlx/mlx/backend/metal/event.cpp +62 -0
- data/mlx/mlx/backend/metal/fence.cpp +162 -0
- data/mlx/mlx/backend/metal/fft.cpp +807 -0
- data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
- data/mlx/mlx/backend/metal/indexing.cpp +727 -0
- data/mlx/mlx/backend/metal/jit/includes.h +58 -0
- data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
- data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
- data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
- data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
- data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
- data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
- data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
- data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
- data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
- data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
- data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
- data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
- data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
- data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
- data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
- data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
- data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
- data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
- data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
- data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
- data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
- data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
- data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
- data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
- data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
- data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
- data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
- data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
- data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
- data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
- data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
- data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
- data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
- data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
- data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
- data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
- data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
- data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
- data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
- data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
- data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
- data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
- data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
- data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
- data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
- data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
- data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
- data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
- data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
- data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
- data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
- data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
- data/mlx/mlx/backend/metal/kernels.h +375 -0
- data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
- data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
- data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
- data/mlx/mlx/backend/metal/matmul.h +144 -0
- data/mlx/mlx/backend/metal/metal.cpp +50 -0
- data/mlx/mlx/backend/metal/metal.h +25 -0
- data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
- data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
- data/mlx/mlx/backend/metal/normalization.cpp +433 -0
- data/mlx/mlx/backend/metal/primitives.cpp +242 -0
- data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
- data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
- data/mlx/mlx/backend/metal/reduce.h +41 -0
- data/mlx/mlx/backend/metal/resident.cpp +100 -0
- data/mlx/mlx/backend/metal/resident.h +32 -0
- data/mlx/mlx/backend/metal/rope.cpp +165 -0
- data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
- data/mlx/mlx/backend/metal/scan.cpp +145 -0
- data/mlx/mlx/backend/metal/scan.h +17 -0
- data/mlx/mlx/backend/metal/slicing.cpp +99 -0
- data/mlx/mlx/backend/metal/softmax.cpp +87 -0
- data/mlx/mlx/backend/metal/sort.cpp +368 -0
- data/mlx/mlx/backend/metal/ternary.cpp +160 -0
- data/mlx/mlx/backend/metal/ternary.h +21 -0
- data/mlx/mlx/backend/metal/unary.cpp +161 -0
- data/mlx/mlx/backend/metal/unary.h +21 -0
- data/mlx/mlx/backend/metal/utils.cpp +77 -0
- data/mlx/mlx/backend/metal/utils.h +99 -0
- data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
- data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
- data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
- data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
- data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
- data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
- data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
- data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
- data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
- data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
- data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
- data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
- data/mlx/mlx/compile.cpp +1243 -0
- data/mlx/mlx/compile.h +45 -0
- data/mlx/mlx/compile_impl.h +70 -0
- data/mlx/mlx/device.cpp +72 -0
- data/mlx/mlx/device.h +56 -0
- data/mlx/mlx/distributed/CMakeLists.txt +14 -0
- data/mlx/mlx/distributed/distributed.cpp +197 -0
- data/mlx/mlx/distributed/distributed.h +61 -0
- data/mlx/mlx/distributed/distributed_impl.h +59 -0
- data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
- data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
- data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
- data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
- data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
- data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
- data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
- data/mlx/mlx/distributed/jaccl/ring.h +178 -0
- data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
- data/mlx/mlx/distributed/jaccl/utils.h +342 -0
- data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
- data/mlx/mlx/distributed/mpi/mpi.h +12 -0
- data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
- data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
- data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
- data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
- data/mlx/mlx/distributed/nccl/nccl.h +12 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
- data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
- data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
- data/mlx/mlx/distributed/ops.cpp +186 -0
- data/mlx/mlx/distributed/ops.h +57 -0
- data/mlx/mlx/distributed/primitives.cpp +95 -0
- data/mlx/mlx/distributed/primitives.h +156 -0
- data/mlx/mlx/distributed/reduction_ops.h +38 -0
- data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
- data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
- data/mlx/mlx/distributed/ring/ring.cpp +870 -0
- data/mlx/mlx/distributed/ring/ring.h +12 -0
- data/mlx/mlx/distributed/utils.cpp +206 -0
- data/mlx/mlx/distributed/utils.h +67 -0
- data/mlx/mlx/dtype.cpp +197 -0
- data/mlx/mlx/dtype.h +116 -0
- data/mlx/mlx/dtype_utils.cpp +42 -0
- data/mlx/mlx/dtype_utils.h +119 -0
- data/mlx/mlx/einsum.cpp +941 -0
- data/mlx/mlx/einsum.h +23 -0
- data/mlx/mlx/event.h +58 -0
- data/mlx/mlx/export.cpp +1130 -0
- data/mlx/mlx/export.h +137 -0
- data/mlx/mlx/export_impl.h +99 -0
- data/mlx/mlx/fast.cpp +941 -0
- data/mlx/mlx/fast.h +103 -0
- data/mlx/mlx/fast_primitives.h +427 -0
- data/mlx/mlx/fence.h +39 -0
- data/mlx/mlx/fft.cpp +262 -0
- data/mlx/mlx/fft.h +159 -0
- data/mlx/mlx/graph_utils.cpp +175 -0
- data/mlx/mlx/graph_utils.h +67 -0
- data/mlx/mlx/io/CMakeLists.txt +25 -0
- data/mlx/mlx/io/gguf.cpp +470 -0
- data/mlx/mlx/io/gguf.h +20 -0
- data/mlx/mlx/io/gguf_quants.cpp +164 -0
- data/mlx/mlx/io/load.cpp +397 -0
- data/mlx/mlx/io/load.h +175 -0
- data/mlx/mlx/io/no_gguf.cpp +20 -0
- data/mlx/mlx/io/no_safetensors.cpp +37 -0
- data/mlx/mlx/io/safetensors.cpp +234 -0
- data/mlx/mlx/io.h +61 -0
- data/mlx/mlx/linalg.cpp +708 -0
- data/mlx/mlx/linalg.h +115 -0
- data/mlx/mlx/memory.h +80 -0
- data/mlx/mlx/mlx.h +25 -0
- data/mlx/mlx/ops.cpp +6094 -0
- data/mlx/mlx/ops.h +1610 -0
- data/mlx/mlx/primitives.cpp +5850 -0
- data/mlx/mlx/primitives.h +2525 -0
- data/mlx/mlx/random.cpp +492 -0
- data/mlx/mlx/random.h +283 -0
- data/mlx/mlx/scheduler.cpp +73 -0
- data/mlx/mlx/scheduler.h +189 -0
- data/mlx/mlx/small_vector.h +540 -0
- data/mlx/mlx/stream.h +42 -0
- data/mlx/mlx/threadpool.h +133 -0
- data/mlx/mlx/transforms.cpp +1065 -0
- data/mlx/mlx/transforms.h +231 -0
- data/mlx/mlx/transforms_impl.h +88 -0
- data/mlx/mlx/types/bf16.h +187 -0
- data/mlx/mlx/types/complex.h +113 -0
- data/mlx/mlx/types/fp16.h +234 -0
- data/mlx/mlx/types/half_types.h +58 -0
- data/mlx/mlx/types/limits.h +70 -0
- data/mlx/mlx/utils.cpp +302 -0
- data/mlx/mlx/utils.h +174 -0
- data/mlx/mlx/version.cpp +11 -0
- data/mlx/mlx/version.h +22 -0
- data/mlx/mlx.pc.in +52 -0
- metadata +643 -0
|
@@ -0,0 +1,1118 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
#include <algorithm>
|
|
3
|
+
#include <cassert>
|
|
4
|
+
#include <numeric>
|
|
5
|
+
|
|
6
|
+
#include "mlx/backend/gpu/copy.h"
|
|
7
|
+
#include "mlx/backend/metal/device.h"
|
|
8
|
+
#include "mlx/backend/metal/kernels.h"
|
|
9
|
+
#include "mlx/backend/metal/kernels/defines.h"
|
|
10
|
+
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
11
|
+
#include "mlx/backend/metal/matmul.h"
|
|
12
|
+
#include "mlx/backend/metal/utils.h"
|
|
13
|
+
#include "mlx/primitives.h"
|
|
14
|
+
#include "mlx/utils.h"
|
|
15
|
+
|
|
16
|
+
using namespace mlx::steel;
|
|
17
|
+
|
|
18
|
+
namespace mlx::core {
|
|
19
|
+
|
|
20
|
+
namespace {
|
|
21
|
+
|
|
22
|
+
template <int N>
|
|
23
|
+
void explicit_gemm_conv_ND_gpu(
|
|
24
|
+
const Stream& s,
|
|
25
|
+
metal::Device& d,
|
|
26
|
+
const array& in,
|
|
27
|
+
const array& wt,
|
|
28
|
+
array out,
|
|
29
|
+
const MLXConvParams<N>& conv_params) {
|
|
30
|
+
// Get gemm shapes
|
|
31
|
+
int implicit_M = out.size() / conv_params.O;
|
|
32
|
+
int implicit_K = wt.size() / conv_params.O;
|
|
33
|
+
int implicit_N = conv_params.O;
|
|
34
|
+
// Prepare unfolding array
|
|
35
|
+
Shape unfolded_shape{implicit_M, implicit_K};
|
|
36
|
+
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
|
37
|
+
|
|
38
|
+
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
|
39
|
+
|
|
40
|
+
// Prepare unfolding kernel
|
|
41
|
+
std::string kname;
|
|
42
|
+
kname.reserve(32);
|
|
43
|
+
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
|
|
44
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
45
|
+
auto kernel = d.get_kernel(kname);
|
|
46
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
47
|
+
|
|
48
|
+
compute_encoder.set_input_array(in, 0);
|
|
49
|
+
compute_encoder.set_output_array(in_unfolded, 1);
|
|
50
|
+
|
|
51
|
+
compute_encoder.set_bytes(conv_params, 2);
|
|
52
|
+
|
|
53
|
+
// Launch unfolding kernel
|
|
54
|
+
size_t tgp_x = std::min(conv_params.C, 64);
|
|
55
|
+
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
|
56
|
+
size_t tgp_y = 256 / tgp_x;
|
|
57
|
+
|
|
58
|
+
MTL::Size grid_dims = MTL::Size(
|
|
59
|
+
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
|
60
|
+
MTL::Size group_dims = MTL::Size(
|
|
61
|
+
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
|
62
|
+
|
|
63
|
+
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
|
64
|
+
|
|
65
|
+
// Reshape weight
|
|
66
|
+
Shape wt_reshape{implicit_K, implicit_N};
|
|
67
|
+
Strides wt_restride{1, implicit_K};
|
|
68
|
+
array wt_reshaped(wt_reshape, wt.dtype(), nullptr, {});
|
|
69
|
+
auto wt_flags = wt.flags();
|
|
70
|
+
wt_flags.row_contiguous = false;
|
|
71
|
+
wt_flags.col_contiguous = true;
|
|
72
|
+
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
|
|
73
|
+
|
|
74
|
+
// Perform gemm
|
|
75
|
+
std::vector<array> copies = {in_unfolded};
|
|
76
|
+
return steel_matmul(
|
|
77
|
+
s,
|
|
78
|
+
d,
|
|
79
|
+
/*a = */ in_unfolded,
|
|
80
|
+
/*b = */ wt_reshaped,
|
|
81
|
+
/*c = */ out,
|
|
82
|
+
/*M = */ implicit_M,
|
|
83
|
+
/*N = */ implicit_N,
|
|
84
|
+
/*K = */ implicit_K,
|
|
85
|
+
/*batch_size_out = */ 1,
|
|
86
|
+
/*a_cols = */ implicit_K,
|
|
87
|
+
/*b_cols = */ implicit_K,
|
|
88
|
+
/*a_transposed = */ false,
|
|
89
|
+
/*b_transposed = */ true,
|
|
90
|
+
/*copies = */ copies);
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
template <int N>
|
|
94
|
+
void explicit_gemm_conv_group_ND_gpu(
|
|
95
|
+
const Stream& s,
|
|
96
|
+
metal::Device& d,
|
|
97
|
+
const array& in,
|
|
98
|
+
const array& wt,
|
|
99
|
+
array out,
|
|
100
|
+
const MLXConvParams<N>& conv_params) {
|
|
101
|
+
const int groups = conv_params.groups;
|
|
102
|
+
const int C_per_group = conv_params.C / conv_params.groups;
|
|
103
|
+
const int O_per_group = conv_params.O / conv_params.groups;
|
|
104
|
+
// Get gemm shapes
|
|
105
|
+
const int implicit_M = out.size() / conv_params.O;
|
|
106
|
+
const int implicit_K = wt.size() / conv_params.O;
|
|
107
|
+
const int implicit_N = O_per_group;
|
|
108
|
+
|
|
109
|
+
int kernel_size = 1;
|
|
110
|
+
for (int i = 0; i < N; ++i) {
|
|
111
|
+
kernel_size *= conv_params.wS[i];
|
|
112
|
+
}
|
|
113
|
+
|
|
114
|
+
// Prepare unfolding array
|
|
115
|
+
Shape unfolded_shape{implicit_M, implicit_K * groups};
|
|
116
|
+
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
|
|
117
|
+
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
|
|
118
|
+
|
|
119
|
+
// Prepare unfolding kernel
|
|
120
|
+
std::string kname;
|
|
121
|
+
kname.reserve(32);
|
|
122
|
+
concatenate(
|
|
123
|
+
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
|
|
124
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
125
|
+
auto kernel = d.get_kernel(kname);
|
|
126
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
127
|
+
|
|
128
|
+
compute_encoder.set_input_array(in, 0);
|
|
129
|
+
compute_encoder.set_output_array(in_unfolded, 1);
|
|
130
|
+
|
|
131
|
+
compute_encoder.set_bytes(conv_params, 2);
|
|
132
|
+
|
|
133
|
+
// Launch unfolding kernel
|
|
134
|
+
size_t tgp_x = std::min(conv_params.C, 64);
|
|
135
|
+
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
|
|
136
|
+
size_t tgp_y = 256 / tgp_x;
|
|
137
|
+
|
|
138
|
+
MTL::Size grid_dims = MTL::Size(
|
|
139
|
+
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
|
|
140
|
+
MTL::Size group_dims = MTL::Size(
|
|
141
|
+
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
|
|
142
|
+
|
|
143
|
+
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
|
144
|
+
|
|
145
|
+
// Transpose kernel weights so that we can slice them by contiguous chunks
|
|
146
|
+
// of channel groups.
|
|
147
|
+
array wt_view(
|
|
148
|
+
{wt.shape(0), C_per_group, kernel_size}, wt.dtype(), nullptr, {});
|
|
149
|
+
wt_view.copy_shared_buffer(
|
|
150
|
+
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
|
151
|
+
|
|
152
|
+
// Materialize
|
|
153
|
+
array wt_transpose = contiguous_copy_gpu(wt_view, s);
|
|
154
|
+
|
|
155
|
+
// Perform gemm
|
|
156
|
+
std::vector<array> copies = {in_unfolded, wt_transpose};
|
|
157
|
+
return steel_matmul_regular(
|
|
158
|
+
/* const Stream& s = */ s,
|
|
159
|
+
/* Device& d = */ d,
|
|
160
|
+
/* const array& a = */ in_unfolded,
|
|
161
|
+
/* const array& b = */ wt_transpose,
|
|
162
|
+
/* array& c = */ out,
|
|
163
|
+
/* int M = */ implicit_M,
|
|
164
|
+
/* int N = */ implicit_N,
|
|
165
|
+
/* int K = */ implicit_K,
|
|
166
|
+
/* int batch_size_out = */ groups,
|
|
167
|
+
/* int lda = */ implicit_K * groups,
|
|
168
|
+
/* int ldb = */ implicit_K,
|
|
169
|
+
/* int ldd = */ implicit_N * groups,
|
|
170
|
+
/* bool transpose_a = */ false,
|
|
171
|
+
/* bool transpose_b = */ true,
|
|
172
|
+
/* std::vector<array>& copies = */ copies,
|
|
173
|
+
/* Shape batch_shape = */ {1},
|
|
174
|
+
/* Strides batch_strides = */ {0},
|
|
175
|
+
/* int64_t A_batch_strides = */ int64_t(implicit_K),
|
|
176
|
+
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
|
|
177
|
+
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
void implicit_gemm_conv_2D_gpu(
|
|
181
|
+
const Stream& s,
|
|
182
|
+
metal::Device& d,
|
|
183
|
+
const array& in,
|
|
184
|
+
const array& wt,
|
|
185
|
+
array out,
|
|
186
|
+
const MLXConvParams<2>& conv_params) {
|
|
187
|
+
const int groups = conv_params.groups;
|
|
188
|
+
const int C_per_group = conv_params.C / conv_params.groups;
|
|
189
|
+
const int O_per_group = conv_params.O / conv_params.groups;
|
|
190
|
+
|
|
191
|
+
// Deduce implicit gemm size
|
|
192
|
+
const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
|
193
|
+
const int implicit_N = O_per_group;
|
|
194
|
+
const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
|
|
195
|
+
|
|
196
|
+
// Determine block and warp tiles
|
|
197
|
+
int wm = 2, wn = 2;
|
|
198
|
+
|
|
199
|
+
int bm = implicit_M >= 8192 && C_per_group >= 64 ? 64 : 32;
|
|
200
|
+
int bn = (bm == 64 || implicit_N >= 64) ? 64 : 32;
|
|
201
|
+
int bk = 16;
|
|
202
|
+
|
|
203
|
+
if (implicit_N <= 16) {
|
|
204
|
+
bn = 8;
|
|
205
|
+
wm = 4;
|
|
206
|
+
wn = 1;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
int tn = (implicit_N + bn - 1) / bn;
|
|
210
|
+
int tm = (implicit_M + bm - 1) / bm;
|
|
211
|
+
int swizzle_log = 0;
|
|
212
|
+
|
|
213
|
+
// Fix small channel specialization
|
|
214
|
+
int n_channel_specialization = 0;
|
|
215
|
+
int channel_k_iters = ((C_per_group + bk - 1) / bk);
|
|
216
|
+
int gemm_k_iters = conv_params.wS[0] * conv_params.wS[1] * channel_k_iters;
|
|
217
|
+
|
|
218
|
+
if (C_per_group <= 2) {
|
|
219
|
+
gemm_k_iters = (implicit_K + bk - 1) / bk;
|
|
220
|
+
n_channel_specialization = C_per_group;
|
|
221
|
+
} else if (C_per_group <= 4) {
|
|
222
|
+
gemm_k_iters = ((conv_params.wS[0] * conv_params.wS[1] * 4) + bk - 1) / bk;
|
|
223
|
+
n_channel_specialization = C_per_group;
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
bool small_filter = (!n_channel_specialization) &&
|
|
227
|
+
(conv_params.wS[0] <= 16 && conv_params.wS[1] <= 16);
|
|
228
|
+
|
|
229
|
+
// Fix host side helper params
|
|
230
|
+
int sign = (conv_params.flip ? -1 : 1);
|
|
231
|
+
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
|
|
232
|
+
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
|
|
233
|
+
|
|
234
|
+
int inp_jump_w = sign * ijw;
|
|
235
|
+
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
|
|
236
|
+
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
|
|
237
|
+
sign * (conv_params.wS[1] - 1) * ijw;
|
|
238
|
+
|
|
239
|
+
// Build implicit gemm params
|
|
240
|
+
ImplicitGemmConv2DParams gemm_params{
|
|
241
|
+
/* const int M = */ implicit_M,
|
|
242
|
+
/* const int N = */ implicit_N,
|
|
243
|
+
/* const int K = */ implicit_K,
|
|
244
|
+
|
|
245
|
+
/* const int gemm_k_iterations = */ gemm_k_iters,
|
|
246
|
+
|
|
247
|
+
/* const int inp_jump_w = */ inp_jump_w,
|
|
248
|
+
/* const int inp_jump_h = */ inp_jump_h,
|
|
249
|
+
/* const int inp_jump_c = */ inp_jump_c,
|
|
250
|
+
|
|
251
|
+
/* const int tiles_n = */ tn,
|
|
252
|
+
/* const int tiles_m = */ tm,
|
|
253
|
+
/* const int swizzle_log = */ swizzle_log};
|
|
254
|
+
|
|
255
|
+
// Determine kernel
|
|
256
|
+
std::string kname;
|
|
257
|
+
kname.reserve(64);
|
|
258
|
+
concatenate(
|
|
259
|
+
kname,
|
|
260
|
+
"implicit_gemm_conv_2d_",
|
|
261
|
+
type_to_name(out),
|
|
262
|
+
"_bm",
|
|
263
|
+
bm,
|
|
264
|
+
"_bn",
|
|
265
|
+
bn,
|
|
266
|
+
"_bk",
|
|
267
|
+
bk,
|
|
268
|
+
"_wm",
|
|
269
|
+
wm,
|
|
270
|
+
"_wn",
|
|
271
|
+
wn,
|
|
272
|
+
"_channel_",
|
|
273
|
+
n_channel_specialization ? std::to_string(n_channel_specialization) : "l",
|
|
274
|
+
"_filter_",
|
|
275
|
+
small_filter ? 's' : 'l');
|
|
276
|
+
|
|
277
|
+
// Encode and dispatch kernel
|
|
278
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
279
|
+
auto kernel = get_steel_conv_kernel(
|
|
280
|
+
d,
|
|
281
|
+
kname,
|
|
282
|
+
out,
|
|
283
|
+
bm,
|
|
284
|
+
bn,
|
|
285
|
+
bk,
|
|
286
|
+
wm,
|
|
287
|
+
wn,
|
|
288
|
+
n_channel_specialization,
|
|
289
|
+
small_filter);
|
|
290
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
291
|
+
|
|
292
|
+
// Deduce grid launch dimensions
|
|
293
|
+
int tile = 1 << swizzle_log;
|
|
294
|
+
size_t grid_dim_y = (tm + tile - 1) / tile;
|
|
295
|
+
size_t grid_dim_x = tn * tile;
|
|
296
|
+
|
|
297
|
+
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
|
298
|
+
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, groups);
|
|
299
|
+
|
|
300
|
+
// Encode arrays
|
|
301
|
+
compute_encoder.set_input_array(in, 0);
|
|
302
|
+
compute_encoder.set_input_array(wt, 1);
|
|
303
|
+
compute_encoder.set_output_array(out, 2);
|
|
304
|
+
|
|
305
|
+
// Encode params
|
|
306
|
+
compute_encoder.set_bytes(conv_params, 3);
|
|
307
|
+
compute_encoder.set_bytes(gemm_params, 4);
|
|
308
|
+
|
|
309
|
+
// Launch kernel
|
|
310
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
void implicit_gemm_conv_2D_general_gpu(
|
|
314
|
+
const Stream& s,
|
|
315
|
+
metal::Device& d,
|
|
316
|
+
const array& in,
|
|
317
|
+
const array& wt,
|
|
318
|
+
array out,
|
|
319
|
+
const MLXConvParams<2>& conv_params) {
|
|
320
|
+
// Deduce implicit gemm size
|
|
321
|
+
int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
|
|
322
|
+
int implicit_N = conv_params.O;
|
|
323
|
+
int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
|
|
324
|
+
|
|
325
|
+
// Determine block and warp tiles
|
|
326
|
+
int wm = 2, wn = 2;
|
|
327
|
+
|
|
328
|
+
// Make jump params
|
|
329
|
+
int f_wgt_jump_h =
|
|
330
|
+
std::lcm(conv_params.idil[0], conv_params.kdil[0]) / conv_params.kdil[0];
|
|
331
|
+
int f_wgt_jump_w =
|
|
332
|
+
std::lcm(conv_params.idil[1], conv_params.kdil[1]) / conv_params.kdil[1];
|
|
333
|
+
|
|
334
|
+
int f_out_jump_h =
|
|
335
|
+
std::lcm(conv_params.idil[0], conv_params.str[0]) / conv_params.str[0];
|
|
336
|
+
int f_out_jump_w =
|
|
337
|
+
std::lcm(conv_params.idil[1], conv_params.str[1]) / conv_params.str[1];
|
|
338
|
+
|
|
339
|
+
int adj_out_h = (conv_params.oS[0] + f_out_jump_h - 1) / f_out_jump_h;
|
|
340
|
+
int adj_out_w = (conv_params.oS[1] + f_out_jump_w - 1) / f_out_jump_w;
|
|
341
|
+
int adj_out_hw = adj_out_h * adj_out_w;
|
|
342
|
+
int adj_implicit_m = conv_params.N * adj_out_hw;
|
|
343
|
+
|
|
344
|
+
Conv2DGeneralJumpParams jump_params{
|
|
345
|
+
/* const int f_wgt_jump_h = */ f_wgt_jump_h,
|
|
346
|
+
/* const int f_wgt_jump_w = */ f_wgt_jump_w,
|
|
347
|
+
|
|
348
|
+
/* const int f_out_jump_h = */ f_out_jump_h,
|
|
349
|
+
/* const int f_out_jump_w = */ f_out_jump_w,
|
|
350
|
+
|
|
351
|
+
/* const int adj_out_h = */ adj_out_h,
|
|
352
|
+
/* const int adj_out_w = */ adj_out_w,
|
|
353
|
+
/* const int adj_out_hw = */ adj_out_hw,
|
|
354
|
+
/* const int adj_implicit_m = */ adj_implicit_m};
|
|
355
|
+
|
|
356
|
+
// Make base info
|
|
357
|
+
std::vector<Conv2DGeneralBaseInfo> base_h(f_out_jump_h);
|
|
358
|
+
std::vector<Conv2DGeneralBaseInfo> base_w(f_out_jump_w);
|
|
359
|
+
|
|
360
|
+
int jump_h = conv_params.flip ? -conv_params.kdil[0] : conv_params.kdil[0];
|
|
361
|
+
int jump_w = conv_params.flip ? -conv_params.kdil[1] : conv_params.kdil[1];
|
|
362
|
+
|
|
363
|
+
int init_h =
|
|
364
|
+
(conv_params.flip ? (conv_params.wS[0] - 1) * conv_params.kdil[0] : 0);
|
|
365
|
+
int init_w =
|
|
366
|
+
(conv_params.flip ? (conv_params.wS[1] - 1) * conv_params.kdil[1] : 0);
|
|
367
|
+
|
|
368
|
+
for (int i = 0; i < f_out_jump_h; ++i) {
|
|
369
|
+
int ih_loop = i * conv_params.str[0] - conv_params.pad[0] + init_h;
|
|
370
|
+
|
|
371
|
+
int wh_base = 0;
|
|
372
|
+
while (wh_base < conv_params.wS[0] && ih_loop % conv_params.idil[0] != 0) {
|
|
373
|
+
wh_base++;
|
|
374
|
+
ih_loop += jump_h;
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
int wh_size =
|
|
378
|
+
((conv_params.wS[0] - wh_base) + f_wgt_jump_h - 1) / f_wgt_jump_h;
|
|
379
|
+
base_h[i] = {wh_base, wh_size};
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
for (int j = 0; j < f_out_jump_w; ++j) {
|
|
383
|
+
int iw_loop = j * conv_params.str[1] - conv_params.pad[1] + init_w;
|
|
384
|
+
|
|
385
|
+
int ww_base = 0;
|
|
386
|
+
while (ww_base < conv_params.wS[1] && iw_loop % conv_params.idil[1] != 0) {
|
|
387
|
+
ww_base++;
|
|
388
|
+
iw_loop += jump_w;
|
|
389
|
+
}
|
|
390
|
+
|
|
391
|
+
int ww_size =
|
|
392
|
+
((conv_params.wS[1] - ww_base) + f_wgt_jump_w - 1) / f_wgt_jump_w;
|
|
393
|
+
base_w[j] = {ww_base, ww_size};
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
// Collect block sizes
|
|
397
|
+
int bm = adj_implicit_m >= 8192 && conv_params.C >= 64 ? 64 : 32;
|
|
398
|
+
int bn = (bm == 64 && implicit_N >= 64) ? 64 : 32;
|
|
399
|
+
int bk = 16;
|
|
400
|
+
|
|
401
|
+
int tn = (implicit_N + bn - 1) / bn;
|
|
402
|
+
int tm = (adj_implicit_m + bm - 1) / bm;
|
|
403
|
+
int swizzle_log = 0;
|
|
404
|
+
|
|
405
|
+
// Get channel iteration info
|
|
406
|
+
int channel_k_iters = ((conv_params.C + bk - 1) / bk);
|
|
407
|
+
int gemm_k_iters = channel_k_iters;
|
|
408
|
+
bool align_C = conv_params.C % bk == 0;
|
|
409
|
+
|
|
410
|
+
// Fix host side helper params
|
|
411
|
+
int sign = (conv_params.flip ? -1 : 1);
|
|
412
|
+
int ijw = conv_params.in_strides[2] * conv_params.kdil[1];
|
|
413
|
+
int ijh = conv_params.in_strides[1] * conv_params.kdil[0];
|
|
414
|
+
|
|
415
|
+
int inp_jump_w = sign * ijw;
|
|
416
|
+
int inp_jump_h = sign * (ijh - (conv_params.wS[1] - 1) * ijw);
|
|
417
|
+
int inp_jump_c = bk - sign * (conv_params.wS[0] - 1) * ijh -
|
|
418
|
+
sign * (conv_params.wS[1] - 1) * ijw;
|
|
419
|
+
|
|
420
|
+
// Build implicit gemm params
|
|
421
|
+
ImplicitGemmConv2DParams gemm_params{
|
|
422
|
+
/* const int M = */ implicit_M,
|
|
423
|
+
/* const int N = */ implicit_N,
|
|
424
|
+
/* const int K = */ implicit_K,
|
|
425
|
+
|
|
426
|
+
/* const int gemm_k_iterations = */ gemm_k_iters,
|
|
427
|
+
|
|
428
|
+
/* const int inp_jump_w = */ inp_jump_w,
|
|
429
|
+
/* const int inp_jump_h = */ inp_jump_h,
|
|
430
|
+
/* const int inp_jump_c = */ inp_jump_c,
|
|
431
|
+
|
|
432
|
+
/* const int tiles_n = */ tn,
|
|
433
|
+
/* const int tiles_m = */ tm,
|
|
434
|
+
/* const int swizzle_log = */ swizzle_log};
|
|
435
|
+
|
|
436
|
+
// Determine kernel
|
|
437
|
+
std::string kname;
|
|
438
|
+
kname.reserve(64);
|
|
439
|
+
concatenate(
|
|
440
|
+
kname,
|
|
441
|
+
"implicit_gemm_conv_2d_general_",
|
|
442
|
+
type_to_name(out),
|
|
443
|
+
"_bm",
|
|
444
|
+
bm,
|
|
445
|
+
"_bn",
|
|
446
|
+
bn,
|
|
447
|
+
"_bk",
|
|
448
|
+
bk,
|
|
449
|
+
"_wm",
|
|
450
|
+
wm,
|
|
451
|
+
"_wn",
|
|
452
|
+
wn);
|
|
453
|
+
std::string hash_name;
|
|
454
|
+
hash_name.reserve(64);
|
|
455
|
+
concatenate(hash_name, kname, "_alC_", align_C);
|
|
456
|
+
metal::MTLFCList func_consts = {
|
|
457
|
+
{&align_C, MTL::DataType::DataTypeBool, 200},
|
|
458
|
+
};
|
|
459
|
+
|
|
460
|
+
// Encode and dispatch kernel
|
|
461
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
462
|
+
auto kernel = get_steel_conv_general_kernel(
|
|
463
|
+
d, kname, hash_name, func_consts, out, bm, bn, bk, wm, wn);
|
|
464
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
465
|
+
|
|
466
|
+
// Deduce grid launch dimensions
|
|
467
|
+
int tile = 1 << swizzle_log;
|
|
468
|
+
size_t grid_dim_y = (tm + tile - 1) / tile;
|
|
469
|
+
size_t grid_dim_x = tn * tile;
|
|
470
|
+
size_t grid_dim_z = f_out_jump_h * f_out_jump_w;
|
|
471
|
+
|
|
472
|
+
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
|
473
|
+
MTL::Size grid_dims = MTL::Size(grid_dim_x, grid_dim_y, grid_dim_z);
|
|
474
|
+
|
|
475
|
+
// Encode arrays
|
|
476
|
+
compute_encoder.set_input_array(in, 0);
|
|
477
|
+
compute_encoder.set_input_array(wt, 1);
|
|
478
|
+
compute_encoder.set_output_array(out, 2);
|
|
479
|
+
|
|
480
|
+
// Encode params
|
|
481
|
+
compute_encoder.set_bytes(conv_params, 3);
|
|
482
|
+
compute_encoder.set_bytes(gemm_params, 4);
|
|
483
|
+
compute_encoder.set_bytes(jump_params, 5);
|
|
484
|
+
|
|
485
|
+
compute_encoder.set_vector_bytes(base_h, 6);
|
|
486
|
+
compute_encoder.set_vector_bytes(base_w, 7);
|
|
487
|
+
|
|
488
|
+
// Launch kernel
|
|
489
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
490
|
+
}
|
|
491
|
+
|
|
492
|
+
void winograd_conv_2D_gpu(
|
|
493
|
+
const Stream& s,
|
|
494
|
+
metal::Device& d,
|
|
495
|
+
const array& in,
|
|
496
|
+
const array& wt,
|
|
497
|
+
array out,
|
|
498
|
+
const MLXConvParams<2>& conv_params,
|
|
499
|
+
std::vector<array>& copies_w) {
|
|
500
|
+
Shape padded_shape = {
|
|
501
|
+
conv_params.N,
|
|
502
|
+
conv_params.iS[0] + 2 * conv_params.pad[0],
|
|
503
|
+
conv_params.iS[1] + 2 * conv_params.pad[1],
|
|
504
|
+
conv_params.C};
|
|
505
|
+
|
|
506
|
+
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
|
|
507
|
+
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
|
|
508
|
+
|
|
509
|
+
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
|
|
510
|
+
|
|
511
|
+
// Fill with zeros
|
|
512
|
+
array zero_arr = array(0, in.dtype());
|
|
513
|
+
fill_gpu(zero_arr, in_padded, s);
|
|
514
|
+
copies_w.push_back(zero_arr);
|
|
515
|
+
|
|
516
|
+
// Pick input slice from padded
|
|
517
|
+
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
|
|
518
|
+
conv_params.pad[1] * in_padded.strides()[2];
|
|
519
|
+
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
|
|
520
|
+
in_padded_slice.copy_shared_buffer(
|
|
521
|
+
in_padded,
|
|
522
|
+
in_padded.strides(),
|
|
523
|
+
in_padded.flags(),
|
|
524
|
+
in_padded_slice.size(),
|
|
525
|
+
data_offset);
|
|
526
|
+
|
|
527
|
+
// Copy input values into the slice
|
|
528
|
+
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
|
|
529
|
+
|
|
530
|
+
copies_w.push_back(in_padded_slice);
|
|
531
|
+
copies_w.push_back(in_padded);
|
|
532
|
+
|
|
533
|
+
MLXConvParams<2> conv_params_updated{
|
|
534
|
+
/* const int N = */ static_cast<int>(in_padded.shape(0)),
|
|
535
|
+
/* const int C = */ static_cast<int>(in_padded.shape(3)),
|
|
536
|
+
/* const int O = */ static_cast<int>(wt.shape(0)),
|
|
537
|
+
/* const int iS[NDIM] = */
|
|
538
|
+
{static_cast<int>(in_padded.shape(1)),
|
|
539
|
+
static_cast<int>(in_padded.shape(2))},
|
|
540
|
+
/* const int wS[NDIM] = */
|
|
541
|
+
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
|
542
|
+
/* const int oS[NDIM] = */
|
|
543
|
+
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
|
544
|
+
/* const int str[NDIM] = */ {1, 1},
|
|
545
|
+
/* const int pad[NDIM] = */ {0, 0},
|
|
546
|
+
/* const int kdil[NDIM] = */ {1, 1},
|
|
547
|
+
/* const int idil[NDIM] = */ {1, 1},
|
|
548
|
+
/* const size_t in_strides[NDIM + 2] = */
|
|
549
|
+
{in_padded.strides()[0],
|
|
550
|
+
in_padded.strides()[1],
|
|
551
|
+
in_padded.strides()[2],
|
|
552
|
+
in_padded.strides()[3]},
|
|
553
|
+
/* const size_t wt_strides[NDIM + 2] = */
|
|
554
|
+
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
|
|
555
|
+
/* const size_t out_strides[NDIM + 2] = */
|
|
556
|
+
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
|
|
557
|
+
/* const int groups = */ 1,
|
|
558
|
+
/* const bool flip = */ false,
|
|
559
|
+
};
|
|
560
|
+
|
|
561
|
+
int O_c = conv_params.O;
|
|
562
|
+
int C_c = conv_params.C;
|
|
563
|
+
|
|
564
|
+
int N_tiles_n = conv_params.N;
|
|
565
|
+
int N_tiles_h = (conv_params.oS[0] + 5) / 6;
|
|
566
|
+
int N_tiles_w = (conv_params.oS[1] + 5) / 6;
|
|
567
|
+
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
|
|
568
|
+
|
|
569
|
+
// Do filter transform
|
|
570
|
+
Shape filt_wg_shape = {8 * 8, conv_params.C, conv_params.O};
|
|
571
|
+
array filt_wg(std::move(filt_wg_shape), wt.dtype(), nullptr, {});
|
|
572
|
+
filt_wg.set_data(allocator::malloc(filt_wg.nbytes()));
|
|
573
|
+
copies_w.push_back(filt_wg);
|
|
574
|
+
{
|
|
575
|
+
int bc = 32;
|
|
576
|
+
int bo = 4;
|
|
577
|
+
std::string kname;
|
|
578
|
+
kname.reserve(32);
|
|
579
|
+
concatenate(
|
|
580
|
+
kname,
|
|
581
|
+
"winograd_conv_2d_weight_transform_",
|
|
582
|
+
type_to_name(out),
|
|
583
|
+
"_bc",
|
|
584
|
+
bc);
|
|
585
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
586
|
+
auto kernel = d.get_kernel(kname);
|
|
587
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
588
|
+
|
|
589
|
+
compute_encoder.set_input_array(wt, 0);
|
|
590
|
+
compute_encoder.set_output_array(filt_wg, 1);
|
|
591
|
+
|
|
592
|
+
compute_encoder.set_bytes(C_c, 2);
|
|
593
|
+
compute_encoder.set_bytes(O_c, 3);
|
|
594
|
+
|
|
595
|
+
MTL::Size group_dims = MTL::Size(32, bo, 1);
|
|
596
|
+
MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);
|
|
597
|
+
|
|
598
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
// Do input transform
|
|
602
|
+
Shape inp_wg_shape = {8 * 8, N_tiles, conv_params.C};
|
|
603
|
+
array inp_wg(std::move(inp_wg_shape), in.dtype(), nullptr, {});
|
|
604
|
+
inp_wg.set_data(allocator::malloc(inp_wg.nbytes()));
|
|
605
|
+
copies_w.push_back(inp_wg);
|
|
606
|
+
{
|
|
607
|
+
int bc = 32;
|
|
608
|
+
int wm = 2;
|
|
609
|
+
int wn = 2;
|
|
610
|
+
std::string kname;
|
|
611
|
+
kname.reserve(32);
|
|
612
|
+
concatenate(
|
|
613
|
+
kname,
|
|
614
|
+
"winograd_conv_2d_input_transform_",
|
|
615
|
+
type_to_name(out),
|
|
616
|
+
"_bc",
|
|
617
|
+
bc);
|
|
618
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
619
|
+
auto kernel = d.get_kernel(kname);
|
|
620
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
621
|
+
|
|
622
|
+
compute_encoder.set_input_array(in_padded, 0);
|
|
623
|
+
compute_encoder.set_output_array(inp_wg, 1);
|
|
624
|
+
|
|
625
|
+
compute_encoder.set_bytes(conv_params_updated, 2);
|
|
626
|
+
|
|
627
|
+
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
|
628
|
+
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
|
629
|
+
|
|
630
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
// Do batched gemm
|
|
634
|
+
Shape out_wg_shape = {8 * 8, N_tiles, conv_params.O};
|
|
635
|
+
array out_wg(std::move(out_wg_shape), in.dtype(), nullptr, {});
|
|
636
|
+
out_wg.set_data(allocator::malloc(out_wg.nbytes()));
|
|
637
|
+
copies_w.push_back(out_wg);
|
|
638
|
+
{
|
|
639
|
+
std::vector<array> empty_copies;
|
|
640
|
+
steel_matmul(
|
|
641
|
+
s,
|
|
642
|
+
d,
|
|
643
|
+
/*a = */ inp_wg,
|
|
644
|
+
/*b = */ filt_wg,
|
|
645
|
+
/*c = */ out_wg,
|
|
646
|
+
/*M = */ N_tiles,
|
|
647
|
+
/*N = */ conv_params.O,
|
|
648
|
+
/*K = */ conv_params.C,
|
|
649
|
+
/*batch_size_out = */ 8 * 8,
|
|
650
|
+
/*a_cols = */ conv_params.C,
|
|
651
|
+
/*b_cols = */ conv_params.O,
|
|
652
|
+
/*a_transposed = */ false,
|
|
653
|
+
/*b_transposed = */ false,
|
|
654
|
+
/*copies = */ empty_copies);
|
|
655
|
+
}
|
|
656
|
+
|
|
657
|
+
// Do output transform
|
|
658
|
+
{
|
|
659
|
+
int bc = 32;
|
|
660
|
+
int wm = 2;
|
|
661
|
+
int wn = 2;
|
|
662
|
+
std::string kname;
|
|
663
|
+
kname.reserve(32);
|
|
664
|
+
concatenate(
|
|
665
|
+
kname,
|
|
666
|
+
"winograd_conv_2d_output_transform_",
|
|
667
|
+
type_to_name(out),
|
|
668
|
+
"_bo",
|
|
669
|
+
bc);
|
|
670
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
671
|
+
auto kernel = d.get_kernel(kname);
|
|
672
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
673
|
+
|
|
674
|
+
compute_encoder.set_input_array(out_wg, 0);
|
|
675
|
+
compute_encoder.set_output_array(out, 1);
|
|
676
|
+
|
|
677
|
+
compute_encoder.set_bytes(conv_params_updated, 2);
|
|
678
|
+
|
|
679
|
+
MTL::Size group_dims = MTL::Size(32, wn, wm);
|
|
680
|
+
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
|
|
681
|
+
|
|
682
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
|
|
686
|
+
void depthwise_conv_2D_gpu(
|
|
687
|
+
const Stream& s,
|
|
688
|
+
metal::Device& d,
|
|
689
|
+
const array& in,
|
|
690
|
+
const array& wt,
|
|
691
|
+
array out,
|
|
692
|
+
const MLXConvParams<2>& conv_params) {
|
|
693
|
+
std::string base_name;
|
|
694
|
+
base_name.reserve(32);
|
|
695
|
+
concatenate(base_name, "depthwise_conv_2d_", type_to_name(out));
|
|
696
|
+
|
|
697
|
+
const int N = conv_params.N;
|
|
698
|
+
const int ker_h = conv_params.wS[0];
|
|
699
|
+
const int ker_w = conv_params.wS[1];
|
|
700
|
+
const int str_h = conv_params.str[0];
|
|
701
|
+
const int str_w = conv_params.str[1];
|
|
702
|
+
const int tc = 8;
|
|
703
|
+
const int tw = 8;
|
|
704
|
+
const int th = 4;
|
|
705
|
+
const bool do_flip = conv_params.flip;
|
|
706
|
+
|
|
707
|
+
metal::MTLFCList func_consts = {
|
|
708
|
+
{&ker_h, MTL::DataType::DataTypeInt, 00},
|
|
709
|
+
{&ker_w, MTL::DataType::DataTypeInt, 01},
|
|
710
|
+
{&str_h, MTL::DataType::DataTypeInt, 10},
|
|
711
|
+
{&str_w, MTL::DataType::DataTypeInt, 11},
|
|
712
|
+
{&th, MTL::DataType::DataTypeInt, 100},
|
|
713
|
+
{&tw, MTL::DataType::DataTypeInt, 101},
|
|
714
|
+
{&do_flip, MTL::DataType::DataTypeBool, 200},
|
|
715
|
+
};
|
|
716
|
+
|
|
717
|
+
// clang-format off
|
|
718
|
+
std::string hash_name;
|
|
719
|
+
hash_name.reserve(64);
|
|
720
|
+
concatenate(
|
|
721
|
+
hash_name,
|
|
722
|
+
base_name,
|
|
723
|
+
"_ker_h_", ker_h,
|
|
724
|
+
"_ker_w_", ker_w,
|
|
725
|
+
"_str_h_", str_h,
|
|
726
|
+
"_str_w_", str_w,
|
|
727
|
+
"_tgp_h_", th,
|
|
728
|
+
"_tgp_w_", tw,
|
|
729
|
+
"_do_flip_", do_flip ? 't' : 'n'); // clang-format on
|
|
730
|
+
|
|
731
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
732
|
+
auto kernel = d.get_kernel(base_name, hash_name, func_consts);
|
|
733
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
734
|
+
|
|
735
|
+
compute_encoder.set_input_array(in, 0);
|
|
736
|
+
compute_encoder.set_input_array(wt, 1);
|
|
737
|
+
compute_encoder.set_output_array(out, 2);
|
|
738
|
+
|
|
739
|
+
compute_encoder.set_bytes(conv_params, 3);
|
|
740
|
+
|
|
741
|
+
MTL::Size group_dims = MTL::Size(tc, tw, th);
|
|
742
|
+
MTL::Size grid_dims = MTL::Size(
|
|
743
|
+
conv_params.C / tc, conv_params.oS[1] / tw, (conv_params.oS[0] / th) * N);
|
|
744
|
+
|
|
745
|
+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
|
746
|
+
}
|
|
747
|
+
|
|
748
|
+
void dispatch_conv_2D_gpu(
|
|
749
|
+
const Stream& s,
|
|
750
|
+
metal::Device& d,
|
|
751
|
+
const array& in,
|
|
752
|
+
const array& wt,
|
|
753
|
+
array out,
|
|
754
|
+
const MLXConvParams<2>& conv_params,
|
|
755
|
+
std::vector<array>& copies) {
|
|
756
|
+
bool is_stride_one = conv_params.str[0] == 1 && conv_params.str[1] == 1;
|
|
757
|
+
bool is_kdil_one = conv_params.kdil[0] == 1 && conv_params.kdil[1] == 1;
|
|
758
|
+
bool is_idil_one = conv_params.idil[0] == 1 && conv_params.idil[1] == 1;
|
|
759
|
+
|
|
760
|
+
if (is_idil_one && conv_params.groups > 1) {
|
|
761
|
+
const int C_per_group = conv_params.C / conv_params.groups;
|
|
762
|
+
const int O_per_group = conv_params.O / conv_params.groups;
|
|
763
|
+
|
|
764
|
+
if (C_per_group == 1 && O_per_group == 1 && is_kdil_one &&
|
|
765
|
+
conv_params.wS[0] <= 7 && conv_params.wS[1] <= 7 &&
|
|
766
|
+
conv_params.str[0] <= 2 && conv_params.str[1] <= 2 &&
|
|
767
|
+
conv_params.oS[0] % 8 == 0 && conv_params.oS[1] % 8 == 0 &&
|
|
768
|
+
conv_params.wt_strides[1] == conv_params.wS[1] &&
|
|
769
|
+
conv_params.C % 16 == 0 && conv_params.C == conv_params.O) {
|
|
770
|
+
return depthwise_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
|
771
|
+
}
|
|
772
|
+
|
|
773
|
+
if ((C_per_group <= 4 || C_per_group % 16 == 0) &&
|
|
774
|
+
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
|
775
|
+
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
|
776
|
+
} else {
|
|
777
|
+
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
|
778
|
+
}
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
// Direct to winograd conv
|
|
782
|
+
bool inp_large =
|
|
783
|
+
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
|
|
784
|
+
bool channels_large = (conv_params.C + conv_params.O) >= 256;
|
|
785
|
+
bool out_large =
|
|
786
|
+
(conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
|
|
787
|
+
if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
|
|
788
|
+
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
|
|
789
|
+
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
|
|
790
|
+
channels_large) {
|
|
791
|
+
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
|
792
|
+
}
|
|
793
|
+
|
|
794
|
+
// Direct to implicit gemm conv
|
|
795
|
+
if (is_idil_one && (conv_params.C <= 4 || conv_params.C % 16 == 0) &&
|
|
796
|
+
(conv_params.O <= 16 || conv_params.O % 16 == 0)) {
|
|
797
|
+
return implicit_gemm_conv_2D_gpu(s, d, in, wt, out, conv_params);
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
else if ((conv_params.C % 16 == 0 && conv_params.O % 16 == 0) || out_large) {
|
|
801
|
+
return implicit_gemm_conv_2D_general_gpu(s, d, in, wt, out, conv_params);
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
// Direct to explicit gemm conv
|
|
805
|
+
else {
|
|
806
|
+
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
|
807
|
+
}
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
void depthwise_conv_1D_gpu(
|
|
811
|
+
const Stream& s,
|
|
812
|
+
metal::Device& d,
|
|
813
|
+
const array& in,
|
|
814
|
+
array wt,
|
|
815
|
+
array out) {
|
|
816
|
+
bool large = in.size() > INT32_MAX || in.data_size() > INT32_MAX;
|
|
817
|
+
std::string base_name;
|
|
818
|
+
base_name.reserve(32);
|
|
819
|
+
concatenate(
|
|
820
|
+
base_name,
|
|
821
|
+
"depthwise_conv_1d_",
|
|
822
|
+
large ? "_large" : "",
|
|
823
|
+
type_to_name(out));
|
|
824
|
+
|
|
825
|
+
if (!wt.flags().row_contiguous) {
|
|
826
|
+
wt = contiguous_copy_gpu(wt, s);
|
|
827
|
+
d.add_temporary(wt, s.index);
|
|
828
|
+
}
|
|
829
|
+
auto& compute_encoder = d.get_command_encoder(s.index);
|
|
830
|
+
auto kernel = d.get_kernel(base_name);
|
|
831
|
+
compute_encoder.set_compute_pipeline_state(kernel);
|
|
832
|
+
|
|
833
|
+
auto B = in.shape(0);
|
|
834
|
+
auto Tout = out.shape(1);
|
|
835
|
+
auto D = in.shape(2);
|
|
836
|
+
auto K = wt.shape(1);
|
|
837
|
+
|
|
838
|
+
compute_encoder.set_input_array(in, 0);
|
|
839
|
+
compute_encoder.set_input_array(wt, 1);
|
|
840
|
+
compute_encoder.set_output_array(out, 2);
|
|
841
|
+
if (large) {
|
|
842
|
+
int64_t strides[3] = {in.strides(0), in.strides(1), in.strides(2)};
|
|
843
|
+
compute_encoder.set_bytes(strides, 3, 3);
|
|
844
|
+
|
|
845
|
+
} else {
|
|
846
|
+
int strides[3] = {
|
|
847
|
+
static_cast<int>(in.strides(0)),
|
|
848
|
+
static_cast<int>(in.strides(1)),
|
|
849
|
+
static_cast<int>(in.strides(2))};
|
|
850
|
+
compute_encoder.set_bytes(strides, 3, 3);
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
compute_encoder.set_bytes(K, 4);
|
|
854
|
+
auto group_dims = get_block_dims(D, Tout, B);
|
|
855
|
+
MTL::Size grid_dims = MTL::Size(D, Tout, B);
|
|
856
|
+
|
|
857
|
+
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
|
858
|
+
}
|
|
859
|
+
|
|
860
|
+
void conv_1D_gpu(
|
|
861
|
+
const Stream& s,
|
|
862
|
+
metal::Device& d,
|
|
863
|
+
const array& in,
|
|
864
|
+
const array& wt,
|
|
865
|
+
array out,
|
|
866
|
+
const std::vector<int>& padding,
|
|
867
|
+
const std::vector<int>& wt_strides,
|
|
868
|
+
const std::vector<int>& wt_dilation,
|
|
869
|
+
const std::vector<int>& in_dilation,
|
|
870
|
+
int groups,
|
|
871
|
+
bool flip,
|
|
872
|
+
std::vector<array>& copies) {
|
|
873
|
+
bool is_idil_one = in_dilation[0] == 1;
|
|
874
|
+
int C = in.shape(2);
|
|
875
|
+
int O = wt.shape(0);
|
|
876
|
+
// Fast path for fully separable 1D convolution
|
|
877
|
+
if (is_idil_one && (groups == C) && groups == O && wt_strides[0] == 1 &&
|
|
878
|
+
wt_dilation[0] == 1 && padding[0] == 0 && !flip) {
|
|
879
|
+
depthwise_conv_1D_gpu(s, d, in, wt, out);
|
|
880
|
+
return;
|
|
881
|
+
}
|
|
882
|
+
|
|
883
|
+
const int C_per_group = C / groups;
|
|
884
|
+
const int O_per_group = O / groups;
|
|
885
|
+
|
|
886
|
+
// Direct to implicit gemm conv
|
|
887
|
+
if (is_idil_one && (C_per_group <= 4 || C_per_group % 16 == 0) &&
|
|
888
|
+
(O_per_group <= 16 || O_per_group % 16 == 0)) {
|
|
889
|
+
MLXConvParams<2> conv_params{
|
|
890
|
+
/* const int N = */ static_cast<int>(in.shape(0)),
|
|
891
|
+
/* const int C = */ C,
|
|
892
|
+
/* const int O = */ O,
|
|
893
|
+
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1)), 1},
|
|
894
|
+
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1)), 1},
|
|
895
|
+
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1)), 1},
|
|
896
|
+
/* const int str[NDIM] = */ {wt_strides[0], 1},
|
|
897
|
+
/* const int pad[NDIM] = */ {padding[0], 0},
|
|
898
|
+
/* const int kdil[NDIM] = */ {wt_dilation[0], 1},
|
|
899
|
+
/* const int idil[NDIM] = */ {in_dilation[0], 1},
|
|
900
|
+
/* const size_t in_strides[NDIM + 2] = */
|
|
901
|
+
{in.strides()[0], in.strides()[1], 0, in.strides()[2]},
|
|
902
|
+
/* const size_t wt_strides[NDIM + 2] = */
|
|
903
|
+
{wt.strides()[0], wt.strides()[1], 0, wt.strides()[2]},
|
|
904
|
+
/* const size_t out_strides[NDIM + 2] = */
|
|
905
|
+
{out.strides()[0], out.strides()[1], 0, out.strides()[2]},
|
|
906
|
+
/* const int groups = */ groups,
|
|
907
|
+
/* const bool flip = */ flip};
|
|
908
|
+
|
|
909
|
+
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
|
910
|
+
return;
|
|
911
|
+
}
|
|
912
|
+
|
|
913
|
+
// Make conv params
|
|
914
|
+
MLXConvParams<1> conv_params{
|
|
915
|
+
/* const int N = */ static_cast<int>(in.shape(0)),
|
|
916
|
+
/* const int C = */ static_cast<int>(in.shape(2)),
|
|
917
|
+
/* const int O = */ static_cast<int>(wt.shape(0)),
|
|
918
|
+
/* const int iS[NDIM] = */ {static_cast<int>(in.shape(1))},
|
|
919
|
+
/* const int wS[NDIM] = */ {static_cast<int>(wt.shape(1))},
|
|
920
|
+
/* const int oS[NDIM] = */ {static_cast<int>(out.shape(1))},
|
|
921
|
+
/* const int str[NDIM] = */ {wt_strides[0]},
|
|
922
|
+
/* const int pad[NDIM] = */ {padding[0]},
|
|
923
|
+
/* const int kdil[NDIM] = */ {wt_dilation[0]},
|
|
924
|
+
/* const int idil[NDIM] = */ {in_dilation[0]},
|
|
925
|
+
/* const size_t in_strides[NDIM + 2] = */
|
|
926
|
+
{in.strides()[0], in.strides()[1], in.strides()[2]},
|
|
927
|
+
/* const size_t wt_strides[NDIM + 2] = */
|
|
928
|
+
{wt.strides()[0], wt.strides()[1], wt.strides()[2]},
|
|
929
|
+
/* const size_t out_strides[NDIM + 2] = */
|
|
930
|
+
{out.strides()[0], out.strides()[1], out.strides()[2]},
|
|
931
|
+
/* const int groups = */ groups,
|
|
932
|
+
/* const bool flip = */ flip};
|
|
933
|
+
|
|
934
|
+
// Direct to explicit gemm conv
|
|
935
|
+
if (groups > 1) {
|
|
936
|
+
return explicit_gemm_conv_group_ND_gpu(s, d, in, wt, out, conv_params);
|
|
937
|
+
} else {
|
|
938
|
+
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
|
939
|
+
}
|
|
940
|
+
}
|
|
941
|
+
|
|
942
|
+
void conv_2D_gpu(
|
|
943
|
+
const Stream& s,
|
|
944
|
+
metal::Device& d,
|
|
945
|
+
const array& in,
|
|
946
|
+
const array& wt,
|
|
947
|
+
array out,
|
|
948
|
+
const std::vector<int>& padding,
|
|
949
|
+
const std::vector<int>& wt_strides,
|
|
950
|
+
const std::vector<int>& wt_dilation,
|
|
951
|
+
const std::vector<int>& in_dilation,
|
|
952
|
+
const int groups,
|
|
953
|
+
bool flip,
|
|
954
|
+
std::vector<array>& copies) {
|
|
955
|
+
// Make conv params
|
|
956
|
+
MLXConvParams<2> conv_params{
|
|
957
|
+
/* const int N = */ static_cast<int>(in.shape(0)),
|
|
958
|
+
/* const int C = */ static_cast<int>(in.shape(3)),
|
|
959
|
+
/* const int O = */ static_cast<int>(wt.shape(0)),
|
|
960
|
+
/* const int iS[NDIM] = */
|
|
961
|
+
{static_cast<int>(in.shape(1)), static_cast<int>(in.shape(2))},
|
|
962
|
+
/* const int wS[NDIM] = */
|
|
963
|
+
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
|
|
964
|
+
/* const int oS[NDIM] = */
|
|
965
|
+
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
|
|
966
|
+
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1]},
|
|
967
|
+
/* const int pad[NDIM] = */ {padding[0], padding[1]},
|
|
968
|
+
/* const int kdil[NDIM] = */ {wt_dilation[0], wt_dilation[1]},
|
|
969
|
+
/* const int idil[NDIM] = */ {in_dilation[0], in_dilation[1]},
|
|
970
|
+
/* const size_t in_strides[NDIM + 2] = */
|
|
971
|
+
{in.strides(0), in.strides(1), in.strides(2), in.strides(3)},
|
|
972
|
+
/* const size_t wt_strides[NDIM + 2] = */
|
|
973
|
+
{wt.strides(0), wt.strides(1), wt.strides(2), wt.strides(3)},
|
|
974
|
+
/* const size_t out_strides[NDIM + 2] = */
|
|
975
|
+
{out.strides(0), out.strides(1), out.strides(2), out.strides(3)},
|
|
976
|
+
/* const int groups = */ groups,
|
|
977
|
+
/* const bool flip = */ flip,
|
|
978
|
+
};
|
|
979
|
+
dispatch_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
|
|
980
|
+
}
|
|
981
|
+
|
|
982
|
+
void conv_3D_gpu(
|
|
983
|
+
const Stream& s,
|
|
984
|
+
metal::Device& d,
|
|
985
|
+
const array& in,
|
|
986
|
+
const array& wt,
|
|
987
|
+
array out,
|
|
988
|
+
const std::vector<int>& padding,
|
|
989
|
+
const std::vector<int>& wt_strides,
|
|
990
|
+
const std::vector<int>& wt_dilation,
|
|
991
|
+
const std::vector<int>& in_dilation,
|
|
992
|
+
bool flip,
|
|
993
|
+
std::vector<array>& copies) {
|
|
994
|
+
// Make conv params
|
|
995
|
+
MLXConvParams<3> conv_params{
|
|
996
|
+
/* const int N = */ static_cast<int>(in.shape(0)),
|
|
997
|
+
/* const int C = */ static_cast<int>(in.shape(4)),
|
|
998
|
+
/* const int O = */ static_cast<int>(wt.shape(0)),
|
|
999
|
+
/* const int iS[NDIM] = */
|
|
1000
|
+
{static_cast<int>(in.shape(1)),
|
|
1001
|
+
static_cast<int>(in.shape(2)),
|
|
1002
|
+
static_cast<int>(in.shape(3))},
|
|
1003
|
+
/* const int wS[NDIM] = */
|
|
1004
|
+
{static_cast<int>(wt.shape(1)),
|
|
1005
|
+
static_cast<int>(wt.shape(2)),
|
|
1006
|
+
static_cast<int>(wt.shape(3))},
|
|
1007
|
+
/* const int oS[NDIM] = */
|
|
1008
|
+
{static_cast<int>(out.shape(1)),
|
|
1009
|
+
static_cast<int>(out.shape(2)),
|
|
1010
|
+
static_cast<int>(out.shape(3))},
|
|
1011
|
+
/* const int str[NDIM] = */ {wt_strides[0], wt_strides[1], wt_strides[2]},
|
|
1012
|
+
/* const int pad[NDIM] = */ {padding[0], padding[1], padding[2]},
|
|
1013
|
+
/* const int kdil[NDIM] = */
|
|
1014
|
+
{wt_dilation[0], wt_dilation[1], wt_dilation[2]},
|
|
1015
|
+
/* const int idil[NDIM] = */
|
|
1016
|
+
{in_dilation[0], in_dilation[1], in_dilation[2]},
|
|
1017
|
+
/* const size_t in_strides[NDIM + 2] = */
|
|
1018
|
+
{in.strides()[0],
|
|
1019
|
+
in.strides()[1],
|
|
1020
|
+
in.strides()[2],
|
|
1021
|
+
in.strides()[3],
|
|
1022
|
+
in.strides()[4]},
|
|
1023
|
+
/* const size_t wt_strides[NDIM + 2] = */
|
|
1024
|
+
{wt.strides()[0],
|
|
1025
|
+
wt.strides()[1],
|
|
1026
|
+
wt.strides()[2],
|
|
1027
|
+
wt.strides()[3],
|
|
1028
|
+
wt.strides()[4]},
|
|
1029
|
+
/* const size_t out_strides[NDIM + 2] = */
|
|
1030
|
+
{out.strides()[0],
|
|
1031
|
+
out.strides()[1],
|
|
1032
|
+
out.strides()[2],
|
|
1033
|
+
out.strides()[3],
|
|
1034
|
+
out.strides()[4]},
|
|
1035
|
+
/* const int groups = */ 1,
|
|
1036
|
+
/* const bool flip = */ flip,
|
|
1037
|
+
};
|
|
1038
|
+
return explicit_gemm_conv_ND_gpu(s, d, in, wt, out, conv_params);
|
|
1039
|
+
}
|
|
1040
|
+
|
|
1041
|
+
} // namespace
|
|
1042
|
+
|
|
1043
|
+
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|
1044
|
+
out.set_data(allocator::malloc(out.nbytes()));
|
|
1045
|
+
auto& s = stream();
|
|
1046
|
+
auto& d = metal::device(s.device);
|
|
1047
|
+
|
|
1048
|
+
// Ensure contiguity
|
|
1049
|
+
std::vector<array> copies;
|
|
1050
|
+
auto in = inputs[0];
|
|
1051
|
+
auto wt = inputs[1];
|
|
1052
|
+
if (!in.flags().row_contiguous) {
|
|
1053
|
+
in = contiguous_copy_gpu(in, s);
|
|
1054
|
+
copies.push_back(in);
|
|
1055
|
+
}
|
|
1056
|
+
if (!wt.flags().row_contiguous) {
|
|
1057
|
+
wt = contiguous_copy_gpu(wt, s);
|
|
1058
|
+
copies.push_back(wt);
|
|
1059
|
+
}
|
|
1060
|
+
|
|
1061
|
+
// 3D conv
|
|
1062
|
+
if (out.ndim() == 5) {
|
|
1063
|
+
conv_3D_gpu(
|
|
1064
|
+
s,
|
|
1065
|
+
d,
|
|
1066
|
+
in,
|
|
1067
|
+
wt,
|
|
1068
|
+
out,
|
|
1069
|
+
padding_lo_,
|
|
1070
|
+
kernel_strides_,
|
|
1071
|
+
kernel_dilation_,
|
|
1072
|
+
input_dilation_,
|
|
1073
|
+
flip_,
|
|
1074
|
+
copies);
|
|
1075
|
+
}
|
|
1076
|
+
// 2D conv
|
|
1077
|
+
else if (out.ndim() == 4) {
|
|
1078
|
+
conv_2D_gpu(
|
|
1079
|
+
s,
|
|
1080
|
+
d,
|
|
1081
|
+
in,
|
|
1082
|
+
wt,
|
|
1083
|
+
out,
|
|
1084
|
+
padding_lo_,
|
|
1085
|
+
kernel_strides_,
|
|
1086
|
+
kernel_dilation_,
|
|
1087
|
+
input_dilation_,
|
|
1088
|
+
groups_,
|
|
1089
|
+
flip_,
|
|
1090
|
+
copies);
|
|
1091
|
+
}
|
|
1092
|
+
// 1D conv
|
|
1093
|
+
else if (out.ndim() == 3) {
|
|
1094
|
+
conv_1D_gpu(
|
|
1095
|
+
s,
|
|
1096
|
+
d,
|
|
1097
|
+
in,
|
|
1098
|
+
wt,
|
|
1099
|
+
out,
|
|
1100
|
+
padding_lo_,
|
|
1101
|
+
kernel_strides_,
|
|
1102
|
+
kernel_dilation_,
|
|
1103
|
+
input_dilation_,
|
|
1104
|
+
groups_,
|
|
1105
|
+
flip_,
|
|
1106
|
+
copies);
|
|
1107
|
+
}
|
|
1108
|
+
// Throw error
|
|
1109
|
+
else {
|
|
1110
|
+
throw std::invalid_argument(
|
|
1111
|
+
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
|
|
1112
|
+
}
|
|
1113
|
+
|
|
1114
|
+
// Record copies
|
|
1115
|
+
d.add_temporaries(std::move(copies), s.index);
|
|
1116
|
+
}
|
|
1117
|
+
|
|
1118
|
+
} // namespace mlx::core
|