mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
// Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <sstream>

#include "mlx/backend/common/unary.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/backend/cpu/simd/simd.h"
#include "mlx/utils.h"
|
|
9
|
+
|
|
10
|
+
namespace mlx::core {
|
|
11
|
+
|
|
12
|
+
// Apply `Op` elementwise along one strided axis.
//
// Reads `shape` elements from `a`, stepping `stride` elements between reads,
// and writes the results densely into `out[0..shape)`.
template <typename T, typename U = T, typename Op>
void unary_op(const T* a, U* out, size_t shape, size_t stride) {
  for (size_t i = 0; i < shape; ++i) {
    out[i] = Op{}(a[i * stride]);
  }
}
|
|
19
|
+
|
|
20
|
+
// Apply `Op` elementwise over a whole array.
//
// Fast path: if `a` is contiguous, process in SIMD chunks of N lanes and
// finish the tail scalar-by-scalar. Otherwise, walk the array one innermost
// row at a time using ContiguousIterator over the outer dimensions and the
// strided scalar kernel for the last axis. The output is assumed dense
// (written at linear offset `elem`).
template <typename T, typename U = T, typename Op>
void unary_op(const array& a, array& out, Op) {
  const T* src = a.data<T>();
  U* dst = out.data<U>();
  auto ndim = a.ndim();
  if (a.flags().contiguous) {
    auto size = a.data_size();
    // Use the widest vector width supported by both the input and output
    // element types.
    constexpr int N = std::min(simd::max_size<T>, simd::max_size<U>);
    while (size >= N) {
      simd::store(dst, simd::Simd<U, N>(Op{}(simd::load<T, N>(src))));
      size -= N;
      src += N;
      dst += N;
    }
    // Scalar tail for the remaining (size % N) elements.
    while (size > 0) {
      *dst = Op{}(*src);
      size--;
      dst++;
      src++;
    }
  } else {
    // Non-contiguous: treat the last axis as a strided row; 0-d arrays are
    // handled as a single element (shape = stride = 1).
    size_t shape = ndim > 0 ? a.shape().back() : 1;
    size_t stride = ndim > 0 ? a.strides().back() : 1;
    if (ndim <= 1) {
      unary_op<T, U, Op>(src, dst, shape, stride);
      return;
    }
    // Iterate the outer ndim-1 dimensions; it.loc is the input offset of the
    // current row while the output advances densely by `shape`.
    auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1);
    for (size_t elem = 0; elem < a.size(); elem += shape) {
      unary_op<T, U, Op>(src + it.loc, dst + elem, shape, stride);
      it.step();
    }
  }
}
|
|
54
|
+
|
|
55
|
+
template <typename Op>
|
|
56
|
+
void unary(const array& a, array& out, Op op, Stream stream) {
|
|
57
|
+
set_unary_output_data(a, out);
|
|
58
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
59
|
+
encoder.set_input_array(a);
|
|
60
|
+
encoder.set_output_array(out);
|
|
61
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
62
|
+
out = array::unsafe_weak_copy(out),
|
|
63
|
+
op = op]() mutable {
|
|
64
|
+
switch (out.dtype()) {
|
|
65
|
+
case bool_:
|
|
66
|
+
unary_op<bool>(a, out, op);
|
|
67
|
+
break;
|
|
68
|
+
case uint8:
|
|
69
|
+
unary_op<uint8_t>(a, out, op);
|
|
70
|
+
break;
|
|
71
|
+
case uint16:
|
|
72
|
+
unary_op<uint16_t>(a, out, op);
|
|
73
|
+
break;
|
|
74
|
+
case uint32:
|
|
75
|
+
unary_op<uint32_t>(a, out, op);
|
|
76
|
+
break;
|
|
77
|
+
case uint64:
|
|
78
|
+
unary_op<uint64_t>(a, out, op);
|
|
79
|
+
break;
|
|
80
|
+
case int8:
|
|
81
|
+
unary_op<int8_t>(a, out, op);
|
|
82
|
+
break;
|
|
83
|
+
case int16:
|
|
84
|
+
unary_op<int16_t>(a, out, op);
|
|
85
|
+
break;
|
|
86
|
+
case int32:
|
|
87
|
+
unary_op<int32_t>(a, out, op);
|
|
88
|
+
break;
|
|
89
|
+
case int64:
|
|
90
|
+
unary_op<int64_t>(a, out, op);
|
|
91
|
+
break;
|
|
92
|
+
case float16:
|
|
93
|
+
unary_op<float16_t>(a, out, op);
|
|
94
|
+
break;
|
|
95
|
+
case float32:
|
|
96
|
+
unary_op<float>(a, out, op);
|
|
97
|
+
break;
|
|
98
|
+
case float64:
|
|
99
|
+
unary_op<double>(a, out, op);
|
|
100
|
+
break;
|
|
101
|
+
case bfloat16:
|
|
102
|
+
unary_op<bfloat16_t>(a, out, op);
|
|
103
|
+
break;
|
|
104
|
+
case complex64:
|
|
105
|
+
unary_op<complex64_t>(a, out, op);
|
|
106
|
+
break;
|
|
107
|
+
}
|
|
108
|
+
});
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
template <typename Op>
|
|
112
|
+
void unary_real_fp(const array& a, array& out, Op op, Stream stream) {
|
|
113
|
+
set_unary_output_data(a, out);
|
|
114
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
115
|
+
encoder.set_input_array(a);
|
|
116
|
+
encoder.set_output_array(out);
|
|
117
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
118
|
+
out = array::unsafe_weak_copy(out),
|
|
119
|
+
op = op]() mutable {
|
|
120
|
+
switch (out.dtype()) {
|
|
121
|
+
case bfloat16:
|
|
122
|
+
unary_op<bfloat16_t>(a, out, op);
|
|
123
|
+
break;
|
|
124
|
+
case float16:
|
|
125
|
+
unary_op<float16_t>(a, out, op);
|
|
126
|
+
break;
|
|
127
|
+
case float32:
|
|
128
|
+
unary_op<float>(a, out, op);
|
|
129
|
+
break;
|
|
130
|
+
case float64:
|
|
131
|
+
unary_op<double>(a, out, op);
|
|
132
|
+
break;
|
|
133
|
+
default:
|
|
134
|
+
std::ostringstream err;
|
|
135
|
+
err << "[unary_real] Does not support " << out.dtype();
|
|
136
|
+
throw std::runtime_error(err.str());
|
|
137
|
+
}
|
|
138
|
+
});
|
|
139
|
+
}
|
|
140
|
+
template <typename Op>
|
|
141
|
+
void unary_fp(const array& a, array& out, Op op, Stream stream) {
|
|
142
|
+
set_unary_output_data(a, out);
|
|
143
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
144
|
+
encoder.set_input_array(a);
|
|
145
|
+
encoder.set_output_array(out);
|
|
146
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
147
|
+
out = array::unsafe_weak_copy(out),
|
|
148
|
+
op = op]() mutable {
|
|
149
|
+
switch (out.dtype()) {
|
|
150
|
+
case bfloat16:
|
|
151
|
+
unary_op<bfloat16_t>(a, out, op);
|
|
152
|
+
break;
|
|
153
|
+
case float16:
|
|
154
|
+
unary_op<float16_t>(a, out, op);
|
|
155
|
+
break;
|
|
156
|
+
case float32:
|
|
157
|
+
unary_op<float>(a, out, op);
|
|
158
|
+
break;
|
|
159
|
+
case float64:
|
|
160
|
+
unary_op<double>(a, out, op);
|
|
161
|
+
break;
|
|
162
|
+
case complex64:
|
|
163
|
+
unary_op<complex64_t>(a, out, op);
|
|
164
|
+
break;
|
|
165
|
+
default:
|
|
166
|
+
std::ostringstream err;
|
|
167
|
+
err << "[unary_fp] Does not support " << out.dtype();
|
|
168
|
+
throw std::runtime_error(err.str());
|
|
169
|
+
}
|
|
170
|
+
});
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
template <typename Op>
|
|
174
|
+
void unary_signed(const array& a, array& out, Op op, Stream stream) {
|
|
175
|
+
set_unary_output_data(a, out);
|
|
176
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
177
|
+
encoder.set_input_array(a);
|
|
178
|
+
encoder.set_output_array(out);
|
|
179
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
180
|
+
out = array::unsafe_weak_copy(out),
|
|
181
|
+
op = op]() mutable {
|
|
182
|
+
switch (out.dtype()) {
|
|
183
|
+
case int8:
|
|
184
|
+
unary_op<int8_t>(a, out, op);
|
|
185
|
+
break;
|
|
186
|
+
case int16:
|
|
187
|
+
unary_op<int16_t>(a, out, op);
|
|
188
|
+
break;
|
|
189
|
+
case int32:
|
|
190
|
+
unary_op<int32_t>(a, out, op);
|
|
191
|
+
break;
|
|
192
|
+
case int64:
|
|
193
|
+
unary_op<int64_t>(a, out, op);
|
|
194
|
+
break;
|
|
195
|
+
case float16:
|
|
196
|
+
unary_op<float16_t>(a, out, op);
|
|
197
|
+
break;
|
|
198
|
+
case float32:
|
|
199
|
+
unary_op<float>(a, out, op);
|
|
200
|
+
break;
|
|
201
|
+
case float64:
|
|
202
|
+
unary_op<double>(a, out, op);
|
|
203
|
+
break;
|
|
204
|
+
case bfloat16:
|
|
205
|
+
unary_op<bfloat16_t>(a, out, op);
|
|
206
|
+
break;
|
|
207
|
+
case complex64:
|
|
208
|
+
unary_op<complex64_t>(a, out, op);
|
|
209
|
+
break;
|
|
210
|
+
default:
|
|
211
|
+
throw std::runtime_error("[Abs] Called on unsigned type");
|
|
212
|
+
}
|
|
213
|
+
});
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
template <typename Op>
|
|
217
|
+
void unary_complex(const array& a, array& out, Op op, Stream stream) {
|
|
218
|
+
set_unary_output_data(a, out);
|
|
219
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
220
|
+
encoder.set_input_array(a);
|
|
221
|
+
encoder.set_output_array(out);
|
|
222
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
223
|
+
out = array::unsafe_weak_copy(out),
|
|
224
|
+
op = op]() mutable { unary_op<complex64_t>(a, out, op); });
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
template <typename Op>
|
|
228
|
+
void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) {
|
|
229
|
+
set_unary_output_data(a, out);
|
|
230
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
231
|
+
encoder.set_input_array(a);
|
|
232
|
+
encoder.set_output_array(out);
|
|
233
|
+
encoder.dispatch(
|
|
234
|
+
[a = array::unsafe_weak_copy(a),
|
|
235
|
+
out = array::unsafe_weak_copy(out),
|
|
236
|
+
op = op]() mutable { unary_op<complex64_t, float>(a, out, op); });
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
template <typename Op>
|
|
240
|
+
void unary_int(const array& a, array& out, Op op, Stream stream) {
|
|
241
|
+
set_unary_output_data(a, out);
|
|
242
|
+
auto& encoder = cpu::get_command_encoder(stream);
|
|
243
|
+
encoder.set_input_array(a);
|
|
244
|
+
encoder.set_output_array(out);
|
|
245
|
+
encoder.dispatch([a = array::unsafe_weak_copy(a),
|
|
246
|
+
out = array::unsafe_weak_copy(out),
|
|
247
|
+
op = op]() mutable {
|
|
248
|
+
switch (out.dtype()) {
|
|
249
|
+
case uint8:
|
|
250
|
+
unary_op<uint8_t>(a, out, op);
|
|
251
|
+
break;
|
|
252
|
+
case uint16:
|
|
253
|
+
unary_op<uint16_t>(a, out, op);
|
|
254
|
+
break;
|
|
255
|
+
case uint32:
|
|
256
|
+
unary_op<uint32_t>(a, out, op);
|
|
257
|
+
break;
|
|
258
|
+
case uint64:
|
|
259
|
+
unary_op<uint64_t>(a, out, op);
|
|
260
|
+
break;
|
|
261
|
+
case int8:
|
|
262
|
+
unary_op<int8_t>(a, out, op);
|
|
263
|
+
break;
|
|
264
|
+
case int16:
|
|
265
|
+
unary_op<int16_t>(a, out, op);
|
|
266
|
+
break;
|
|
267
|
+
case int32:
|
|
268
|
+
unary_op<int32_t>(a, out, op);
|
|
269
|
+
break;
|
|
270
|
+
case int64:
|
|
271
|
+
unary_op<int64_t>(a, out, op);
|
|
272
|
+
break;
|
|
273
|
+
default:
|
|
274
|
+
std::ostringstream err;
|
|
275
|
+
err << "[unary_int] Does not support " << out.dtype();
|
|
276
|
+
throw std::runtime_error(err.str());
|
|
277
|
+
}
|
|
278
|
+
});
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
} // namespace mlx::core
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <stdint.h>
#include <cmath>
#include <complex>
#include <cstring>

#include "mlx/backend/cpu/simd/simd.h"
|
|
10
|
+
|
|
11
|
+
namespace mlx::core::detail {
|
|
12
|
+
|
|
13
|
+
using namespace mlx::core::simd;
|
|
14
|
+
|
|
15
|
+
// Adds a scalar overload that funnels through the width-1 SIMD path, so each
// op only needs a vector implementation. Note the return type is the input
// type T, converting the .value of the width-1 result if they differ.
#define SINGLE()                         \
  template <typename T>                  \
  T operator()(T x) {                    \
    return (*this)(Simd<T, 1>(x)).value; \
  }

// Defines a functor `Op` whose vector (and, via SINGLE, scalar) call
// operators forward to the corresponding `simd::op` primitive.
#define DEFAULT_OP(Op, op)                    \
  struct Op {                                 \
    template <int N, typename T>              \
    Simd<T, N> operator()(Simd<T, N> x) {     \
      return simd::op(x);                     \
    }                                         \
    SINGLE()                                  \
  };

// One functor per unary primitive op; each simply wraps the simd:: kernel
// of the same name (or operator).
DEFAULT_OP(Abs, abs)
DEFAULT_OP(ArcCos, acos)
DEFAULT_OP(ArcCosh, acosh)
DEFAULT_OP(ArcSin, asin)
DEFAULT_OP(ArcSinh, asinh)
DEFAULT_OP(ArcTan, atan)
DEFAULT_OP(ArcTanh, atanh)
DEFAULT_OP(BitwiseInvert, operator~)
DEFAULT_OP(Ceil, ceil)
DEFAULT_OP(Conjugate, conj)
DEFAULT_OP(Cos, cos)
DEFAULT_OP(Cosh, cosh)
DEFAULT_OP(Erf, erf)
DEFAULT_OP(ErfInv, erfinv)
DEFAULT_OP(Exp, exp)
DEFAULT_OP(Expm1, expm1)
DEFAULT_OP(Floor, floor);
DEFAULT_OP(Log, log);
DEFAULT_OP(Log2, log2);
DEFAULT_OP(Log10, log10);
DEFAULT_OP(Log1p, log1p);
DEFAULT_OP(LogicalNot, operator!)
DEFAULT_OP(Negative, operator-)
DEFAULT_OP(Round, rint);
DEFAULT_OP(Sin, sin)
DEFAULT_OP(Sinh, sinh)
DEFAULT_OP(Sqrt, sqrt)
DEFAULT_OP(Rsqrt, rsqrt)
DEFAULT_OP(Tan, tan)
DEFAULT_OP(Tanh, tanh)
|
|
60
|
+
|
|
61
|
+
// Imaginary part of a complex64 SIMD vector, as float.
struct Imag {
  template <int N>
  Simd<float, N> operator()(Simd<complex64_t, N> x) {
    return simd::imag(x);
  }
  SINGLE()
};
|
|
68
|
+
|
|
69
|
+
// Real part of a complex64 SIMD vector, as float.
struct Real {
  template <int N>
  Simd<float, N> operator()(Simd<complex64_t, N> x) {
    return simd::real(x);
  }
  SINGLE()
};
|
|
76
|
+
|
|
77
|
+
// Numerically stable logistic sigmoid.
struct Sigmoid {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x) {
    // exp is evaluated at |x|, so y = 1/(1+e^{|x|}) = sigmoid(-|x|), which
    // lies in (0, 0.5] and cannot produce NaN even when exp overflows to inf
    // (1/inf -> 0). The select then restores sigmoid(x): y for x < 0,
    // 1 - y otherwise.
    auto y = 1.0f / (1.0f + simd::exp(simd::abs(x)));
    return simd::select(x < Simd<T, N>{0}, y, Simd<T, N>{1} - y);
  }
  SINGLE()
};
|
|
85
|
+
|
|
86
|
+
// Sign function with type-specific semantics:
//  - unsigned: 0 for zero, 1 otherwise (no negative values exist);
//  - complex64: 0 stays 0, otherwise x/|x| (the unit-magnitude phase);
//  - signed real: -1 / 0 / +1.
struct Sign {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x) {
    auto z = Simd<T, N>{0};
    auto o = Simd<T, N>{1};
    auto m = Simd<T, N>{-1};
    if constexpr (std::is_unsigned_v<T>) {
      return simd::select(x == z, z, o);
    } else if constexpr (std::is_same_v<T, complex64_t>) {
      // x == z selects x itself (which is zero there), avoiding 0/0.
      return simd::select(x == z, x, Simd<T, N>(x / simd::abs(x)));
    } else {
      return simd::select(x < z, m, simd::select(x > z, o, z));
    }
  }
  SINGLE()
};
|
|
102
|
+
|
|
103
|
+
// Elementwise square: x * x.
struct Square {
  template <int N, typename T>
  Simd<T, N> operator()(Simd<T, N> x) {
    return x * x;
  }
  SINGLE()
};
|
|
110
|
+
|
|
111
|
+
template <int N>
|
|
112
|
+
Simd<float, N> fp32_from_bits(Simd<uint32_t, N> x) {
|
|
113
|
+
return *(Simd<float, N>*)(&x);
|
|
114
|
+
}
|
|
115
|
+
template <int N>
|
|
116
|
+
Simd<uint32_t, N> fp32_to_bits(Simd<float, N> x) {
|
|
117
|
+
return *(Simd<uint32_t, N>*)(&x);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
// Convert float lanes to an 8-bit floating-point encoding via float32 bit
// manipulation. The constants (max 543<<21, denorm mask 141<<23, saturation
// byte 0x7E, exponent rebias 7 vs 127) are consistent with an E4M3-style
// format — presumably fp8 e4m3; confirm against the simd/fp8 kernels.
struct ToFP8 {
  template <typename T, int N>
  Simd<uint8_t, N> operator()(Simd<T, N> f) {
    // Smallest float32 bit pattern that saturates the fp8 range.
    uint32_t fp8_max = 543 << 21;
    // Adding this as a float forces subnormal fp8 results to round correctly.
    auto denorm_mask = Simd<uint32_t, N>(141 << 23);
    Simd<uint32_t, N> f_bits;
    Simd<float, N> f32 = f; // widen input lanes to float32 first
    f_bits = fp32_to_bits(f32);
    Simd<uint8_t, N> result = 0u;
    // Strip the sign; it is re-attached at the end.
    auto sign = f_bits & 0x80000000;
    f_bits = f_bits ^ sign;

    // Subnormal path: float addition performs the shift/round, then the
    // mask bits are subtracted back out leaving the fp8 payload byte.
    auto f_bits_low =
        fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
    auto result_low = Simd<uint8_t, N>(f_bits_low - denorm_mask);

    // Normal path: rebias the exponent and round-to-nearest-even (the
    // +0x7FFFF plus the odd-mantissa bit implements ties-to-even).
    auto mant_odd = Simd<uint8_t, N>((f_bits >> 20) & 1);
    auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF);
    f_bits_high = f_bits_high + Simd<uint32_t, N>(mant_odd);

    auto result_high = Simd<uint8_t, N>(f_bits_high >> 20);
    // Inputs below 2^-6 (121 << 23) take the subnormal path.
    result = select(f_bits < (121 << 23), result_low, result_high);

    // Saturate out-of-range magnitudes to the max finite encoding 0x7E.
    auto result_sat = Simd<uint8_t, N>(0x7E);
    result = select(f_bits >= fp8_max, result_sat, result);
    // Move the float32 sign bit into bit 7 of the byte.
    return result | Simd<uint8_t, N>(sign >> 24);
  }

  // Scalar convenience overload via the width-1 SIMD path.
  template <typename T>
  uint8_t operator()(T x) {
    return (*this)(Simd<T, 1>(x)).value;
  }
};
|
|
153
|
+
|
|
154
|
+
// Decode 8-bit floating-point lanes back to float32 (inverse of ToFP8's
// encoding; same presumed E4M3-style layout — confirm against ToFP8).
struct FromFP8 {
  template <int N>
  Simd<float, N> operator()(Simd<uint8_t, N> x) {
    // Place the byte in the top 8 bits of a uint32 word.
    auto w = Simd<uint32_t, N>(x) << 24;
    auto sign = w & 0x80000000;
    auto nonsign = w & 0x7FFFFFFF;

    // For subnormal encodings, extra leading zeros (beyond the 4 of a
    // normal encoding) tell us how far to renormalize the mantissa.
    auto renorm_shift = clz(nonsign);
    renorm_shift = simd::select(
        renorm_shift > Simd<uint32_t, N>{4},
        renorm_shift - Simd<uint32_t, N>{4},
        Simd<uint32_t, N>{0});

    // Saturating-add trick: only the max-exponent encodings overflow into
    // the float32 exponent field, producing an all-ones-exponent mask.
    Simd<int32_t, N> inf_nan_mask =
        (Simd<int32_t, N>(nonsign + 0x01000000) >> 8) & 0x7F800000;
    // Arithmetic shift of (nonsign - 1) yields all-ones only when the
    // encoding is +/-0, zeroing the result below.
    auto zero_mask = Simd<int32_t, N>(nonsign - 1) >> 31;
    // Rebias exponent (0x78 = 127 - 7), merge mantissa, inf/nan and sign.
    auto result = sign |
        ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
          inf_nan_mask) &
         ~zero_mask);
    return fp32_from_bits(result);
  }
  // Scalar convenience overload via the width-1 SIMD path.
  float operator()(uint8_t x) {
    return (*this)(Simd<uint8_t, 1>(x)).value;
  }
};
|
|
180
|
+
} // namespace mlx::core::detail
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/allocator.h"
#include "mlx/backend/common/buffer_cache.h"
#include "mlx/backend/cuda/cuda_utils.h"

