mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/ops.h
ADDED
|
@@ -0,0 +1,1627 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include <optional>
|
|
6
|
+
|
|
7
|
+
#include "mlx/array.h"
|
|
8
|
+
#include "mlx/device.h"
|
|
9
|
+
#include "mlx/stream.h"
|
|
10
|
+
#include "mlx/utils.h"
|
|
11
|
+
|
|
12
|
+
namespace mlx::core {
|
|
13
|
+
|
|
14
|
+
/**
|
|
15
|
+
* \defgroup ops Core array operations
|
|
16
|
+
* @{
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* A 1D array of numbers starting at `start` (optional),
|
|
21
|
+
* stopping at stop, stepping by `step` (optional). */
|
|
22
|
+
array arange(
|
|
23
|
+
double start,
|
|
24
|
+
double stop,
|
|
25
|
+
double step,
|
|
26
|
+
Dtype dtype,
|
|
27
|
+
StreamOrDevice s = {});
|
|
28
|
+
array arange(double start, double stop, double step, StreamOrDevice s = {});
|
|
29
|
+
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
|
|
30
|
+
array arange(double start, double stop, StreamOrDevice s = {});
|
|
31
|
+
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
|
|
32
|
+
array arange(double stop, StreamOrDevice s = {});
|
|
33
|
+
|
|
34
|
+
array arange(int start, int stop, int step, StreamOrDevice s = {});
|
|
35
|
+
array arange(int start, int stop, StreamOrDevice s = {});
|
|
36
|
+
array arange(int stop, StreamOrDevice s = {});
|
|
37
|
+
|
|
38
|
+
/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */
|
|
39
|
+
array linspace(
|
|
40
|
+
double start,
|
|
41
|
+
double stop,
|
|
42
|
+
int num = 50,
|
|
43
|
+
Dtype dtype = float32,
|
|
44
|
+
StreamOrDevice s = {});
|
|
45
|
+
|
|
46
|
+
/** Convert an array to the given data type. */
|
|
47
|
+
array astype(array a, Dtype dtype, StreamOrDevice s = {});
|
|
48
|
+
|
|
49
|
+
/** Create a view of an array with the given shape and strides. */
|
|
50
|
+
array as_strided(
|
|
51
|
+
array a,
|
|
52
|
+
Shape shape,
|
|
53
|
+
Strides strides,
|
|
54
|
+
size_t offset,
|
|
55
|
+
StreamOrDevice s = {});
|
|
56
|
+
|
|
57
|
+
/** Copy another array. */
|
|
58
|
+
array copy(array a, StreamOrDevice s = {});
|
|
59
|
+
|
|
60
|
+
/** Fill an array of the given shape with the given value(s). */
|
|
61
|
+
array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {});
|
|
62
|
+
array full(Shape shape, array vals, StreamOrDevice s = {});
|
|
63
|
+
template <typename T>
|
|
64
|
+
array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) {
|
|
65
|
+
return full(std::move(shape), array(val, dtype), to_stream(s));
|
|
66
|
+
}
|
|
67
|
+
template <typename T>
|
|
68
|
+
array full(Shape shape, T val, StreamOrDevice s = {}) {
|
|
69
|
+
return full(std::move(shape), array(val), to_stream(s));
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {});
|
|
73
|
+
array full_like(const array& a, array vals, StreamOrDevice s = {});
|
|
74
|
+
template <typename T>
|
|
75
|
+
array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) {
|
|
76
|
+
return full_like(a, array(val, dtype), dtype, to_stream(s));
|
|
77
|
+
}
|
|
78
|
+
template <typename T>
|
|
79
|
+
array full_like(const array& a, T val, StreamOrDevice s = {}) {
|
|
80
|
+
return full_like(a, array(val, a.dtype()), to_stream(s));
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/** Fill an array of the given shape with zeros. */
|
|
84
|
+
array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
|
85
|
+
inline array zeros(const Shape& shape, StreamOrDevice s = {}) {
|
|
86
|
+
return zeros(shape, float32, s);
|
|
87
|
+
}
|
|
88
|
+
array zeros_like(const array& a, StreamOrDevice s = {});
|
|
89
|
+
|
|
90
|
+
/** Fill an array of the given shape with ones. */
|
|
91
|
+
array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {});
|
|
92
|
+
inline array ones(const Shape& shape, StreamOrDevice s = {}) {
|
|
93
|
+
return ones(shape, float32, s);
|
|
94
|
+
}
|
|
95
|
+
array ones_like(const array& a, StreamOrDevice s = {});
|
|
96
|
+
|
|
97
|
+
/** Fill an array of the given shape (n,m) with ones in the specified diagonal
|
|
98
|
+
* k, and zeros everywhere else. */
|
|
99
|
+
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
|
|
100
|
+
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
|
|
101
|
+
return eye(n, n, 0, dtype, s);
|
|
102
|
+
}
|
|
103
|
+
inline array eye(int n, int m, StreamOrDevice s = {}) {
|
|
104
|
+
return eye(n, m, 0, float32, s);
|
|
105
|
+
}
|
|
106
|
+
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
|
|
107
|
+
return eye(n, m, k, float32, s);
|
|
108
|
+
}
|
|
109
|
+
inline array eye(int n, StreamOrDevice s = {}) {
|
|
110
|
+
return eye(n, n, 0, float32, s);
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/** Create a square matrix of shape (n,n) of zeros, and ones in the major
|
|
114
|
+
* diagonal. */
|
|
115
|
+
array identity(int n, Dtype dtype, StreamOrDevice s = {});
|
|
116
|
+
inline array identity(int n, StreamOrDevice s = {}) {
|
|
117
|
+
return identity(n, float32, s);
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
|
|
121
|
+
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
|
|
122
|
+
return tri(n, n, 0, type, s);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
array tril(array x, int k = 0, StreamOrDevice s = {});
|
|
126
|
+
array triu(array x, int k = 0, StreamOrDevice s = {});
|
|
127
|
+
|
|
128
|
+
/** Reshape an array to the given shape. */
|
|
129
|
+
array reshape(const array& a, Shape shape, StreamOrDevice s = {});
|
|
130
|
+
|
|
131
|
+
/** Unflatten the axis to the given shape. */
|
|
132
|
+
array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {});
|
|
133
|
+
|
|
134
|
+
/** Flatten the dimensions in the range `[start_axis, end_axis]` . */
|
|
135
|
+
array flatten(
|
|
136
|
+
const array& a,
|
|
137
|
+
int start_axis,
|
|
138
|
+
int end_axis = -1,
|
|
139
|
+
StreamOrDevice s = {});
|
|
140
|
+
|
|
141
|
+
/** Flatten the array to 1D. */
|
|
142
|
+
array flatten(const array& a, StreamOrDevice s = {});
|
|
143
|
+
|
|
144
|
+
/** Multiply the array by the Hadamard matrix of corresponding size. */
|
|
145
|
+
array hadamard_transform(
|
|
146
|
+
const array& a,
|
|
147
|
+
std::optional<float> scale = std::nullopt,
|
|
148
|
+
StreamOrDevice s = {});
|
|
149
|
+
|
|
150
|
+
/** Remove singleton dimensions at the given axes. */
|
|
151
|
+
array squeeze(
|
|
152
|
+
const array& a,
|
|
153
|
+
const std::vector<int>& axes,
|
|
154
|
+
StreamOrDevice s = {});
|
|
155
|
+
|
|
156
|
+
/** Remove singleton dimensions at the given axis. */
|
|
157
|
+
array squeeze(const array& a, int axis, StreamOrDevice s = {});
|
|
158
|
+
|
|
159
|
+
/** Remove all singleton dimensions. */
|
|
160
|
+
array squeeze(const array& a, StreamOrDevice s = {});
|
|
161
|
+
|
|
162
|
+
/** Add a singleton dimension at the given axes. */
|
|
163
|
+
array expand_dims(
|
|
164
|
+
const array& a,
|
|
165
|
+
const std::vector<int>& axes,
|
|
166
|
+
StreamOrDevice s = {});
|
|
167
|
+
|
|
168
|
+
/** Add a singleton dimension at the given axis. */
|
|
169
|
+
array expand_dims(const array& a, int axis, StreamOrDevice s = {});
|
|
170
|
+
|
|
171
|
+
/** Slice an array. */
|
|
172
|
+
array slice(
|
|
173
|
+
const array& a,
|
|
174
|
+
Shape start,
|
|
175
|
+
Shape stop,
|
|
176
|
+
Shape strides,
|
|
177
|
+
StreamOrDevice s = {});
|
|
178
|
+
inline array slice(
|
|
179
|
+
const array& a,
|
|
180
|
+
std::initializer_list<int> start,
|
|
181
|
+
Shape stop,
|
|
182
|
+
Shape strides,
|
|
183
|
+
StreamOrDevice s = {}) {
|
|
184
|
+
return slice(a, Shape(start), std::move(stop), std::move(strides), s);
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
/** Slice an array with a stride of 1 in each dimension. */
|
|
188
|
+
array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {});
|
|
189
|
+
|
|
190
|
+
/** Slice an array with dynamic starting indices. */
|
|
191
|
+
array slice(
|
|
192
|
+
const array& a,
|
|
193
|
+
const array& start,
|
|
194
|
+
std::vector<int> axes,
|
|
195
|
+
Shape slice_size,
|
|
196
|
+
StreamOrDevice s = {});
|
|
197
|
+
|
|
198
|
+
/** Update a slice from the source array. */
|
|
199
|
+
array slice_update(
|
|
200
|
+
const array& src,
|
|
201
|
+
const array& update,
|
|
202
|
+
Shape start,
|
|
203
|
+
Shape stop,
|
|
204
|
+
Shape strides,
|
|
205
|
+
StreamOrDevice s = {});
|
|
206
|
+
|
|
207
|
+
/** Update a slice from the source array with stride 1 in each dimension. */
|
|
208
|
+
array slice_update(
|
|
209
|
+
const array& src,
|
|
210
|
+
const array& update,
|
|
211
|
+
Shape start,
|
|
212
|
+
Shape stop,
|
|
213
|
+
StreamOrDevice s = {});
|
|
214
|
+
|
|
215
|
+
/** Update a slice from the source array with dynamic starting indices. */
|
|
216
|
+
array slice_update(
|
|
217
|
+
const array& src,
|
|
218
|
+
const array& update,
|
|
219
|
+
const array& start,
|
|
220
|
+
std::vector<int> axes,
|
|
221
|
+
StreamOrDevice s = {});
|
|
222
|
+
|
|
223
|
+
/** Split an array into sub-arrays along a given axis. */
|
|
224
|
+
std::vector<array>
|
|
225
|
+
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
|
|
226
|
+
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
|
|
227
|
+
std::vector<array>
|
|
228
|
+
split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {});
|
|
229
|
+
std::vector<array>
|
|
230
|
+
split(const array& a, const Shape& indices, StreamOrDevice s = {});
|
|
231
|
+
|
|
232
|
+
/** A vector of coordinate arrays from coordinate vectors. */
|
|
233
|
+
std::vector<array> meshgrid(
|
|
234
|
+
const std::vector<array>& arrays,
|
|
235
|
+
bool sparse = false,
|
|
236
|
+
const std::string& indexing = "xy",
|
|
237
|
+
StreamOrDevice s = {});
|
|
238
|
+
|
|
239
|
+
/**
|
|
240
|
+
* Clip (limit) the values in an array.
|
|
241
|
+
*/
|
|
242
|
+
array clip(
|
|
243
|
+
const array& a,
|
|
244
|
+
const std::optional<array>& a_min = std::nullopt,
|
|
245
|
+
const std::optional<array>& a_max = std::nullopt,
|
|
246
|
+
StreamOrDevice s = {});
|
|
247
|
+
|
|
248
|
+
/** Concatenate arrays along a given axis. */
|
|
249
|
+
array concatenate(std::vector<array> arrays, int axis, StreamOrDevice s = {});
|
|
250
|
+
array concatenate(std::vector<array> arrays, StreamOrDevice s = {});
|
|
251
|
+
|
|
252
|
+
/** Stack arrays along a new axis. */
|
|
253
|
+
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
|
|
254
|
+
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});
|
|
255
|
+
|
|
256
|
+
/** Repeat an array along an axis. */
|
|
257
|
+
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
|
|
258
|
+
array repeat(const array& arr, int repeats, StreamOrDevice s = {});
|
|
259
|
+
|
|
260
|
+
array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});
|
|
261
|
+
|
|
262
|
+
/** Permutes the dimensions according to the given axes. */
|
|
263
|
+
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
|
|
264
|
+
inline array transpose(
|
|
265
|
+
const array& a,
|
|
266
|
+
std::initializer_list<int> axes,
|
|
267
|
+
StreamOrDevice s = {}) {
|
|
268
|
+
return transpose(a, std::vector<int>(axes), s);
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
/** Swap two axes of an array. */
|
|
272
|
+
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});
|
|
273
|
+
|
|
274
|
+
/** Move an axis of an array. */
|
|
275
|
+
array moveaxis(
|
|
276
|
+
const array& a,
|
|
277
|
+
int source,
|
|
278
|
+
int destination,
|
|
279
|
+
StreamOrDevice s = {});
|
|
280
|
+
|
|
281
|
+
/** Pad an array with a constant value */
|
|
282
|
+
array pad(
|
|
283
|
+
const array& a,
|
|
284
|
+
const std::vector<int>& axes,
|
|
285
|
+
const Shape& low_pad_size,
|
|
286
|
+
const Shape& high_pad_size,
|
|
287
|
+
const array& pad_value = array(0),
|
|
288
|
+
const std::string& mode = "constant",
|
|
289
|
+
StreamOrDevice s = {});
|
|
290
|
+
|
|
291
|
+
/** Pad an array with a constant value along all axes */
|
|
292
|
+
array pad(
|
|
293
|
+
const array& a,
|
|
294
|
+
const std::vector<std::pair<int, int>>& pad_width,
|
|
295
|
+
const array& pad_value = array(0),
|
|
296
|
+
const std::string& mode = "constant",
|
|
297
|
+
StreamOrDevice s = {});
|
|
298
|
+
array pad(
|
|
299
|
+
const array& a,
|
|
300
|
+
const std::pair<int, int>& pad_width,
|
|
301
|
+
const array& pad_value = array(0),
|
|
302
|
+
const std::string& mode = "constant",
|
|
303
|
+
StreamOrDevice s = {});
|
|
304
|
+
array pad(
|
|
305
|
+
const array& a,
|
|
306
|
+
int pad_width,
|
|
307
|
+
const array& pad_value = array(0),
|
|
308
|
+
const std::string& mode = "constant",
|
|
309
|
+
StreamOrDevice s = {});
|
|
310
|
+
|
|
311
|
+
/** Permutes the dimensions in reverse order. */
|
|
312
|
+
array transpose(const array& a, StreamOrDevice s = {});
|
|
313
|
+
|
|
314
|
+
/** Broadcast an array to a given shape. */
|
|
315
|
+
array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {});
|
|
316
|
+
|
|
317
|
+
/** Broadcast a vector of arrays against one another. */
|
|
318
|
+
std::vector<array> broadcast_arrays(
|
|
319
|
+
const std::vector<array>& inputs,
|
|
320
|
+
StreamOrDevice s = {});
|
|
321
|
+
|
|
322
|
+
/** Returns the bool array with (a == b) element-wise. */
|
|
323
|
+
array equal(const array& a, const array& b, StreamOrDevice s = {});
|
|
324
|
+
inline array operator==(const array& a, const array& b) {
|
|
325
|
+
return equal(a, b);
|
|
326
|
+
}
|
|
327
|
+
template <typename T>
|
|
328
|
+
array operator==(T a, const array& b) {
|
|
329
|
+
return equal(array(a), b);
|
|
330
|
+
}
|
|
331
|
+
template <typename T>
|
|
332
|
+
array operator==(const array& a, T b) {
|
|
333
|
+
return equal(a, array(b));
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
/** Returns the bool array with (a != b) element-wise. */
|
|
337
|
+
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
|
|
338
|
+
inline array operator!=(const array& a, const array& b) {
|
|
339
|
+
return not_equal(a, b);
|
|
340
|
+
}
|
|
341
|
+
template <typename T>
|
|
342
|
+
array operator!=(T a, const array& b) {
|
|
343
|
+
return not_equal(array(a), b);
|
|
344
|
+
}
|
|
345
|
+
template <typename T>
|
|
346
|
+
array operator!=(const array& a, T b) {
|
|
347
|
+
return not_equal(a, array(b));
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
/** Returns bool array with (a > b) element-wise. */
|
|
351
|
+
array greater(const array& a, const array& b, StreamOrDevice s = {});
|
|
352
|
+
inline array operator>(const array& a, const array& b) {
|
|
353
|
+
return greater(a, b);
|
|
354
|
+
}
|
|
355
|
+
template <typename T>
|
|
356
|
+
array operator>(T a, const array& b) {
|
|
357
|
+
return greater(array(a), b);
|
|
358
|
+
}
|
|
359
|
+
template <typename T>
|
|
360
|
+
array operator>(const array& a, T b) {
|
|
361
|
+
return greater(a, array(b));
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
/** Returns bool array with (a >= b) element-wise. */
|
|
365
|
+
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
|
|
366
|
+
inline array operator>=(const array& a, const array& b) {
|
|
367
|
+
return greater_equal(a, b);
|
|
368
|
+
}
|
|
369
|
+
template <typename T>
|
|
370
|
+
array operator>=(T a, const array& b) {
|
|
371
|
+
return greater_equal(array(a), b);
|
|
372
|
+
}
|
|
373
|
+
template <typename T>
|
|
374
|
+
array operator>=(const array& a, T b) {
|
|
375
|
+
return greater_equal(a, array(b));
|
|
376
|
+
}
|
|
377
|
+
|
|
378
|
+
/** Returns bool array with (a < b) element-wise. */
|
|
379
|
+
array less(const array& a, const array& b, StreamOrDevice s = {});
|
|
380
|
+
inline array operator<(const array& a, const array& b) {
|
|
381
|
+
return less(a, b);
|
|
382
|
+
}
|
|
383
|
+
template <typename T>
|
|
384
|
+
array operator<(T a, const array& b) {
|
|
385
|
+
return less(array(a), b);
|
|
386
|
+
}
|
|
387
|
+
template <typename T>
|
|
388
|
+
array operator<(const array& a, T b) {
|
|
389
|
+
return less(a, array(b));
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
/** Returns bool array with (a <= b) element-wise. */
|
|
393
|
+
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
|
|
394
|
+
inline array operator<=(const array& a, const array& b) {
|
|
395
|
+
return less_equal(a, b);
|
|
396
|
+
}
|
|
397
|
+
template <typename T>
|
|
398
|
+
array operator<=(T a, const array& b) {
|
|
399
|
+
return less_equal(array(a), b);
|
|
400
|
+
}
|
|
401
|
+
template <typename T>
|
|
402
|
+
array operator<=(const array& a, T b) {
|
|
403
|
+
return less_equal(a, array(b));
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
/** True if two arrays have the same shape and elements. */
|
|
407
|
+
array array_equal(
|
|
408
|
+
const array& a,
|
|
409
|
+
const array& b,
|
|
410
|
+
bool equal_nan,
|
|
411
|
+
StreamOrDevice s = {});
|
|
412
|
+
inline array
|
|
413
|
+
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
|
|
414
|
+
return array_equal(a, b, false, s);
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
array isnan(const array& a, StreamOrDevice s = {});
|
|
418
|
+
|
|
419
|
+
array isinf(const array& a, StreamOrDevice s = {});
|
|
420
|
+
|
|
421
|
+
array isfinite(const array& a, StreamOrDevice s = {});
|
|
422
|
+
|
|
423
|
+
array isposinf(const array& a, StreamOrDevice s = {});
|
|
424
|
+
|
|
425
|
+
array isneginf(const array& a, StreamOrDevice s = {});
|
|
426
|
+
|
|
427
|
+
/** Select from x or y depending on condition. */
|
|
428
|
+
array where(
|
|
429
|
+
const array& condition,
|
|
430
|
+
const array& x,
|
|
431
|
+
const array& y,
|
|
432
|
+
StreamOrDevice s = {});
|
|
433
|
+
|
|
434
|
+
/** Replace NaN and infinities with finite numbers. */
|
|
435
|
+
array nan_to_num(
|
|
436
|
+
const array& a,
|
|
437
|
+
float nan = 0.0f,
|
|
438
|
+
const std::optional<float> posinf = std::nullopt,
|
|
439
|
+
const std::optional<float> neginf = std::nullopt,
|
|
440
|
+
StreamOrDevice s = {});
|
|
441
|
+
|
|
442
|
+
/** True if all elements in the array are true (or non-zero). **/
|
|
443
|
+
array all(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
444
|
+
inline array all(const array& a, StreamOrDevice s = {}) {
|
|
445
|
+
return all(a, false, to_stream(s));
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
/** True if the two arrays are equal within the specified tolerance. */
|
|
449
|
+
array allclose(
|
|
450
|
+
const array& a,
|
|
451
|
+
const array& b,
|
|
452
|
+
double rtol = 1e-5,
|
|
453
|
+
double atol = 1e-8,
|
|
454
|
+
bool equal_nan = false,
|
|
455
|
+
StreamOrDevice s = {});
|
|
456
|
+
|
|
457
|
+
/** Returns a boolean array where two arrays are element-wise equal within the
|
|
458
|
+
* specified tolerance. */
|
|
459
|
+
array isclose(
|
|
460
|
+
const array& a,
|
|
461
|
+
const array& b,
|
|
462
|
+
double rtol = 1e-5,
|
|
463
|
+
double atol = 1e-8,
|
|
464
|
+
bool equal_nan = false,
|
|
465
|
+
StreamOrDevice s = {});
|
|
466
|
+
|
|
467
|
+
/**
|
|
468
|
+
* Reduces the input along the given axes. An output value is true
|
|
469
|
+
* if all the corresponding inputs are true.
|
|
470
|
+
**/
|
|
471
|
+
array all(
|
|
472
|
+
const array& a,
|
|
473
|
+
const std::vector<int>& axes,
|
|
474
|
+
bool keepdims = false,
|
|
475
|
+
StreamOrDevice s = {});
|
|
476
|
+
|
|
477
|
+
/**
|
|
478
|
+
* Reduces the input along the given axis. An output value is true
|
|
479
|
+
* if all the corresponding inputs are true.
|
|
480
|
+
**/
|
|
481
|
+
array all(
|
|
482
|
+
const array& a,
|
|
483
|
+
int axis,
|
|
484
|
+
bool keepdims = false,
|
|
485
|
+
StreamOrDevice s = {});
|
|
486
|
+
|
|
487
|
+
/** True if any elements in the array are true (or non-zero). **/
|
|
488
|
+
array any(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
489
|
+
inline array any(const array& a, StreamOrDevice s = {}) {
|
|
490
|
+
return any(a, false, to_stream(s));
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
/**
|
|
494
|
+
* Reduces the input along the given axes. An output value is true
|
|
495
|
+
* if any of the corresponding inputs are true.
|
|
496
|
+
**/
|
|
497
|
+
array any(
|
|
498
|
+
const array& a,
|
|
499
|
+
const std::vector<int>& axes,
|
|
500
|
+
bool keepdims = false,
|
|
501
|
+
StreamOrDevice s = {});
|
|
502
|
+
|
|
503
|
+
/**
|
|
504
|
+
* Reduces the input along the given axis. An output value is true
|
|
505
|
+
* if any of the corresponding inputs are true.
|
|
506
|
+
**/
|
|
507
|
+
array any(
|
|
508
|
+
const array& a,
|
|
509
|
+
int axis,
|
|
510
|
+
bool keepdims = false,
|
|
511
|
+
StreamOrDevice s = {});
|
|
512
|
+
|
|
513
|
+
/** Sums the elements of an array. */
|
|
514
|
+
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
515
|
+
inline array sum(const array& a, StreamOrDevice s = {}) {
|
|
516
|
+
return sum(a, false, to_stream(s));
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
/** Sums the elements of an array along the given axes. */
|
|
520
|
+
array sum(
|
|
521
|
+
const array& a,
|
|
522
|
+
const std::vector<int>& axes,
|
|
523
|
+
bool keepdims = false,
|
|
524
|
+
StreamOrDevice s = {});
|
|
525
|
+
|
|
526
|
+
/** Sums the elements of an array along the given axis. */
|
|
527
|
+
array sum(
|
|
528
|
+
const array& a,
|
|
529
|
+
int axis,
|
|
530
|
+
bool keepdims = false,
|
|
531
|
+
StreamOrDevice s = {});
|
|
532
|
+
|
|
533
|
+
/** Computes the mean of the elements of an array. */
|
|
534
|
+
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
535
|
+
inline array mean(const array& a, StreamOrDevice s = {}) {
|
|
536
|
+
return mean(a, false, to_stream(s));
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
/** Computes the mean of the elements of an array along the given axes */
|
|
540
|
+
array mean(
|
|
541
|
+
const array& a,
|
|
542
|
+
const std::vector<int>& axes,
|
|
543
|
+
bool keepdims = false,
|
|
544
|
+
StreamOrDevice s = {});
|
|
545
|
+
|
|
546
|
+
/** Computes the mean of the elements of an array along the given axis */
|
|
547
|
+
array mean(
|
|
548
|
+
const array& a,
|
|
549
|
+
int axis,
|
|
550
|
+
bool keepdims = false,
|
|
551
|
+
StreamOrDevice s = {});
|
|
552
|
+
|
|
553
|
+
/** Computes the median of the elements of an array. */
|
|
554
|
+
array median(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
555
|
+
inline array median(const array& a, StreamOrDevice s = {}) {
|
|
556
|
+
return median(a, false, to_stream(s));
|
|
557
|
+
}
|
|
558
|
+
|
|
559
|
+
/** Computes the median of the elements of an array along the given axes */
|
|
560
|
+
array median(
|
|
561
|
+
const array& a,
|
|
562
|
+
const std::vector<int>& axes,
|
|
563
|
+
bool keepdims = false,
|
|
564
|
+
StreamOrDevice s = {});
|
|
565
|
+
|
|
566
|
+
/** Computes the median of the elements of an array along the given axis */
|
|
567
|
+
array median(
|
|
568
|
+
const array& a,
|
|
569
|
+
int axis,
|
|
570
|
+
bool keepdims = false,
|
|
571
|
+
StreamOrDevice s = {});
|
|
572
|
+
|
|
573
|
+
/** Computes the variance of the elements of an array. */
|
|
574
|
+
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
|
575
|
+
inline array var(const array& a, StreamOrDevice s = {}) {
|
|
576
|
+
return var(a, false, 0, to_stream(s));
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
/** Computes the variance of the elements of an array along the given
|
|
580
|
+
* axes */
|
|
581
|
+
array var(
|
|
582
|
+
const array& a,
|
|
583
|
+
const std::vector<int>& axes,
|
|
584
|
+
bool keepdims = false,
|
|
585
|
+
int ddof = 0,
|
|
586
|
+
StreamOrDevice s = {});
|
|
587
|
+
|
|
588
|
+
/** Computes the variance of the elements of an array along the given
|
|
589
|
+
* axis */
|
|
590
|
+
array var(
|
|
591
|
+
const array& a,
|
|
592
|
+
int axis,
|
|
593
|
+
bool keepdims = false,
|
|
594
|
+
int ddof = 0,
|
|
595
|
+
StreamOrDevice s = {});
|
|
596
|
+
|
|
597
|
+
/** Computes the standard deviation of the elements of an array. */
|
|
598
|
+
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
|
|
599
|
+
inline array std(const array& a, StreamOrDevice s = {}) {
|
|
600
|
+
return std(a, false, 0, to_stream(s));
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
/** Computes the standard deviation of the elements of an array along the given
|
|
604
|
+
* axes */
|
|
605
|
+
array std(
|
|
606
|
+
const array& a,
|
|
607
|
+
const std::vector<int>& axes,
|
|
608
|
+
bool keepdims = false,
|
|
609
|
+
int ddof = 0,
|
|
610
|
+
StreamOrDevice s = {});
|
|
611
|
+
|
|
612
|
+
/** Computes the standard deviation of the elements of an array along the given
|
|
613
|
+
* axis */
|
|
614
|
+
array std(
|
|
615
|
+
const array& a,
|
|
616
|
+
int axis,
|
|
617
|
+
bool keepdims = false,
|
|
618
|
+
int ddof = 0,
|
|
619
|
+
StreamOrDevice s = {});
|
|
620
|
+
|
|
621
|
+
/** The product of all elements of the array. */
|
|
622
|
+
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
623
|
+
inline array prod(const array& a, StreamOrDevice s = {}) {
|
|
624
|
+
return prod(a, false, to_stream(s));
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
/** The product of the elements of an array along the given axes. */
|
|
628
|
+
array prod(
|
|
629
|
+
const array& a,
|
|
630
|
+
const std::vector<int>& axes,
|
|
631
|
+
bool keepdims = false,
|
|
632
|
+
StreamOrDevice s = {});
|
|
633
|
+
|
|
634
|
+
/** The product of the elements of an array along the given axis. */
|
|
635
|
+
array prod(
|
|
636
|
+
const array& a,
|
|
637
|
+
int axis,
|
|
638
|
+
bool keepdims = false,
|
|
639
|
+
StreamOrDevice s = {});
|
|
640
|
+
|
|
641
|
+
/** The maximum of all elements of the array. */
|
|
642
|
+
array max(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
643
|
+
inline array max(const array& a, StreamOrDevice s = {}) {
|
|
644
|
+
return max(a, false, to_stream(s));
|
|
645
|
+
}
|
|
646
|
+
|
|
647
|
+
/** The maximum of the elements of an array along the given axes. */
|
|
648
|
+
array max(
|
|
649
|
+
const array& a,
|
|
650
|
+
const std::vector<int>& axes,
|
|
651
|
+
bool keepdims = false,
|
|
652
|
+
StreamOrDevice s = {});
|
|
653
|
+
|
|
654
|
+
/** The maximum of the elements of an array along the given axis. */
|
|
655
|
+
array max(
|
|
656
|
+
const array& a,
|
|
657
|
+
int axis,
|
|
658
|
+
bool keepdims = false,
|
|
659
|
+
StreamOrDevice s = {});
|
|
660
|
+
|
|
661
|
+
/** The minimum of all elements of the array. */
|
|
662
|
+
array min(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
663
|
+
inline array min(const array& a, StreamOrDevice s = {}) {
|
|
664
|
+
return min(a, false, to_stream(s));
|
|
665
|
+
}
|
|
666
|
+
|
|
667
|
+
/** The minimum of the elements of an array along the given axes. */
|
|
668
|
+
array min(
|
|
669
|
+
const array& a,
|
|
670
|
+
const std::vector<int>& axes,
|
|
671
|
+
bool keepdims = false,
|
|
672
|
+
StreamOrDevice s = {});
|
|
673
|
+
|
|
674
|
+
/** The minimum of the elements of an array along the given axis. */
|
|
675
|
+
array min(
|
|
676
|
+
const array& a,
|
|
677
|
+
int axis,
|
|
678
|
+
bool keepdims = false,
|
|
679
|
+
StreamOrDevice s = {});
|
|
680
|
+
|
|
681
|
+
/** Returns the index of the minimum value in the array. */
|
|
682
|
+
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
683
|
+
inline array argmin(const array& a, StreamOrDevice s = {}) {
|
|
684
|
+
return argmin(a, false, s);
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
/** Returns the indices of the minimum values along a given axis. */
|
|
688
|
+
array argmin(
|
|
689
|
+
const array& a,
|
|
690
|
+
int axis,
|
|
691
|
+
bool keepdims = false,
|
|
692
|
+
StreamOrDevice s = {});
|
|
693
|
+
|
|
694
|
+
/** Returns the index of the maximum value in the array. */
|
|
695
|
+
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
696
|
+
inline array argmax(const array& a, StreamOrDevice s = {}) {
|
|
697
|
+
return argmax(a, false, s);
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
/** Returns the indices of the maximum values along a given axis. */
|
|
701
|
+
array argmax(
|
|
702
|
+
const array& a,
|
|
703
|
+
int axis,
|
|
704
|
+
bool keepdims = false,
|
|
705
|
+
StreamOrDevice s = {});
|
|
706
|
+
|
|
707
|
+
/** Returns a sorted copy of the flattened array. */
|
|
708
|
+
array sort(const array& a, StreamOrDevice s = {});
|
|
709
|
+
|
|
710
|
+
/** Returns a sorted copy of the array along a given axis. */
|
|
711
|
+
array sort(const array& a, int axis, StreamOrDevice s = {});
|
|
712
|
+
|
|
713
|
+
/** Returns indices that sort the flattened array. */
|
|
714
|
+
array argsort(const array& a, StreamOrDevice s = {});
|
|
715
|
+
|
|
716
|
+
/** Returns indices that sort the array along a given axis. */
|
|
717
|
+
array argsort(const array& a, int axis, StreamOrDevice s = {});
|
|
718
|
+
|
|
719
|
+
/**
|
|
720
|
+
* Returns a partitioned copy of the flattened array
|
|
721
|
+
* such that the smaller kth elements are first.
|
|
722
|
+
**/
|
|
723
|
+
array partition(const array& a, int kth, StreamOrDevice s = {});
|
|
724
|
+
|
|
725
|
+
/**
|
|
726
|
+
* Returns a partitioned copy of the array along a given axis
|
|
727
|
+
* such that the smaller kth elements are first.
|
|
728
|
+
**/
|
|
729
|
+
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
|
730
|
+
|
|
731
|
+
/**
|
|
732
|
+
* Returns indices that partition the flattened array
|
|
733
|
+
* such that the smaller kth elements are first.
|
|
734
|
+
**/
|
|
735
|
+
array argpartition(const array& a, int kth, StreamOrDevice s = {});
|
|
736
|
+
|
|
737
|
+
/**
|
|
738
|
+
* Returns indices that partition the array along a given axis
|
|
739
|
+
* such that the smaller kth elements are first.
|
|
740
|
+
**/
|
|
741
|
+
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});
|
|
742
|
+
|
|
743
|
+
/** Returns topk elements of the flattened array. */
|
|
744
|
+
array topk(const array& a, int k, StreamOrDevice s = {});
|
|
745
|
+
|
|
746
|
+
/** Returns topk elements of the array along a given axis. */
|
|
747
|
+
array topk(const array& a, int k, int axis, StreamOrDevice s = {});
|
|
748
|
+
|
|
749
|
+
/** Cumulative logsumexp of an array. */
|
|
750
|
+
array logcumsumexp(
|
|
751
|
+
const array& a,
|
|
752
|
+
bool reverse = false,
|
|
753
|
+
bool inclusive = true,
|
|
754
|
+
StreamOrDevice s = {});
|
|
755
|
+
|
|
756
|
+
/** Cumulative logsumexp of an array along the given axis. */
|
|
757
|
+
array logcumsumexp(
|
|
758
|
+
const array& a,
|
|
759
|
+
int axis,
|
|
760
|
+
bool reverse = false,
|
|
761
|
+
bool inclusive = true,
|
|
762
|
+
StreamOrDevice s = {});
|
|
763
|
+
|
|
764
|
+
/** The logsumexp of all elements of the array. */
|
|
765
|
+
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
|
|
766
|
+
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
|
|
767
|
+
return logsumexp(a, false, to_stream(s));
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
/** The logsumexp of the elements of an array along the given axes. */
|
|
771
|
+
array logsumexp(
|
|
772
|
+
const array& a,
|
|
773
|
+
const std::vector<int>& axes,
|
|
774
|
+
bool keepdims = false,
|
|
775
|
+
StreamOrDevice s = {});
|
|
776
|
+
|
|
777
|
+
/** The logsumexp of the elements of an array along the given axis. */
|
|
778
|
+
array logsumexp(
|
|
779
|
+
const array& a,
|
|
780
|
+
int axis,
|
|
781
|
+
bool keepdims = false,
|
|
782
|
+
StreamOrDevice s = {});
|
|
783
|
+
|
|
784
|
+
/** Absolute value of elements in an array. */
|
|
785
|
+
array abs(const array& a, StreamOrDevice s = {});
|
|
786
|
+
|
|
787
|
+
/** Negate an array. */
|
|
788
|
+
array negative(const array& a, StreamOrDevice s = {});
|
|
789
|
+
array operator-(const array& a);
|
|
790
|
+
|
|
791
|
+
/** The sign of the elements in an array. */
|
|
792
|
+
array sign(const array& a, StreamOrDevice s = {});
|
|
793
|
+
|
|
794
|
+
/** Logical not of an array */
|
|
795
|
+
array logical_not(const array& a, StreamOrDevice s = {});
|
|
796
|
+
|
|
797
|
+
/** Logical and of two arrays */
|
|
798
|
+
array logical_and(const array& a, const array& b, StreamOrDevice s = {});
|
|
799
|
+
array operator&&(const array& a, const array& b);
|
|
800
|
+
|
|
801
|
+
/** Logical or of two arrays */
|
|
802
|
+
array logical_or(const array& a, const array& b, StreamOrDevice s = {});
|
|
803
|
+
array operator||(const array& a, const array& b);
|
|
804
|
+
|
|
805
|
+
/** The reciprocal (1/x) of the elements in an array. */
|
|
806
|
+
array reciprocal(const array& a, StreamOrDevice s = {});
|
|
807
|
+
|
|
808
|
+
/** Add two arrays. */
|
|
809
|
+
array add(const array& a, const array& b, StreamOrDevice s = {});
|
|
810
|
+
array operator+(const array& a, const array& b);
|
|
811
|
+
template <typename T>
|
|
812
|
+
array operator+(T a, const array& b) {
|
|
813
|
+
return add(array(a), b);
|
|
814
|
+
}
|
|
815
|
+
template <typename T>
|
|
816
|
+
array operator+(const array& a, T b) {
|
|
817
|
+
return add(a, array(b));
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
/** Subtract two arrays. */
|
|
821
|
+
array subtract(const array& a, const array& b, StreamOrDevice s = {});
|
|
822
|
+
array operator-(const array& a, const array& b);
|
|
823
|
+
template <typename T>
|
|
824
|
+
array operator-(T a, const array& b) {
|
|
825
|
+
return subtract(array(a), b);
|
|
826
|
+
}
|
|
827
|
+
template <typename T>
|
|
828
|
+
array operator-(const array& a, T b) {
|
|
829
|
+
return subtract(a, array(b));
|
|
830
|
+
}
|
|
831
|
+
|
|
832
|
+
/** Multiply two arrays. */
|
|
833
|
+
array multiply(const array& a, const array& b, StreamOrDevice s = {});
|
|
834
|
+
array operator*(const array& a, const array& b);
|
|
835
|
+
template <typename T>
|
|
836
|
+
array operator*(T a, const array& b) {
|
|
837
|
+
return multiply(array(a), b);
|
|
838
|
+
}
|
|
839
|
+
template <typename T>
|
|
840
|
+
array operator*(const array& a, T b) {
|
|
841
|
+
return multiply(a, array(b));
|
|
842
|
+
}
|
|
843
|
+
|
|
844
|
+
/** Divide two arrays. */
|
|
845
|
+
array divide(const array& a, const array& b, StreamOrDevice s = {});
|
|
846
|
+
array operator/(const array& a, const array& b);
|
|
847
|
+
array operator/(double a, const array& b);
|
|
848
|
+
array operator/(const array& a, double b);
|
|
849
|
+
|
|
850
|
+
/** Compute the element-wise quotient and remainder. */
|
|
851
|
+
std::vector<array>
|
|
852
|
+
divmod(const array& a, const array& b, StreamOrDevice s = {});
|
|
853
|
+
|
|
854
|
+
/** Compute integer division. Equivalent to doing floor(a / x). */
|
|
855
|
+
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});
|
|
856
|
+
|
|
857
|
+
/** Compute the element-wise remainder of division */
|
|
858
|
+
array remainder(const array& a, const array& b, StreamOrDevice s = {});
|
|
859
|
+
array operator%(const array& a, const array& b);
|
|
860
|
+
template <typename T>
|
|
861
|
+
array operator%(T a, const array& b) {
|
|
862
|
+
return remainder(array(a), b);
|
|
863
|
+
}
|
|
864
|
+
template <typename T>
|
|
865
|
+
array operator%(const array& a, T b) {
|
|
866
|
+
return remainder(a, array(b));
|
|
867
|
+
}
|
|
868
|
+
|
|
869
|
+
/** Element-wise maximum between two arrays. */
|
|
870
|
+
array maximum(const array& a, const array& b, StreamOrDevice s = {});
|
|
871
|
+
|
|
872
|
+
/** Element-wise minimum between two arrays. */
|
|
873
|
+
array minimum(const array& a, const array& b, StreamOrDevice s = {});
|
|
874
|
+
|
|
875
|
+
/** Floor the element of an array. **/
|
|
876
|
+
array floor(const array& a, StreamOrDevice s = {});
|
|
877
|
+
|
|
878
|
+
/** Ceil the element of an array. **/
|
|
879
|
+
array ceil(const array& a, StreamOrDevice s = {});
|
|
880
|
+
|
|
881
|
+
/** Square the elements of an array. */
|
|
882
|
+
array square(const array& a, StreamOrDevice s = {});
|
|
883
|
+
|
|
884
|
+
/** Exponential of the elements of an array. */
|
|
885
|
+
array exp(const array& a, StreamOrDevice s = {});
|
|
886
|
+
|
|
887
|
+
/** Sine of the elements of an array */
|
|
888
|
+
array sin(const array& a, StreamOrDevice s = {});
|
|
889
|
+
|
|
890
|
+
/** Cosine of the elements of an array */
|
|
891
|
+
array cos(const array& a, StreamOrDevice s = {});
|
|
892
|
+
|
|
893
|
+
/** Tangent of the elements of an array */
|
|
894
|
+
array tan(const array& a, StreamOrDevice s = {});
|
|
895
|
+
|
|
896
|
+
/** Arc Sine of the elements of an array */
|
|
897
|
+
array arcsin(const array& a, StreamOrDevice s = {});
|
|
898
|
+
|
|
899
|
+
/** Arc Cosine of the elements of an array */
|
|
900
|
+
array arccos(const array& a, StreamOrDevice s = {});
|
|
901
|
+
|
|
902
|
+
/** Arc Tangent of the elements of an array */
|
|
903
|
+
array arctan(const array& a, StreamOrDevice s = {});
|
|
904
|
+
|
|
905
|
+
/** Inverse tangent of the ratio of two arrays */
|
|
906
|
+
array arctan2(const array& a, const array& b, StreamOrDevice s = {});
|
|
907
|
+
|
|
908
|
+
/** Hyperbolic Sine of the elements of an array */
|
|
909
|
+
array sinh(const array& a, StreamOrDevice s = {});
|
|
910
|
+
|
|
911
|
+
/** Hyperbolic Cosine of the elements of an array */
|
|
912
|
+
array cosh(const array& a, StreamOrDevice s = {});
|
|
913
|
+
|
|
914
|
+
/** Hyperbolic Tangent of the elements of an array */
|
|
915
|
+
array tanh(const array& a, StreamOrDevice s = {});
|
|
916
|
+
|
|
917
|
+
/** Inverse Hyperbolic Sine of the elements of an array */
|
|
918
|
+
array arcsinh(const array& a, StreamOrDevice s = {});
|
|
919
|
+
|
|
920
|
+
/** Inverse Hyperbolic Cosine of the elements of an array */
|
|
921
|
+
array arccosh(const array& a, StreamOrDevice s = {});
|
|
922
|
+
|
|
923
|
+
/** Inverse Hyperbolic Tangent of the elements of an array */
|
|
924
|
+
array arctanh(const array& a, StreamOrDevice s = {});
|
|
925
|
+
|
|
926
|
+
/** Convert the elements of an array from Radians to Degrees **/
|
|
927
|
+
array degrees(const array& a, StreamOrDevice s = {});
|
|
928
|
+
|
|
929
|
+
/** Convert the elements of an array from Degrees to Radians **/
|
|
930
|
+
array radians(const array& a, StreamOrDevice s = {});
|
|
931
|
+
|
|
932
|
+
/** Natural logarithm of the elements of an array. */
|
|
933
|
+
array log(const array& a, StreamOrDevice s = {});
|
|
934
|
+
|
|
935
|
+
/** Log base 2 of the elements of an array. */
|
|
936
|
+
array log2(const array& a, StreamOrDevice s = {});
|
|
937
|
+
|
|
938
|
+
/** Log base 10 of the elements of an array. */
|
|
939
|
+
array log10(const array& a, StreamOrDevice s = {});
|
|
940
|
+
|
|
941
|
+
/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
|
|
942
|
+
array log1p(const array& a, StreamOrDevice s = {});
|
|
943
|
+
|
|
944
|
+
/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */
|
|
945
|
+
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});
|
|
946
|
+
|
|
947
|
+
/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */
|
|
948
|
+
array sigmoid(const array& a, StreamOrDevice s = {});
|
|
949
|
+
|
|
950
|
+
/** Computes the error function of the elements of an array. */
|
|
951
|
+
array erf(const array& a, StreamOrDevice s = {});
|
|
952
|
+
|
|
953
|
+
/** Computes the inverse error function of the elements of an array. */
|
|
954
|
+
array erfinv(const array& a, StreamOrDevice s = {});
|
|
955
|
+
|
|
956
|
+
/** Computes the expm1 function of the elements of an array. */
|
|
957
|
+
array expm1(const array& a, StreamOrDevice s = {});
|
|
958
|
+
|
|
959
|
+
/** Stop the flow of gradients. */
|
|
960
|
+
array stop_gradient(const array& a, StreamOrDevice s = {});
|
|
961
|
+
|
|
962
|
+
/** Round a floating point number */
|
|
963
|
+
array round(const array& a, int decimals, StreamOrDevice s = {});
|
|
964
|
+
inline array round(const array& a, StreamOrDevice s = {}) {
|
|
965
|
+
return round(a, 0, s);
|
|
966
|
+
}
|
|
967
|
+
|
|
968
|
+
/** Matrix-matrix multiplication. */
|
|
969
|
+
array matmul(const array& a, const array& b, StreamOrDevice s = {});
|
|
970
|
+
|
|
971
|
+
/** Gather array entries given indices and slices */
|
|
972
|
+
array gather(
|
|
973
|
+
const array& a,
|
|
974
|
+
const std::vector<array>& indices,
|
|
975
|
+
const std::vector<int>& axes,
|
|
976
|
+
const Shape& slice_sizes,
|
|
977
|
+
StreamOrDevice s = {});
|
|
978
|
+
inline array gather(
|
|
979
|
+
const array& a,
|
|
980
|
+
const array& indices,
|
|
981
|
+
int axis,
|
|
982
|
+
const Shape& slice_sizes,
|
|
983
|
+
StreamOrDevice s = {}) {
|
|
984
|
+
return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s);
|
|
985
|
+
}
|
|
986
|
+
|
|
987
|
+
/** Compute the Kronecker product of two arrays. */
|
|
988
|
+
array kron(const array& a, const array& b, StreamOrDevice s = {});
|
|
989
|
+
|
|
990
|
+
/** Take array slices at the given indices of the specified axis. */
|
|
991
|
+
array take(
|
|
992
|
+
const array& a,
|
|
993
|
+
const array& indices,
|
|
994
|
+
int axis,
|
|
995
|
+
StreamOrDevice s = {});
|
|
996
|
+
array take(const array& a, int index, int axis, StreamOrDevice s = {});
|
|
997
|
+
|
|
998
|
+
/** Take array entries at the given indices treating the array as flattened. */
|
|
999
|
+
array take(const array& a, const array& indices, StreamOrDevice s = {});
|
|
1000
|
+
array take(const array& a, int index, StreamOrDevice s = {});
|
|
1001
|
+
|
|
1002
|
+
/** Take array entries given indices along the axis */
|
|
1003
|
+
array take_along_axis(
|
|
1004
|
+
const array& a,
|
|
1005
|
+
const array& indices,
|
|
1006
|
+
int axis,
|
|
1007
|
+
StreamOrDevice s = {});
|
|
1008
|
+
|
|
1009
|
+
/** Put the values into the array at the given indices along the axis */
|
|
1010
|
+
array put_along_axis(
|
|
1011
|
+
const array& a,
|
|
1012
|
+
const array& indices,
|
|
1013
|
+
const array& values,
|
|
1014
|
+
int axis,
|
|
1015
|
+
StreamOrDevice s = {});
|
|
1016
|
+
|
|
1017
|
+
/** Add the values into the array at the given indices along the axis */
|
|
1018
|
+
array scatter_add_axis(
|
|
1019
|
+
const array& a,
|
|
1020
|
+
const array& indices,
|
|
1021
|
+
const array& values,
|
|
1022
|
+
int axis,
|
|
1023
|
+
StreamOrDevice s = {});
|
|
1024
|
+
|
|
1025
|
+
/** Scatter updates to the given indices.
|
|
1026
|
+
*
|
|
1027
|
+
* The parameters ``indices`` and ``axes`` determine the locations of ``a``
|
|
1028
|
+
* that are updated with the values in ``updates``. Assuming 1-d ``indices``
|
|
1029
|
+
* for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
|
|
1030
|
+
* the values in ``updates`` will be applied. Note each array in
|
|
1031
|
+
* ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
|
|
1032
|
+
* axes.size()``. If an index/axis pair is not provided then indices along that
|
|
1033
|
+
* axis are assumed to be zero.
|
|
1034
|
+
*
|
|
1035
|
+
* Note the rank of ``updates`` must be equal to the sum of the rank of the
|
|
1036
|
+
* broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
|
|
1037
|
+
* arrays in ``indices`` have the same shape, ``updates.ndim() ==
|
|
1038
|
+
* indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
|
|
1039
|
+
* correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
|
|
1040
|
+
* values that will be applied to the given location in ``a``.
|
|
1041
|
+
*
|
|
1042
|
+
* For example:
|
|
1043
|
+
*
|
|
1044
|
+
* @code
|
|
1045
|
+
* auto in = zeros({4, 4}, float32);
|
|
1046
|
+
* auto indices = array({2});
|
|
1047
|
+
* auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
|
|
1048
|
+
* std::vector<int> axes{0};
|
|
1049
|
+
*
|
|
1050
|
+
* auto out = scatter(in, {indices}, updates, axes);
|
|
1051
|
+
* @endcode
|
|
1052
|
+
*
|
|
1053
|
+
* will produce:
|
|
1054
|
+
*
|
|
1055
|
+
* @code
|
|
1056
|
+
* array([[0, 0, 0, 0],
|
|
1057
|
+
* [0, 0, 0, 0],
|
|
1058
|
+
* [1, 2, 0, 0],
|
|
1059
|
+
* [0, 0, 0, 0]], dtype=float32)
|
|
1060
|
+
* @endcode
|
|
1061
|
+
*
|
|
1062
|
+
* This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
|
|
1063
|
+
* 0)`` position of ``a``.
|
|
1064
|
+
*
|
|
1065
|
+
* Adding another element to ``indices`` will scatter into another location of
|
|
1066
|
+
* ``a``. We also have to add an another update for the new index:
|
|
1067
|
+
*
|
|
1068
|
+
* @code
|
|
1069
|
+
* auto in = zeros({4, 4}, float32);
|
|
1070
|
+
* auto indices = array({2, 0});
|
|
1071
|
+
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
|
|
1072
|
+
* std::vector<int> axes{0};
|
|
1073
|
+
*
|
|
1074
|
+
* auto out = scatter(in, {indices}, updates, axes):
|
|
1075
|
+
* @endcode
|
|
1076
|
+
*
|
|
1077
|
+
* will produce:
|
|
1078
|
+
*
|
|
1079
|
+
* @code
|
|
1080
|
+
* array([[3, 4, 0, 0],
|
|
1081
|
+
* [0, 0, 0, 0],
|
|
1082
|
+
* [1, 2, 0, 0],
|
|
1083
|
+
* [0, 0, 0, 0]], dtype=float32)
|
|
1084
|
+
* @endcode
|
|
1085
|
+
*
|
|
1086
|
+
* To control the scatter location on an additional axis, add another index
|
|
1087
|
+
* array to ``indices`` and another axis to ``axes``:
|
|
1088
|
+
*
|
|
1089
|
+
* @code
|
|
1090
|
+
* auto in = zeros({4, 4}, float32);
|
|
1091
|
+
* auto indices = std::vector{array({2, 0}), array({1, 2})};
|
|
1092
|
+
* auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
|
|
1093
|
+
* std::vector<int> axes{0, 1};
|
|
1094
|
+
*
|
|
1095
|
+
* auto out = scatter(in, indices, updates, axes);
|
|
1096
|
+
* @endcode
|
|
1097
|
+
*
|
|
1098
|
+
* will produce:
|
|
1099
|
+
*
|
|
1100
|
+
* @code
|
|
1101
|
+
* array([[0, 0, 3, 4],
|
|
1102
|
+
* [0, 0, 0, 0],
|
|
1103
|
+
* [0, 1, 2, 0],
|
|
1104
|
+
* [0, 0, 0, 0]], dtype=float32)
|
|
1105
|
+
* @endcode
|
|
1106
|
+
*
|
|
1107
|
+
* Items in indices are broadcasted together. This means:
|
|
1108
|
+
*
|
|
1109
|
+
* @code
|
|
1110
|
+
* auto indices = std::vector{array({2, 0}), array({1})};
|
|
1111
|
+
* @endcode
|
|
1112
|
+
*
|
|
1113
|
+
* is equivalent to:
|
|
1114
|
+
*
|
|
1115
|
+
* @code
|
|
1116
|
+
* auto indices = std::vector{array({2, 0}), array({1, 1})};
|
|
1117
|
+
* @endcode
|
|
1118
|
+
*
|
|
1119
|
+
* Note, ``scatter`` does not perform bounds checking on the indices and
|
|
1120
|
+
* updates. Out-of-bounds accesses on ``a`` are undefined and typically result
|
|
1121
|
+
* in unintended or invalid memory writes.
|
|
1122
|
+
*/
|
|
1123
|
+
array scatter(
|
|
1124
|
+
const array& a,
|
|
1125
|
+
const std::vector<array>& indices,
|
|
1126
|
+
const array& updates,
|
|
1127
|
+
const std::vector<int>& axes,
|
|
1128
|
+
StreamOrDevice s = {});
|
|
1129
|
+
inline array scatter(
|
|
1130
|
+
const array& a,
|
|
1131
|
+
const array& indices,
|
|
1132
|
+
const array& updates,
|
|
1133
|
+
int axis,
|
|
1134
|
+
StreamOrDevice s = {}) {
|
|
1135
|
+
return scatter(a, {indices}, updates, std::vector<int>{axis}, s);
|
|
1136
|
+
}
|
|
1137
|
+
|
|
1138
|
+
/** Scatter and add updates to given indices */
|
|
1139
|
+
array scatter_add(
|
|
1140
|
+
const array& a,
|
|
1141
|
+
const std::vector<array>& indices,
|
|
1142
|
+
const array& updates,
|
|
1143
|
+
const std::vector<int>& axes,
|
|
1144
|
+
StreamOrDevice s = {});
|
|
1145
|
+
inline array scatter_add(
|
|
1146
|
+
const array& a,
|
|
1147
|
+
const array& indices,
|
|
1148
|
+
const array& updates,
|
|
1149
|
+
int axis,
|
|
1150
|
+
StreamOrDevice s = {}) {
|
|
1151
|
+
return scatter_add(a, {indices}, updates, std::vector<int>{axis}, s);
|
|
1152
|
+
}
|
|
1153
|
+
|
|
1154
|
+
/** Scatter and prod updates to given indices */
|
|
1155
|
+
array scatter_prod(
|
|
1156
|
+
const array& a,
|
|
1157
|
+
const std::vector<array>& indices,
|
|
1158
|
+
const array& updates,
|
|
1159
|
+
const std::vector<int>& axes,
|
|
1160
|
+
StreamOrDevice s = {});
|
|
1161
|
+
inline array scatter_prod(
|
|
1162
|
+
const array& a,
|
|
1163
|
+
const array& indices,
|
|
1164
|
+
const array& updates,
|
|
1165
|
+
int axis,
|
|
1166
|
+
StreamOrDevice s = {}) {
|
|
1167
|
+
return scatter_prod(a, {indices}, updates, std::vector<int>{axis}, s);
|
|
1168
|
+
}
|
|
1169
|
+
|
|
1170
|
+
/** Scatter and max updates to given linear indices */
|
|
1171
|
+
array scatter_max(
|
|
1172
|
+
const array& a,
|
|
1173
|
+
const std::vector<array>& indices,
|
|
1174
|
+
const array& updates,
|
|
1175
|
+
const std::vector<int>& axes,
|
|
1176
|
+
StreamOrDevice s = {});
|
|
1177
|
+
inline array scatter_max(
|
|
1178
|
+
const array& a,
|
|
1179
|
+
const array& indices,
|
|
1180
|
+
const array& updates,
|
|
1181
|
+
int axis,
|
|
1182
|
+
StreamOrDevice s = {}) {
|
|
1183
|
+
return scatter_max(a, {indices}, updates, std::vector<int>{axis}, s);
|
|
1184
|
+
}
|
|
1185
|
+
/** Scatter and min updates to given linear indices */
|
|
1186
|
+
array scatter_min(
|
|
1187
|
+
const array& a,
|
|
1188
|
+
const std::vector<array>& indices,
|
|
1189
|
+
const array& updates,
|
|
1190
|
+
const std::vector<int>& axes,
|
|
1191
|
+
StreamOrDevice s = {});
|
|
1192
|
+
inline array scatter_min(
|
|
1193
|
+
const array& a,
|
|
1194
|
+
const array& indices,
|
|
1195
|
+
const array& updates,
|
|
1196
|
+
int axis,
|
|
1197
|
+
StreamOrDevice s = {}) {
|
|
1198
|
+
return scatter_min(a, {indices}, updates, std::vector<int>{axis}, s);
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1201
|
+
/**
 * Masked scatter of `src` into `a` under `mask`.
 * NOTE(review): presumably the positions where `mask` is true are filled,
 * in order, with consecutive elements of `src` while other positions keep
 * the values of `a` -- confirm against the implementation.
 */
array masked_scatter(const array& a, const array& mask, const array& src,
                     StreamOrDevice s = {});
|
|
1206
|
+
|
|
1207
|
+
/** Square root the elements of an array. */
|
|
1208
|
+
array sqrt(const array& a, StreamOrDevice s = {});
|
|
1209
|
+
|
|
1210
|
+
/** Square root and reciprocal the elements of an array. */
|
|
1211
|
+
array rsqrt(const array& a, StreamOrDevice s = {});
|
|
1212
|
+
|
|
1213
|
+
/** Softmax of an array. */
|
|
1214
|
+
array softmax(
|
|
1215
|
+
const array& a,
|
|
1216
|
+
const std::vector<int>& axes,
|
|
1217
|
+
bool precise = false,
|
|
1218
|
+
StreamOrDevice s = {});
|
|
1219
|
+
|
|
1220
|
+
/** Softmax of an array. */
|
|
1221
|
+
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});
|
|
1222
|
+
|
|
1223
|
+
/** Softmax of an array. */
|
|
1224
|
+
inline array
|
|
1225
|
+
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
|
|
1226
|
+
return softmax(a, std::vector<int>{axis}, precise, s);
|
|
1227
|
+
}
|
|
1228
|
+
|
|
1229
|
+
/** Raise elements of a to the power of b element-wise */
|
|
1230
|
+
array power(const array& a, const array& b, StreamOrDevice s = {});
|
|
1231
|
+
|
|
1232
|
+
/** Cumulative sum of an array. */
|
|
1233
|
+
array cumsum(
|
|
1234
|
+
const array& a,
|
|
1235
|
+
bool reverse = false,
|
|
1236
|
+
bool inclusive = true,
|
|
1237
|
+
StreamOrDevice s = {});
|
|
1238
|
+
|
|
1239
|
+
/** Cumulative sum of an array along the given axis. */
|
|
1240
|
+
array cumsum(
|
|
1241
|
+
const array& a,
|
|
1242
|
+
int axis,
|
|
1243
|
+
bool reverse = false,
|
|
1244
|
+
bool inclusive = true,
|
|
1245
|
+
StreamOrDevice s = {});
|
|
1246
|
+
|
|
1247
|
+
/** Cumulative product of an array. */
|
|
1248
|
+
array cumprod(
|
|
1249
|
+
const array& a,
|
|
1250
|
+
bool reverse = false,
|
|
1251
|
+
bool inclusive = true,
|
|
1252
|
+
StreamOrDevice s = {});
|
|
1253
|
+
|
|
1254
|
+
/** Cumulative product of an array along the given axis. */
|
|
1255
|
+
array cumprod(
|
|
1256
|
+
const array& a,
|
|
1257
|
+
int axis,
|
|
1258
|
+
bool reverse = false,
|
|
1259
|
+
bool inclusive = true,
|
|
1260
|
+
StreamOrDevice s = {});
|
|
1261
|
+
|
|
1262
|
+
/** Cumulative max of an array. */
|
|
1263
|
+
array cummax(
|
|
1264
|
+
const array& a,
|
|
1265
|
+
bool reverse = false,
|
|
1266
|
+
bool inclusive = true,
|
|
1267
|
+
StreamOrDevice s = {});
|
|
1268
|
+
|
|
1269
|
+
/** Cumulative max of an array along the given axis. */
|
|
1270
|
+
array cummax(
|
|
1271
|
+
const array& a,
|
|
1272
|
+
int axis,
|
|
1273
|
+
bool reverse = false,
|
|
1274
|
+
bool inclusive = true,
|
|
1275
|
+
StreamOrDevice s = {});
|
|
1276
|
+
|
|
1277
|
+
/** Cumulative min of an array. */
|
|
1278
|
+
array cummin(
|
|
1279
|
+
const array& a,
|
|
1280
|
+
bool reverse = false,
|
|
1281
|
+
bool inclusive = true,
|
|
1282
|
+
StreamOrDevice s = {});
|
|
1283
|
+
|
|
1284
|
+
/** Cumulative min of an array along the given axis. */
|
|
1285
|
+
array cummin(
|
|
1286
|
+
const array& a,
|
|
1287
|
+
int axis,
|
|
1288
|
+
bool reverse = false,
|
|
1289
|
+
bool inclusive = true,
|
|
1290
|
+
StreamOrDevice s = {});
|
|
1291
|
+
|
|
1292
|
+
/** General convolution with a filter */
|
|
1293
|
+
array conv_general(
|
|
1294
|
+
array input,
|
|
1295
|
+
array weight,
|
|
1296
|
+
std::vector<int> stride = {},
|
|
1297
|
+
std::vector<int> padding_lo = {},
|
|
1298
|
+
std::vector<int> padding_hi = {},
|
|
1299
|
+
std::vector<int> kernel_dilation = {},
|
|
1300
|
+
std::vector<int> input_dilation = {},
|
|
1301
|
+
int groups = 1,
|
|
1302
|
+
bool flip = false,
|
|
1303
|
+
StreamOrDevice s = {});
|
|
1304
|
+
|
|
1305
|
+
/** General convolution with a filter */
|
|
1306
|
+
inline array conv_general(
|
|
1307
|
+
const array& input,
|
|
1308
|
+
const array& weight,
|
|
1309
|
+
std::vector<int> stride = {},
|
|
1310
|
+
std::vector<int> padding = {},
|
|
1311
|
+
std::vector<int> kernel_dilation = {},
|
|
1312
|
+
std::vector<int> input_dilation = {},
|
|
1313
|
+
int groups = 1,
|
|
1314
|
+
bool flip = false,
|
|
1315
|
+
StreamOrDevice s = {}) {
|
|
1316
|
+
return conv_general(
|
|
1317
|
+
/* const array& input = */ input,
|
|
1318
|
+
/* const array& weight = */ weight,
|
|
1319
|
+
/* std::vector<int> stride = */ stride,
|
|
1320
|
+
/* std::vector<int> padding_lo = */ padding,
|
|
1321
|
+
/* std::vector<int> padding_hi = */ padding,
|
|
1322
|
+
/* std::vector<int> kernel_dilation = */ kernel_dilation,
|
|
1323
|
+
/* std::vector<int> input_dilation = */ input_dilation,
|
|
1324
|
+
/* int groups = */ groups,
|
|
1325
|
+
/* bool flip = */ flip,
|
|
1326
|
+
/* StreamOrDevice s = */ s);
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
/** 1D convolution with a filter */
|
|
1330
|
+
array conv1d(
|
|
1331
|
+
const array& input,
|
|
1332
|
+
const array& weight,
|
|
1333
|
+
int stride = 1,
|
|
1334
|
+
int padding = 0,
|
|
1335
|
+
int dilation = 1,
|
|
1336
|
+
int groups = 1,
|
|
1337
|
+
StreamOrDevice s = {});
|
|
1338
|
+
|
|
1339
|
+
/** 2D convolution with a filter */
|
|
1340
|
+
array conv2d(
|
|
1341
|
+
const array& input,
|
|
1342
|
+
const array& weight,
|
|
1343
|
+
const std::pair<int, int>& stride = {1, 1},
|
|
1344
|
+
const std::pair<int, int>& padding = {0, 0},
|
|
1345
|
+
const std::pair<int, int>& dilation = {1, 1},
|
|
1346
|
+
int groups = 1,
|
|
1347
|
+
StreamOrDevice s = {});
|
|
1348
|
+
|
|
1349
|
+
/** 3D convolution with a filter */
|
|
1350
|
+
array conv3d(
|
|
1351
|
+
const array& input,
|
|
1352
|
+
const array& weight,
|
|
1353
|
+
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
|
1354
|
+
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
|
1355
|
+
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
|
1356
|
+
int groups = 1,
|
|
1357
|
+
StreamOrDevice s = {});
|
|
1358
|
+
|
|
1359
|
+
/** 1D transposed convolution with a filter */
|
|
1360
|
+
array conv_transpose1d(
|
|
1361
|
+
const array& input,
|
|
1362
|
+
const array& weight,
|
|
1363
|
+
int stride = 1,
|
|
1364
|
+
int padding = 0,
|
|
1365
|
+
int dilation = 1,
|
|
1366
|
+
int output_padding = 0,
|
|
1367
|
+
int groups = 1,
|
|
1368
|
+
StreamOrDevice s = {});
|
|
1369
|
+
|
|
1370
|
+
/** 2D transposed convolution with a filter */
|
|
1371
|
+
array conv_transpose2d(
|
|
1372
|
+
const array& input,
|
|
1373
|
+
const array& weight,
|
|
1374
|
+
const std::pair<int, int>& stride = {1, 1},
|
|
1375
|
+
const std::pair<int, int>& padding = {0, 0},
|
|
1376
|
+
const std::pair<int, int>& dilation = {1, 1},
|
|
1377
|
+
const std::pair<int, int>& output_padding = {0, 0},
|
|
1378
|
+
int groups = 1,
|
|
1379
|
+
StreamOrDevice s = {});
|
|
1380
|
+
|
|
1381
|
+
/** 3D transposed convolution with a filter */
|
|
1382
|
+
array conv_transpose3d(
|
|
1383
|
+
const array& input,
|
|
1384
|
+
const array& weight,
|
|
1385
|
+
const std::tuple<int, int, int>& stride = {1, 1, 1},
|
|
1386
|
+
const std::tuple<int, int, int>& padding = {0, 0, 0},
|
|
1387
|
+
const std::tuple<int, int, int>& dilation = {1, 1, 1},
|
|
1388
|
+
const std::tuple<int, int, int>& output_padding = {0, 0, 0},
|
|
1389
|
+
int groups = 1,
|
|
1390
|
+
StreamOrDevice s = {});
|
|
1391
|
+
|
|
1392
|
+
/** Quantized matmul multiplies x with a quantized matrix w*/
|
|
1393
|
+
array quantized_matmul(
|
|
1394
|
+
array x,
|
|
1395
|
+
array w,
|
|
1396
|
+
array scales,
|
|
1397
|
+
std::optional<array> biases = std::nullopt,
|
|
1398
|
+
bool transpose = true,
|
|
1399
|
+
std::optional<int> group_size = std::nullopt,
|
|
1400
|
+
std::optional<int> bits = std::nullopt,
|
|
1401
|
+
const std::string& mode = "affine",
|
|
1402
|
+
StreamOrDevice s = {});
|
|
1403
|
+
|
|
1404
|
+
/** Quantize a matrix along its last axis */
|
|
1405
|
+
std::vector<array> quantize(
|
|
1406
|
+
const array& w,
|
|
1407
|
+
std::optional<int> group_size = std::nullopt,
|
|
1408
|
+
std::optional<int> bits = std::nullopt,
|
|
1409
|
+
const std::string& mode = "affine",
|
|
1410
|
+
StreamOrDevice s = {});
|
|
1411
|
+
|
|
1412
|
+
/** Dequantize a matrix produced by quantize() */
|
|
1413
|
+
array dequantize(
|
|
1414
|
+
const array& w,
|
|
1415
|
+
const array& scales,
|
|
1416
|
+
const std::optional<array>& biases = std::nullopt,
|
|
1417
|
+
std::optional<int> group_size = std::nullopt,
|
|
1418
|
+
std::optional<int> bits = std::nullopt,
|
|
1419
|
+
const std::string& mode = "affine",
|
|
1420
|
+
std::optional<Dtype> dtype = std::nullopt,
|
|
1421
|
+
StreamOrDevice s = {});
|
|
1422
|
+
|
|
1423
|
+
array qqmm(
|
|
1424
|
+
array x, // input activations
|
|
1425
|
+
array w, // maybe quantized weights
|
|
1426
|
+
std::optional<array> w_scales = std::nullopt, // optional scales if w is
|
|
1427
|
+
// quantized
|
|
1428
|
+
std::optional<int> group_size = std::nullopt,
|
|
1429
|
+
std::optional<int> bits = std::nullopt,
|
|
1430
|
+
const std::string& mode = "nvfp4",
|
|
1431
|
+
StreamOrDevice s = {});
|
|
1432
|
+
|
|
1433
|
+
/** Convert an E4M3 float8 to the given floating point dtype. */
|
|
1434
|
+
array from_fp8(array x, Dtype dtype, StreamOrDevice s = {});
|
|
1435
|
+
|
|
1436
|
+
/** Convert a floating point matrix to E4M3 float8. */
|
|
1437
|
+
array to_fp8(array x, StreamOrDevice s = {});
|
|
1438
|
+
|
|
1439
|
+
/** Compute matrix products with matrix-level gather. */
|
|
1440
|
+
array gather_qmm(
|
|
1441
|
+
const array& x,
|
|
1442
|
+
const array& w,
|
|
1443
|
+
const array& scales,
|
|
1444
|
+
const std::optional<array>& biases = std::nullopt,
|
|
1445
|
+
std::optional<array> lhs_indices = std::nullopt,
|
|
1446
|
+
std::optional<array> rhs_indices = std::nullopt,
|
|
1447
|
+
bool transpose = true,
|
|
1448
|
+
std::optional<int> group_size = std::nullopt,
|
|
1449
|
+
std::optional<int> bits = std::nullopt,
|
|
1450
|
+
const std::string& mode = "affine",
|
|
1451
|
+
bool sorted_indices = false,
|
|
1452
|
+
StreamOrDevice s = {});
|
|
1453
|
+
|
|
1454
|
+
/** Returns a contraction of a and b over multiple dimensions. */
|
|
1455
|
+
array tensordot(
|
|
1456
|
+
const array& a,
|
|
1457
|
+
const array& b,
|
|
1458
|
+
const int axis = 2,
|
|
1459
|
+
StreamOrDevice s = {});
|
|
1460
|
+
|
|
1461
|
+
array tensordot(
|
|
1462
|
+
const array& a,
|
|
1463
|
+
const array& b,
|
|
1464
|
+
const std::vector<int>& axes_a,
|
|
1465
|
+
const std::vector<int>& axes_b,
|
|
1466
|
+
StreamOrDevice s = {});
|
|
1467
|
+
|
|
1468
|
+
/** Compute the outer product of two vectors. */
|
|
1469
|
+
array outer(const array& a, const array& b, StreamOrDevice s = {});
|
|
1470
|
+
|
|
1471
|
+
/** Compute the inner product of two vectors. */
|
|
1472
|
+
array inner(const array& a, const array& b, StreamOrDevice s = {});
|
|
1473
|
+
|
|
1474
|
+
/** Compute D = beta * C + alpha * (A @ B) */
|
|
1475
|
+
array addmm(
|
|
1476
|
+
array c,
|
|
1477
|
+
array a,
|
|
1478
|
+
array b,
|
|
1479
|
+
const float& alpha = 1.f,
|
|
1480
|
+
const float& beta = 1.f,
|
|
1481
|
+
StreamOrDevice s = {});
|
|
1482
|
+
|
|
1483
|
+
/** Compute matrix product with block masking */
|
|
1484
|
+
array block_masked_mm(
|
|
1485
|
+
array a,
|
|
1486
|
+
array b,
|
|
1487
|
+
int block_size,
|
|
1488
|
+
std::optional<array> mask_out = std::nullopt,
|
|
1489
|
+
std::optional<array> mask_lhs = std::nullopt,
|
|
1490
|
+
std::optional<array> mask_rhs = std::nullopt,
|
|
1491
|
+
StreamOrDevice s = {});
|
|
1492
|
+
|
|
1493
|
+
/** Compute matrix product with matrix-level gather */
|
|
1494
|
+
array gather_mm(
|
|
1495
|
+
array a,
|
|
1496
|
+
array b,
|
|
1497
|
+
std::optional<array> lhs_indices = std::nullopt,
|
|
1498
|
+
std::optional<array> rhs_indices = std::nullopt,
|
|
1499
|
+
bool sorted_indices = false,
|
|
1500
|
+
StreamOrDevice s = {});
|
|
1501
|
+
|
|
1502
|
+
/**
|
|
1503
|
+
* Compute a matrix product but segment the inner dimension and write the
|
|
1504
|
+
* result separately for each segment.
|
|
1505
|
+
*/
|
|
1506
|
+
array segmented_mm(array a, array b, array segments, StreamOrDevice s = {});
|
|
1507
|
+
|
|
1508
|
+
/** Extract a diagonal or construct a diagonal array */
|
|
1509
|
+
array diagonal(
|
|
1510
|
+
const array& a,
|
|
1511
|
+
int offset = 0,
|
|
1512
|
+
int axis1 = 0,
|
|
1513
|
+
int axis2 = 1,
|
|
1514
|
+
StreamOrDevice s = {});
|
|
1515
|
+
|
|
1516
|
+
/** Extract diagonal from a 2d array or create a diagonal matrix. */
|
|
1517
|
+
array diag(const array& a, int k = 0, StreamOrDevice s = {});
|
|
1518
|
+
|
|
1519
|
+
/** Return the sum along a specified diagonal in the given array. */
|
|
1520
|
+
array trace(
|
|
1521
|
+
const array& a,
|
|
1522
|
+
int offset,
|
|
1523
|
+
int axis1,
|
|
1524
|
+
int axis2,
|
|
1525
|
+
Dtype dtype,
|
|
1526
|
+
StreamOrDevice s = {});
|
|
1527
|
+
array trace(
|
|
1528
|
+
const array& a,
|
|
1529
|
+
int offset,
|
|
1530
|
+
int axis1,
|
|
1531
|
+
int axis2,
|
|
1532
|
+
StreamOrDevice s = {});
|
|
1533
|
+
array trace(const array& a, StreamOrDevice s = {});
|
|
1534
|
+
|
|
1535
|
+
/**
|
|
1536
|
+
* Implements the identity function but allows injecting dependencies to other
|
|
1537
|
+
* arrays. This ensures that these other arrays will have been computed
|
|
1538
|
+
* when the outputs of this function are computed.
|
|
1539
|
+
*/
|
|
1540
|
+
std::vector<array> depends(
|
|
1541
|
+
const std::vector<array>& inputs,
|
|
1542
|
+
const std::vector<array>& dependencies);
|
|
1543
|
+
|
|
1544
|
+
/** convert an array to an atleast ndim array */
|
|
1545
|
+
array atleast_1d(const array& a, StreamOrDevice s = {});
|
|
1546
|
+
std::vector<array> atleast_1d(
|
|
1547
|
+
const std::vector<array>& a,
|
|
1548
|
+
StreamOrDevice s = {});
|
|
1549
|
+
array atleast_2d(const array& a, StreamOrDevice s = {});
|
|
1550
|
+
std::vector<array> atleast_2d(
|
|
1551
|
+
const std::vector<array>& a,
|
|
1552
|
+
StreamOrDevice s = {});
|
|
1553
|
+
array atleast_3d(const array& a, StreamOrDevice s = {});
|
|
1554
|
+
std::vector<array> atleast_3d(
|
|
1555
|
+
const std::vector<array>& a,
|
|
1556
|
+
StreamOrDevice s = {});
|
|
1557
|
+
|
|
1558
|
+
/**
|
|
1559
|
+
* Extract the number of elements along some axes as a scalar array. Used to
|
|
1560
|
+
* allow shape dependent shapeless compilation (pun intended).
|
|
1561
|
+
*/
|
|
1562
|
+
array number_of_elements(
|
|
1563
|
+
const array& a,
|
|
1564
|
+
std::vector<int> axes,
|
|
1565
|
+
bool inverted,
|
|
1566
|
+
Dtype dtype = int32,
|
|
1567
|
+
StreamOrDevice s = {});
|
|
1568
|
+
|
|
1569
|
+
array conjugate(const array& a, StreamOrDevice s = {});
|
|
1570
|
+
|
|
1571
|
+
/** Bitwise and. */
|
|
1572
|
+
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
|
|
1573
|
+
array operator&(const array& a, const array& b);
|
|
1574
|
+
|
|
1575
|
+
/** Bitwise inclusive or. */
|
|
1576
|
+
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
|
|
1577
|
+
array operator|(const array& a, const array& b);
|
|
1578
|
+
|
|
1579
|
+
/** Bitwise exclusive or. */
|
|
1580
|
+
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
|
|
1581
|
+
array operator^(const array& a, const array& b);
|
|
1582
|
+
|
|
1583
|
+
/** Shift bits to the left. */
|
|
1584
|
+
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
|
|
1585
|
+
array operator<<(const array& a, const array& b);
|
|
1586
|
+
|
|
1587
|
+
/** Shift bits to the right. */
|
|
1588
|
+
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
|
|
1589
|
+
array operator>>(const array& a, const array& b);
|
|
1590
|
+
|
|
1591
|
+
/** Invert the bits. */
|
|
1592
|
+
array bitwise_invert(const array& a, StreamOrDevice s = {});
|
|
1593
|
+
array operator~(const array& a);
|
|
1594
|
+
|
|
1595
|
+
array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
|
|
1596
|
+
|
|
1597
|
+
/** Roll elements along an axis and introduce them on the other side */
|
|
1598
|
+
array roll(const array& a, int shift, StreamOrDevice s = {});
|
|
1599
|
+
array roll(const array& a, const Shape& shift, StreamOrDevice s = {});
|
|
1600
|
+
array roll(const array& a, int shift, int axis, StreamOrDevice s = {});
|
|
1601
|
+
array roll(
|
|
1602
|
+
const array& a,
|
|
1603
|
+
int shift,
|
|
1604
|
+
const std::vector<int>& axes,
|
|
1605
|
+
StreamOrDevice s = {});
|
|
1606
|
+
array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {});
|
|
1607
|
+
array roll(
|
|
1608
|
+
const array& a,
|
|
1609
|
+
const Shape& shift,
|
|
1610
|
+
const std::vector<int>& axes,
|
|
1611
|
+
StreamOrDevice s = {});
|
|
1612
|
+
|
|
1613
|
+
/* The real part of a complex array. */
|
|
1614
|
+
array real(const array& a, StreamOrDevice s = {});
|
|
1615
|
+
|
|
1616
|
+
/* The imaginary part of a complex array. */
|
|
1617
|
+
array imag(const array& a, StreamOrDevice s = {});
|
|
1618
|
+
|
|
1619
|
+
/* Ensure the array's underlying memory is contiguous. */
|
|
1620
|
+
array contiguous(
|
|
1621
|
+
const array& a,
|
|
1622
|
+
bool allow_col_major = false,
|
|
1623
|
+
StreamOrDevice s = {});
|
|
1624
|
+
|
|
1625
|
+
/** @} */
|
|
1626
|
+
|
|
1627
|
+
} // namespace mlx::core
|