mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,2524 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <unordered_set>
6
+
7
+ #include "mlx/array.h"
8
+ #include "mlx/device.h"
9
+ #include "mlx/io/load.h"
10
+ #include "mlx/stream.h"
11
+
12
// Helper macros that stamp out the boilerplate member declarations shared by
// the primitive classes below. Each macro already ends its declaration (with
// `;` or `}`), so it is invoked inside a class body without a trailing `;`.

// Declares the `vmap` override; the definition lives out of line.
#define DEFINE_VMAP() \
  std::pair<std::vector<array>, std::vector<int>> vmap( \
      const std::vector<array>& inputs, const std::vector<int>& axes) \
      override;

// Declares the `jvp` and `vjp` overrides; the definitions live out of line.
#define DEFINE_GRADS() \
  std::vector<array> jvp( \
      const std::vector<array>& primals, \
      const std::vector<array>& tangents, \
      const std::vector<int>& argnums) override; \
  \
  std::vector<array> vjp( \
      const std::vector<array>& primals, \
      const std::vector<array>& cotangents, \
      const std::vector<int>& argnums, \
      const std::vector<array>& outputs) override;

// Defines `name()` to return the stringized macro argument.
#define DEFINE_NAME(PRIMITIVE) \
  const char* name() const override { \
    return #PRIMITIVE; \
  }

// Defines `is_equivalent` to return true for any argument. The parameter is
// left unnamed because it is never read (avoids unused-parameter warnings).
#define DEFINE_DEFAULT_IS_EQUIVALENT() \
  bool is_equivalent(const Primitive& /* other */) const override { \
    return true; \
  }

// Defines `output_shapes` to return a single shape: that of `inputs[0]`.
// `inputs` is indexed without a bounds check, so it must be non-empty.
#define DEFINE_INPUT_OUTPUT_SHAPE() \
  std::vector<Shape> output_shapes(const std::vector<array>& inputs) \
      override { \
    return {inputs[0].shape()}; \
  }