#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <utility>
#include <vector>
|
+
|
|
14
|
+
namespace mlx::core::cu {
|
|
15
|
+
|
|
16
|
+
class CommandEncoder;
|
|
17
|
+
|
|
18
|
+
using allocator::Buffer;
|
|
19
|
+
|
|
20
|
+
// Stores cuda-managed unified memory.
struct CudaBuffer {
  void* data; // device/unified pointer
  size_t size; // allocation size in bytes
  int device; // -1 for managed
};
|
|
26
|
+
|
|
27
|
+
// Fixed pool of small CudaBuffer slots backed by a single region.
// Each slot is a union: while free it stores the next-free link in place of
// the CudaBuffer payload, forming an intrusive free list.
class SmallSizePool {
 private:
  union Block {
    Block* next; // next free slot (valid only while this slot is free)
    CudaBuffer buf; // payload (valid only while this slot is allocated)
  };

  Block* buffer_{nullptr}; // slot array
  void* data_{nullptr}; // backing data region for the pooled buffers
  Block* next_free_{nullptr}; // head of the free list

 public:
  SmallSizePool();
  ~SmallSizePool();

  // Non-copyable: owns the slot array and backing storage.
  SmallSizePool(const SmallSizePool&) = delete;
  SmallSizePool& operator=(const SmallSizePool&) = delete;

  // Pops a slot from the free list; presumably returns nullptr when the
  // pool is exhausted (implementation not visible in this header).
  CudaBuffer* malloc();
  // Returns a slot obtained from malloc() back to the free list.
  void free(CudaBuffer* buf);
  // Whether `buf` points into this pool's slot range.
  bool in_pool(CudaBuffer* buf);
};
|
|
49
|
+
|
|
50
|
+
// CUDA implementation of the core Allocator interface, combining a buffer
// cache, per-device memory pools, and a small-size pool for tiny buffers.
class CudaAllocator : public allocator::Allocator {
 public:
  Buffer malloc(size_t size) override;
  // Allocation associated with a specific device and CUDA stream —
  // presumably stream-ordered; confirm against the .cpp implementation.
  Buffer malloc_async(size_t size, int device, cudaStream_t stream);
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;

  // Accounting: currently live bytes and the historical peak.
  size_t get_active_memory() const;
  size_t get_peak_memory() const;
  void reset_peak_memory();
  // Memory/cache limits; the setters return the previous limit per the
  // usual MLX memory API conventions — confirm in mlx/memory.h.
  size_t get_memory_limit();
  size_t set_memory_limit(size_t limit);
  size_t get_cache_memory() const;
  size_t set_cache_limit(size_t limit);
  void clear_cache();