44
+
45
+ namespace mlx::core {
46
+
47
+ // Abstract base class
48
+ class Primitive {
49
+ public:
50
+ explicit Primitive(Stream stream) : stream_(stream) {}
51
+
52
+ /** The device the primitive will run on. */
53
+ const Device& device() {
54
+ return stream().device;
55
+ }
56
+
57
+ /** The stream the primitive will run on. */
58
+ const Stream& stream() {
59
+ return stream_;
60
+ }
61
+
62
+ /**
63
+ * A primitive must know how to evaluate itself on
64
+ * the CPU/GPU for the given inputs and populate the output arrays.
65
+ *
66
+ * To avoid unnecessary allocations, the evaluation function
67
+ * is responsible for allocating space for the array.
68
+ */
69
+ virtual void eval_cpu(
70
+ const std::vector<array>& inputs,
71
+ std::vector<array>& outputs) = 0;
72
+ virtual void eval_gpu(
73
+ const std::vector<array>& inputs,
74
+ std::vector<array>& outputs) = 0;
75
+
76
+ /**
77
+ * The Jacobian-vector product.
78
+ */
79
+ virtual std::vector<array> jvp(
80
+ const std::vector<array>& primals,
81
+ const std::vector<array>& tangents,
82
+ const std::vector<int>& argnums);
83
+
84
+ /**
85
+ * The vector-Jacobian product.
86
+ */
87
+ virtual std::vector<array> vjp(
88
+ const std::vector<array>& primals,
89
+ const std::vector<array>& cotangents,
90
+ const std::vector<int>& argnums,
91
+ const std::vector<array>& outputs);
92
+
93
+ /**
94
+ * The primitive must know how to vectorize itself across
95
+ * the given axes. The output is a pair containing the output arrays
96
+ * representing the vectorized computation and the axes which
97
+ * corresponds to the vectorized dimensions of each output.
98
+ */
99
+ virtual std::pair<std::vector<array>, std::vector<int>> vmap(
100
+ const std::vector<array>& inputs,
101
+ const std::vector<int>& axes);
102
+
103
+ /** Get the name of primitive. */
104
+ virtual const char* name() const = 0;
105
+
106
+ /** Equivalence check defaults to false unless overridden by the primitive */
107
+ virtual bool is_equivalent(const Primitive& other) const {
108
+ return false;
109
+ }
110
+
111
+ /** Get the output shapes of the primitive. This is not required to be
112
+ * implemented by derived classes, in which case it will throw. */
113
+ virtual std::vector<Shape> output_shapes(const std::vector<array>& inputs);
114
+
115
+ virtual ~Primitive() = default;
116
+ Primitive(const Primitive& other) = delete;
117
+ Primitive(Primitive&& other) = delete;
118
+ Primitive& operator=(const Primitive& other) = delete;
119
+ Primitive& operator=(Primitive&& other) = delete;
120
+
121
+ private:
122
+ // Every primitive stores the stream it should run in
123
+ Stream stream_;
124
+ };
125
+
126
+ class UnaryPrimitive : public Primitive {
127
+ /**
128
+ * An abstract base class for a primitive with a single output.
129
+ */
130
+ public:
131
+ explicit UnaryPrimitive(Stream stream) : Primitive(stream) {}
132
+
133
+ virtual void eval_cpu(const std::vector<array>& inputs, array& output) = 0;
134
+ virtual void eval_gpu(const std::vector<array>& inputs, array& output) = 0;
135
+
136
+ inline void eval_cpu(
137
+ const std::vector<array>& inputs,
138
+ std::vector<array>& outputs) override {
139
+ eval_cpu(inputs, outputs[0]);
140
+ }
141
+ inline void eval_gpu(
142
+ const std::vector<array>& inputs,
143
+ std::vector<array>& outputs) override {
144
+ eval_gpu(inputs, outputs[0]);
145
+ }
146
+
147
+ virtual ~UnaryPrimitive() = default;
148
+ UnaryPrimitive(const UnaryPrimitive& other) = delete;
149
+ UnaryPrimitive(UnaryPrimitive&& other) = delete;
150
+ UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete;
151
+ UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete;
152
+ };
153
+
154
/** The quantization modes known to this header. */
enum class QuantizationMode {
  Affine,
  Mxfp4,
  Mxfp8,
  Nvfp4,
};

/** Convert a quantization mode to its string form (defined out of line). */
std::string quantization_mode_to_string(QuantizationMode mode);

/**
 * Parse a quantization mode from a string (defined out of line).
 *
 * NOTE(review): `error_tag` defaults to an empty string and is presumably
 * used to prefix error messages -- confirm against the definition.
 */
QuantizationMode string_to_quantization_mode(
    const std::string& mode,
    std::string_view error_tag = "");
160
+
161
+ class Abs : public UnaryPrimitive {
162
+ public:
163
+ explicit Abs(Stream stream) : UnaryPrimitive(stream) {}
164
+
165
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
166
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
167
+
168
+ DEFINE_VMAP()
169
+ DEFINE_GRADS()
170
+ DEFINE_NAME(Abs)
171
+ DEFINE_DEFAULT_IS_EQUIVALENT()
172
+ DEFINE_INPUT_OUTPUT_SHAPE()
173
+ };
174
+
175
+ class Add : public UnaryPrimitive {
176
+ public:
177
+ explicit Add(Stream stream) : UnaryPrimitive(stream) {}
178
+
179
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
180
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
181
+
182
+ DEFINE_VMAP()
183
+ DEFINE_GRADS()
184
+ DEFINE_NAME(Add)
185
+ DEFINE_DEFAULT_IS_EQUIVALENT()
186
+ DEFINE_INPUT_OUTPUT_SHAPE()
187
+ };
188
+
189
+ class AddMM : public UnaryPrimitive {
190
+ public:
191
+ explicit AddMM(Stream stream, float alpha, float beta)
192
+ : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {}
193
+
194
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
195
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
196
+
197
+ DEFINE_GRADS()
198
+ DEFINE_VMAP()
199
+ DEFINE_NAME(AddMM)
200
+
201
+ bool is_equivalent(const Primitive& other) const override;
202
+ std::pair<float, float> state() const {
203
+ return {alpha_, beta_};
204
+ };
205
+
206
+ private:
207
+ const float alpha_;
208
+ const float beta_;
209
+ };
210
+
211
+ class Arange : public UnaryPrimitive {
212
+ public:
213
+ explicit Arange(Stream stream, double start, double stop, double step)
214
+ : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {}
215
+
216
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
217
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
218
+
219
+ DEFINE_NAME(Arange)
220
+ bool is_equivalent(const Primitive& other) const override;
221
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
222
+ std::tuple<double, double, double> state() const {
223
+ return {start_, stop_, step_};
224
+ };
225
+
226
+ private:
227
+ double start_;
228
+ double stop_;
229
+ double step_;
230
+ };
231
+
232
+ class ArcCos : public UnaryPrimitive {
233
+ public:
234
+ explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {}
235
+
236
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
237
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
238
+
239
+ DEFINE_VMAP()
240
+ DEFINE_GRADS()
241
+ DEFINE_NAME(ArcCos)
242
+ DEFINE_DEFAULT_IS_EQUIVALENT()
243
+ DEFINE_INPUT_OUTPUT_SHAPE()
244
+ };
245
+
246
+ class ArcCosh : public UnaryPrimitive {
247
+ public:
248
+ explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {}
249
+
250
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
251
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
252
+
253
+ DEFINE_VMAP()
254
+ DEFINE_GRADS()
255
+ DEFINE_NAME(ArcCosh)
256
+ DEFINE_DEFAULT_IS_EQUIVALENT()
257
+ DEFINE_INPUT_OUTPUT_SHAPE()
258
+ };
259
+
260
+ class ArcSin : public UnaryPrimitive {
261
+ public:
262
+ explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {}
263
+
264
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
265
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
266
+
267
+ DEFINE_VMAP()
268
+ DEFINE_GRADS()
269
+ DEFINE_NAME(ArcSin)
270
+ DEFINE_DEFAULT_IS_EQUIVALENT()
271
+ DEFINE_INPUT_OUTPUT_SHAPE()
272
+ };
273
+
274
+ class ArcSinh : public UnaryPrimitive {
275
+ public:
276
+ explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {}
277
+
278
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
279
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
280
+
281
+ DEFINE_VMAP()
282
+ DEFINE_GRADS()
283
+ DEFINE_NAME(ArcSinh)
284
+ DEFINE_DEFAULT_IS_EQUIVALENT()
285
+ DEFINE_INPUT_OUTPUT_SHAPE()
286
+ };
287
+
288
+ class ArcTan : public UnaryPrimitive {
289
+ public:
290
+ explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {}
291
+
292
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
293
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
294
+
295
+ DEFINE_VMAP()
296
+ DEFINE_GRADS()
297
+ DEFINE_NAME(ArcTan)
298
+ DEFINE_DEFAULT_IS_EQUIVALENT()
299
+ DEFINE_INPUT_OUTPUT_SHAPE()
300
+ };
301
+
302
+ class ArcTan2 : public UnaryPrimitive {
303
+ public:
304
+ explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {}
305
+
306
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
307
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
308
+
309
+ DEFINE_VMAP()
310
+ DEFINE_GRADS()
311
+ DEFINE_NAME(ArcTan2)
312
+ DEFINE_DEFAULT_IS_EQUIVALENT()
313
+ DEFINE_INPUT_OUTPUT_SHAPE()
314
+ };
315
+
316
+ class ArcTanh : public UnaryPrimitive {
317
+ public:
318
+ explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {}
319
+
320
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
321
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
322
+
323
+ DEFINE_VMAP()
324
+ DEFINE_GRADS()
325
+ DEFINE_NAME(ArcTanh)
326
+ DEFINE_DEFAULT_IS_EQUIVALENT()
327
+ DEFINE_INPUT_OUTPUT_SHAPE()
328
+ };
329
+
330
+ class ArgPartition : public UnaryPrimitive {
331
+ public:
332
+ explicit ArgPartition(Stream stream, int kth, int axis)
333
+ : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
334
+
335
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
336
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
337
+
338
+ DEFINE_VMAP()
339
+ DEFINE_GRADS()
340
+ DEFINE_NAME(ArgPartition)
341
+ DEFINE_INPUT_OUTPUT_SHAPE()
342
+ bool is_equivalent(const Primitive& other) const override;
343
+ std::pair<int, int> state() const {
344
+ return {kth_, axis_};
345
+ };
346
+
347
+ private:
348
+ int kth_;
349
+ int axis_;
350
+ };
351
+
352
+ class ArgReduce : public UnaryPrimitive {
353
+ public:
354
+ enum ReduceType {
355
+ ArgMin,
356
+ ArgMax,
357
+ };
358
+
359
+ explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis)
360
+ : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
361
+
362
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
363
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
364
+
365
+ DEFINE_VMAP()
366
+ DEFINE_GRADS()
367
+ DEFINE_NAME(ArgReduce)
368
+ bool is_equivalent(const Primitive& other) const override;
369
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
370
+ std::pair<ReduceType, int> state() const {
371
+ return {reduce_type_, axis_};
372
+ };
373
+
374
+ private:
375
+ ReduceType reduce_type_;
376
+ int axis_;
377
+ };
378
+
379
+ class ArgSort : public UnaryPrimitive {
380
+ public:
381
+ explicit ArgSort(Stream stream, int axis)
382
+ : UnaryPrimitive(stream), axis_(axis) {}
383
+
384
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
385
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
386
+
387
+ DEFINE_VMAP()
388
+ DEFINE_GRADS()
389
+ DEFINE_NAME(ArgSort)
390
+ DEFINE_INPUT_OUTPUT_SHAPE()
391
+ bool is_equivalent(const Primitive& other) const override;
392
+ int state() const {
393
+ return axis_;
394
+ };
395
+
396
+ private:
397
+ int axis_;
398
+ };
399
+
400
+ class AsType : public UnaryPrimitive {
401
+ public:
402
+ explicit AsType(Stream stream, Dtype dtype)
403
+ : UnaryPrimitive(stream), dtype_(dtype) {}
404
+
405
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
406
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
407
+
408
+ DEFINE_VMAP()
409
+ DEFINE_GRADS()
410
+ DEFINE_NAME(AsType)
411
+ DEFINE_INPUT_OUTPUT_SHAPE()
412
+ bool is_equivalent(const Primitive& other) const override;
413
+ Dtype state() const {
414
+ return dtype_;
415
+ };
416
+
417
+ private:
418
+ Dtype dtype_;
419
+ };
420
+
421
+ class AsStrided : public UnaryPrimitive {
422
+ public:
423
+ explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset)
424
+ : UnaryPrimitive(stream),
425
+ shape_(std::move(shape)),
426
+ strides_(std::move(strides)),
427
+ offset_(offset) {}
428
+
429
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
430
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
431
+
432
+ DEFINE_GRADS()
433
+ DEFINE_NAME(AsStrided)
434
+ bool is_equivalent(const Primitive& other) const override;
435
+ auto state() const {
436
+ return std::make_tuple(shape_, strides_, offset_);
437
+ }
438
+
439
+ private:
440
+ Shape shape_;
441
+ Strides strides_;
442
+ size_t offset_;
443
+
444
+ void eval(const std::vector<array>& inputs, array& out);
445
+ };
446
+
447
+ class BitwiseBinary : public UnaryPrimitive {
448
+ public:
449
+ enum Op { And, Or, Xor, LeftShift, RightShift };
450
+
451
+ explicit BitwiseBinary(Stream stream, Op op)
452
+ : UnaryPrimitive(stream), op_(op) {}
453
+
454
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
455
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
456
+
457
+ DEFINE_VMAP()
458
+ DEFINE_GRADS()
459
+
460
+ const char* name() const override {
461
+ switch (op_) {
462
+ case BitwiseBinary::And:
463
+ return "BitwiseAnd";
464
+ case BitwiseBinary::Or:
465
+ return "BitwiseOr";
466
+ case BitwiseBinary::Xor:
467
+ return "BitwiseXor";
468
+ case BitwiseBinary::LeftShift:
469
+ return "LeftShift";
470
+ case BitwiseBinary::RightShift:
471
+ return "RightShift";
472
+ }
473
+ return "<unknwon BitwiseBinary>";
474
+ }
475
+
476
+ bool is_equivalent(const Primitive& other) const override;
477
+ DEFINE_INPUT_OUTPUT_SHAPE()
478
+ auto state() const {
479
+ return op_;
480
+ }
481
+
482
+ private:
483
+ Op op_;
484
+ };
485
+
486
+ class BitwiseInvert : public UnaryPrimitive {
487
+ public:
488
+ explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {}
489
+
490
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
491
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
492
+
493
+ DEFINE_VMAP()
494
+ DEFINE_NAME(BitwiseInvert)
495
+ DEFINE_DEFAULT_IS_EQUIVALENT()
496
+ DEFINE_INPUT_OUTPUT_SHAPE()
497
+ };
498
+
499
+ class BlockMaskedMM : public UnaryPrimitive {
500
+ public:
501
+ explicit BlockMaskedMM(Stream stream, int block_size)
502
+ : UnaryPrimitive(stream), block_size_(block_size) {}
503
+
504
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
505
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
506
+
507
+ std::vector<array> vjp(
508
+ const std::vector<array>& primals,
509
+ const std::vector<array>& cotangents,
510
+ const std::vector<int>& argnums,
511
+ const std::vector<array>& outputs) override;
512
+
513
+ DEFINE_NAME(BlockMaskedMM)
514
+ bool is_equivalent(const Primitive& other) const override;
515
+ auto state() const {
516
+ return block_size_;
517
+ }
518
+
519
+ private:
520
+ int block_size_;
521
+ };
522
+
523
+ class GatherMM : public UnaryPrimitive {
524
+ public:
525
+ explicit GatherMM(
526
+ Stream stream,
527
+ bool left_sorted = false,
528
+ bool right_sorted = false)
529
+ : UnaryPrimitive(stream),
530
+ left_sorted_(left_sorted),
531
+ right_sorted_(right_sorted) {}
532
+
533
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
534
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
535
+
536
+ std::vector<array> vjp(
537
+ const std::vector<array>& primals,
538
+ const std::vector<array>& cotangents,
539
+ const std::vector<int>& argnums,
540
+ const std::vector<array>& outputs) override;
541
+
542
+ DEFINE_NAME(GatherMM)
543
+ bool is_equivalent(const Primitive& other) const override;
544
+ auto state() const {
545
+ return std::make_pair(left_sorted_, right_sorted_);
546
+ }
547
+
548
+ private:
549
+ bool left_sorted_;
550
+ bool right_sorted_;
551
+ };
552
+
553
+ class SegmentedMM : public UnaryPrimitive {
554
+ public:
555
+ explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {}
556
+
557
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
558
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
559
+
560
+ DEFINE_NAME(SegmentedMM)
561
+ };
562
+
563
+ class BroadcastAxes : public UnaryPrimitive { // Primitive that stores a list of axes to ignore; shape logic is defined elsewhere.
564
+ public:
565
+ explicit BroadcastAxes(Stream stream, std::vector<int> ignore_axes = {}) // ignore_axes is taken by value and moved into the member.
566
+ : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {}
567
+
568
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
569
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
570
+
571
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
572
+ DEFINE_GRADS()
573
+ DEFINE_NAME(BroadcastAxes)
574
+ bool is_equivalent(const Primitive& other) const override;
575
+ static Shape output_shape( // Static helper: callable without an instance, takes the ignore list explicitly.
576
+ const std::vector<array>& inputs,
577
+ const std::vector<int>& ignore_axes);
578
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
579
+ auto state() const { // Returns a copy of ignore_axes_ (return type deduced by value).
580
+ return ignore_axes_;
581
+ }
582
+
583
+ private:
584
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
585
+ std::vector<int> ignore_axes_;
586
+ };
587
+
588
+ class Broadcast : public UnaryPrimitive { // Primitive that stores a target Shape; evaluation is defined elsewhere.
589
+ public:
590
+ explicit Broadcast(Stream stream, const Shape& shape) // The shape is copied into the member.
591
+ : UnaryPrimitive(stream), shape_(shape) {}
592
+
593
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
594
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
595
+
596
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
597
+ DEFINE_GRADS()
598
+ DEFINE_NAME(Broadcast)
599
+ static Shape output_shape(const std::vector<array>& inputs); // Static helper: needs only the inputs, not an instance.
600
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
601
+ bool is_equivalent(const Primitive& other) const override;
602
+ Shape state() const { // Returns a copy of the stored shape.
603
+ return shape_;
604
+ } // Redundant ';' after the function body removed (matches sibling classes, silences -Wextra-semi).
605
+
606
+ private:
607
+ Shape shape_;
608
+
609
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
610
+ };
611
+
612
+ class Ceil : public UnaryPrimitive { // Ceil primitive; declares no data members of its own.
613
+ public:
614
+ explicit Ceil(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
615
+
616
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
617
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
618
+
619
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
620
+ DEFINE_GRADS()
621
+ DEFINE_NAME(Ceil)
622
+ DEFINE_DEFAULT_IS_EQUIVALENT()
623
+ DEFINE_INPUT_OUTPUT_SHAPE()
624
+ };
625
+
626
+ class Compiled : public Primitive { // Multi-output primitive (derives from Primitive) that owns a captured inputs/outputs/tape graph.
627
+ public:
628
+ /*
629
+ * The inputs, outputs and tape are either tracers or constants.
630
+ * - The tape should not contain the inputs, but it should contain the
631
+ * outputs.
632
+ * - The tape should also have only one array per primitive for multi-output
633
+ * primitives.
634
+ * - The constant_ids contains ids of arrays in the input list that are safe
635
+ * to treat as scalar constants.
636
+ */
637
+ explicit Compiled( // Constructor is only declared here; all arguments are taken by value.
638
+ Stream stream,
639
+ std::vector<array> inputs,
640
+ std::vector<array> outputs,
641
+ std::vector<array> tape,
642
+ std::unordered_set<uintptr_t> constant_ids);
643
+
644
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Writes into a vector of outputs rather than a single array.
645
+ override;
646
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
647
+ override;
648
+
649
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
650
+ DEFINE_GRADS()
651
+ const char* name() const override; // Declared explicitly instead of using DEFINE_NAME; defined elsewhere.
652
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
653
+ bool is_equivalent(const Primitive& other) const override;
654
+
655
+ std::string lib_name() const { // Returns a copy of kernel_lib_.
656
+ return kernel_lib_;
657
+ }
658
+
659
+ private:
660
+ const std::vector<array> inputs_; // The captured graph members below are const: fixed after construction.
661
+ const std::vector<array> outputs_;
662
+ const std::vector<array> tape_;
663
+ const std::unordered_set<uintptr_t> constant_ids_;
664
+ const std::function<bool(size_t)> is_constant_; // NOTE(review): presumably a predicate over input indices backed by constant_ids_ - confirm in the constructor.
665
+
666
+ mutable std::string name_; // mutable, so it can be written from const members; NOTE(review): presumably a cache for name() - confirm.
667
+ std::string kernel_lib_; // Value exposed through lib_name().
668
+ };
669
+
670
+ class Concatenate : public UnaryPrimitive { // Primitive that stores a single integer axis.
671
+ public:
672
+ explicit Concatenate(Stream stream, int axis)
673
+ : UnaryPrimitive(stream), axis_(axis) {}
674
+
675
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
676
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
677
+
678
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
679
+ DEFINE_GRADS()
680
+ DEFINE_NAME(Concatenate)
681
+ bool is_equivalent(const Primitive& other) const override;
682
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
683
+ auto state() const { // Returns axis_ by value.
684
+ return axis_;
685
+ }
686
+
687
+ private:
688
+ int axis_; // Stored as passed to the constructor; no normalisation happens in this header.
689
+ };
690
+
691
+ class Conjugate : public UnaryPrimitive { // Conjugate primitive; declares no data members of its own.
692
+ public:
693
+ explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
694
+
695
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
696
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
697
+
698
+ DEFINE_VMAP() // NOTE(review): no DEFINE_GRADS() here, unlike most sibling primitives in this file - confirm intentional.
699
+ DEFINE_NAME(Conjugate)
700
+ DEFINE_DEFAULT_IS_EQUIVALENT()
701
+ DEFINE_INPUT_OUTPUT_SHAPE()
702
+ };
703
+
704
+ class Contiguous : public UnaryPrimitive { // Primitive that stores one boolean option, allow_col_major.
705
+ public:
706
+ explicit Contiguous(Stream stream, bool allow_col_major)
707
+ : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {}
708
+
709
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
710
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
711
+
712
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
713
+ DEFINE_GRADS()
714
+ DEFINE_NAME(Contiguous)
715
+ DEFINE_INPUT_OUTPUT_SHAPE()
716
+
717
+ bool is_equivalent(const Primitive& other) const override; // Custom equivalence is declared, but no state() accessor is exposed.
718
+
719
+ private:
720
+ bool allow_col_major_; // NOTE(review): presumably lets column-major layouts count as contiguous - confirm in the implementation.
721
+ };
722
+
723
+ class Convolution : public UnaryPrimitive { // Primitive holding stride/padding/dilation vectors plus a group count and flip flag.
724
+ public:
725
+ explicit Convolution(
726
+ Stream stream,
727
+ const std::vector<int>& kernel_strides, // All vector arguments are copied into members.
728
+ const std::vector<int>& padding_lo,
729
+ const std::vector<int>& padding_hi,
730
+ const std::vector<int>& kernel_dilation,
731
+ const std::vector<int>& input_dilation,
732
+ const int groups = 1,
733
+ const bool flip = false)
734
+ : UnaryPrimitive(stream),
735
+ padding_lo_(padding_lo), // Initialiser order follows member declaration order, not parameter order.
736
+ padding_hi_(padding_hi),
737
+ kernel_strides_(kernel_strides),
738
+ kernel_dilation_(kernel_dilation),
739
+ input_dilation_(input_dilation),
740
+ groups_(groups),
741
+ flip_(flip) {}
742
+
743
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
744
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
745
+
746
+ std::vector<array> vjp( // vjp is declared explicitly; no DEFINE_GRADS() macro is used in this class.
747
+ const std::vector<array>& primals,
748
+ const std::vector<array>& cotangents,
749
+ const std::vector<int>& argnums,
750
+ const std::vector<array>& outputs) override;
751
+
752
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
753
+ DEFINE_NAME(Convolution)
754
+ bool is_equivalent(const Primitive& other) const override;
755
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
756
+ auto state() const { // Tuple order matches the constructor parameter order (strides first), not the member order.
757
+ return std::make_tuple(
758
+ kernel_strides_,
759
+ padding_lo_,
760
+ padding_hi_,
761
+ kernel_dilation_,
762
+ input_dilation_,
763
+ groups_,
764
+ flip_);
765
+ }
766
+
767
+ static Shape conv_out_shape( // Static helper: takes every parameter explicitly, so no instance is needed.
768
+ const Shape& in_shape,
769
+ const Shape& wt_shape,
770
+ const std::vector<int>& strides,
771
+ const std::vector<int>& pads_lo,
772
+ const std::vector<int>& pads_hi,
773
+ const std::vector<int>& kernel_dilation,
774
+ const std::vector<int>& input_dilation);
775
+
776
+ private:
777
+ std::vector<int> padding_lo_;
778
+ std::vector<int> padding_hi_;
779
+ std::vector<int> kernel_strides_;
780
+ std::vector<int> kernel_dilation_;
781
+ std::vector<int> input_dilation_;
782
+ int groups_; // Defaults to 1 via the constructor.
783
+ bool flip_; // Defaults to false via the constructor.
784
+ };
785
+
786
+ class Copy : public UnaryPrimitive { // Copy primitive; declares no data members of its own.
787
+ public:
788
+ explicit Copy(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
789
+
790
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
791
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
792
+
793
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
794
+ DEFINE_GRADS()
795
+ DEFINE_NAME(Copy)
796
+ DEFINE_DEFAULT_IS_EQUIVALENT()
797
+ DEFINE_INPUT_OUTPUT_SHAPE()
798
+
799
+ private:
800
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
801
+ };
802
+
803
+ class Cos : public UnaryPrimitive { // Cos primitive; declares no data members of its own.
804
+ public:
805
+ explicit Cos(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
806
+
807
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
808
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
809
+
810
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
811
+ DEFINE_GRADS()
812
+ DEFINE_NAME(Cos)
813
+ DEFINE_DEFAULT_IS_EQUIVALENT()
814
+ DEFINE_INPUT_OUTPUT_SHAPE()
815
+ };
816
+
817
+ class Cosh : public UnaryPrimitive { // Cosh primitive; declares no data members of its own.
818
+ public:
819
+ explicit Cosh(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
820
+
821
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
822
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
823
+
824
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
825
+ DEFINE_GRADS()
826
+ DEFINE_NAME(Cosh)
827
+ DEFINE_DEFAULT_IS_EQUIVALENT()
828
+ DEFINE_INPUT_OUTPUT_SHAPE()
829
+ };
830
+
831
+ class CustomTransforms : public Primitive { // Multi-output primitive that stores caller-supplied vjp, jvp and vmap callbacks.
832
+ public:
833
+ explicit CustomTransforms(
834
+ Stream stream,
835
+ int num_outputs,
836
+ std::function<std::vector<array>( // vjp callback: three array vectors in, one array vector out.
837
+ const std::vector<array>&,
838
+ const std::vector<array>&,
839
+ const std::vector<array>&)> vjp,
840
+ std::function<std::vector<array>( // jvp callback: two array vectors and an int vector in, one array vector out.
841
+ const std::vector<array>&,
842
+ const std::vector<array>&,
843
+ const std::vector<int>&)> jvp,
844
+ std::function<std::pair<std::vector<array>, std::vector<int>>( // vmap callback: returns arrays paired with an int vector.
845
+ const std::vector<array>&,
846
+ const std::vector<int>&)> vmap)
847
+ : Primitive(stream),
848
+ num_outputs_(num_outputs),
849
+ vjp_fun_(std::move(vjp)), // Callbacks are taken by value and moved into the members.
850
+ jvp_fun_(std::move(jvp)),
851
+ vmap_fun_(std::move(vmap)) {}
852
+
853
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
854
+ override;
855
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
856
+ override;
857
+
858
+ DEFINE_GRADS() // Trailing ';' dropped on these three macros: every other use in this file is written without one.
859
+ DEFINE_VMAP()
860
+ DEFINE_NAME(CustomTransforms)
861
+
862
+ private:
863
+ void eval(const std::vector<array>& inputs, std::vector<array>& outputs); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
864
+
865
+ int num_outputs_;
866
+
867
+ std::function<std::vector<array>(
868
+ const std::vector<array>&,
869
+ const std::vector<array>&,
870
+ const std::vector<array>&)>
871
+ vjp_fun_;
872
+ std::function<std::vector<array>(
873
+ const std::vector<array>&,
874
+ const std::vector<array>&,
875
+ const std::vector<int>&)>
876
+ jvp_fun_;
877
+ std::function<std::pair<std::vector<array>, std::vector<int>>(
878
+ const std::vector<array>&,
879
+ const std::vector<int>&)>
880
+ vmap_fun_;
881
+ };
882
+
883
+ class Depends : public Primitive { // Multi-output primitive (derives from Primitive); declares no data members of its own.
884
+ public:
885
+ explicit Depends(Stream stream) : Primitive(stream) {} // Only forwards the stream to the base class.
886
+
887
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
888
+ override;
889
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
890
+ override;
891
+
892
+ std::vector<array> vjp( // Only vjp is overridden explicitly; no DEFINE_GRADS()/DEFINE_VMAP() macro is used in this class.
893
+ const std::vector<array>& primals,
894
+ const std::vector<array>& cotan,
895
+ const std::vector<int>& argnums,
896
+ const std::vector<array>& outputs) override;
897
+
898
+ DEFINE_NAME(Depends) // Trailing ';' dropped: every other DEFINE_NAME use in this file is written without one.
899
+
900
+ private:
901
+ void eval(const std::vector<array>& inputs, std::vector<array>& outputs); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
902
+ };
903
+
904
+ class Divide : public UnaryPrimitive { // Divide primitive; declares no data members of its own.
905
+ public:
906
+ explicit Divide(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
907
+
908
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Takes a vector of inputs and one output array; defined elsewhere.
909
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
910
+
911
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
912
+ DEFINE_GRADS()
913
+ DEFINE_NAME(Divide)
914
+ DEFINE_DEFAULT_IS_EQUIVALENT()
915
+ DEFINE_INPUT_OUTPUT_SHAPE()
916
+ };
917
+
918
+ class DivMod : public Primitive { // Multi-output primitive (derives from Primitive); reports two output shapes.
919
+ public:
920
+ explicit DivMod(Stream stream) : Primitive(stream) {} // Only forwards the stream to the base class.
921
+
922
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
923
+ override;
924
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) // Declared here; defined elsewhere.
925
+ override;
926
+
927
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
928
+ DEFINE_GRADS()
929
+ DEFINE_NAME(DivMod)
930
+ DEFINE_DEFAULT_IS_EQUIVALENT()
931
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override { // Both outputs take the shape of inputs[0]; inputs must be non-empty.
932
+ return std::vector{inputs[0].shape(), inputs[0].shape()};
933
+ }
934
+ };
935
+
936
+ class Select : public UnaryPrimitive { // Select primitive; declares no data members of its own.
937
+ public:
938
+ explicit Select(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
939
+
940
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
941
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
942
+
943
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
944
+ DEFINE_GRADS()
945
+ DEFINE_NAME(Select)
946
+ DEFINE_DEFAULT_IS_EQUIVALENT()
947
+ DEFINE_INPUT_OUTPUT_SHAPE()
948
+ };
949
+
950
+ class Remainder : public UnaryPrimitive { // Remainder primitive; declares no data members of its own.
951
+ public:
952
+ explicit Remainder(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
953
+
954
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
955
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
956
+
957
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
958
+ DEFINE_GRADS()
959
+ DEFINE_NAME(Remainder)
960
+ DEFINE_DEFAULT_IS_EQUIVALENT()
961
+ DEFINE_INPUT_OUTPUT_SHAPE()
962
+ };
963
+
964
+ class Equal : public UnaryPrimitive { // Equality primitive with an equal_nan option that also selects the reported name.
965
+ public:
966
+ explicit Equal(Stream stream, bool equal_nan = false) // equal_nan defaults to false.
967
+ : UnaryPrimitive(stream), equal_nan_(equal_nan) {}
968
+
969
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
970
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
971
+
972
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
973
+ DEFINE_GRADS()
974
+ DEFINE_DEFAULT_IS_EQUIVALENT() // NOTE(review): if the default ignores state, Equal and NaNEqual instances compare equivalent - confirm.
975
+ DEFINE_INPUT_OUTPUT_SHAPE()
976
+
977
+ const char* name() const override { // Returns "NaNEqual" when equal_nan_ is set, otherwise "Equal".
978
+ if (equal_nan_) {
979
+ return "NaNEqual";
980
+ } else {
981
+ return "Equal";
982
+ }
983
+ }
984
+ auto state() const { // Returns equal_nan_ by value.
985
+ return equal_nan_;
986
+ } // Redundant ';' after the function body removed (matches sibling classes, silences -Wextra-semi).
987
+
988
+ private:
989
+ bool equal_nan_;
990
+ };
991
+
992
+ class Erf : public UnaryPrimitive { // Erf primitive; declares no data members of its own.
993
+ public:
994
+ explicit Erf(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
995
+
996
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
997
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
998
+
999
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1000
+ DEFINE_GRADS()
1001
+ DEFINE_NAME(Erf)
1002
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1003
+ DEFINE_INPUT_OUTPUT_SHAPE()
1004
+ };
1005
+
1006
+ class ErfInv : public UnaryPrimitive { // ErfInv primitive; declares no data members of its own.
1007
+ public:
1008
+ explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1009
+
1010
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1011
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1012
+
1013
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1014
+ DEFINE_GRADS()
1015
+ DEFINE_NAME(ErfInv)
1016
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1017
+ DEFINE_INPUT_OUTPUT_SHAPE()
1018
+ };
1019
+
1020
+ class Exp : public UnaryPrimitive { // Exp primitive; declares no data members of its own.
1021
+ public:
1022
+ explicit Exp(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1023
+
1024
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1025
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1026
+
1027
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1028
+ DEFINE_GRADS()
1029
+ DEFINE_NAME(Exp)
1030
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1031
+ DEFINE_INPUT_OUTPUT_SHAPE()
1032
+ };
1033
+
1034
+ class Expm1 : public UnaryPrimitive { // Expm1 primitive; declares no data members of its own.
1035
+ public:
1036
+ explicit Expm1(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1037
+
1038
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1039
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1040
+
1041
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1042
+ DEFINE_GRADS()
1043
+ DEFINE_NAME(Expm1) // NOTE(review): no DEFINE_DEFAULT_IS_EQUIVALENT() here, unlike Exp in this file - confirm intentional.
1044
+ DEFINE_INPUT_OUTPUT_SHAPE()
1045
+ };
1046
+
1047
+ class ExpandDims : public UnaryPrimitive { // Primitive that stores a list of axes.
1048
+ public:
1049
+ explicit ExpandDims(Stream stream, std::vector<int> axes) // axes is taken by value and moved into the member.
1050
+ : UnaryPrimitive(stream), axes_(std::move(axes)) {}
1051
+
1052
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1053
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1054
+
1055
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1056
+ DEFINE_GRADS()
1057
+ DEFINE_NAME(ExpandDims)
1058
+
1059
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1060
+ bool is_equivalent(const Primitive& other) const override;
1061
+
1062
+ static Shape output_shape(const array& input, const std::vector<int>& axes); // Static helper: takes the axes explicitly, so no instance is needed.
1063
+ auto state() const { // Returns a copy of axes_ (return type deduced by value).
1064
+ return axes_;
1065
+ }
1066
+
1067
+ private:
1068
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
1069
+ std::vector<int> axes_;
1070
+ };
1071
+
1072
+ class FFT : public UnaryPrimitive { // Primitive that stores transform axes plus inverse and real flags.
1073
+ public:
1074
+ explicit FFT(
1075
+ Stream stream,
1076
+ const std::vector<size_t>& axes, // Axes are unsigned (size_t) here, unlike the int axes used by sibling primitives.
1077
+ bool inverse,
1078
+ bool real)
1079
+ : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {}
1080
+
1081
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1082
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1083
+
1084
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1085
+ DEFINE_GRADS()
1086
+ DEFINE_NAME(FFT)
1087
+
1088
+ bool is_equivalent(const Primitive& other) const override;
1089
+ auto state() const { // Returns (axes_, inverse_, real_) as a std::tuple of copies.
1090
+ return std::make_tuple(axes_, inverse_, real_);
1091
+ }
1092
+
1093
+ private:
1094
+ std::vector<size_t> axes_;
1095
+ bool inverse_;
1096
+ bool real_;
1097
+ };
1098
+
1099
+ class Flatten : public UnaryPrimitive { // Primitive that stores a start axis and an end axis.
1100
+ public:
1101
+ explicit Flatten(Stream stream, int start_axis, int end_axis)
1102
+ : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {}
1103
+
1104
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1105
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1106
+
1107
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1108
+ DEFINE_GRADS()
1109
+ DEFINE_NAME(Flatten)
1110
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1111
+ bool is_equivalent(const Primitive& other) const override;
1112
+
1113
+ static Shape output_shape(const array& input, int start_axis, int end_axis); // Static helper: takes both axes explicitly, so no instance is needed.
1114
+ auto state() const { // Returns (start_axis_, end_axis_) as a std::pair.
1115
+ return std::make_pair(start_axis_, end_axis_);
1116
+ }
1117
+
1118
+ private:
1119
+ int start_axis_;
1120
+ int end_axis_;
1121
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
1122
+ };
1123
+
1124
+ class Floor : public UnaryPrimitive { // Floor primitive; declares no data members of its own.
1125
+ public:
1126
+ explicit Floor(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1127
+
1128
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1129
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1130
+
1131
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1132
+ DEFINE_GRADS()
1133
+ DEFINE_NAME(Floor)
1134
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1135
+ DEFINE_INPUT_OUTPUT_SHAPE()
1136
+ };
1137
+
1138
+ class Full : public UnaryPrimitive { // Full primitive; declares no data members of its own.
1139
+ public:
1140
+ explicit Full(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1141
+
1142
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1143
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1144
+
1145
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1146
+ DEFINE_GRADS()
1147
+ DEFINE_NAME(Full)
1148
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1149
+ DEFINE_INPUT_OUTPUT_SHAPE()
1150
+ };
1151
+
1152
+ class Gather : public UnaryPrimitive { // Primitive that stores a list of axes and a Shape of slice sizes.
1153
+ public:
1154
+ explicit Gather(Stream stream, std::vector<int> axes, Shape slice_sizes) // Both arguments are taken by value and moved into members.
1155
+ : UnaryPrimitive(stream),
1156
+ axes_(std::move(axes)),
1157
+ slice_sizes_(std::move(slice_sizes)) {}
1158
+
1159
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1160
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1161
+
1162
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1163
+ DEFINE_GRADS()
1164
+ DEFINE_NAME(Gather)
1165
+ bool is_equivalent(const Primitive& other) const override;
1166
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1167
+ std::pair<std::vector<int>, Shape> state() const { // Returns copies of (axes_, slice_sizes_).
1168
+ return {axes_, slice_sizes_};
1169
+ }
1170
+
1171
+ private:
1172
+ std::vector<int> axes_;
1173
+ Shape slice_sizes_;
1174
+ };
1175
+
1176
+ class GatherAxis : public UnaryPrimitive { // Primitive that stores a single integer axis.
1178
+ public:
1179
+ explicit GatherAxis(Stream stream, int axis)
1180
+ : UnaryPrimitive(stream), axis_(axis) {}
1181
+
1182
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1183
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1184
+
1185
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1186
+ DEFINE_GRADS()
1187
+ DEFINE_NAME(GatherAxis)
1188
+ bool is_equivalent(const Primitive& other) const override;
1189
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1190
+ auto state() const { // Returns axis_ by value.
1191
+ return axis_;
1192
+ }
1193
+
1194
+ private:
1195
+ int axis_;
1196
+ };
1196
+
1197
+ class Greater : public UnaryPrimitive { // Greater primitive; declares no data members of its own.
1199
+ public:
1200
+ explicit Greater(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1201
+
1202
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1203
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1204
+
1205
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1206
+ DEFINE_GRADS()
1207
+ DEFINE_NAME(Greater)
1208
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1209
+ DEFINE_INPUT_OUTPUT_SHAPE()
1210
+ };
1210
+
1211
+ class GreaterEqual : public UnaryPrimitive { // GreaterEqual primitive; declares no data members of its own.
1212
+ public:
1213
+ explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1214
+
1215
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1216
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1217
+
1218
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1219
+ DEFINE_GRADS()
1220
+ DEFINE_NAME(GreaterEqual)
1221
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1222
+ DEFINE_INPUT_OUTPUT_SHAPE()
1223
+ };
1224
+
1225
+ class Hadamard : public UnaryPrimitive { // Primitive that stores a single float scale.
1226
+ public:
1227
+ explicit Hadamard(Stream stream, float scale)
1228
+ : UnaryPrimitive(stream), scale_(scale) {}
1229
+
1230
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1231
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1232
+
1233
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1234
+ DEFINE_GRADS()
1235
+ DEFINE_NAME(Hadamard)
1236
+ DEFINE_INPUT_OUTPUT_SHAPE()
1237
+
1238
+ bool is_equivalent(const Primitive& other) const override;
1239
+ auto state() const { // Returns scale_ by value.
1240
+ return scale_;
1241
+ }
1242
+
1243
+ private:
1244
+ float scale_; // NOTE(review): presumably a multiplier applied to the transform output - confirm in the implementation.
1245
+ };
1246
+
1247
+ class Imag : public UnaryPrimitive { // Imag primitive; declares no data members of its own.
1248
+ public:
1249
+ explicit Imag(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1250
+
1251
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1252
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1253
+
1254
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1255
+ DEFINE_GRADS()
1256
+ DEFINE_NAME(Imag)
1257
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1258
+ DEFINE_INPUT_OUTPUT_SHAPE()
1259
+ };
1260
+
1261
+ class Less : public UnaryPrimitive { // Less primitive; declares no data members of its own.
1262
+ public:
1263
+ explicit Less(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1264
+
1265
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1266
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1267
+
1268
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1269
+ DEFINE_GRADS()
1270
+ DEFINE_NAME(Less)
1271
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1272
+ DEFINE_INPUT_OUTPUT_SHAPE()
1273
+ };
1274
+
1275
+ class LessEqual : public UnaryPrimitive { // LessEqual primitive; declares no data members of its own.
1276
+ public:
1277
+ explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1278
+
1279
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1280
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1281
+
1282
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1283
+ DEFINE_GRADS()
1284
+ DEFINE_NAME(LessEqual)
1285
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1286
+ DEFINE_INPUT_OUTPUT_SHAPE()
1287
+ };
1288
+
1289
+ class Load : public UnaryPrimitive { // Primitive that holds a shared io::Reader, an offset and an endianness flag.
1290
+ public:
1291
+ explicit Load(
1292
+ Stream stream,
1293
+ std::shared_ptr<io::Reader> reader, // Taken by value and moved into the member, so this object shares ownership of the reader.
1294
+ size_t offset,
1295
+ bool swap_endianness = false) // Defaults to false when omitted.
1296
+ : UnaryPrimitive(stream),
1297
+ reader_(std::move(reader)),
1298
+ offset_(offset),
1299
+ swap_endianness_(swap_endianness) {}
1300
+
1301
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1302
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1303
+
1304
+ DEFINE_NAME(Load) // Only the name macro is used: no vmap/grad/equivalence macros appear in this class.
1305
+
1306
+ private:
1307
+ std::shared_ptr<io::Reader> reader_;
1308
+ size_t offset_; // NOTE(review): presumably the position within the reader to start from - confirm units in the implementation.
1309
+ bool swap_endianness_;
1310
+ };
1311
+
1312
+ class Log : public UnaryPrimitive { // Logarithm primitive whose base (two, ten or e) is chosen at construction.
1313
+ public:
1314
+ enum Base { two, ten, e }; // Unscoped enum: enumerators are reachable as Log::two, Log::ten, Log::e.
1315
+
1316
+ explicit Log(Stream stream, Base base)
1317
+ : UnaryPrimitive(stream), base_(base) {}
1318
+
1319
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1320
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1321
+
1322
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1323
+ DEFINE_GRADS()
1324
+ DEFINE_DEFAULT_IS_EQUIVALENT() // NOTE(review): if the default ignores state, Log instances with different bases compare equivalent - confirm.
1325
+ DEFINE_INPUT_OUTPUT_SHAPE()
1326
+
1327
+ Base state() const { // Returns the stored base by value.
1328
+ return base_;
1329
+ } // Redundant ';' after the function body removed (matches sibling classes, silences -Wextra-semi).
1330
+
1331
+ const char* name() const override { // Maps the base to "Log", "Log2" or "Log10".
1332
+ switch (base_) {
1333
+ case e:
1334
+ return "Log";
1335
+ case two:
1336
+ return "Log2";
1337
+ case ten:
1338
+ return "Log10";
1339
+ }
1340
+ return "<unknown Log>"; // Fallback for a base_ value outside the enum; spelling fixed from "<unknwon Log>".
1341
+ }
1342
+
1343
+ private:
1344
+ Base base_;
1345
+ };
1346
+
1347
+ class Log1p : public UnaryPrimitive { // Log1p primitive; declares no data members of its own.
1348
+ public:
1349
+ explicit Log1p(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1350
+
1351
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1352
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1353
+
1354
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1355
+ DEFINE_GRADS()
1356
+ DEFINE_NAME(Log1p) // NOTE(review): no DEFINE_DEFAULT_IS_EQUIVALENT() here, unlike most stateless siblings in this file - confirm intentional.
1357
+ DEFINE_INPUT_OUTPUT_SHAPE()
1358
+ };
1359
+
1360
+ class LogicalNot : public UnaryPrimitive { // LogicalNot primitive; declares no data members of its own.
1361
+ public:
1362
+ explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1363
+
1364
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1365
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1366
+
1367
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1368
+ DEFINE_GRADS()
1369
+ DEFINE_NAME(LogicalNot)
1370
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1371
+ DEFINE_INPUT_OUTPUT_SHAPE()
1372
+ };
1373
+
1374
+ class LogicalAnd : public UnaryPrimitive { // LogicalAnd primitive; declares no data members of its own.
1375
+ public:
1376
+ explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1377
+
1378
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1379
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1380
+
1381
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1382
+ DEFINE_GRADS()
1383
+ DEFINE_NAME(LogicalAnd)
1384
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1385
+ DEFINE_INPUT_OUTPUT_SHAPE()
1386
+ };
1387
+
1388
+ class LogicalOr : public UnaryPrimitive { // LogicalOr primitive; declares no data members of its own.
1389
+ public:
1390
+ explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1391
+
1392
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1393
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1394
+
1395
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1396
+ DEFINE_GRADS()
1397
+ DEFINE_NAME(LogicalOr)
1398
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1399
+ DEFINE_INPUT_OUTPUT_SHAPE()
1400
+ };
1401
+
1402
+ class LogAddExp : public UnaryPrimitive { // LogAddExp primitive; declares no data members of its own.
1403
+ public:
1404
+ explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1405
+
1406
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1407
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1408
+
1409
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1410
+ DEFINE_GRADS()
1411
+ DEFINE_NAME(LogAddExp)
1412
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1413
+ DEFINE_INPUT_OUTPUT_SHAPE()
1414
+ };
1415
+
1416
+ class LogSumExp : public UnaryPrimitive { // LogSumExp primitive; declares no data members of its own.
1417
+ public:
1418
+ explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1419
+
1420
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1421
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1422
+
1423
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1424
+ DEFINE_GRADS()
1425
+ DEFINE_NAME(LogSumExp)
1426
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1427
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; // Declared explicitly instead of using DEFINE_INPUT_OUTPUT_SHAPE(); defined elsewhere.
1428
+ };
1429
+
1430
+ class Matmul : public UnaryPrimitive { // Matmul primitive; declares no data members of its own.
1431
+ public:
1432
+ explicit Matmul(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1433
+
1434
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1435
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1436
+
1437
+ DEFINE_GRADS() // DEFINE_* macros are defined earlier in this header (outside this view).
1438
+ DEFINE_VMAP()
1439
+ DEFINE_NAME(Matmul)
1440
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1441
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override; // Declared explicitly instead of using DEFINE_INPUT_OUTPUT_SHAPE(); defined elsewhere.
1442
+ };
1443
+
1444
+ class Maximum : public UnaryPrimitive { // Maximum primitive; declares no data members of its own.
1445
+ public:
1446
+ explicit Maximum(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1447
+
1448
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1449
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1450
+
1451
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1452
+ DEFINE_GRADS()
1453
+ DEFINE_NAME(Maximum)
1454
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1455
+ DEFINE_INPUT_OUTPUT_SHAPE()
1456
+ };
1457
+
1458
+ class Minimum : public UnaryPrimitive { // Minimum primitive; declares no data members of its own.
1459
+ public:
1460
+ explicit Minimum(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1461
+
1462
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1463
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1464
+
1465
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1466
+ DEFINE_GRADS()
1467
+ DEFINE_NAME(Minimum)
1468
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1469
+ DEFINE_INPUT_OUTPUT_SHAPE()
1470
+ };
1471
+
1472
+ class Multiply : public UnaryPrimitive { // Multiply primitive; declares no data members of its own.
1473
+ public:
1474
+ explicit Multiply(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1475
+
1476
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1477
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1478
+
1479
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1480
+ DEFINE_GRADS()
1481
+ DEFINE_NAME(Multiply)
1482
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1483
+ DEFINE_INPUT_OUTPUT_SHAPE()
1484
+ };
1485
+
1486
+ class Negative : public UnaryPrimitive { // Negative primitive; declares no data members of its own.
1487
+ public:
1488
+ explicit Negative(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1489
+
1490
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1491
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1492
+
1493
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1494
+ DEFINE_GRADS()
1495
+ DEFINE_NAME(Negative)
1496
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1497
+ DEFINE_INPUT_OUTPUT_SHAPE()
1498
+ };
1499
+
1500
+ class NotEqual : public UnaryPrimitive { // NotEqual primitive; declares no data members of its own.
1501
+ public:
1502
+ explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {} // Only forwards the stream to the base class.
1503
+
1504
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1505
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1506
+
1507
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1508
+ DEFINE_GRADS()
1509
+ DEFINE_NAME(NotEqual)
1510
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1511
+ DEFINE_INPUT_OUTPUT_SHAPE()
1512
+ };
1513
+
1514
+ class NumberOfElements : public UnaryPrimitive { // Primitive that stores axes, an inverted flag and a Dtype; its output shape is always empty.
1515
+ public:
1516
+ explicit NumberOfElements(
1517
+ Stream stream,
1518
+ std::vector<int> axes, // Taken by value and moved into the member.
1519
+ bool inverted,
1520
+ Dtype dtype)
1521
+ : UnaryPrimitive(stream),
1522
+ axes_(std::move(axes)),
1523
+ inverted_(inverted),
1524
+ dtype_(dtype) {}
1525
+
1526
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1527
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1528
+
1529
+ DEFINE_VMAP() // No DEFINE_GRADS() in this class; macros are defined earlier in this header (outside this view).
1530
+ DEFINE_NAME(NumberOfElements)
1531
+ bool is_equivalent(const Primitive& other) const override;
1532
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override { // Ignores inputs: always returns one empty Shape.
1533
+ return {{}};
1534
+ }
1535
+ std::tuple<std::vector<int>, bool, Dtype> state() const { // Returns copies of (axes_, inverted_, dtype_).
1536
+ return {axes_, inverted_, dtype_};
1537
+ }
1538
+
1539
+ private:
1540
+ std::vector<int> axes_;
1541
+ bool inverted_; // NOTE(review): presumably selects the reciprocal of the count - confirm in the implementation.
1542
+ Dtype dtype_;
1543
+
1544
+ void eval(const std::vector<array>& inputs, array& out); // NOTE(review): presumably shared by eval_cpu/eval_gpu - confirm in the .cpp.
1545
+ };
1546
+
1547
+ class Pad : public UnaryPrimitive { // Primitive that stores axes plus low and high pad sizes.
1548
+ public:
1549
+ explicit Pad(
1550
+ Stream stream,
1551
+ const std::vector<int>& axes, // All three arguments are copied into members.
1552
+ const Shape& low_pad_size,
1553
+ const Shape& high_pad_size)
1554
+ : UnaryPrimitive(stream),
1555
+ axes_(axes),
1556
+ low_pad_size_(low_pad_size),
1557
+ high_pad_size_(high_pad_size) {}
1558
+
1559
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1560
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1561
+
1562
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1563
+ DEFINE_GRADS()
1564
+ DEFINE_NAME(Pad)
1565
+ bool is_equivalent(const Primitive& other) const override;
1566
+ auto state() const { // Returns (axes_, low_pad_size_, high_pad_size_) as a std::tuple of copies.
1567
+ return std::make_tuple(axes_, low_pad_size_, high_pad_size_);
1568
+ }
1569
+
1570
+ private:
1571
+ std::vector<int> axes_;
1572
+ Shape low_pad_size_;
1573
+ Shape high_pad_size_;
1574
+ };
1575
+
1576
+ class Partition : public UnaryPrimitive { // Primitive that stores a kth index and an axis.
1577
+ public:
1578
+ explicit Partition(Stream stream, int kth, int axis)
1579
+ : UnaryPrimitive(stream), kth_(kth), axis_(axis) {}
1580
+
1581
+ void eval_cpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1582
+ void eval_gpu(const std::vector<array>& inputs, array& out) override; // Declared here; defined elsewhere.
1583
+
1584
+ DEFINE_VMAP() // DEFINE_* macros are defined earlier in this header (outside this view).
1585
+ DEFINE_GRADS()
1586
+ DEFINE_NAME(Partition)
1587
+ DEFINE_INPUT_OUTPUT_SHAPE()
1588
+ bool is_equivalent(const Primitive& other) const override;
1589
+ auto state() const { // Returns (kth_, axis_) as a std::pair.
1590
+ return std::make_pair(kth_, axis_);
1591
+ } // Redundant ';' after the function body removed (matches sibling classes, silences -Wextra-semi).
1592
+
1593
+ private:
1594
+ int kth_;
1595
+ int axis_;
1596
+ };
1597
+
1598
// Power primitive declaration.
// Holds no state of its own: the constructor only forwards the stream to
// UnaryPrimitive. Remaining members come from the DEFINE_* macros (defined
// outside this excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Power : public UnaryPrimitive {
1599
+ public:
1600
+ explicit Power(Stream stream) : UnaryPrimitive(stream) {}
1601
+
1602
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1603
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1604
+
1605
+ DEFINE_VMAP()
1606
+ DEFINE_GRADS()
1607
+ DEFINE_NAME(Power)
1608
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1609
+ DEFINE_INPUT_OUTPUT_SHAPE()
1610
+ };
1611
+
1612
// QuantizedMatmul primitive declaration.
// Stores the quantization parameters passed to the constructor
// (group_size, bits, mode) plus a `transpose` flag; state() returns the four
// values as a tuple. output_shapes and is_equivalent are defined out of line.
+ class QuantizedMatmul : public UnaryPrimitive {
1613
+ public:
1614
+ explicit QuantizedMatmul(
1615
+ Stream stream,
1616
+ int group_size,
1617
+ int bits,
1618
+ QuantizationMode mode,
1619
+ bool transpose)
1620
+ : UnaryPrimitive(stream),
1621
+ group_size_(group_size),
1622
+ bits_(bits),
1623
+ mode_(mode),
1624
+ transpose_(transpose) {}
1625
+
1626
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1627
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1628
+
1629
+ DEFINE_VMAP()
1630
+ DEFINE_GRADS()
1631
+ DEFINE_NAME(QuantizedMatmul)
1632
+ bool is_equivalent(const Primitive& other) const override;
1633
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1634
+ auto state() const {
1635
+ return std::make_tuple(group_size_, bits_, mode_, transpose_);
1636
+ }
1637
+
1638
+ private:
1639
+ int group_size_;
1640
+ int bits_;
1641
+ QuantizationMode mode_;
1642
+ bool transpose_;
1643
+ };
1644
+
1645
// QQMatmul primitive declaration.
// Stores group_size, bits and the quantization mode; state() returns them as
// a tuple. Unlike the neighbouring primitives, the DEFINE_VMAP() line is
// commented out, so this class does not invoke that macro.
+ class QQMatmul : public UnaryPrimitive {
1646
+ public:
1647
+ explicit QQMatmul(
1648
+ Stream stream,
1649
+ int group_size,
1650
+ int bits,
1651
+ QuantizationMode mode)
1652
+ : UnaryPrimitive(stream),
1653
+ group_size_(group_size),
1654
+ bits_(bits),
1655
+ mode_(mode) {}
1656
+
1657
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1658
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1659
+
1660
+ // DEFINE_VMAP()
1661
+ DEFINE_GRADS()
1662
+ DEFINE_NAME(QQMatmul)
1663
+ bool is_equivalent(const Primitive& other) const override;
1664
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1665
+ auto state() const {
1666
+ return std::make_tuple(group_size_, bits_, mode_);
1667
+ }
1668
+
1669
+ private:
1670
+ int group_size_;
1671
+ int bits_;
1672
+ QuantizationMode mode_;
1673
+ };
1674
+
1675
// GatherQMM primitive declaration.
// Stores group_size, bits, quantization mode and a `transpose` flag, plus
// two flags `left_sorted` / `right_sorted` that both default to false.
// state() returns all six values as a tuple, in member order.
+ class GatherQMM : public UnaryPrimitive {
1676
+ public:
1677
+ explicit GatherQMM(
1678
+ Stream stream,
1679
+ int group_size,
1680
+ int bits,
1681
+ QuantizationMode mode,
1682
+ bool transpose,
1683
+ bool left_sorted = false,
1684
+ bool right_sorted = false)
1685
+ : UnaryPrimitive(stream),
1686
+ group_size_(group_size),
1687
+ bits_(bits),
1688
+ mode_(mode),
1689
+ transpose_(transpose),
1690
+ left_sorted_(left_sorted),
1691
+ right_sorted_(right_sorted) {}
1692
+
1693
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1694
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1695
+
1696
+ DEFINE_VMAP()
1697
+ DEFINE_GRADS()
1698
+ DEFINE_NAME(GatherQMM)
1699
+ bool is_equivalent(const Primitive& other) const override;
1700
+ auto state() const {
1701
+ return std::make_tuple(
1702
+ group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
1703
+ }
1704
+
1705
+ private:
1706
+ int group_size_;
1707
+ int bits_;
1708
+ QuantizationMode mode_;
1709
+ bool transpose_;
1710
+ bool left_sorted_;
1711
+ bool right_sorted_;
1712
+ };
1713
+
1714
// RandomBits primitive declaration.
// Stores a copy of the `shape` argument and the integer `width`; state()
// returns them as a (shape, width) pair. No DEFINE_GRADS() is invoked here.
// NOTE(review): the unit of `width` is not visible in this header; confirm
// against the implementation.
+ class RandomBits : public UnaryPrimitive {
1715
+ public:
1716
+ explicit RandomBits(Stream stream, const Shape& shape, int width)
1717
+ : UnaryPrimitive(stream), shape_(shape), width_(width) {}
1718
+
1719
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1720
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1721
+
1722
+ DEFINE_VMAP()
1723
+ DEFINE_NAME(RandomBits)
1724
+ bool is_equivalent(const Primitive& other) const override;
1725
+ std::pair<Shape, int> state() const {
1726
+ return {shape_, width_};
1727
+ };
1728
+
1729
+ private:
1730
+ Shape shape_;
1731
+ int width_;
1732
+ };
1733
+
1734
// Real primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Real : public UnaryPrimitive {
1735
+ public:
1736
+ explicit Real(Stream stream) : UnaryPrimitive(stream) {}
1737
+
1738
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1739
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1740
+
1741
+ DEFINE_VMAP()
1742
+ DEFINE_GRADS()
1743
+ DEFINE_NAME(Real)
1744
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1745
+ DEFINE_INPUT_OUTPUT_SHAPE()
1746
+ };
1747
+
1748
// Reshape primitive declaration.
// Stores a copy of the target `shape`; state() returns that shape by value.
// Also declares a static output_shape(input, shape) helper and an
// output_shapes override, both defined out of line.
+ class Reshape : public UnaryPrimitive {
1749
+ public:
1750
+ explicit Reshape(Stream stream, const Shape& shape)
1751
+ : UnaryPrimitive(stream), shape_(shape) {}
1752
+
1753
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1754
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1755
+
1756
+ DEFINE_VMAP()
1757
+ DEFINE_GRADS()
1758
+ DEFINE_NAME(Reshape)
1759
+ bool is_equivalent(const Primitive& other) const override;
1760
+ Shape state() const {
1761
+ return shape_;
1762
+ };
// Static helper: callable without a Reshape instance.
1763
+ static Shape output_shape(const array& input, Shape shape);
1764
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1765
+
1766
+ private:
1767
+ Shape shape_;
1768
+ };
1769
+
1770
// Reduce primitive declaration.
// One class covers six reductions, selected by ReduceType. Stores the
// reduction kind and a copy of the `axes` vector; state() returns them as a
// (reduce_type, axes) pair. name() maps the stored kind to a fixed string.
+ class Reduce : public UnaryPrimitive {
1771
+ public:
1772
+ enum ReduceType { And, Or, Sum, Prod, Min, Max };
1773
+
1774
+ explicit Reduce(
1775
+ Stream stream,
1776
+ ReduceType reduce_type,
1777
+ const std::vector<int>& axes)
1778
+ : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1779
+
1780
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1781
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1782
+
1783
+ DEFINE_VMAP()
1784
+ DEFINE_GRADS();
1785
+
1786
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1787
+
// Returns the name of the selected reduction. The switch lists every
// enumerator; the trailing return is only reached for a value outside them.
1788
+ const char* name() const override {
1789
+ switch (reduce_type_) {
1790
+ case And:
1791
+ return "And";
1792
+ case Or:
1793
+ return "Or";
1794
+ case Sum:
1795
+ return "Sum";
1796
+ case Prod:
1797
+ return "Prod";
1798
+ case Min:
1799
+ return "Min";
1800
+ case Max:
1801
+ return "Max";
1802
+ }
1803
+ return "<unknown Reduce>";
1804
+ }
1805
+
1806
+ bool is_equivalent(const Primitive& other) const override;
1807
+ std::pair<ReduceType, std::vector<int>> state() const {
1808
+ return {reduce_type_, axes_};
1809
+ };
1810
+
1811
+ private:
1812
+ ReduceType reduce_type_;
1813
+ std::vector<int> axes_;
1814
+ };
1815
+
1816
// Round primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Round : public UnaryPrimitive {
1817
+ public:
1818
+ explicit Round(Stream stream) : UnaryPrimitive(stream) {}
1819
+
1820
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1821
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1822
+
1823
+ DEFINE_VMAP()
1824
+ DEFINE_GRADS()
1825
+ DEFINE_NAME(Round)
1826
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1827
+ DEFINE_INPUT_OUTPUT_SHAPE()
1828
+ };
1829
+
1830
// Scan primitive declaration.
// One class covers five scan kinds, selected by ReduceType. Stores the kind,
// the `axis`, and the `reverse` / `inclusive` flags; state() returns the
// four values as a tuple. name() maps the stored kind to a "Cum*" string.
+ class Scan : public UnaryPrimitive {
1831
+ public:
1832
+ enum ReduceType { Max, Min, Sum, Prod, LogAddExp };
1833
+
1834
+ explicit Scan(
1835
+ Stream stream,
1836
+ ReduceType reduce_type,
1837
+ int axis,
1838
+ bool reverse,
1839
+ bool inclusive)
1840
+ : UnaryPrimitive(stream),
1841
+ reduce_type_(reduce_type),
1842
+ axis_(axis),
1843
+ reverse_(reverse),
1844
+ inclusive_(inclusive) {}
1845
+
1846
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1847
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1848
+
1849
+ DEFINE_VMAP()
1850
+ DEFINE_GRADS();
1851
+
// Returns the name of the selected scan. The switch lists every enumerator;
// the trailing return is only reached for a value outside them.
1852
+ const char* name() const override {
1853
+ switch (reduce_type_) {
1854
+ case Sum:
1855
+ return "CumSum";
1856
+ case Prod:
1857
+ return "CumProd";
1858
+ case Min:
1859
+ return "CumMin";
1860
+ case Max:
1861
+ return "CumMax";
1862
+ case LogAddExp:
1863
+ return "CumLogAddExp";
1864
+ }
1865
+ return "<unknown Scan>";
1866
+ }
1867
+
1868
+ bool is_equivalent(const Primitive& other) const override;
1869
+ auto state() const {
1870
+ return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_);
1871
+ }
1872
+
1873
+ private:
1874
+ ReduceType reduce_type_;
1875
+ int axis_;
1876
+ bool reverse_;
1877
+ bool inclusive_;
1878
+ };
1879
+
1880
// Scatter primitive declaration.
// The ReduceType selects how the scatter combines values; `None` maps to the
// plain "Scatter" name. Stores the kind and a copy of the `axes` vector;
// state() returns them as a (reduce_type, axes) pair.
+ class Scatter : public UnaryPrimitive {
1881
+ public:
1882
+ enum ReduceType { Max, Min, Sum, Prod, None };
1883
+
1884
+ explicit Scatter(
1885
+ Stream stream,
1886
+ ReduceType reduce_type,
1887
+ const std::vector<int>& axes)
1888
+ : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {}
1889
+
1890
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1891
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1892
+
1893
+ DEFINE_VMAP();
1894
+ DEFINE_GRADS();
1895
+
// Returns the name of the selected scatter variant. The switch lists every
// enumerator; the trailing return is only reached for a value outside them.
1896
+ const char* name() const override {
1897
+ switch (reduce_type_) {
1898
+ case Sum:
1899
+ return "Scatter Sum";
1900
+ case Prod:
1901
+ return "Scatter Prod";
1902
+ case Min:
1903
+ return "Scatter Min";
1904
+ case Max:
1905
+ return "Scatter Max";
1906
+ case None:
1907
+ return "Scatter";
1908
+ }
1909
+ return "<unknown Scatter>";
1910
+ }
1911
+
1912
+ bool is_equivalent(const Primitive& other) const override;
1913
+ std::pair<ReduceType, std::vector<int>> state() const {
1914
+ return {reduce_type_, axes_};
1915
+ };
1916
+
1917
+ private:
1918
+ ReduceType reduce_type_;
1919
+ std::vector<int> axes_;
1920
+ };
1921
+
1922
// ScatterAxis primitive declaration.
// The ReduceType has two variants (Sum, None). Stores the kind and a single
// `axis`; state() returns them as a (reduce_type, axis) pair.
// output_shapes and is_equivalent are defined out of line.
+ class ScatterAxis : public UnaryPrimitive {
1923
+ public:
1924
+ enum ReduceType { Sum, None };
1925
+
1926
+ explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis)
1927
+ : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {}
1928
+
1929
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1930
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1931
+
1932
+ DEFINE_VMAP()
1933
+ DEFINE_GRADS()
1934
+
// Returns the name of the selected variant. The switch lists every
// enumerator; the trailing return is only reached for a value outside them.
1935
+ const char* name() const override {
1936
+ switch (reduce_type_) {
1937
+ case Sum:
1938
+ return "ScatterAxis Sum";
1939
+ case None:
1940
+ return "ScatterAxis";
1941
+ }
1942
+ return "<unknown ScatterAxis>";
1943
+ }
1944
+
1945
+ bool is_equivalent(const Primitive& other) const override;
1946
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
1947
+ std::pair<ReduceType, int> state() const {
1948
+ return {reduce_type_, axis_};
1949
+ }
1950
+
1951
+ private:
1952
+ ReduceType reduce_type_;
1953
+ int axis_;
1954
+ };
1955
+
1956
// MaskedScatter primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// Each DEFINE_* macro invocation below is followed by a `;`, unlike most
// sibling classes; the macros themselves are defined outside this excerpt.
+ class MaskedScatter : public UnaryPrimitive {
1957
+ public:
1958
+ explicit MaskedScatter(Stream stream) : UnaryPrimitive(stream) {}
1959
+
1960
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1961
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1962
+
1963
+ DEFINE_VMAP();
1964
+ DEFINE_GRADS();
1965
+ DEFINE_NAME(MaskedScatter);
1966
+ DEFINE_DEFAULT_IS_EQUIVALENT();
1967
+ DEFINE_INPUT_OUTPUT_SHAPE();
1968
+ };
1969
+
1970
// Sigmoid primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Sigmoid : public UnaryPrimitive {
1971
+ public:
1972
+ explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {}
1973
+
1974
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1975
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1976
+
1977
+ DEFINE_VMAP()
1978
+ DEFINE_GRADS()
1979
+ DEFINE_NAME(Sigmoid)
1980
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1981
+ DEFINE_INPUT_OUTPUT_SHAPE()
1982
+ };
1983
+
1984
// Sign primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Sign : public UnaryPrimitive {
1985
+ public:
1986
+ explicit Sign(Stream stream) : UnaryPrimitive(stream) {}
1987
+
1988
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
1989
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
1990
+
1991
+ DEFINE_VMAP()
1992
+ DEFINE_GRADS()
1993
+ DEFINE_NAME(Sign)
1994
+ DEFINE_DEFAULT_IS_EQUIVALENT()
1995
+ DEFINE_INPUT_OUTPUT_SHAPE()
1996
+ };
1997
+
1998
// Sin primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Sin : public UnaryPrimitive {
1999
+ public:
2000
+ explicit Sin(Stream stream) : UnaryPrimitive(stream) {}
2001
+
2002
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2003
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2004
+
2005
+ DEFINE_VMAP()
2006
+ DEFINE_GRADS()
2007
+ DEFINE_NAME(Sin)
2008
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2009
+ DEFINE_INPUT_OUTPUT_SHAPE()
2010
+ };
2011
+
2012
// Sinh primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Sinh : public UnaryPrimitive {
2013
+ public:
2014
+ explicit Sinh(Stream stream) : UnaryPrimitive(stream) {}
2015
+
2016
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2017
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2018
+
2019
+ DEFINE_VMAP()
2020
+ DEFINE_GRADS()
2021
+ DEFINE_NAME(Sinh)
2022
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2023
+ DEFINE_INPUT_OUTPUT_SHAPE()
2024
+ };
2025
+
2026
// Slice primitive declaration.
// Stores copies of the `start_indices`, `end_indices` and `strides` Shape
// arguments; state() returns the three as a tuple. This class does not
// invoke DEFINE_INPUT_OUTPUT_SHAPE() and declares no output_shapes override.
+ class Slice : public UnaryPrimitive {
2027
+ public:
2028
+ explicit Slice(
2029
+ Stream stream,
2030
+ const Shape& start_indices,
2031
+ const Shape& end_indices,
2032
+ const Shape& strides)
2033
+ : UnaryPrimitive(stream),
2034
+ start_indices_(start_indices),
2035
+ end_indices_(end_indices),
2036
+ strides_(strides) {}
2037
+
2038
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2039
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2040
+
2041
+ DEFINE_VMAP()
2042
+ DEFINE_GRADS()
2043
+ DEFINE_NAME(Slice)
2044
+ bool is_equivalent(const Primitive& other) const override;
2045
+ auto state() const {
2046
+ return std::make_tuple(start_indices_, end_indices_, strides_);
2047
+ }
2048
+
2049
+ private:
2050
+ Shape start_indices_;
2051
+ Shape end_indices_;
2052
+ Shape strides_;
2053
+ };
2054
+
2055
// SliceUpdate primitive declaration.
// Stores the same three Shape members as Slice (start_indices, end_indices,
// strides) and returns them as a tuple from state(). It additionally
// invokes DEFINE_INPUT_OUTPUT_SHAPE().
+ class SliceUpdate : public UnaryPrimitive {
2056
+ public:
2057
+ explicit SliceUpdate(
2058
+ Stream stream,
2059
+ const Shape& start_indices,
2060
+ const Shape& end_indices,
2061
+ const Shape& strides)
2062
+ : UnaryPrimitive(stream),
2063
+ start_indices_(start_indices),
2064
+ end_indices_(end_indices),
2065
+ strides_(strides) {}
2066
+
2067
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2068
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2069
+
2070
+ DEFINE_VMAP()
2071
+ DEFINE_GRADS()
2072
+ DEFINE_NAME(SliceUpdate)
2073
+ bool is_equivalent(const Primitive& other) const override;
2074
+ DEFINE_INPUT_OUTPUT_SHAPE()
2075
+ auto state() const {
2076
+ return std::make_tuple(start_indices_, end_indices_, strides_);
2077
+ }
2078
+
2079
+ private:
2080
+ Shape start_indices_;
2081
+ Shape end_indices_;
2082
+ Shape strides_;
2083
+ };
2084
+
2085
// DynamicSlice primitive declaration.
// Takes `axes` and `slice_size` by value and moves them into the members;
// state() returns them as an (axes, slice_size) pair.
// output_shapes and is_equivalent are defined out of line.
+ class DynamicSlice : public UnaryPrimitive {
2086
+ public:
2087
+ explicit DynamicSlice(Stream stream, std::vector<int> axes, Shape slice_size)
2088
+ : UnaryPrimitive(stream),
2089
+ axes_(std::move(axes)),
2090
+ slice_size_(std::move(slice_size)) {}
2091
+
2092
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2093
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2094
+
2095
+ DEFINE_VMAP()
2096
+ DEFINE_GRADS()
2097
+ DEFINE_NAME(DynamicSlice)
2098
+ bool is_equivalent(const Primitive& other) const override;
2099
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2100
+ auto state() const {
2101
+ return std::make_pair(axes_, slice_size_);
2102
+ }
2103
+
2104
+ private:
2105
+ std::vector<int> axes_;
2106
+ Shape slice_size_;
2107
+ };
2108
+
2109
// DynamicSliceUpdate primitive declaration.
// Takes `axes` by value and moves it into axes_. state() has an `auto`
// return type, so it returns a copy of the vector rather than a reference.
+ class DynamicSliceUpdate : public UnaryPrimitive {
2110
+ public:
2111
+ explicit DynamicSliceUpdate(Stream stream, std::vector<int> axes)
2112
+ : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2113
+
2114
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2115
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2116
+
2117
+ DEFINE_VMAP()
2118
+ DEFINE_GRADS()
2119
+ DEFINE_NAME(DynamicSliceUpdate)
2120
+ bool is_equivalent(const Primitive& other) const override;
2121
+ DEFINE_INPUT_OUTPUT_SHAPE()
2122
+ auto state() const {
2123
+ return axes_;
2124
+ }
2125
+
2126
+ private:
2127
+ std::vector<int> axes_;
2128
+ };
2129
+
2130
// Softmax primitive declaration.
// Stores a single `precise` flag and returns it from state().
// NOTE(review): what `precise` changes in the computation is not visible in
// this header; confirm against the eval implementations.
+ class Softmax : public UnaryPrimitive {
2131
+ public:
2132
+ explicit Softmax(Stream stream, bool precise)
2133
+ : UnaryPrimitive(stream), precise_(precise) {}
2134
+
2135
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2136
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2137
+
2138
+ DEFINE_VMAP()
2139
+ DEFINE_GRADS()
2140
+ DEFINE_NAME(Softmax)
2141
+ DEFINE_INPUT_OUTPUT_SHAPE()
2142
+
2143
+ bool is_equivalent(const Primitive& other) const override;
2144
+ auto state() const {
2145
+ return precise_;
2146
+ };
2147
+
2148
+ private:
2149
+ bool precise_;
2150
+ };
2151
+
2152
// Sort primitive declaration.
// Stores the integer `axis` given to the constructor and returns it from
// state(). Evaluation and is_equivalent are defined out of line.
+ class Sort : public UnaryPrimitive {
2153
+ public:
2154
+ explicit Sort(Stream stream, int axis)
2155
+ : UnaryPrimitive(stream), axis_(axis) {}
2156
+
2157
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2158
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2159
+
2160
+ DEFINE_VMAP()
2161
+ DEFINE_GRADS()
2162
+ DEFINE_NAME(Sort)
2163
+ DEFINE_INPUT_OUTPUT_SHAPE()
2164
+ bool is_equivalent(const Primitive& other) const override;
2165
+ auto state() const {
2166
+ return axis_;
2167
+ }
2168
+
2169
+ private:
2170
+ int axis_;
2171
+ };
2172
+
2173
// Split primitive declaration.
// Derives from Primitive (not UnaryPrimitive): eval_cpu/eval_gpu take a
// vector of output arrays. Stores a copy of `indices` and the `axis`;
// state() returns them as an (indices, axis) pair.
+ class Split : public Primitive {
2174
+ public:
2175
+ explicit Split(Stream stream, const Shape& indices, int axis)
2176
+ : Primitive(stream), indices_(indices), axis_(axis) {}
2177
+
2178
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2179
+ override;
2180
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2181
+ override;
2182
+
2183
+ DEFINE_VMAP()
2184
+ DEFINE_GRADS()
2185
+ DEFINE_NAME(Split)
2186
+ bool is_equivalent(const Primitive& other) const override;
2187
+ std::pair<Shape, int> state() const {
2188
+ return {indices_, axis_};
2189
+ };
2190
+
2191
+ private:
// Private helper with the same signature as eval_cpu/eval_gpu.
// NOTE(review): presumably shared by both backends; confirm in the .cpp.
2192
+ void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
2193
+
2194
+ Shape indices_;
2195
+ int axis_;
2196
+ };
2197
+
2198
// Square primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Square : public UnaryPrimitive {
2199
+ public:
2200
+ explicit Square(Stream stream) : UnaryPrimitive(stream) {}
2201
+
2202
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2203
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2204
+
2205
+ DEFINE_VMAP()
2206
+ DEFINE_GRADS()
2207
+ DEFINE_NAME(Square)
2208
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2209
+ DEFINE_INPUT_OUTPUT_SHAPE()
2210
+ };
2211
+
2212
// Sqrt primitive declaration.
// One class serves two names, chosen by the `recip` flag (default false):
// name() returns "Rsqrt" when the flag is set and "Sqrt" otherwise.
// state() returns the flag.
+ class Sqrt : public UnaryPrimitive {
2213
+ public:
2214
+ explicit Sqrt(Stream stream, bool recip = false)
2215
+ : UnaryPrimitive(stream), recip_(recip) {}
2216
+
2217
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2218
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2219
+
2220
+ DEFINE_VMAP()
2221
+ DEFINE_GRADS()
2222
+ DEFINE_INPUT_OUTPUT_SHAPE()
2223
+ bool is_equivalent(const Primitive& other) const override;
2224
+ auto state() const {
2225
+ return recip_;
2226
+ }
2227
+
// Written by hand instead of DEFINE_NAME because the name depends on recip_.
2228
+ const char* name() const override {
2229
+ if (recip_) {
2230
+ return "Rsqrt";
2231
+ } else {
2232
+ return "Sqrt";
2233
+ }
2234
+ }
2235
+
2236
+ private:
2237
+ bool recip_;
2238
+ };
2239
+
2240
// StopGradient primitive declaration.
// Stateless. Unlike most neighbouring primitives it does not invoke
// DEFINE_GRADS(). Declares a private eval() helper with the same signature
// as eval_cpu/eval_gpu.
// NOTE(review): presumably both backends forward to eval(); confirm.
+ class StopGradient : public UnaryPrimitive {
2241
+ public:
2242
+ explicit StopGradient(Stream stream) : UnaryPrimitive(stream) {}
2243
+
2244
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2245
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2246
+
2247
+ DEFINE_VMAP()
2248
+ DEFINE_NAME(StopGradient)
2249
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2250
+ DEFINE_INPUT_OUTPUT_SHAPE()
2251
+
2252
+ private:
2253
+ void eval(const std::vector<array>& inputs, array& out);
2254
+ };
2255
+
2256
// Subtract primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Subtract : public UnaryPrimitive {
2257
+ public:
2258
+ explicit Subtract(Stream stream) : UnaryPrimitive(stream) {}
2259
+
2260
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2261
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2262
+
2263
+ DEFINE_VMAP()
2264
+ DEFINE_GRADS()
2265
+ DEFINE_NAME(Subtract)
2266
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2267
+ DEFINE_INPUT_OUTPUT_SHAPE()
2268
+ };
2269
+
2270
// Squeeze primitive declaration.
// Takes `axes` by value and moves it into axes_; state() returns a copy.
// Declares a static output_shape(input, axes) helper alongside the
// output_shapes override, plus a private eval() helper; all are defined out
// of line.
+ class Squeeze : public UnaryPrimitive {
2271
+ public:
2272
+ explicit Squeeze(Stream stream, std::vector<int> axes)
2273
+ : UnaryPrimitive(stream), axes_(std::move(axes)) {}
2274
+
2275
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2276
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2277
+
2278
+ DEFINE_VMAP()
2279
+ DEFINE_GRADS()
2280
+ DEFINE_NAME(Squeeze)
2281
+
2282
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2283
+ bool is_equivalent(const Primitive& other) const override;
2284
+
// Static helper: callable without a Squeeze instance.
2285
+ static Shape output_shape(const array& input, const std::vector<int>& axes);
2286
+ auto state() const {
2287
+ return axes_;
2288
+ };
2289
+
2290
+ private:
2291
+ void eval(const std::vector<array>& inputs, array& out);
2292
+ std::vector<int> axes_;
2293
+ };
2294
+
2295
// Tan primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Tan : public UnaryPrimitive {
2296
+ public:
2297
+ explicit Tan(Stream stream) : UnaryPrimitive(stream) {}
2298
+
2299
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2300
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2301
+
2302
+ DEFINE_VMAP()
2303
+ DEFINE_GRADS()
2304
+ DEFINE_NAME(Tan)
2305
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2306
+ DEFINE_INPUT_OUTPUT_SHAPE()
2307
+ };
2308
+
2309
// Tanh primitive declaration.
// Stateless: the constructor only forwards the stream to UnaryPrimitive.
// All other members come from the DEFINE_* macros (defined outside this
// excerpt) and the out-of-line eval_cpu/eval_gpu.
+ class Tanh : public UnaryPrimitive {
2310
+ public:
2311
+ explicit Tanh(Stream stream) : UnaryPrimitive(stream) {}
2312
+
2313
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2314
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2315
+
2316
+ DEFINE_VMAP()
2317
+ DEFINE_GRADS()
2318
+ DEFINE_NAME(Tanh)
2319
+ DEFINE_DEFAULT_IS_EQUIVALENT()
2320
+ DEFINE_INPUT_OUTPUT_SHAPE()
2321
+ };
2322
+
2323
// Unflatten primitive declaration.
// Stores the integer `axis` and a Shape (taken by value and moved into
// shape_); state() returns them as an (axis, shape) pair. Declares a static
// output_shape(input, axis, shape) helper and a private eval() helper, both
// defined out of line.
+ class Unflatten : public UnaryPrimitive {
2324
+ public:
2325
+ explicit Unflatten(Stream stream, int axis, Shape shape)
2326
+ : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {}
2327
+
2328
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2329
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2330
+
2331
+ DEFINE_VMAP()
2332
+ DEFINE_GRADS()
2333
+ DEFINE_NAME(Unflatten)
2334
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2335
+ bool is_equivalent(const Primitive& other) const override;
2336
+
// Static helper: callable without an Unflatten instance.
2337
+ static Shape output_shape(const array& input, int axis, const Shape& shape);
2338
+ auto state() const {
2339
+ return std::make_pair(axis_, shape_);
2340
+ }
2341
+
2342
+ private:
2343
+ int axis_;
2344
+ Shape shape_;
2345
+ void eval(const std::vector<array>& inputs, array& out);
2346
+ };
2347
+
2348
// View primitive declaration.
// Stores the target Dtype and returns it from state(). name() is declared
// here and defined out of line rather than through DEFINE_NAME.
// NOTE(review): name_ is mutable, presumably so the const name() can cache a
// string derived from dtype_; confirm in the implementation file.
+ class View : public UnaryPrimitive {
2349
+ public:
2350
+ explicit View(Stream stream, Dtype dtype)
2351
+ : UnaryPrimitive(stream), dtype_(dtype) {}
2352
+
2353
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2354
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2355
+
2356
+ DEFINE_VMAP()
2357
+ const char* name() const override;
2358
+ bool is_equivalent(const Primitive& other) const override;
2359
+ auto state() const {
2360
+ return dtype_;
2361
+ }
2362
+
2363
+ private:
2364
+ Dtype dtype_;
2365
+ mutable std::string name_;
2366
+ };
2367
+
2368
// Transpose primitive declaration.
// Stores a copy of the `axes` vector and returns it by value from state().
// output_shapes, is_equivalent and the private eval() helper are defined out
// of line.
+ class Transpose : public UnaryPrimitive {
2369
+ public:
2370
+ explicit Transpose(Stream stream, const std::vector<int>& axes)
2371
+ : UnaryPrimitive(stream), axes_(axes) {}
2372
+
2373
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2374
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2375
+
2376
+ DEFINE_VMAP()
2377
+ DEFINE_GRADS()
2378
+ DEFINE_NAME(Transpose)
2379
+ bool is_equivalent(const Primitive& other) const override;
2380
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2381
+ std::vector<int> state() const {
2382
+ return axes_;
2383
+ };
2384
+
2385
+ private:
2386
+ std::vector<int> axes_;
2387
+
2388
+ void eval(const std::vector<array>& inputs, array& out);
2389
+ };
2390
+
2391
+ /* QR Factorization primitive. */
2392
// Derives from Primitive: eval_cpu/eval_gpu take a vector of output arrays.
// Stateless. Only DEFINE_NAME is invoked; no vmap, grads, is_equivalent or
// state() members are declared in this class.
+ class QRF : public Primitive {
2393
+ public:
2394
+ explicit QRF(Stream stream) : Primitive(stream) {}
2395
+
2396
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2397
+ override;
2398
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2399
+ override;
2400
+
2401
+ DEFINE_NAME(QRF)
2402
+ };
2403
+
2404
+ /* SVD primitive. */
2405
// Derives from Primitive: eval_cpu/eval_gpu take a vector of output arrays.
// Stores the `compute_uv` flag and returns it from state().
// NOTE(review): presumably the flag controls whether the U/V factors are
// produced in addition to the singular values; confirm in the eval code.
+ class SVD : public Primitive {
2406
+ public:
2407
+ explicit SVD(Stream stream, bool compute_uv)
2408
+ : Primitive(stream), compute_uv_(compute_uv) {}
2409
+
2410
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2411
+ override;
2412
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2413
+ override;
2414
+
2415
+ DEFINE_VMAP()
2416
+ DEFINE_NAME(SVD)
2417
+ auto state() const {
2418
+ return compute_uv_;
2419
+ }
2420
+
2421
+ private:
2422
+ bool compute_uv_;
2423
+ };
2424
+
2425
+ /* Matrix inversion primitive. */
2426
// Stores two flags, `tri` and `upper`; state() returns them as a
// (tri, upper) pair.
// NOTE(review): presumably `tri` marks a triangular input and `upper` picks
// which triangle; confirm against the eval implementations.
+ class Inverse : public UnaryPrimitive {
2427
+ public:
2428
+ explicit Inverse(Stream stream, bool tri, bool upper)
2429
+ : UnaryPrimitive(stream), tri_(tri), upper_(upper) {}
2430
+
2431
+ void eval_cpu(const std::vector<array>& inputs, array& output) override;
2432
+ void eval_gpu(const std::vector<array>& inputs, array& output) override;
2433
+
2434
+ DEFINE_VMAP()
2435
+ DEFINE_NAME(Inverse)
2436
+ auto state() const {
2437
+ return std::make_pair(tri_, upper_);
2438
+ }
2439
+
2440
+ private:
2441
+ bool tri_;
2442
+ bool upper_;
2443
+ };
2444
+
2445
// Cholesky primitive declaration.
// Stores a single `upper` flag and returns it from state().
// NOTE(review): presumably selects the upper versus lower triangular factor;
// confirm against the eval implementations.
+ class Cholesky : public UnaryPrimitive {
2446
+ public:
2447
+ explicit Cholesky(Stream stream, bool upper)
2448
+ : UnaryPrimitive(stream), upper_(upper) {}
2449
+
2450
+ void eval_cpu(const std::vector<array>& inputs, array& out) override;
2451
+ void eval_gpu(const std::vector<array>& inputs, array& out) override;
2452
+ auto state() const {
2453
+ return upper_;
2454
+ }
2455
+
2456
+ DEFINE_VMAP()
2457
+ DEFINE_NAME(Cholesky)
2458
+
2459
+ private:
2460
+ bool upper_;
2461
+ };
2462
+
2463
// Eig primitive declaration.
// Derives from Primitive: eval_cpu/eval_gpu take a vector of output arrays.
// Stores the `compute_eigenvectors` flag and returns it from state().
// output_shapes and is_equivalent are defined out of line.
+ class Eig : public Primitive {
2464
+ public:
2465
+ explicit Eig(Stream stream, bool compute_eigenvectors)
2466
+ : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {}
2467
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2468
+ override;
2469
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2470
+ override;
2471
+
2472
+ DEFINE_VMAP()
2473
+ DEFINE_NAME(Eig)
2474
+
2475
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2476
+
2477
+ bool is_equivalent(const Primitive& other) const override;
2478
+ auto state() const {
2479
+ return compute_eigenvectors_;
2480
+ }
2481
+
2482
+ private:
2483
+ bool compute_eigenvectors_;
2484
+ };
2485
+
2486
// Eigh primitive declaration.
// Derives from Primitive: eval_cpu/eval_gpu take a vector of output arrays.
// Stores the `uplo` string (taken by value and moved in) and the
// `compute_eigenvectors` flag; state() returns them as a pair.
// NOTE(review): the accepted values of `uplo` are not visible in this
// header; confirm against the caller or the eval implementation.
+ class Eigh : public Primitive {
2487
+ public:
2488
+ explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
2489
+ : Primitive(stream),
2490
+ uplo_(std::move(uplo)),
2491
+ compute_eigenvectors_(compute_eigenvectors) {}
2492
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2493
+ override;
2494
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2495
+ override;
2496
+
2497
+ DEFINE_VMAP()
2498
+ DEFINE_NAME(Eigh)
2499
+
2500
+ std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
2501
+
2502
+ bool is_equivalent(const Primitive& other) const override;
2503
+ auto state() const {
2504
+ return std::make_pair(uplo_, compute_eigenvectors_);
2505
+ }
2506
+
2507
+ private:
2508
+ std::string uplo_;
2509
+ bool compute_eigenvectors_;
2510
+ };
2511
+
2512
+ /* LU Factorization primitive. */
2513
// Derives from Primitive: eval_cpu/eval_gpu take a vector of output arrays.
// Stateless. Only DEFINE_NAME is invoked; no vmap, grads, is_equivalent or
// state() members are declared in this class.
+ class LUF : public Primitive {
2514
+ public:
2515
+ explicit LUF(Stream stream) : Primitive(stream) {}
2516
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2517
+ override;
2518
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
2519
+ override;
2520
+
2521
+ DEFINE_NAME(LUF)
2522
+ };
2523
+
2524
+ } // namespace mlx::core