 private:
  void cuda_free(CudaBuffer* buf);

  // Construction is private; the singleton is reached via allocator().
  CudaAllocator();
  friend CudaAllocator& allocator();

  std::mutex mutex_; // guards the mutable state below
  size_t memory_limit_;
  size_t free_limit_;
  size_t total_memory_;
  size_t max_pool_size_;
  BufferCache<CudaBuffer> buffer_cache_;
  size_t active_memory_{0};
  size_t peak_memory_{0};
  std::vector<cudaStream_t> free_streams_;
  std::vector<cudaMemPool_t> mem_pools_;
  SmallSizePool scalar_pool_; // fast path for tiny (scalar-sized) buffers
};
|
|
84
|
+
|
|
85
|
+
// Accessor for the shared CudaAllocator instance (its constructor is
// private; this friend function is the only way to obtain one).
CudaAllocator& allocator();

// Allocates `size` bytes tied to the given command encoder's stream/device.
Buffer malloc_async(size_t size, CommandEncoder& encoder);
|
|
88
|
+
|
|
89
|
+
} // namespace mlx::core::cu
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/cuda/device.h"
|
|
6
|
+
#include "mlx/backend/gpu/copy.h"
|
|
7
|
+
|
|
8
|
+
namespace mlx::core {
|
|
9
|
+
|
|
10
|
+
// POD parameter bundle for an NDIM-dimensional convolution, flattened into
// fixed-size arrays so it can be passed to CUDA kernels by value.
// Layout assumes channels-last tensors: in = (N, spatial..., C) and
// wt = (O, spatial..., C_per_group) — inferred from shape(0)/shape(-1)
// usage below; confirm against the conv kernels.
template <int NDIM>
struct ConvParams {
  int N; // Batch size
  int C; // In channels
  int O; // Out channels
  int strides[NDIM];
  int padding[NDIM];
  int kernel_dilation[NDIM];
  int input_dilation[NDIM];
  int groups;
  bool flip; // flip kernel spatially (convolution vs correlation)
  int in_spatial_dims[NDIM];
  int wt_spatial_dims[NDIM];
  int out_spatial_dims[NDIM];
  int64_t in_strides[NDIM + 2]; // element strides for all in axes (N..C)

  // Copies shapes/strides out of the arrays and the std::vectors into the
  // fixed arrays; each vector must hold at least NDIM entries.
  ConvParams(
      const array& in,
      const array& wt,
      const array& out,
      const std::vector<int>& strides,
      const std::vector<int>& padding,
      const std::vector<int>& kernel_dilation,
      const std::vector<int>& input_dilation,
      int groups,
      bool flip)
      : N(in.shape(0)),
        C(in.shape(-1)),
        O(wt.shape(0)),
        groups(groups),
        flip(flip) {
    std::copy_n(strides.begin(), NDIM, this->strides);
    std::copy_n(padding.begin(), NDIM, this->padding);
    std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation);
    std::copy_n(input_dilation.begin(), NDIM, this->input_dilation);
    // Spatial extents are the axes between batch/out-channel and channel.
    std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims);
    std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims);
    std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims);
    std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides);
  }
};
|
|
51
|
+
|
|
52
|
+
// GEMM-based grouped convolution (groups > 1); defined in the backend .cpp.
void gemm_grouped_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    int groups,
    bool flip,
    Stream s);

// GEMM-based convolution for the ungrouped (groups == 1) case.
void gemm_conv(
    cu::CommandEncoder& encoder,
    const array& in,
    const array& wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    bool flip,
    Stream s);
|
|
76
|
+
|
|
77
|
+
// Convenience entry point: ensures `in` and `wt` are row-contiguous
// (copying on the GPU and registering the copy as a temporary if needed),
// then forwards to the grouped or ungrouped GEMM convolution.
inline void gemm_conv(
    cu::CommandEncoder& encoder,
    array in,
    array wt,
    array& out,
    const std::vector<int>& strides,
    const std::vector<int>& padding,
    const std::vector<int>& kernel_dilation,
    const std::vector<int>& input_dilation,
    int groups,
    bool flip,
    Stream s) {
  auto ensure_row_contiguous = [&](array& x) {
    if (!x.flags().row_contiguous) {
      x = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x);
    }
  };
  ensure_row_contiguous(in);
  ensure_row_contiguous(wt);

  if (groups == 1) {
    gemm_conv(
        encoder,
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        flip,
        s);
  } else {
    gemm_grouped_conv(
        encoder,
        in,
        wt,
        out,
        strides,
        padding,
        kernel_dilation,
        input_dilation,
        groups,
        flip,
        s);
  }
}
|
|
125
|
+
|
|
126
|
+
} // namespace mlx::core
|