mlx 1.0.0

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlx might be problematic. Click here for more details.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
@@ -0,0 +1,4159 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ // Required for using M_PI_2 in MSVC.
4
+ #define _USE_MATH_DEFINES
5
+ #include <cmath>
6
+ #include <numeric>
7
+
8
+ #include "doctest/doctest.h"
9
+
10
+ #include "mlx/backend/cuda/cuda.h"
11
+ #include "mlx/mlx.h"
12
+
13
+ using namespace mlx::core;
14
+
15
+ TEST_CASE("test copy") { // Checks that copy() returns a distinct array (different id) with the same shape, dtype and contents.
16
+ array x(1.0); // scalar input: the copy is expected to have an empty shape
17
+ auto y = copy(x);
18
+ CHECK_EQ(y.shape(), Shape{});
19
+ CHECK_NE(y.id(), x.id()); // the copy must not report the same id as its source
20
+ CHECK_EQ(y.item<float>(), 1.0f);
21
+ // Same checks for an int32 array of shape {2, 1}
22
+ x = array({1, 2}, {2, 1});
23
+ y = copy(x);
24
+ CHECK_EQ(y.shape(), Shape{2, 1});
25
+ CHECK_EQ(y.dtype(), int32);
26
+ CHECK_NE(y.id(), x.id());
27
+ CHECK(array_equal(y, x).item<bool>()); // element-wise contents must match the source
28
+ }
29
+
30
+ TEST_CASE("test reshape") { // Exercises reshape(): shape validation, -1 inference, empty arrays, copy-free reshapes of transposed data, and contiguity flags.
31
+ array x(1.0);
32
+ CHECK_EQ(reshape(x, {}).shape(), Shape{});
33
+ CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument); // x holds one element, the target shape asks for two
34
+ auto y = reshape(x, {1, 1, 1});
35
+ CHECK_EQ(y.shape(), Shape{1, 1, 1});
36
+ y = reshape(x, {-1, 1, 1}); // a single -1 entry is expected to be inferred (here as 1)
37
+ CHECK_EQ(y.shape(), Shape{1, 1, 1});
38
+ y = reshape(x, {1, 1, -1});
39
+ CHECK_EQ(y.shape(), Shape{1, 1, 1});
40
+ CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument); // two -1 entries are expected to be rejected
41
+ CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument);
42
+ // Same kind of checks on an 8-element array of shape {2, 2, 2}
43
+ x = zeros({2, 2, 2});
44
+ y = reshape(x, {8});
45
+ CHECK_EQ(y.shape(), Shape{8});
46
+ CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument);
47
+ y = reshape(x, {-1});
48
+ CHECK_EQ(y.shape(), Shape{8});
49
+ y = reshape(x, {-1, 2});
50
+ CHECK_EQ(y.shape(), Shape{4, 2});
51
+ CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument); // 8 elements cannot be split into rows of 7
52
+
53
+ // Works with empty array
54
+ x = array({});
55
+ y = reshape(x, {0, 0, 0});
56
+ CHECK_EQ(y.shape(), Shape{0, 0, 0});
57
+ y.eval();
58
+ CHECK_EQ(y.size(), 0);
59
+ CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument); // shapes {} and {1} both describe one element, not zero
60
+ CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument);
61
+ y = reshape(x, {1, 5, 0});
62
+ CHECK_EQ(y.shape(), Shape{1, 5, 0});
63
+
64
+ // Check that reshaping a transposed array doesn't result in a copy
65
+ x = reshape(arange(64), {2, 4, 8});
66
+ x.eval();
67
+ CHECK_EQ(x.strides()[0], 32);
68
+ CHECK_EQ(x.strides()[1], 8);
69
+ CHECK_EQ(x.strides()[2], 1);
70
+ y = reshape(transpose(x, {0, 2, 1}), {2, 4, 2, 4});
71
+ y.eval();
72
+ CHECK_EQ(y.strides()[0], 32);
73
+ CHECK_EQ(y.strides()[1], 2);
74
+ CHECK_EQ(y.strides()[2], 1);
75
+ CHECK_EQ(y.strides()[3], 8);
76
+ CHECK_EQ(x.data<int32_t>(), y.data<int32_t>()); // equal data pointers are taken as evidence that no copy was made
77
+
78
+ // Split transposed (2, 8, 4) -> (2, 8, 2, 2)
79
+ y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 2});
80
+ y.eval();
81
+ CHECK_EQ(y.strides()[0], 32);
82
+ CHECK_EQ(y.strides()[1], 1);
83
+ CHECK_EQ(y.strides()[2], 16);
84
+ CHECK_EQ(y.strides()[3], 8);
85
+ CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
86
+
87
+ // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2)
88
+ y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2});
89
+ y.eval();
90
+ CHECK_EQ(y.strides()[0], 32);
91
+ CHECK_EQ(y.strides()[1], 1);
92
+ CHECK_EQ(y.strides()[2], 16);
93
+ // y.strides()[3] can be anything since y.shape()[3] == 1
94
+ CHECK_EQ(y.strides()[4], 8);
95
+ CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
96
+
97
+ // Split transposed (2, 8, 4) -> (2, 8, 2, 1, 2, 1)
98
+ y = reshape(transpose(x, {0, 2, 1}), {2, 8, 2, 1, 2, 1});
99
+ y.eval();
100
+ CHECK_EQ(y.strides()[0], 32);
101
+ CHECK_EQ(y.strides()[1], 1);
102
+ CHECK_EQ(y.strides()[2], 16);
103
+ // y.strides()[3] can be anything since y.shape()[3] == 1
104
+ CHECK_EQ(y.strides()[4], 8);
105
+ // y.strides()[5] can be anything since y.shape()[5] == 1
106
+ CHECK_EQ(x.data<int32_t>(), y.data<int32_t>());
107
+
108
+ // Check contiguity preservation
109
+ x = ones({10, 10});
110
+ eval(x);
111
+ CHECK(x.flags().row_contiguous);
112
+ CHECK(!x.flags().col_contiguous);
113
+ y = reshape(x, {2, 5, 10});
114
+ eval(y);
115
+ CHECK(y.flags().row_contiguous);
116
+ CHECK(!y.flags().col_contiguous);
117
+ y = reshape(x, {10, 1, 10, 1});
118
+ eval(y);
119
+ CHECK(y.flags().row_contiguous);
120
+ CHECK(!y.flags().col_contiguous);
121
+ x = transpose(x, {1, 0}); // after the transpose the flags are expected to flip to column-contiguous
122
+ eval(x);
123
+ CHECK(!x.flags().row_contiguous);
124
+ CHECK(x.flags().col_contiguous);
125
+ y = reshape(x, {2, 5, 10});
126
+ eval(y);
127
+ CHECK(!y.flags().row_contiguous);
128
+ CHECK(y.flags().col_contiguous);
129
+ y = reshape(x, {2, 50}); // this reshape of the transposed array is expected to come back row-contiguous
130
+ eval(y);
131
+ CHECK(y.flags().row_contiguous);
132
+ CHECK(!y.flags().col_contiguous);
133
+ y = reshape(x, {10, 1, 10, 1});
134
+ eval(y);
135
+ CHECK(!y.flags().row_contiguous);
136
+ CHECK(y.flags().col_contiguous);
137
+ }
138
+
139
+ TEST_CASE("test flatten") { // Checks the output shape of flatten() for various start/end axes and that invalid axis ranges throw.
140
+ array x = zeros({2, 3, 4});
141
+ CHECK_EQ(flatten(x).shape(), Shape({2 * 3 * 4})); // no axes given: everything collapses into one dimension
142
+
143
+ CHECK_EQ(flatten(x, 1, 1).shape(), Shape({2, 3, 4})); // start == end leaves the shape unchanged
144
+ CHECK_EQ(flatten(x, 1, 2).shape(), Shape({2, 3 * 4}));
145
+ CHECK_EQ(flatten(x, 1, 3).shape(), Shape({2, 3 * 4})); // end axis 3 is past the last axis; expected result matches end = 2
146
+ CHECK_EQ(flatten(x, 1, -1).shape(), Shape({2, 3 * 4}));
147
+ CHECK_EQ(flatten(x, -2, -1).shape(), Shape({2, 3 * 4}));
148
+ CHECK_EQ(flatten(x, -3, -1).shape(), Shape({2 * 3 * 4}));
149
+ CHECK_EQ(flatten(x, -4, -1).shape(), Shape({2 * 3 * 4})); // start axis -4 is before the first axis; expected result matches start = -3
150
+
151
+ // Check start > end throws
152
+ CHECK_THROWS(flatten(x, 2, 1));
153
+
154
+ // Check start >= ndim throws
155
+ CHECK_THROWS(flatten(x, 5, 6));
156
+
157
+ // Check end < 0 throws
158
+ CHECK_THROWS(flatten(x, -5, -4));
159
+
160
+ // Check scalar flattens to 1D
161
+ x = array(1);
162
+ CHECK_EQ(flatten(x, -3, -1).shape(), Shape({1}));
163
+ CHECK_EQ(flatten(x, 0, 0).shape(), Shape({1}));
164
+ }
165
+
166
+ TEST_CASE("test unflatten") { // Checks unflatten() output shapes and that invalid axis / shape arguments throw.
167
+ array x = array(1);
168
+ CHECK_THROWS(unflatten(x, 0, {1, 1})); // scalar input: expected to throw
169
+ // One-element 1-D input
170
+ x = array({1});
171
+ auto out = unflatten(x, 0, {1, 1});
172
+ CHECK_EQ(out.shape(), Shape({1, 1}));
173
+ CHECK_THROWS(unflatten(x, 1, {1, 1})); // axis 1 on a 1-D array
174
+ CHECK_THROWS(unflatten(x, 0, {-1, -1})); // two -1 entries in the target shape
175
+ CHECK_THROWS(unflatten(x, 0, {-1, 2})); // an axis of size 1 cannot be split with a factor of 2
176
+ CHECK_THROWS(unflatten(x, 0, {})); // empty target shape
177
+ // Split the size-8 axis of a {4, 8} array into {2, 2, 2}
178
+ x = zeros({4, 8});
179
+ out = unflatten(x, 1, {2, 2, 2});
180
+ CHECK_EQ(out.shape(), Shape({4, 2, 2, 2}));
181
+ }
182
+
183
+ TEST_CASE("test squeeze and expand") { // Checks output shapes of squeeze() and expand_dims() and that invalid or repeated axes throw.
184
+ array x = zeros({2, 1, 2, 1, 2, 1});
185
+ CHECK_EQ(squeeze(x).shape(), Shape{2, 2, 2}); // no axes given: every size-1 axis is expected to be removed
186
+ CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), Shape{2, 2, 2});
187
+ CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), Shape{2, 2, 2}); // negative axes name the same three size-1 axes
188
+ CHECK_EQ(squeeze(x, 1).shape(), Shape{2, 2, 1, 2, 1});
189
+ CHECK_EQ(squeeze(x, -1).shape(), Shape{2, 1, 2, 1, 2});
190
+ // Axes 0 and 2 have size 2, so squeezing them is expected to throw
191
+ CHECK_THROWS(squeeze(x, 0));
192
+ CHECK_THROWS(squeeze(x, 2));
193
+ CHECK_THROWS(squeeze(x, {1, 3, 1})); // axis 1 listed twice
194
+ CHECK_THROWS(squeeze(x, {1, 3, -3})); // -3 refers to the same axis as 3 for this 6-d array
195
+
196
+ x = zeros({2, 2});
197
+ CHECK_EQ(expand_dims(x, 0).shape(), Shape{1, 2, 2});
198
+ CHECK_EQ(expand_dims(x, -1).shape(), Shape{2, 2, 1});
199
+ CHECK_EQ(expand_dims(x, 1).shape(), Shape{2, 1, 2});
200
+ CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), Shape{1, 1, 1, 2, 2});
201
+ CHECK_EQ(
202
+ expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(),
203
+ Shape{1, 1, 1, 2, 2, 1, 1, 1});
204
+ // Out-of-range and repeated axes are expected to throw
205
+ CHECK_THROWS(expand_dims(x, 3));
206
+ CHECK_THROWS(expand_dims(x, -4));
207
+ CHECK_THROWS(expand_dims(x, {0, 1, 0})); // axis 0 listed twice
208
+ CHECK_THROWS(expand_dims(x, {0, 1, -4})); // presumably -4 resolves to axis 1 of the 5-d result, repeating it - TODO confirm
209
+ }
210
+
211
+ TEST_CASE("test slice") { // Exercises slice(): argument validation, bounds handling, strides, contiguity flags and data_size() of the result.
212
+ array x = array(3);
213
+ auto out = slice(x, {}, {}); // scalar input with empty start/stop lists
214
+ CHECK_EQ(out.item<int>(), 3);
215
+ CHECK_THROWS_AS(slice(x, {1}, {2}), std::invalid_argument); // start/stop carry entries although the input is a scalar
216
+ CHECK_THROWS_AS(slice(x, {}, {2}), std::invalid_argument);
217
+ CHECK_THROWS_AS(slice(x, {0}, {}), std::invalid_argument);
218
+ // One-element 1-D input
219
+ x = array({3});
220
+ out = slice(x, {0}, {1});
221
+ CHECK_EQ(out.item<int>(), 3);
222
+ out = slice(x, {-1}, {1}); // negative start
223
+ CHECK_EQ(out.item<int>(), 3);
224
+
225
+ out = slice(x, {-3}, {10}); // bounds outside the array still yield the single element
226
+ CHECK_EQ(out.item<int>(), 3);
227
+
228
+ out = slice(x, {1}, {0}); // start past stop: expected to give an empty result
229
+ eval(out);
230
+ CHECK_EQ(out.shape(), Shape{0});
231
+
232
+ out = slice(x, {0}, {1}, {1}); // explicit stride of 1
233
+ CHECK_EQ(out.item<int>(), 3);
234
+
235
+ out = slice(x, {0}, {1}, {10}); // stride larger than the extent still yields the first element
236
+ CHECK_EQ(out.item<int>(), 3);
237
+ // 2-D input of shape {2, 4}
238
+ x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});
239
+ out = slice(x, {0, 0}, {2, 2});
240
+ CHECK(array_equal(out, array({0, 1, 4, 5}, {2, 2})).item<bool>());
241
+
242
+ out = slice(x, {0, 0}, {0, 2});
243
+ CHECK(array_equal(out, reshape(array({}), {0, 2})).item<bool>()); // zero rows selected
244
+
245
+ out = slice(x, {0, 2}, {2, 3});
246
+ CHECK(array_equal(out, array({2, 6}, {2, 1})).item<bool>());
247
+
248
+ out = slice(x, {0, 0}, {2, 4}, {1, 2}); // stride 2 along the second axis
249
+ CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item<bool>());
250
+
251
+ // Check contiguity preservation
252
+ x = ones({10, 10});
253
+ eval(x);
254
+ CHECK(x.flags().row_contiguous);
255
+ CHECK(!x.flags().col_contiguous);
256
+ out = slice(x, {0, 0}, {10, 5});
257
+ eval(out);
258
+ CHECK(!out.flags().row_contiguous);
259
+ CHECK(!out.flags().col_contiguous);
260
+ out = slice(x, {0, 0}, {5, 10});
261
+ eval(out);
262
+ CHECK(out.flags().row_contiguous);
263
+ CHECK(!out.flags().col_contiguous);
264
+ x = transpose(x, {1, 0}); // repeat the two slices on the transposed array; the expected flags swap roles
265
+ eval(x);
266
+ CHECK(!x.flags().row_contiguous);
267
+ CHECK(x.flags().col_contiguous);
268
+ out = slice(x, {0, 0}, {10, 5});
269
+ eval(out);
270
+ CHECK(!out.flags().row_contiguous);
271
+ CHECK(out.flags().col_contiguous);
272
+ out = slice(x, {0, 0}, {5, 10});
273
+ eval(out);
274
+ CHECK(!out.flags().row_contiguous);
275
+ CHECK(!out.flags().col_contiguous);
276
+ // Strided slice of a 3-D array: none of the contiguity flags is expected to be set
277
+ x = ones({6, 4, 10});
278
+ out = slice(x, {0, 0, 0}, {6, 4, 10}, {2, 1, 2});
279
+ eval(out);
280
+ CHECK(!out.flags().contiguous);
281
+ CHECK(!out.flags().row_contiguous);
282
+ CHECK(!out.flags().col_contiguous);
283
+
284
+ // Check data size correctness
285
+ x = ones({4});
286
+ out = slice(x, {0}, {2});
287
+ eval(out);
288
+ CHECK_EQ(out.data_size(), 2);
289
+
290
+ out = slice(x, {2}, {4});
291
+ eval(out);
292
+ CHECK_EQ(out.data_size(), 2);
293
+
294
+ out = slice(x, {0}, {4}, {2});
295
+ eval(out);
296
+ CHECK_EQ(out.data_size(), 3); // indices 0 and 2 are selected; presumably data_size spans the elements between them - TODO confirm
297
+
298
+ x = ones({4, 4});
299
+ out = slice(x, {0, 0}, {2, 4});
300
+ eval(out);
301
+ CHECK_EQ(out.data_size(), 8);
302
+
303
+ out = slice(x, {0, 0}, {1, 2});
304
+ eval(out);
305
+ CHECK_EQ(out.data_size(), 2);
306
+
307
+ out = slice(x, {0, 1}, {4, 4});
308
+ eval(out);
309
+ CHECK_EQ(out.data_size(), 15);
310
+
311
+ out = slice(x, {1, 2}, {3, 4});
312
+ eval(out);
313
+ CHECK_EQ(out.data_size(), 6);
314
+
315
+ x = ones({4, 4, 4});
316
+ out = slice(x, {0, 0, 0}, {4, 2, 2});
317
+ eval(out);
318
+ CHECK_EQ(out.data_size(), 54);
319
+
320
+ x = ones({4, 4, 4});
321
+ out = slice(x, {2, 2, 2}, {3, 3, 3});
322
+ eval(out);
323
+ CHECK_EQ(out.data_size(), 1);
324
+
325
+ x = ones({4, 4, 4});
326
+ out = slice(x, {2, 2, 2}, {3, 4, 3});
327
+ eval(out);
328
+ CHECK_EQ(out.data_size(), 5);
329
+ // Negative strides
330
+ x = ones({8});
331
+ out = slice(x, {7}, {-9}, {-1});
332
+ eval(out);
333
+ CHECK_EQ(out.data_size(), 8);
334
+ // NOTE(review): the next check repeats the previous one verbatim
335
+ out = slice(x, {7}, {-9}, {-1});
336
+ eval(out);
337
+ CHECK_EQ(out.data_size(), 8);
338
+
339
+ x = ones({4, 2});
340
+ out = slice(x, {3, 0}, {-5, 2}, {-1, 1});
341
+ eval(out);
342
+ CHECK_EQ(out.data_size(), 8);
343
+ }
344
+
345
+ TEST_CASE("test slice update") { // Checks that slice_update() writes the update into the addressed region by reading it back with slice().
346
+ array x = array({0., 0., 0., 0., 0., 0., 0., 0.}, {8}, float32); // destination: eight zeros
347
+ array y = array(
348
+ {
349
+ 1.,
350
+ 2.,
351
+ 3.,
352
+ 4.,
353
+ },
354
+ {4},
355
+ float32);
356
+ // Forward stride: the region [2, 6) is expected to hold y afterwards
357
+ auto out = slice_update(x, y, {2}, {6}, {1});
358
+ CHECK(array_equal(slice(out, {2}, {6}, {1}), y).item<bool>());
359
+ // Negative stride: the same start/stop/stride triple is used to write and to read back
360
+ out = slice_update(x, y, {5}, {1}, {-1});
361
+ CHECK(array_equal(slice(out, {5}, {1}, {-1}), y).item<bool>());
362
+ // y has shape {4} and is written into the full {2, 4} region; both rows are expected to equal y
363
+ x = reshape(x, {2, 4});
364
+ out = slice_update(x, y, {0, 0}, {2, 4}, {1, 1});
365
+ out = reshape(out, {8});
366
+ CHECK(array_equal(slice(out, {0}, {4}, {1}), y).item<bool>());
367
+ CHECK(array_equal(slice(out, {4}, {8}, {1}), y).item<bool>());
368
+ }
369
+
370
+ TEST_CASE("test dynamic slice") { // Exercises the slice() overload called with (source, start-index array, axis list, slice size).
371
+ auto src = reshape(arange(6), {2, 3}); // rows are {0, 1, 2} and {3, 4, 5}
372
+ CHECK_THROWS(slice(src, array({1, 0, 0}), {0, 0, 0}, {1, 1})); // three start indices / axes for a 2-D source
373
+ CHECK_THROWS(slice(src, array({1, 0}), {0}, {1, 1})); // two start indices but only one axis
374
+ CHECK_THROWS(slice(src, array({1}), {3}, {1, 1})); // axis 3 on a 2-D source
375
+ CHECK_THROWS(slice(src, array({1, 0}), {0, 0}, {1, 1})); // axis 0 listed twice
376
+
377
+ CHECK_THROWS(slice(src, array({1}), {0}, {2, 4})); // slice size {2, 4} is larger than the {2, 3} source
378
+ CHECK_THROWS(slice(src, array({1.0f}, float32), {0}, {1, 1})); // float32 start index
379
+ // Start at index 1 on axis 0, size {1, 2}: first two entries of the second row
380
+ auto out = slice(src, array({1}), {0}, {1, 2});
381
+ auto expected = array({3, 4}, {1, 2});
382
+ CHECK(array_equal(out, expected).item<bool>());
383
+ // Start at (1, 1) on axes (0, 1), size {1, 2}: last two entries of the second row
384
+ out = slice(src, array({1, 1}), {0, 1}, {1, 2});
385
+ expected = array({4, 5}, {1, 2});
386
+ CHECK(array_equal(out, expected).item<bool>());
387
+ }
388
+
389
+ TEST_CASE("test dynamic slice update") { // Exercises the slice_update() overload called with (source, update, start-index array, axis list).
390
+ auto src = zeros({2, 3}, int32);
391
+ auto upd = ones({1, 2}, int32);
392
+ CHECK_THROWS(slice_update(src, upd, array({1, 0, 0}), {0, 0, 0})); // three start indices / axes for a 2-D source
393
+ CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0})); // two start indices but only one axis
394
+ CHECK_THROWS(slice_update(src, upd, array({1}), {3})); // axis 3 on a 2-D source
395
+ CHECK_THROWS(slice_update(src, upd, array({1, 0}), {0, 0})); // axis 0 listed twice
396
+ // Updates whose shape does not fit the {2, 3} source are expected to throw
397
+ upd = ones({4}, int32);
398
+ CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
399
+ upd = ones({1, 4}, int32);
400
+ CHECK_THROWS(slice_update(src, upd, array({1}), {0}));
401
+ CHECK_THROWS(slice_update(src, upd, array({1.0f}, float32), {0})); // float32 start index (upd is still the {1, 4} array here)
402
+ // Write a {1, 2} block of ones starting at row 1
403
+ upd = ones({1, 2}, int32);
404
+ auto out = slice_update(src, upd, array({1}), {0});
405
+ auto expected = reshape(array({0, 0, 0, 1, 1, 0}), {2, 3});
406
+ CHECK(array_equal(out, expected).item<bool>());
407
+ // Write the same block starting at row 1, column 1
408
+ upd = ones({1, 2}, int32);
409
+ out = slice_update(src, upd, array({1, 1}), {0, 1});
410
+ expected = reshape(array({0, 0, 0, 0, 1, 1}), {2, 3});
411
+ CHECK(array_equal(out, expected).item<bool>());
412
+ }
413
+
414
+ TEST_CASE("test split") { // Covers split by section count and by index list, with and without an explicit axis.
415
+ array x = array(1);
416
+ CHECK_THROWS(split(x, 0));
417
+ // One-element array split into a single part.
418
+ x = array({3});
419
+ CHECK_EQ(split(x, 1)[0].item<int>(), 3);
420
+ // Axes 1 and -2 are rejected for a 1-D array.
421
+ x = array({0, 1, 2});
422
+ CHECK_THROWS(split(x, 3, 1));
423
+ CHECK_THROWS(split(x, 3, -2));
424
+ // Axis 0 and axis -1 both give three parts.
425
+ auto out = split(x, 3, 0);
426
+ CHECK_EQ(out.size(), 3);
427
+
428
+ out = split(x, 3, -1);
429
+ CHECK_EQ(out.size(), 3);
430
+ for (auto i = 0; i < 3; ++i) {
431
+ CHECK_EQ(out[i].shape(), Shape{1});
432
+ CHECK_EQ(out[i].dtype(), int32);
433
+ CHECK_EQ(out[i].item<int>(), i);
434
+ }
435
+ // 2x3 input: two row parts when no axis is given, three column parts along axis 1.
436
+ x = array({0, 1, 2, 3, 4, 5}, {2, 3});
437
+ out = split(x, 2);
438
+ CHECK(array_equal(out[0], array({0, 1, 2}, {1, 3})).item<bool>());
439
+ CHECK(array_equal(out[1], array({3, 4, 5}, {1, 3})).item<bool>());
440
+ out = split(x, 3, 1);
441
+ CHECK(array_equal(out[0], array({0, 3}, {2, 1})).item<bool>());
442
+ CHECK(array_equal(out[1], array({1, 4}, {2, 1})).item<bool>());
443
+ CHECK(array_equal(out[2], array({2, 5}, {2, 1})).item<bool>());
444
+ // Shape-only checks on an 8x12 array.
445
+ x = zeros({8, 12});
446
+ out = split(x, 2);
447
+ CHECK_EQ(out.size(), 2);
448
+ CHECK_EQ(out[0].shape(), Shape{4, 12});
449
+ CHECK_EQ(out[1].shape(), Shape{4, 12});
450
+ out = split(x, 3, 1);
451
+ CHECK_EQ(out.size(), 3);
452
+ CHECK_EQ(out[0].shape(), Shape{8, 4});
453
+ CHECK_EQ(out[1].shape(), Shape{8, 4});
454
+ CHECK_EQ(out[2].shape(), Shape{8, 4});
455
+ // An empty index list returns the input as a single part.
456
+ out = split(x, Shape{});
457
+ CHECK_EQ(out.size(), 1);
458
+ CHECK_EQ(out[0].shape(), x.shape());
459
+ // Indices {3, 7} cut the 8 rows into parts of 3, 4 and 1 rows.
460
+ out = split(x, {3, 7});
461
+ CHECK_EQ(out.size(), 3);
462
+ CHECK_EQ(out[0].shape(), Shape{3, 12});
463
+ CHECK_EQ(out[1].shape(), Shape{4, 12});
464
+ CHECK_EQ(out[2].shape(), Shape{1, 12});
465
+ // An index past the end (20) yields the whole array plus an empty {0, 12} part.
466
+ out = split(x, Shape{20});
467
+ CHECK_EQ(out.size(), 2);
468
+ CHECK_EQ(out[0].shape(), Shape{8, 12});
469
+ CHECK_EQ(out[1].shape(), Shape{0, 12});
470
+
471
+ // Negative indices
472
+ out = split(x, Shape{-5});
473
+ CHECK_EQ(out[0].shape(), Shape{3, 12});
474
+ CHECK_EQ(out[1].shape(), Shape{5, 12});
475
+
476
+ // Different axis
477
+ out = split(x, {2, 8}, 1);
478
+ CHECK_EQ(out[0].shape(), Shape{8, 2});
479
+ CHECK_EQ(out[1].shape(), Shape{8, 6});
480
+ CHECK_EQ(out[2].shape(), Shape{8, 4});
481
+
482
+ // Out of order indices
483
+ x = arange(5);
484
+ out = split(x, {2, 1, 2});
485
+ CHECK(array_equal(out[0], array({0, 1})).item<bool>());
486
+ CHECK(array_equal(out[1], array({})).item<bool>());
487
+ CHECK(array_equal(out[2], array({1})).item<bool>());
488
+ CHECK(array_equal(out[3], array({2, 3, 4})).item<bool>());
489
+ }
490
+
491
+ TEST_CASE("test swap and move axes") { // Covers argument validation and resulting shapes for swapaxes and moveaxis.
492
+ // Test swapaxes
493
+ array a(0.0);
494
+ CHECK_THROWS(swapaxes(a, 0, 0));
495
+ // 1-D array: axis 1 throws, while 0 and -1 are accepted.
496
+ a = zeros({2});
497
+ CHECK_THROWS(swapaxes(a, 0, 1));
498
+ CHECK_EQ(swapaxes(a, 0, 0).shape(), Shape{2});
499
+ CHECK_EQ(swapaxes(a, -1, -1).shape(), Shape{2});
500
+ // 3-D array: axes 3 and -4 throw in either position; valid pairs exchange the two dimensions.
501
+ a = zeros({2, 3, 4});
502
+ CHECK_THROWS(swapaxes(a, 0, -4));
503
+ CHECK_THROWS(swapaxes(a, 0, 3));
504
+ CHECK_THROWS(swapaxes(a, 3, 0));
505
+ CHECK_THROWS(swapaxes(a, -4, 0));
506
+ CHECK_EQ(swapaxes(a, 0, 2).shape(), Shape{4, 3, 2});
507
+ CHECK_EQ(swapaxes(a, 0, 1).shape(), Shape{3, 2, 4});
508
+ CHECK_EQ(swapaxes(a, 0, -1).shape(), Shape{4, 3, 2});
509
+ CHECK_EQ(swapaxes(a, -2, 2).shape(), Shape{2, 4, 3});
510
+
511
+ // Test moveaxis
512
+ a = array(0.0);
513
+ CHECK_THROWS(moveaxis(a, 0, 0));
514
+ // 1-D array: same accepted and rejected axes as for swapaxes above.
515
+ a = zeros({2});
516
+ CHECK_THROWS(moveaxis(a, 0, 1));
517
+ CHECK_EQ(moveaxis(a, 0, 0).shape(), Shape{2});
518
+ CHECK_EQ(moveaxis(a, -1, -1).shape(), Shape{2});
519
+ // 3-D array: moving axis 0 to the end gives {3, 4, 2}, unlike the {4, 3, 2} produced by swapaxes(a, 0, 2).
520
+ a = zeros({2, 3, 4});
521
+ CHECK_THROWS(moveaxis(a, 0, -4));
522
+ CHECK_THROWS(moveaxis(a, 0, 3));
523
+ CHECK_THROWS(moveaxis(a, 3, 0));
524
+ CHECK_THROWS(moveaxis(a, -4, 0));
525
+ CHECK_EQ(moveaxis(a, 0, 2).shape(), Shape{3, 4, 2});
526
+ CHECK_EQ(moveaxis(a, 0, 1).shape(), Shape{3, 2, 4});
527
+ CHECK_EQ(moveaxis(a, 0, -1).shape(), Shape{3, 4, 2});
528
+ CHECK_EQ(moveaxis(a, -2, 2).shape(), Shape{2, 4, 3});
529
+ }
530
+
531
+ TEST_CASE("test transpose") { // Covers transpose with default and explicit axes, rejected axis lists, element order and contiguity flags.
532
+ array x(1); // scalar: transposes to itself, and any explicit axis list throws
533
+ auto y = transpose(x);
534
+ CHECK_EQ(y.shape(), Shape{});
535
+ CHECK_EQ(y.item<int>(), 1);
536
+ CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);
537
+ CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);
538
+ // 1-D single-element array.
539
+ x = array({1}, {1});
540
+ y = transpose(x);
541
+ CHECK_EQ(y.shape(), Shape{1});
542
+ CHECK_EQ(y.item<int>(), 1);
543
+
544
+ // Negative indices
545
+ y = transpose(x, {-1});
546
+ CHECK_EQ(y.shape(), Shape{1});
547
+ CHECK_EQ(y.item<int>(), 1);
548
+ // An out-of-range axis and an over-long axis list both throw std::invalid_argument.
549
+ CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument);
550
+ CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument);
551
+
552
+ // Works with empty array
553
+ x = array({});
554
+ y = transpose(x);
555
+ CHECK_EQ(y.shape(), Shape{0});
556
+ y.eval();
557
+ CHECK_EQ(y.size(), 0);
558
+ // 2x3 input: shapes for the default call and for explicit (including negative) axes.
559
+ x = array({1, 2, 3, 4, 5, 6}, {2, 3});
560
+ y = transpose(x);
561
+ CHECK_EQ(y.shape(), Shape{3, 2});
562
+ y = transpose(x, {-1, 0});
563
+ CHECK_EQ(y.shape(), Shape{3, 2});
564
+ y = transpose(x, {-1, -2});
565
+ CHECK_EQ(y.shape(), Shape{3, 2});
566
+ y.eval();
567
+ CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item<bool>());
568
+ y = transpose(x, {0, 1});
569
+ CHECK_EQ(y.shape(), Shape{2, 3});
570
+ CHECK(array_equal(y, x).item<bool>());
571
+ y = transpose(x, {0, -1});
572
+ CHECK_EQ(y.shape(), Shape{2, 3});
573
+ CHECK(array_equal(y, x).item<bool>());
574
+ // Axis lists that are empty, too short, too long, or repeat an axis are rejected for the 2-D input.
575
+ CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument);
576
+ CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument);
577
+ CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument);
578
+ CHECK_THROWS_AS(transpose(x, {0, 0, 0}), std::invalid_argument);
579
+ CHECK_THROWS_AS(transpose(x, {0, 1, 1}), std::invalid_argument);
580
+ // 3-D input of shape {2, 3, 2}; expected element order is spelled out explicitly.
581
+ x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2});
582
+ y = transpose(x);
583
+ CHECK_EQ(y.shape(), Shape{2, 3, 2});
584
+ auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2});
585
+ CHECK(array_equal(y, expected).item<bool>());
586
+ // Explicit permutations of the 3-D input, starting with the identity.
587
+ y = transpose(x, {0, 1, 2});
588
+ CHECK_EQ(y.shape(), Shape{2, 3, 2});
589
+ CHECK(array_equal(y, x).item<bool>());
590
+ y = transpose(x, {1, 0, 2});
591
+ CHECK_EQ(y.shape(), Shape{3, 2, 2});
592
+ expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2});
593
+ CHECK(array_equal(y, expected).item<bool>());
594
+ y = transpose(x, {0, 2, 1});
595
+ CHECK_EQ(y.shape(), Shape{2, 2, 3});
596
+ expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3});
597
+ CHECK(array_equal(y, expected).item<bool>());
598
+
599
+ // Check reshaping a transposed array
600
+ x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});
601
+ x = reshape(transpose(x), {2, 2, 2});
602
+ expected = array({0, 2, 4, 6, 1, 3, 5, 7}, {2, 2, 2});
603
+ CHECK(array_equal(x, expected).item<bool>());
604
+
605
+ // Check maintaining contiguous status
606
+ x = array({0, 1, 2, 3, 4, 5, 6, 7}, {1, 4, 1, 2});
607
+ CHECK(x.flags().row_contiguous);
608
+ x = transpose(x, {2, 1, 0, 3}); // exchanges the two size-1 axes (0 and 2)
609
+ eval(x);
610
+ CHECK(x.flags().row_contiguous);
611
+ }
612
+
613
+ TEST_CASE("test comparison ops") { // Covers equality and ordering ops (named functions and operators), array_equal, type promotion and broadcasting.
614
+ // Empty array
615
+ {
616
+ array x({});
617
+ array y({});
618
+ auto z = x == y;
619
+ CHECK_EQ(z.dtype(), bool_);
620
+ CHECK_EQ(z.shape(), Shape{0});
621
+ }
622
+
623
+ // Basic cases
624
+ {
625
+ array x(1.0);
626
+ array y(1.0);
627
+ CHECK(equal(x, y).item<bool>());
628
+ CHECK((x == y).item<bool>());
629
+ CHECK((x == 1.0f).item<bool>());
630
+ CHECK((1.0f == y).item<bool>());
631
+ // Every inequality form must be false for equal scalars.
632
+ CHECK(!(x != y).item<bool>());
633
+ CHECK(!not_equal(x, y).item<bool>());
634
+ CHECK(!(1.0f != y).item<bool>());
635
+ CHECK(!(x != 1.0f).item<bool>());
636
+
637
+ CHECK(array_equal(x, y).item<bool>());
638
+ // After x becomes 0.0 the two scalars differ.
639
+ x = array(0.0);
640
+ CHECK(!equal(x, y).item<bool>());
641
+ CHECK(!array_equal(x, y).item<bool>());
642
+ CHECK(not_equal(x, y).item<bool>());
643
+ }
644
+
645
+ // Greater and less
646
+ {
647
+ array x(1.0);
648
+ array y(0.0);
649
+ CHECK(greater(x, y).item<bool>());
650
+ CHECK((x > 0.0f).item<bool>());
651
+ CHECK((1.0f > y).item<bool>());
652
+ CHECK(greater_equal(x, y).item<bool>());
653
+ CHECK((1.0f >= y).item<bool>());
654
+ CHECK(!(x > 1.0f).item<bool>());
655
+ CHECK((x >= 1.0f).item<bool>());
656
+ // less / less-or-equal counterparts, including the equal-value boundary.
657
+ CHECK(less(y, x).item<bool>());
658
+ CHECK((y < 1.0).item<bool>());
659
+ CHECK((y <= 1.0f).item<bool>());
660
+ CHECK(!(x < 1.0).item<bool>());
661
+ CHECK((x <= 1.0f).item<bool>());
662
+ }
663
+
664
+ // Check array_equal works
665
+ {
666
+ auto x = zeros({5, 5});
667
+ auto y = zeros({5, 5});
668
+ CHECK(array_equal(x, y).item<bool>());
669
+ // Different shapes ({1, 1} vs {5, 5}) are not array_equal.
670
+ x = zeros({1, 1});
671
+ CHECK(!array_equal(x, y).item<bool>());
672
+ // Same shape, different values.
673
+ x = ones({5, 5});
674
+ CHECK(!array_equal(x, y).item<bool>());
675
+ // Arrays containing NaN compare equal only when the third argument is true.
676
+ x = array({0.0f, 1.0f, NAN});
677
+ y = array({0.0f, 1.0f, NAN});
678
+ CHECK(!array_equal(x, y).item<bool>());
679
+ CHECK(array_equal(x, y, true).item<bool>());
680
+ }
681
+
682
+ // Check other types
683
+ {
684
+ auto x = zeros({5, 5}, int32);
685
+ auto y = zeros({5, 5}, int32);
686
+ CHECK(array_equal(x, y).item<bool>());
687
+
688
+ x = ones({5, 5}, bool_);
689
+ y = ones({5, 5}, bool_);
690
+ CHECK(array_equal(x, y).item<bool>());
691
+ }
692
+
693
+ // Check type promotion
694
+ {
695
+ array x(1.0f);
696
+ array y(1);
697
+ CHECK_EQ(equal(x, y).item<bool>(), true);
698
+ // bool true compared against int 1.
699
+ x = array(true, bool_);
700
+ CHECK_EQ(equal(x, y).item<bool>(), true);
701
+ }
702
+
703
+ // Broadcasting works
704
+ {
705
+ auto x = zeros({1, 2});
706
+ auto y = zeros({2, 1});
707
+ auto z = equal(x, y);
708
+ CHECK_EQ(z.dtype(), bool_);
709
+ CHECK_EQ(z.shape(), Shape{2, 2});
710
+ auto expected = array({true, true, true, true}, {2, 2});
711
+ CHECK(array_equal(z, expected).item<bool>());
712
+ // x holds {1, 2} as a 1x2 row and y holds {1, 2} as a 2x1 column; every result below is 2x2.
713
+ x = array({1.0, 2.0}, {1, 2});
714
+ y = array({1.0, 2.0}, {2, 1});
715
+ z = equal(x, y);
716
+ CHECK_EQ(z.dtype(), bool_);
717
+ CHECK_EQ(z.shape(), Shape{2, 2});
718
+ expected = array({true, false, false, true}, {2, 2});
719
+ CHECK(array_equal(z, expected).item<bool>());
720
+
721
+ expected = array({false, true, false, false}, {2, 2});
722
+ z = greater(x, y);
723
+ CHECK(array_equal(z, expected).item<bool>());
724
+
725
+ expected = array({true, true, false, true}, {2, 2});
726
+ z = greater_equal(x, y);
727
+ CHECK(array_equal(z, expected).item<bool>());
728
+
729
+ expected = array({false, false, true, false}, {2, 2});
730
+ z = less(x, y);
731
+ CHECK(array_equal(z, expected).item<bool>());
732
+
733
+ expected = array({true, false, true, true}, {2, 2});
734
+ z = less_equal(x, y);
735
+ CHECK(array_equal(z, expected).item<bool>());
736
+ }
737
+ }
738
+
739
+ TEST_CASE("test is nan") { // Checks isnan on scalars and arrays for float32, bfloat16 and float16 inputs.
740
+ array x(1.0f);
741
+ CHECK_FALSE(isnan(x).item<bool>());
742
+
743
+ array y(NAN);
744
+ CHECK(isnan(y).item<bool>());
745
+ // Arrays: neither identity(7) nor a partly-NaN array is NaN in every element.
746
+ array z = identity(7);
747
+ CHECK_FALSE(all(isnan(z)).item<bool>());
748
+
749
+ array w = array({1.0f, NAN, 2.0f});
750
+ CHECK_FALSE(all(isnan(w)).item<bool>());
751
+ // Finite scalars in bfloat16 and float16.
752
+ array a(1.0f, bfloat16);
753
+ CHECK_FALSE(isnan(a).item<bool>());
754
+
755
+ array b(1.0f, float16);
756
+ CHECK_FALSE(isnan(b).item<bool>());
757
+ // NaN scalars in bfloat16 and float16.
758
+ array c(NAN, bfloat16);
759
+ CHECK(isnan(c).item<bool>());
760
+
761
+ array d(NAN, float16);
762
+ CHECK(isnan(d).item<bool>());
763
+ }
764
+
765
+ TEST_CASE("test is inf") { // Checks isinf on scalars and arrays for float32, bfloat16 and float16 inputs.
766
+ array x(1.0f);
767
+ CHECK_FALSE(isinf(x).item<bool>());
768
+
769
+ auto inf = std::numeric_limits<float>::infinity();
770
+
771
+ array y(inf);
772
+ CHECK(isinf(y).item<bool>());
773
+ // Negative infinity is reported as infinite too.
774
+ auto neginf = -std::numeric_limits<float>::infinity();
775
+ CHECK(isinf(array(neginf)).item<bool>());
776
+ // Arrays: identity(7) has no infinite entry; a mixed array flags only the inf element.
777
+ array z = identity(7);
778
+ CHECK_FALSE(any(isinf(z)).item<bool>());
779
+
780
+ array w = array({1.0f, inf, 2.0f});
781
+ CHECK(array_equal(array({false, true, false}), isinf(w)).item<bool>());
782
+ // Finite scalars in bfloat16 and float16.
783
+ array a(1.0f, bfloat16);
784
+ CHECK_FALSE(isinf(a).item<bool>());
785
+
786
+ array b(1.0f, float16);
787
+ CHECK_FALSE(isinf(b).item<bool>());
788
+ // Infinite scalars in bfloat16 and float16.
789
+ array c(inf, bfloat16);
790
+ CHECK(isinf(c).item<bool>());
791
+
792
+ array d(inf, float16);
793
+ CHECK(isinf(d).item<bool>());
794
+ }
795
+
796
+ TEST_CASE("test all close") { // Checks allclose on scalars with default and explicitly passed tolerances.
797
+ array x(1.0f);
798
+ array y(1.0f);
799
+ CHECK(allclose(x, y).item<bool>());
800
+ // 1.0 vs 1.1: outcome depends on the 3rd/4th arguments (presumably rtol and atol - confirm against the declaration).
801
+ y = array(1.1f);
802
+ CHECK_FALSE(allclose(x, y).item<bool>());
803
+ CHECK(allclose(x, y, 0.1).item<bool>());
804
+ CHECK_FALSE(allclose(x, y, 0.01).item<bool>());
805
+ CHECK(allclose(x, y, 0.01, 0.1).item<bool>());
806
+ }
807
+
808
+ TEST_CASE("test is close") { // Checks element-wise isclose results for infinite and NaN entries.
809
+ { // +inf against +inf is close
810
+ array a({1.0, std::numeric_limits<float>::infinity()});
811
+ array b({1.0, std::numeric_limits<float>::infinity()});
812
+ CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
813
+ }
814
+ { // -inf against -inf is close
815
+ array a({1.0, -std::numeric_limits<float>::infinity()});
816
+ array b({1.0, -std::numeric_limits<float>::infinity()});
817
+ CHECK(array_equal(isclose(a, b), array({true, true})).item<bool>());
818
+ }
819
+ { // +inf against -inf is not close
820
+ array a({1.0, std::numeric_limits<float>::infinity()});
821
+ array b({1.0, -std::numeric_limits<float>::infinity()});
822
+ CHECK(array_equal(isclose(a, b), array({true, false})).item<bool>());
823
+ }
824
+ { // with the default arguments, NaN is not close to NaN or to a number
825
+ array a({1.0, std::nan("1"), std::nan("1")});
826
+ array b({1.0, std::nan("1"), 2.0});
827
+ CHECK(array_equal(isclose(a, b), array({true, false, false})).item<bool>());
828
+ }
829
+ { // with the fifth argument set to true, NaN is close to NaN but still not to a number
830
+ array a({1.0, std::nan("1"), std::nan("1")});
831
+ array b({1.0, std::nan("1"), 2.0});
832
+ CHECK(
833
+ array_equal(isclose(a, b, 1e-5, 1e-8, true), array({true, true, false}))
834
+ .item<bool>());
835
+ }
836
+ }
837
+
838
+ TEST_CASE("test reduction ops") { // Covers sum, prod, all, any, max, min, logsumexp and softmax: shapes, errors, dtypes and values.
839
+ // Check shapes and throws correctly
840
+ {
841
+ auto x = array(1);
842
+ auto out = sum(x);
843
+ CHECK_EQ(out.ndim(), 0);
844
+ CHECK_THROWS_AS(sum(x, 0), std::out_of_range);
845
+ CHECK_THROWS_AS(sum(x, -1), std::out_of_range);
846
+ out = sum(x, std::vector<int>{});
847
+ CHECK_EQ(out.shape(), Shape{});
848
+ CHECK_EQ(out.size(), 1);
849
+ // Empty 1-D input: a full reduction still yields a scalar-shaped result of size 1.
850
+ x = array({});
851
+ out = sum(x);
852
+ CHECK_EQ(out.shape(), Shape{});
853
+ CHECK_EQ(out.size(), 1);
854
+ out = sum(x, true);
855
+ CHECK_EQ(out.shape(), Shape{1});
856
+ out = sum(x, std::vector<int>{});
857
+ CHECK_EQ(out.shape(), x.shape());
858
+ // 1-D input of length 2.
859
+ x = zeros({2});
860
+ out = sum(x);
861
+ CHECK_EQ(out.ndim(), 0);
862
+ out = sum(x, -1);
863
+ CHECK_EQ(out.ndim(), 0);
864
+ out = sum(x, -1, true);
865
+ CHECK_EQ(out.ndim(), 1);
866
+ CHECK_EQ(out.shape(), Shape{1});
867
+ // Out-of-range axes throw std::out_of_range; {0, 0} and {-1, 0} on this 1-D input throw std::invalid_argument.
868
+ CHECK_THROWS_AS(sum(x, 1), std::out_of_range);
869
+ CHECK_THROWS_AS(sum(x, -2), std::out_of_range);
870
+ CHECK_THROWS_AS(sum(x, {0, 0}), std::invalid_argument);
871
+ CHECK_THROWS_AS(sum(x, {-1, 0}), std::invalid_argument);
872
+ // 3-D input: multi-axis reductions; a trailing `true` keeps the reduced axes with size 1.
873
+ x = zeros({2, 3, 4});
874
+ out = sum(x, {0, 2});
875
+ CHECK_EQ(out.shape(), Shape{3});
876
+ out = sum(x, std::vector<int>{});
877
+ CHECK_EQ(out.shape(), x.shape());
878
+
879
+ out = sum(x, {0, -1});
880
+ CHECK_EQ(out.shape(), Shape{3});
881
+
882
+ out = sum(x, {0, -1}, true);
883
+ CHECK_EQ(out.shape(), Shape{1, 3, 1});
884
+
885
+ out = sum(x, true);
886
+ CHECK_EQ(out.shape(), Shape{1, 1, 1});
887
+
888
+ out = sum(x);
889
+ CHECK_EQ(out.shape(), Shape{});
890
+
891
+ CHECK_THROWS_AS(sum(x, 3), std::out_of_range);
892
+ CHECK_THROWS_AS(sum(x, -4), std::out_of_range);
893
+ CHECK_THROWS_AS(sum(x, {0, 1, -2}), std::invalid_argument);
894
+ }
895
+
896
+ // Test sum
897
+ {
898
+ auto x = array({});
899
+ CHECK_EQ(sum(x).item<float>(), 0.0f);
900
+
901
+ x = array({1, 2, 3});
902
+ CHECK_EQ(sum(x).item<int>(), 6);
903
+ CHECK(array_equal(sum(x, std::vector<int>{}), x).item<bool>());
904
+
905
+ x = ones({2, 3});
906
+ CHECK(array_equal(sum(x, 1), full({2}, 3.0f)).item<bool>());
907
+ CHECK(array_equal(sum(x, 0), full({3}, 2.0f)).item<bool>());
908
+ CHECK_EQ(sum(x, {0, 1}).item<float>(), 6.0f);
909
+
910
+ x = ones({2, 3, 4});
911
+ CHECK(array_equal(sum(x, 0), full({3, 4}, 2.0f)).item<bool>());
912
+ CHECK(array_equal(sum(x, 1), full({2, 4}, 3.0f)).item<bool>());
913
+ CHECK(array_equal(sum(x, 2), full({2, 3}, 4.0f)).item<bool>());
914
+ CHECK(array_equal(sum(x, {0, 1}), full({4}, 6.0f)).item<bool>());
915
+ CHECK(array_equal(sum(x, {0, 2}), full({3}, 8.0f)).item<bool>());
916
+ CHECK(array_equal(sum(x, {1, 2}), full({2}, 12.0f)).item<bool>());
917
+
918
+ // Output for bool gets higher precision
919
+ x = array({true, true, true});
920
+ CHECK_EQ(sum(x).item<int32_t>(), 3);
921
+ // Broadcast input: the scalar 2.0 expanded to {2, 2, 2}.
922
+ x = array(2.0f);
923
+ x = broadcast_to(x, {2, 2, 2});
924
+ CHECK_EQ(sum(x).item<float>(), 16.0f);
925
+
926
+ // Tests with non-uniform results after reduction
927
+ x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3});
928
+ CHECK(array_equal(sum(x, 0), full({3}, 3.0f)).item<bool>());
929
+ CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item<bool>());
930
+ }
931
+
932
+ // Test unsigned sum
933
+ {
934
+ const int num_elems = 1000;
935
+ // The uint8 and uint16 totals (255 * 1000, 65535 * 1000) exceed the input type's range and are read back as uint32_t.
936
+ auto x = astype(full({num_elems}, 255), uint8);
937
+ CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 255 * num_elems);
938
+
939
+ x = astype(full({num_elems}, 65535), uint16);
940
+ CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 65535 * num_elems);
941
+
942
+ x = full({3, 3, 3}, 10000, uint32);
943
+ CHECK_EQ(sum(x, Device::cpu).item<uint32_t>(), 270000);
944
+
945
+ x = full({3, 3, 3}, 10000, uint64);
946
+ CHECK_EQ(sum(x, Device::cpu).item<uint64_t>(), 270000);
947
+ }
948
+
949
+ // Test prod
950
+ {
951
+ auto x = array({});
952
+ CHECK_EQ(prod(x).item<float>(), 1.0f);
953
+
954
+ x = array({2, 2, 2});
955
+ CHECK_EQ(prod(x).item<int>(), 8);
956
+ CHECK(array_equal(prod(x, std::vector<int>{}), x).item<bool>());
957
+
958
+ x = full({2, 3}, 2.0f);
959
+ CHECK(array_equal(prod(x, 1), full({2}, 8.0f)).item<bool>());
960
+ CHECK(array_equal(prod(x, 0), full({3}, 4.0f)).item<bool>());
961
+ CHECK_EQ(prod(x, {0, 1}).item<float>(), 64.0f);
962
+
963
+ x = full({2, 3, 4}, 2.0f);
964
+ CHECK(array_equal(prod(x, 0), full({3, 4}, 4.0f)).item<bool>());
965
+ CHECK(array_equal(prod(x, 1), full({2, 4}, 8.0f)).item<bool>());
966
+ CHECK(array_equal(prod(x, 2), full({2, 3}, 16.0f)).item<bool>());
967
+ CHECK(array_equal(prod(x, {0, 1}), full({4}, 64.0f)).item<bool>());
968
+ CHECK(array_equal(prod(x, {0, 2}), full({3}, 256.0f)).item<bool>());
969
+ CHECK(array_equal(prod(x, {1, 2}), full({2}, 4096.0f)).item<bool>());
970
+
971
+ // Tests with non-uniform results after reduction
972
+ x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3});
973
+ CHECK(array_equal(prod(x, 0), full({3}, 2.0f)).item<bool>());
974
+ CHECK(array_equal(prod(x, 1), array({1.0f, 8.0f}, {2})).item<bool>());
975
+ // Bool input: the expected products are bool arrays.
976
+ x = array({true, true, true, false, true, false}, {2, 3});
977
+ CHECK(array_equal(prod(x, 0), array({false, true, false})).item<bool>());
978
+ CHECK(array_equal(prod(x, 1), array({true, false})).item<bool>());
979
+ }
980
+
981
+ // Test unsigned prod
982
+ {
983
+ auto x = array({255, 255}, {2}, uint8);
984
+ CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 65025);
985
+
986
+ x = array({65535, 2}, {2}, uint16);
987
+ CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 131070);
988
+
989
+ x = array({100000, 2}, {2}, uint32);
990
+ CHECK_EQ(prod(x, Device::cpu).item<uint32_t>(), 200000);
991
+
992
+ x = array({100000, 2}, {2}, uint64);
993
+ CHECK_EQ(prod(x, Device::cpu).item<uint64_t>(), 200000);
994
+ }
995
+
996
+ // Test all
997
+ {
998
+ auto x = array({});
999
+ CHECK_EQ(all(x).item<bool>(), true);
1000
+
1001
+ x = array({2, 2, 2});
1002
+ CHECK_EQ(all(x).item<bool>(), true);
1003
+ auto out = all(x, std::vector<int>{});
1004
+ CHECK(array_equal(out, array({true, true, true})).item<bool>());
1005
+
1006
+ x = array({0, 2, 2});
1007
+ CHECK_EQ(all(x).item<bool>(), false);
1008
+
1009
+ x = array({true, true, true, false, true, false}, {2, 3});
1010
+ CHECK(array_equal(all(x, 1), array({true, false})).item<bool>());
1011
+ CHECK(array_equal(all(x, 0), array({false, true, false})).item<bool>());
1012
+ }
1013
+
1014
+ // Test any
1015
+ {
1016
+ auto x = array({});
1017
+ CHECK_EQ(any(x).item<bool>(), false);
1018
+
1019
+ x = array({0, 0, 0});
1020
+ CHECK_EQ(any(x).item<bool>(), false);
1021
+
1022
+ x = array({0, 2, 0});
1023
+ CHECK_EQ(any(x).item<bool>(), true);
1024
+ auto out = any(x, std::vector<int>{});
1025
+ CHECK(array_equal(out, array({false, true, false})).item<bool>());
1026
+
1027
+ x = array({true, false, true, false, false, false}, {2, 3});
1028
+ CHECK(array_equal(any(x, 1), array({true, false})).item<bool>());
1029
+ CHECK(array_equal(any(x, 0), array({true, false, true})).item<bool>());
1030
+ }
1031
+
1032
+ // Test max and min
1033
+ {
1034
+ auto x = array({}); // max and min of an empty array must throw
1035
+ CHECK_THROWS(max(x));
1036
+ CHECK_THROWS(min(x));
1037
+
1038
+ x = array({1.0f, 2.0f, 3.0f});
1039
+ CHECK_EQ(max(x).item<float>(), 3.0f);
1040
+ CHECK_EQ(min(x).item<float>(), 1.0f);
1041
+
1042
+ x = array({-2.0f, -1.0f});
1043
+ CHECK_EQ(max(x).item<float>(), -1.0f);
1044
+ CHECK_EQ(min(x).item<float>(), -2.0f);
1045
+
1046
+ constexpr float inf = std::numeric_limits<float>::infinity();
1047
+ x = array({inf});
1048
+ CHECK_EQ(min(x).item<float>(), inf);
1049
+
1050
+ x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
1051
+ CHECK(array_equal(max(x, 0), array({4.0f, 5.0f, 6.0f})).item<bool>());
1052
+ CHECK(array_equal(max(x, 1), array({3.0f, 6.0f})).item<bool>());
1053
+ CHECK(array_equal(min(x, 0), array({1.0f, 2.0f, 3.0f})).item<bool>());
1054
+ CHECK(array_equal(min(x, 1), array({1.0f, 4.0f})).item<bool>());
1055
+ // Same checks with unsigned integer input.
1056
+ x = array({1u, 2u, 3u});
1057
+ CHECK_EQ(max(x).item<uint32_t>(), 3u);
1058
+ CHECK_EQ(min(x).item<uint32_t>(), 1u);
1059
+
1060
+ x = array({1u, 2u, 3u, 4u, 5u, 6u}, {2, 3});
1061
+ CHECK(array_equal(max(x, 0), array({4u, 5u, 6u})).item<bool>());
1062
+ CHECK(array_equal(max(x, 1), array({3u, 6u})).item<bool>());
1063
+ CHECK(array_equal(min(x, 0), array({1u, 2u, 3u})).item<bool>());
1064
+ CHECK(array_equal(min(x, 1), array({1u, 4u})).item<bool>());
1065
+ // Bool input: the expected values match those of any (for max) and all (for min) on the same data.
1066
+ x = array({true, false, true, false, false, false}, {2, 3});
1067
+ CHECK(array_equal(max(x, 1), array({true, false})).item<bool>());
1068
+ CHECK(array_equal(max(x, 0), array({true, false, true})).item<bool>());
1069
+
1070
+ x = array({true, true, true, false, true, false}, {2, 3});
1071
+ CHECK(array_equal(min(x, 1), array({true, false})).item<bool>());
1072
+ CHECK(array_equal(min(x, 0), array({false, true, false})).item<bool>());
1073
+ // A NaN in the input is expected to appear in every max result that covers it.
1074
+ x = array({1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
1075
+ CHECK(array_equal(max(x, 0), array({4.0f, NAN, 6.0f}), true).item<bool>());
1076
+ CHECK(array_equal(max(x, 1), array({NAN, 6.0f}), true).item<bool>());
1077
+ }
1078
+
1079
+ // Test logsumexp
1080
+ {
1081
+ auto x = array({});
1082
+ CHECK_THROWS(logsumexp(x));
1083
+
1084
+ constexpr float inf = std::numeric_limits<float>::infinity();
1085
+ // Infinite inputs: all -inf gives -inf, and a single +inf gives +inf.
1086
+ x = array({-inf, -inf});
1087
+ CHECK_EQ(logsumexp(x).item<float>(), -inf);
1088
+
1089
+ x = repeat(array(-inf), 5000);
1090
+ CHECK_EQ(logsumexp(x).item<float>(), -inf);
1091
+
1092
+ x = array({0.0f, -inf});
1093
+ CHECK_EQ(logsumexp(x).item<float>(), 0.0f);
1094
+
1095
+ x = array({0.0f, inf});
1096
+ CHECK_EQ(logsumexp(x).item<float>(), inf);
1097
+
1098
+ x = reshape(arange(6, float32), {2, 3});
1099
+ // NOTE(review): the reshape(arange(...)) value assigned above is overwritten below before it is used.
1100
+ std::vector<float> nums = {0.0f, 1.0f, 2.0f, 3.0f};
1101
+ x = array(nums.data(), {2, 2});
1102
+ auto y = logsumexp(x, {0, 1}, true);
1103
+ CHECK_EQ(y.shape(), Shape{1, 1});
1104
+ auto result = std::log(
1105
+ std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) +
1106
+ std::exp(nums[3]));
1107
+ CHECK(y.item<float>() == doctest::Approx(result));
1108
+ auto expected = array(
1109
+ {std::log(std::exp(nums[0]) + std::exp(nums[2])),
1110
+ std::log(std::exp(nums[1]) + std::exp(nums[3]))});
1111
+ CHECK(allclose(logsumexp(x, 0), expected).item<bool>());
1112
+
1113
+ expected = array(
1114
+ {std::log(std::exp(nums[0]) + std::exp(nums[1])),
1115
+ std::log(std::exp(nums[2]) + std::exp(nums[3]))});
1116
+ CHECK(allclose(logsumexp(x, 1), expected).item<bool>());
1117
+ }
1118
+
1119
+ // Test softmax
1120
+ {
1121
+ for (auto t : {float16, bfloat16, float32}) {
1122
+ const auto rtol = t == float32 ? 1e-5 : 1e-2; // looser tolerance for the 16-bit types
1123
+ auto x = array({}, t);
1124
+ CHECK(array_equal(x, softmax(x)).item<bool>());
1125
+
1126
+ // all zeros
1127
+ x = array({0., 0., 0., 0.}, t);
1128
+ auto y = array({0.25, 0.25, 0.25, 0.25}, t);
1129
+ CHECK(array_equal(y, softmax(x)).item<bool>());
1130
+ CHECK(array_equal(y, softmax(x, -1)).item<bool>());
1131
+ CHECK(array_equal(y, softmax(x, std::vector<int>{-1})).item<bool>());
1132
+ CHECK(array_equal(y, softmax(x, std::vector<int>{0})).item<bool>());
1133
+ // Scalar 1 in dtype t, compared against the softmax sums below.
1134
+ auto ones = array(1.0f, t);
1135
+ CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
1136
+
1137
+ // all ones
1138
+ x = array({1., 1., 1., 1.}, t);
1139
+ CHECK(array_equal(y, softmax(x)).item<bool>());
1140
+ CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
1141
+
1142
+ // negative values
1143
+ x = array({-1., -2., -3., -4.}, t);
1144
+ y = array({0.643914, 0.236883, 0.0871443, 0.0320586}, t);
1145
+ CHECK(allclose(y, softmax(x), rtol).item<bool>());
1146
+ CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
1147
+
1148
+ // positive and negative values
1149
+ x = array({1., 0., -1., 0.}, t);
1150
+ y = array({0.534447, 0.196612, 0.0723295, 0.196612}, t);
1151
+ CHECK(allclose(y, softmax(x), rtol).item<bool>());
1152
+ CHECK(allclose(ones, sum(softmax(x)), rtol).item<bool>());
1153
+
1154
+ // large positive values
1155
+ x = array({1000., 1000., 1000.}, t);
1156
+ y = array({0.333333, 0.333333, 0.333333}, t);
1157
+ CHECK(allclose(y, softmax(x)).item<bool>());
1158
+ CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
1159
+
1160
+ // large negative values
1161
+ x = negative(x);
1162
+ CHECK(allclose(y, softmax(x)).item<bool>());
1163
+ CHECK(array_equal(ones, sum(softmax(x))).item<bool>());
1164
+ }
1165
+ }
1166
+ }
1167
+
1168
+ TEST_CASE("test irregular binary ops") { // Checks add/subtract on strided slices and on broadcast operands.
1169
+ // 1D strided
1170
+ {
1171
+ auto x = full({128}, 1.0f);
1172
+ auto y = full({64}, 1.0f);
1173
+ x = slice(x, {0}, {128}, {4}); // every 4th of 128 elements
1174
+ y = slice(y, {0}, {64}, {2}); // every 2nd of 64 elements
1175
+ CHECK(array_equal(add(x, y), full({32}, 2.0f)).item<bool>());
1176
+ }
1177
+
1178
+ // 2D broadcasts
1179
+ {
1180
+ auto x = full({32, 32}, 4.0f);
1181
+ auto y = full({32}, 4.0f);
1182
+ CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item<bool>());
1183
+ y = reshape(y, {32, 1}); // same values as a 32x1 column
1184
+ CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item<bool>());
1185
+ CHECK(array_equal(subtract(y, x), zeros({32, 32})).item<bool>());
1186
+ }
1187
+ }
1188
+
1189
+ TEST_CASE("test arithmetic unary ops") {
1190
+ // Test negative
1191
+ {
1192
+ array x(1.0f);
1193
+ CHECK_EQ(negative(x).item<float>(), -1.0f);
1194
+ CHECK_EQ((-x).item<float>(), -1.0f);
1195
+
1196
+ // works on empty array
1197
+ CHECK(array_equal(-array({}), array({})).item<bool>());
1198
+
1199
+ // Throws on bool
1200
+ CHECK_THROWS(negative(array(true)));
1201
+ }
1202
+
1203
+ // Test logical not
1204
+ {
1205
+ array x(false);
1206
+ CHECK_EQ(logical_not(x).item<bool>(), true);
1207
+
1208
+ x = array(1.0f);
1209
+ auto y = logical_not(x);
1210
+ CHECK_EQ(y.dtype(), bool_);
1211
+ CHECK_EQ(y.item<bool>(), false);
1212
+
1213
+ x = array(0);
1214
+ y = logical_not(x);
1215
+ CHECK_EQ(y.dtype(), bool_);
1216
+ CHECK_EQ(y.item<bool>(), true);
1217
+ }
1218
+
1219
+ // Test logical and
1220
+ {
1221
+ array x(true);
1222
+ array y(true);
1223
+ CHECK_EQ(logical_and(x, y).item<bool>(), true);
1224
+
1225
+ x = array(1.0f);
1226
+ y = array(1.0f);
1227
+ auto z = logical_and(x, y);
1228
+ CHECK_EQ(z.dtype(), bool_);
1229
+ CHECK_EQ(z.item<bool>(), true);
1230
+
1231
+ x = array(0);
1232
+ y = array(1.0f);
1233
+ z = logical_and(x, y);
1234
+ CHECK_EQ(z.dtype(), bool_);
1235
+ CHECK_EQ(z.item<bool>(), false);
1236
+ }
1237
+
1238
+ // Test logical or
1239
+ {
1240
+ array x(false);
1241
+ array y(false);
1242
+ CHECK_EQ(logical_or(x, y).item<bool>(), false);
1243
+
1244
+ x = array(1.0f);
1245
+ y = array(1.0f);
1246
+ auto z = logical_or(x, y);
1247
+ CHECK_EQ(z.dtype(), bool_);
1248
+ CHECK_EQ(z.item<bool>(), true);
1249
+
1250
+ x = array(0);
1251
+ y = array(1.0f);
1252
+ z = logical_or(x, y);
1253
+ CHECK_EQ(z.dtype(), bool_);
1254
+ CHECK_EQ(z.item<bool>(), true);
1255
+ }
1256
+
1257
+ // Test abs
1258
+ {
1259
+ array x({-1.0f, 0.0f, 1.0f});
1260
+ CHECK(array_equal(abs(x), array({1.0f, 0.0f, 1.0f})).item<bool>());
1261
+
1262
+ // works on empty array
1263
+ CHECK(array_equal(abs(array({})), array({})).item<bool>());
1264
+
1265
+ // int32
1266
+ x = array({-1, 0, 1});
1267
+ CHECK(array_equal(abs(x), array({1, 0, 1})).item<bool>());
1268
+
1269
+ // uint32
1270
+ x = array({1u, 0u, 1u});
1271
+ CHECK(array_equal(abs(x), array({1u, 0u, 1u})).item<bool>());
1272
+
1273
+ // bool
1274
+ x = array({false, true});
1275
+ CHECK(array_equal(abs(x), array({false, true})).item<bool>());
1276
+ }
1277
+
1278
+ // Test sign
1279
+ {
1280
+ array x({-1.0f, 0.0f, 1.0f});
1281
+ CHECK(array_equal(sign(x), x).item<bool>());
1282
+
1283
+ // works on empty array
1284
+ CHECK(array_equal(sign(array({})), array({})).item<bool>());
1285
+
1286
+ // int32
1287
+ x = array({-1, 0, 1});
1288
+ CHECK(array_equal(sign(x), x).item<bool>());
1289
+
1290
+ // uint32
1291
+ x = array({1u, 0u, 1u});
1292
+ CHECK(array_equal(sign(x), x).item<bool>());
1293
+
1294
+ // bool
1295
+ x = array({false, true});
1296
+ CHECK(array_equal(sign(x), x).item<bool>());
1297
+
1298
+ // uint64
1299
+ array x_uint64(
1300
+ {uint64_t(0xa11cc311cb6acd70),
1301
+ uint64_t(0x7a375ac3ebb533f3),
1302
+ uint64_t(0x734969adf9d7190c),
1303
+ uint64_t(0xb400515a4f673424)});
1304
+ array expected(
1305
+ {uint64_t(0x0000000000000001),
1306
+ uint64_t(0x0000000000000001),
1307
+ uint64_t(0x0000000000000001),
1308
+ uint64_t(0x0000000000000001)});
1309
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1310
+
1311
+ x_uint64 = array(
1312
+ {uint64_t(0xa11cc311cb6acd70),
1313
+ uint64_t(0x7a375ac3ebb533f3),
1314
+ uint64_t(0x734969adf9d7190c)});
1315
+ expected = array(
1316
+ {uint64_t(0x0000000000000001),
1317
+ uint64_t(0x0000000000000001),
1318
+ uint64_t(0x0000000000000001)});
1319
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1320
+
1321
+ x_uint64 =
1322
+ array({uint64_t(0xa11cc311cb6acd70), uint64_t(0x7a375ac3ebb533f3)});
1323
+ expected =
1324
+ array({uint64_t(0x0000000000000001), uint64_t(0x0000000000000001)});
1325
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1326
+
1327
+ x_uint64 = array({uint64_t(0xa11cc311cb6acd70)});
1328
+ expected = array({uint64_t(0x0000000000000001)});
1329
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1330
+
1331
+ x_uint64 = array({uint64_t(0xffffffffffffffff)});
1332
+ expected = array({uint64_t(0x0000000000000001)});
1333
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1334
+
1335
+ x_uint64 = array({uint64_t(0x0000000000000001)});
1336
+ expected = array({uint64_t(0x0000000000000001)});
1337
+ CHECK(array_equal(sign(x_uint64), expected).item<bool>());
1338
+ }
1339
+
1340
+ constexpr float neginf = -std::numeric_limits<float>::infinity();
1341
+
1342
+ // Test floor and ceil
1343
+ {
1344
+ array x(1.0f);
1345
+ CHECK_EQ(floor(x).item<float>(), 1.0f);
1346
+ CHECK_EQ(ceil(x).item<float>(), 1.0f);
1347
+
1348
+ x = array(1.5f);
1349
+ CHECK_EQ(floor(x).item<float>(), 1.0f);
1350
+ CHECK_EQ(ceil(x).item<float>(), 2.0f);
1351
+
1352
+ x = array(-1.5f);
1353
+ CHECK_EQ(floor(x).item<float>(), -2.0f);
1354
+ CHECK_EQ(ceil(x).item<float>(), -1.0f);
1355
+
1356
+ x = array(neginf);
1357
+ CHECK_EQ(floor(x).item<float>(), neginf);
1358
+ CHECK_EQ(ceil(x).item<float>(), neginf);
1359
+
1360
+ x = array(std::complex<float>(1.0f, 1.0f));
1361
+ CHECK_THROWS_AS(floor(x), std::invalid_argument);
1362
+ CHECK_THROWS_AS(ceil(x), std::invalid_argument);
1363
+ }
1364
+
1365
+ // Test round
1366
+ {
1367
+ array x({0.5, -0.5, 1.5, -1.5, 2.3, 2.6});
1368
+ CHECK(array_equal(round(x), array({0, -0, 2, -2, 2, 3})).item<bool>());
1369
+
1370
+ x = array({11, 222, 32});
1371
+ CHECK(array_equal(round(x, -1), array({10, 220, 30})).item<bool>());
1372
+ }
1373
+
1374
+ // Test exponential
1375
+ {
1376
+ array x(0.0);
1377
+ CHECK_EQ(exp(x).item<float>(), 1.0);
1378
+
1379
+ x = array(2.0);
1380
+ CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));
1381
+
1382
+ CHECK(array_equal(exp(array({})), array({})).item<bool>());
1383
+
1384
+ x = array(neginf);
1385
+ CHECK_EQ(exp(x).item<float>(), doctest::Approx(0.0f));
1386
+
1387
+ // Integer input type
1388
+ x = array(2);
1389
+ CHECK_EQ(x.dtype(), int32);
1390
+ CHECK_EQ(exp(x).item<float>(), doctest::Approx(std::exp(2.0f)));
1391
+
1392
+ // Input is irregularly strided
1393
+ x = broadcast_to(array(1.0f), {2, 2, 2});
1394
+ CHECK(allclose(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>());
1395
+
1396
+ x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
1397
+ auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
1398
+ CHECK(allclose(exp(x), expected).item<bool>());
1399
+
1400
+ // Complex of -inf
1401
+ constexpr float inf = std::numeric_limits<float>::infinity();
1402
+ x = array(complex64_t{-inf, -inf});
1403
+ CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
1404
+ }
1405
+
1406
+ // Test expm1
1407
+ {
1408
+ array x(-1.0f);
1409
+ CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(-1.0f)));
1410
+
1411
+ x = array(1.0f);
1412
+ CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
1413
+
1414
+ // Integer input type
1415
+ x = array(1);
1416
+ CHECK_EQ(expm1(x).dtype(), float32);
1417
+ CHECK_EQ(expm1(x).item<float>(), doctest::Approx(std::expm1(1.0f)));
1418
+ }
1419
+
1420
+ // Test sine
1421
+ {
1422
+ array x(0.0);
1423
+ CHECK_EQ(sin(x).item<float>(), 0.0);
1424
+
1425
+ x = array(M_PI_2);
1426
+ CHECK(sin(x).item<float>() == doctest::Approx(std::sin(M_PI_2)));
1427
+
1428
+ CHECK(array_equal(sin(array({})), array({})).item<bool>());
1429
+
1430
+ // Integer input type
1431
+ x = array(0);
1432
+ CHECK_EQ(x.dtype(), int32);
1433
+ CHECK_EQ(sin(x).item<float>(), std::sin(0.0f));
1434
+
1435
+ // Input is irregularly strided
1436
+ x = broadcast_to(array(1.0f), {2, 2, 2});
1437
+ CHECK(allclose(sin(x), full({2, 2, 2}, std::sin(1.0f))).item<bool>());
1438
+
1439
+ x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
1440
+ auto expected = array({std::sin(0.0f), std::sin(2.0f)}, {2, 1});
1441
+ CHECK(allclose(sin(x), expected).item<bool>());
1442
+ }
1443
+
1444
+ // Test cos
1445
+ {
1446
+ array x(0.0);
1447
+ CHECK_EQ(cos(x).item<float>(), doctest::Approx(1.0));
1448
+
1449
+ x = array(M_PI_2);
1450
+ CHECK(cos(x).item<float>() == doctest::Approx(std::cos(M_PI_2)));
1451
+
1452
+ CHECK(array_equal(cos(array({})), array({})).item<bool>());
1453
+
1454
+ // Integer input type
1455
+ x = array(0);
1456
+ CHECK_EQ(x.dtype(), int32);
1457
+ CHECK(cos(x).item<float>() == doctest::Approx(std::cos(0.0f)));
1458
+
1459
+ // Input is irregularly strided
1460
+ x = broadcast_to(array(1.0f), {2, 2, 2});
1461
+ CHECK(allclose(cos(x), full({2, 2, 2}, std::cos(1.0f))).item<bool>());
1462
+
1463
+ x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
1464
+ auto expected = array({std::cos(0.0f), std::cos(2.0f)}, {2, 1});
1465
+ CHECK(allclose(cos(x), expected).item<bool>());
1466
+ }
1467
+
1468
+ // Test degrees
1469
+ {
1470
+ array x(0.0);
1471
+ CHECK_EQ(degrees(x).item<float>(), 0.0);
1472
+
1473
+ x = array(M_PI_2);
1474
+ CHECK(degrees(x).item<float>() == doctest::Approx(90.0));
1475
+
1476
+ CHECK(array_equal(degrees(array({})), array({})).item<bool>());
1477
+
1478
+ // Integer input type
1479
+ x = array(0);
1480
+ CHECK_EQ(x.dtype(), int32);
1481
+ CHECK_EQ(degrees(x).item<float>(), 0.0);
1482
+
1483
+ // Input is irregularly strided
1484
+ x = broadcast_to(array(M_PI_2), {2, 2, 2});
1485
+ CHECK(allclose(degrees(x), full({2, 2, 2}, 90.0)).item<bool>());
1486
+
1487
+ float angles[] = {0.0f, M_PI_2, M_PI, 3.0f * M_PI_2};
1488
+ x = split(array(angles, {2, 2}), 2, 1)[0];
1489
+ auto expected = array({0.0f, 180.0f}, {2, 1});
1490
+ CHECK(allclose(degrees(x), expected).item<bool>());
1491
+ }
1492
+
1493
+ // Test radians
1494
+ {
1495
+ array x(0.0);
1496
+ CHECK_EQ(radians(x).item<float>(), 0.0);
1497
+
1498
+ x = array(90.0);
1499
+ CHECK(radians(x).item<float>() == doctest::Approx(M_PI_2));
1500
+
1501
+ CHECK(array_equal(radians(array({})), array({})).item<bool>());
1502
+
1503
+ // Integer input type
1504
+ x = array(90);
1505
+ CHECK_EQ(x.dtype(), int32);
1506
+ CHECK(radians(x).item<float>() == doctest::Approx(M_PI_2));
1507
+
1508
+ // Input is irregularly strided
1509
+ x = broadcast_to(array(90.0f), {2, 2, 2});
1510
+ CHECK(allclose(radians(x), full({2, 2, 2}, M_PI_2)).item<bool>());
1511
+
1512
+ x = split(array({0.0f, 90.0f, 180.0f, 270.0f}, {2, 2}), 2, 1)[0];
1513
+ float angles[] = {0.0f, M_PI};
1514
+ auto expected = array(angles, {2, 1});
1515
+ CHECK(allclose(radians(x), expected).item<bool>());
1516
+ }
1517
+
1518
+ // Test log
1519
+ {
1520
+ array x(0.0);
1521
+ CHECK_EQ(log(x).item<float>(), neginf);
1522
+
1523
+ x = array(1.0);
1524
+ CHECK_EQ(log(x).item<float>(), log(1.0f));
1525
+
1526
+ // Integer input type
1527
+ x = array(1);
1528
+ CHECK_EQ(log(x).dtype(), float32);
1529
+ CHECK_EQ(log(x).item<float>(), log(1.0f));
1530
+
1531
+ // Input is irregularly strided
1532
+ x = broadcast_to(array(1.0f), {2, 2, 2});
1533
+ CHECK(array_equal(log(x), full({2, 2, 2}, std::log(1.0f))).item<bool>());
1534
+
1535
+ x = split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0];
1536
+ auto expected = array({std::log(1.0f), std::log(3.0f)}, {2, 1});
1537
+ CHECK(array_equal(log(x), expected).item<bool>());
1538
+ }
1539
+
1540
+ // Test log2
1541
+ {
1542
+ array x(0.0);
1543
+ CHECK_EQ(log2(x).item<float>(), neginf);
1544
+
1545
+ x = array(1.0);
1546
+ CHECK_EQ(log2(x).item<float>(), 0.0f);
1547
+
1548
+ x = array(1024.0f);
1549
+ CHECK_EQ(log2(x).item<float>(), 10.0f);
1550
+ }
1551
+
1552
+ // Test log10
1553
+ {
1554
+ array x(0.0);
1555
+ CHECK_EQ(log10(x).item<float>(), neginf);
1556
+
1557
+ x = array(1.0);
1558
+ CHECK_EQ(log10(x).item<float>(), 0.0f);
1559
+
1560
+ x = array(1000.0f);
1561
+ CHECK_EQ(log10(x).item<float>(), 3.0f);
1562
+ }
1563
+
1564
+ // Test log1p
1565
+ {
1566
+ array x(-1.0f);
1567
+ CHECK_EQ(log1p(x).item<float>(), neginf);
1568
+
1569
+ x = array(1.0f);
1570
+ CHECK_EQ(log1p(x).item<float>(), std::log1pf(1.0f));
1571
+
1572
+ // Integer input type
1573
+ x = array(1);
1574
+ CHECK_EQ(log1p(x).dtype(), float32);
1575
+ CHECK_EQ(log1p(x).item<float>(), std::log1pf(1.0f));
1576
+
1577
+ // Input is irregularly strided
1578
+ x = broadcast_to(array(1.0f), {2, 2, 2});
1579
+ CHECK(
1580
+ array_equal(log1p(x), full({2, 2, 2}, std::log1pf(1.0f))).item<bool>());
1581
+
1582
+ x = split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0];
1583
+ auto expected = array({std::log1pf(1.0f), std::log1pf(3.0f)}, {2, 1});
1584
+ CHECK(array_equal(log1p(x), expected).item<bool>());
1585
+ }
1586
+
1587
+ // Test sigmoid
1588
+ {
1589
+ array x(0.0);
1590
+ CHECK_EQ(sigmoid(x).item<float>(), 0.5f);
1591
+
1592
+ // Integer input type
1593
+ x = array(0);
1594
+ CHECK_EQ(sigmoid(x).dtype(), float32);
1595
+ CHECK_EQ(sigmoid(x).item<float>(), 0.5f);
1596
+
1597
+ constexpr auto inf = std::numeric_limits<float>::infinity();
1598
+ x = array(inf);
1599
+ CHECK_EQ(sigmoid(x).item<float>(), 1.0f);
1600
+ x = array(-inf);
1601
+ CHECK_EQ(sigmoid(x).item<float>(), 0.0f);
1602
+ }
1603
+
1604
+ // Test square
1605
+ {
1606
+ array x(3.0);
1607
+ CHECK_EQ(square(x).item<float>(), 9.0);
1608
+
1609
+ x = array(2);
1610
+ CHECK_EQ(square(x).item<int>(), 4);
1611
+
1612
+ x = full({3, 3}, 2.0f);
1613
+ CHECK(array_equal(square(x), full({3, 3}, 4.0f)).item<bool>());
1614
+ }
1615
+
1616
+ // Test sqrt and rsqrt
1617
+ {
1618
+ array x(4.0);
1619
+ CHECK_EQ(sqrt(x).item<float>(), 2.0);
1620
+ CHECK_EQ(rsqrt(x).item<float>(), 0.5);
1621
+
1622
+ x = full({3, 3}, 9.0f);
1623
+ CHECK(array_equal(sqrt(x), full({3, 3}, 3.0f)).item<bool>());
1624
+
1625
+ x = array(4, int32);
1626
+ CHECK_EQ(sqrt(x).item<float>(), 2.0f);
1627
+ CHECK_EQ(rsqrt(x).item<float>(), 0.5f);
1628
+ }
1629
+
1630
+ // Test reciprocal
1631
+ {
1632
+ array x(8.0);
1633
+ CHECK_EQ(reciprocal(x).item<float>(), 0.125f);
1634
+
1635
+ x = array(2);
1636
+ auto out = reciprocal(x);
1637
+ CHECK_EQ(out.dtype(), float32);
1638
+ CHECK_EQ(out.item<float>(), 0.5f);
1639
+
1640
+ x = full({3, 3}, 2.0f);
1641
+ CHECK(array_equal(reciprocal(x), full({3, 3}, 0.5f)).item<bool>());
1642
+ }
1643
+ }
1644
+
1645
+ TEST_CASE("test error functions") {
1646
+ constexpr float inf = std::numeric_limits<float>::infinity();
1647
+ array x(0.0f); // erf at zero, then at +inf and -inf
1648
+ CHECK_EQ(erf(x).item<float>(), 0.0f);
1649
+ x = array(inf);
1650
+ CHECK_EQ(erf(x).item<float>(), 1.0f);
1651
+ x = array(-inf);
1652
+ CHECK_EQ(erf(x).item<float>(), -1.0f);
1653
+ // An int32 input produces a float32 result
1654
+ x = array(1, int32);
1655
+ CHECK_EQ(erf(x).dtype(), float32);
1656
+ // erfinv at zero, then at the +1 and -1 endpoints
1657
+ x = array(0.0f);
1658
+ CHECK_EQ(erfinv(x).item<float>(), 0.0f);
1659
+ x = array(1.0f);
1660
+ CHECK_EQ(erfinv(x).item<float>(), inf);
1661
+ x = array(-1.0f);
1662
+ CHECK_EQ(erfinv(x).item<float>(), -inf);
1663
+ // An int32 input produces a float32 result
1664
+ x = array(1, int32);
1665
+ CHECK_EQ(erfinv(x).dtype(), float32);
1666
+ // Arguments outside [-1, 1] are expected to give NaN
1667
+ x = array(2.0f);
1668
+ CHECK(std::isnan(erfinv(x).item<float>()));
1669
+ x = array(-2.0f);
1670
+ CHECK(std::isnan(erfinv(x).item<float>()));
1671
+ // Sample points shared by the erf and erfinv comparisons below
1672
+ auto vals = {0.9f, 0.5f, 0.1f, -0.1f, -0.5f, -0.9f};
1673
+ // Expected values are generated from scipy's error function:
1674
+ // python -c "import scipy.special as ss;
1675
+ // vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9];
1676
+ // print([ss.erf(x) for x in vals])"
1677
+ {
1678
+ auto expected = {
1679
+ 0.7969082124228322,
1680
+ 0.5204998778130465,
1681
+ 0.1124629160182849,
1682
+ -0.1124629160182849,
1683
+ -0.5204998778130465,
1684
+ -0.7969082124228322};
1685
+ for (size_t i = 0; i < vals.size(); ++i) { // unsigned index matches vals.size()
1686
+ x = array(vals.begin()[i]);
1687
+ CHECK_EQ(erf(x).item<float>(), doctest::Approx(expected.begin()[i]));
1688
+ }
1689
+ }
1690
+
1691
+ // Expected values are generated from scipy's inverse error function:
1692
+ // python -c "import scipy.special as ss;
1693
+ // vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9];
1694
+ // print([ss.erfinv(x) for x in vals])"
1695
+ {
1696
+ auto expected = {
1697
+ 1.1630871536766738,
1698
+ 0.4769362762044699,
1699
+ 0.08885599049425778,
1700
+ -0.08885599049425769,
1701
+ -0.4769362762044699,
1702
+ -1.1630871536766743};
1703
+ for (size_t i = 0; i < vals.size(); ++i) { // unsigned index matches vals.size()
1704
+ x = array(vals.begin()[i]);
1705
+ CHECK_EQ(erfinv(x).item<float>(), doctest::Approx(expected.begin()[i]));
1706
+ }
1707
+ }
1708
+
1709
+ // float16 input: the result dtype stays float16
1710
+ {
1711
+ array x(0.0f, float16);
1712
+ auto out = erf(x);
1713
+ CHECK_EQ(out.dtype(), float16);
1714
+ CHECK_EQ(out.item<float16_t>(), 0.0f);
1715
+
1716
+ out = erfinv(x);
1717
+ CHECK_EQ(out.dtype(), float16);
1718
+ CHECK_EQ(out.item<float16_t>(), 0.0f);
1719
+ }
1720
+
1721
+ // bfloat16 input: the result dtype stays bfloat16
1722
+ {
1723
+ array x(0.0f, bfloat16);
1724
+ auto out = erf(x);
1725
+ CHECK_EQ(out.dtype(), bfloat16);
1726
+ CHECK_EQ(out.item<bfloat16_t>(), 0.0f);
1727
+
1728
+ out = erfinv(x);
1729
+ CHECK_EQ(out.dtype(), bfloat16);
1730
+ CHECK_EQ(out.item<bfloat16_t>(), 0.0f); // read with the element type the dtype check just confirmed
1731
+ }
1732
+ }
1733
+
1734
+ TEST_CASE("test arithmetic binary ops") {
1735
+ array lhs(1.0);
1736
+ array rhs(1.0);
1737
+ auto res = add(lhs, rhs);
1738
+ CHECK_EQ(res.item<float>(), 2.0);
1739
+ res = lhs + rhs;
1740
+ CHECK_EQ(res.item<float>(), 2.0);
1741
+ res = add(res, lhs);
1742
+ CHECK_EQ(res.item<float>(), 3.0);
1743
+ res.eval(); // Expected to change nothing
1744
+ CHECK_EQ(res.item<float>(), 3.0);
1745
+
1746
+ // Accumulate ten more additions on top of the starting value
1747
+ auto acc = lhs;
1748
+ for (int step = 0; step != 10; ++step) {
1749
+ acc = add(acc, lhs);
1750
+ }
1751
+ CHECK_EQ(acc.item<float>(), 11.0);
1752
+
1753
+ // Both operands given an explicit {1, 3} shape
1754
+ rhs = array({1.0, 2.0, 3.0}, {1, 3});
1755
+ lhs = array({1.0, 2.0, 3.0}, {1, 3});
1756
+ res = add(lhs, rhs);
1757
+ CHECK_EQ(res.shape(), Shape{1, 3});
1758
+ auto same = array_equal(res, array({2.0, 4.0, 6.0}, {1, 3}));
1759
+ CHECK(same.item<bool>());
1760
+
1761
+ // A plain scalar on either side of operator+
1762
+ lhs = array({1.0, 2.0, 3.0}, {1, 3});
1763
+ rhs = lhs + 2.0;
1764
+ CHECK_EQ(rhs.dtype(), float32);
1765
+ same = array_equal(rhs, array({3.0, 4.0, 5.0}, {1, 3}));
1766
+ CHECK(same.item<bool>());
1767
+ rhs = 2.0 + lhs;
1768
+ CHECK_EQ(rhs.dtype(), float32);
1769
+ same = array_equal(rhs, array({3.0, 4.0, 5.0}, {1, 3}));
1770
+ CHECK(same.item<bool>());
1771
+
1772
+ // Result dtype when integer and floating operands are mixed
1773
+ rhs = 2 + lhs;
1774
+ CHECK_EQ(rhs.dtype(), float32);
1775
+
1776
+ rhs = 2.0 + array({1, 2, 3});
1777
+ CHECK_EQ(rhs.dtype(), float32);
1778
+ CHECK(array_equal(rhs, array({3.0, 4.0, 5.0})).item<bool>());
1779
+
1780
+ // Operands that are themselves broadcast views
1781
+ lhs = broadcast_to(array({1.0}), {10});
1782
+ rhs = broadcast_to(array({2.0}), {10});
1783
+ res = add(lhs, rhs);
1784
+ CHECK(array_equal(res, full({10}, 3.0)).item<bool>());
1785
+
1786
+ rhs = array({1.0, 2.0}, {2, 1});
1787
+ lhs = array({1.0, 2.0}, {1, 2});
1788
+ res = add(lhs, rhs);
1789
+ CHECK_EQ(res.shape(), Shape{2, 2});
1790
+ same = array_equal(res, array({2.0, 3.0, 3.0, 4.0}, {2, 2}));
1791
+ CHECK(same.item<bool>());
1792
+
1793
+ lhs = ones({3, 2, 1});
1794
+ res = lhs + 2.0;
1795
+ CHECK_EQ(res.shape(), Shape{3, 2, 1});
1796
+ same = array_equal(res, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1}));
1797
+ CHECK(same.item<bool>());
1798
+
1799
+ // Adding two empty arrays gives an empty result of shape {0}
1800
+ lhs = array({});
1801
+ rhs = array({});
1802
+ res = lhs + rhs;
1803
+ res.eval();
1804
+ CHECK_EQ(res.size(), 0);
1805
+ CHECK_EQ(res.shape(), Shape{0});
1806
+
1807
+ // Subtraction
1808
+ rhs = array({1, 1, 1});
1809
+ lhs = array({3, 2, 1});
1810
+ CHECK(array_equal(lhs - rhs, array({2, 1, 0})).item<bool>());
1811
+
1812
+ // Multiplication
1813
+ rhs = array({2, 2, 2});
1814
+ lhs = array({1, 2, 3});
1815
+ CHECK(array_equal(lhs * rhs, array({2, 4, 6})).item<bool>());
1816
+
1817
+ // Division, with results read back as float
1818
+ rhs = array(1);
1819
+ lhs = array(1);
1820
+ CHECK_EQ(divide(lhs, rhs).item<float>(), 1.0f);
1821
+
1822
+ rhs = array(0.5);
1823
+ lhs = array(1);
1824
+ CHECK_EQ(divide(lhs, rhs).item<float>(), 2.0f);
1825
+
1826
+ rhs = array(4);
1827
+ lhs = array(1);
1828
+ CHECK_EQ(divide(lhs, rhs).item<float>(), 0.25f);
1829
+
1830
+ rhs = array(true);
1831
+ lhs = array(true);
1832
+ CHECK_EQ(divide(lhs, rhs).item<float>(), 1.0f);
1833
+
1834
+ rhs = array(true);
1835
+ lhs = array(false);
1836
+ CHECK_EQ(divide(lhs, rhs).item<float>(), 0.0f);
1837
+
1838
+ rhs = array(false);
1839
+ lhs = array(true);
1840
+ CHECK(std::isinf(divide(lhs, rhs).item<float>()));
1841
+
1842
+ rhs = array(false);
1843
+ lhs = array(false);
1844
+ CHECK(std::isnan(divide(lhs, rhs).item<float>()));
1845
+
1846
+ // Elementwise maximum and minimum
1847
+ lhs = array(1.0f);
1848
+ rhs = array(0.0f);
1849
+ CHECK_EQ(maximum(lhs, rhs).item<float>(), 1.0f);
1850
+ CHECK_EQ(minimum(lhs, rhs).item<float>(), 0.0f);
1851
+ rhs = array(2.0f);
1852
+ CHECK_EQ(maximum(lhs, rhs).item<float>(), 2.0f);
1853
+ CHECK_EQ(minimum(lhs, rhs).item<float>(), 1.0f);
1854
+
1855
+ // logaddexp on finite, unsigned, infinite and complex operands
1856
+ lhs = array(0.0f);
1857
+ rhs = array(0.0f);
1858
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), std::log(2.0f));
1859
+
1860
+ lhs = array(0u);
1861
+ rhs = array(10000u);
1862
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), 10000.0f);
1863
+
1864
+ constexpr float kInf = std::numeric_limits<float>::infinity();
1865
+ lhs = array(kInf);
1866
+ rhs = array(3.0f);
1867
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), kInf);
1868
+
1869
+ lhs = array(-kInf);
1870
+ rhs = array(3.0f);
1871
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), 3.0f);
1872
+
1873
+ lhs = array(-kInf);
1874
+ rhs = array(-kInf);
1875
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), -kInf);
1876
+
1877
+ lhs = array(kInf);
1878
+ rhs = array(kInf);
1879
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), kInf);
1880
+
1881
+ lhs = array(-kInf);
1882
+ rhs = array(kInf);
1883
+ CHECK_EQ(logaddexp(lhs, rhs).item<float>(), kInf);
1884
+
1885
+ lhs = array(complex64_t{1, 1});
1886
+ rhs = array(complex64_t{-kInf, -kInf});
1887
+ CHECK_EQ(logaddexp(lhs, rhs).item<complex64_t>(), complex64_t{1, 1});
1888
+ }
1889
+
1890
+ TEST_CASE("test broadcast") {
1891
+ auto s = broadcast_shapes({1}, {1, 2});
1892
+ CHECK_EQ(s, Shape{1, 2});
1893
+
1894
+ s = broadcast_shapes({1, 2}, {1});
1895
+ CHECK_EQ(s, Shape{1, 2});
1896
+
1897
+ s = broadcast_shapes({2, 2}, {});
1898
+ CHECK_EQ(s, Shape{2, 2});
1899
+
1900
+ s = broadcast_shapes({}, {1, 1});
1901
+ CHECK_EQ(s, Shape{1, 1});
1902
+
1903
+ s = broadcast_shapes({1, 2, 1}, {2});
1904
+ CHECK_EQ(s, Shape{1, 2, 2});
1905
+
1906
+ s = broadcast_shapes({2}, {1, 2, 1});
1907
+ CHECK_EQ(s, Shape{1, 2, 2});
1908
+
1909
+ s = broadcast_shapes({2, 2, 2}, {1, 2, 1});
1910
+ CHECK_EQ(s, Shape{2, 2, 2});
1911
+
1912
+ s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1});
1913
+ CHECK_EQ(s, Shape{2, 2, 2, 1});
1914
+ // Shapes containing zero-sized dimensions
1915
+ s = broadcast_shapes({0}, {0, 0});
1916
+ CHECK_EQ(s, Shape{0, 0});
1917
+
1918
+ CHECK_EQ(broadcast_shapes({}, {0}), Shape{0});
1919
+
1920
+ s = broadcast_shapes({5, 0}, {0, 5, 0});
1921
+ CHECK_EQ(s, Shape{0, 5, 0});
1922
+
1923
+ CHECK_EQ(broadcast_shapes({0}, {}), Shape{0}); // mirrored order of the ({}, {0}) check
1924
+ CHECK_EQ(broadcast_shapes({1}, {0}), Shape{0});
1925
+ CHECK_EQ(broadcast_shapes({0}, {1}), Shape{0}); // mirrored order of the ({1}, {0}) check
1926
+ CHECK_EQ(broadcast_shapes({1}, {0, 0}), Shape{0, 0});
1927
+ CHECK_EQ(broadcast_shapes({1, 1}, {0}), Shape{1, 0});
1928
+ CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), Shape{0, 0});
1929
+ CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), Shape{2, 0});
1930
+ CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), Shape{2, 0});
1931
+ CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), Shape{1, 2, 0});
1932
+ CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument);
1933
+ CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument);
1934
+ // Incompatible non-zero dimensions are rejected
1935
+ CHECK_THROWS_AS(broadcast_shapes({3}, {2}), std::invalid_argument);
1936
+ CHECK_THROWS_AS(broadcast_shapes({1, 3}, {2}), std::invalid_argument);
1937
+ CHECK_THROWS_AS(broadcast_shapes({3}, {1, 2}), std::invalid_argument);
1938
+ CHECK_THROWS_AS(
1939
+ broadcast_shapes({1, 3, 2}, {1, 2, 2}), std::invalid_argument);
1940
+ // broadcast_to: resulting shapes, and strides once evaluated
1941
+ auto x = full({1, 1}, 2.3f);
1942
+ CHECK_EQ(broadcast_to(x, {1, 1}).item<float>(), 2.3f);
1943
+
1944
+ x = broadcast_to(x, {5, 1});
1945
+ CHECK_EQ(x.shape(), Shape{5, 1});
1946
+ x.eval();
1947
+ CHECK_EQ(x.strides(), Strides{0, 0});
1948
+
1949
+ CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument);
1950
+ x = broadcast_to(x, {5, 5});
1951
+ CHECK_EQ(x.shape(), Shape{5, 5});
1952
+
1953
+ x = zeros({2, 1, 2});
1954
+ x = broadcast_to(x, {4, 2, 1, 2});
1955
+ CHECK_EQ(x.shape(), Shape{4, 2, 1, 2});
1956
+ x.eval();
1957
+ CHECK_EQ(x.strides(), Strides{0, 2, 0, 1});
1958
+
1959
+ // Broadcasting an empty array to a non-empty shape is rejected
1960
+ x = array({});
1961
+ CHECK_THROWS_AS(broadcast_to(x, {1}), std::invalid_argument);
1962
+
1963
+ // Broadcasting to a shape with a zero-sized dimension gives an empty array
1964
+ x = array({1});
1965
+ auto y = broadcast_to(x, {0});
1966
+ eval(y);
1967
+ CHECK_EQ(y.size(), 0);
1968
+ CHECK_EQ(y.shape(), Shape{0});
1969
+
1970
+ x = array({1, 2}, {2, 1});
1971
+ y = broadcast_to(x, {2, 0});
1972
+ eval(y);
1973
+ CHECK_EQ(y.size(), 0);
1974
+ CHECK_EQ(y.shape(), Shape{2, 0});
1975
+
1976
+ // Nested broadcast_to calls give the same shape and strides as expected
1977
+ x = zeros({2});
1978
+ x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2});
1979
+ CHECK_EQ(x.shape(), Shape{2, 2});
1980
+ x.eval();
1981
+ CHECK_EQ(x.strides(), Strides{0, 1});
1982
+ x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2});
1983
+ CHECK_EQ(x.shape(), Shape{2, 2, 2});
1984
+ x.eval();
1985
+ CHECK_EQ(x.strides(), Strides{0, 0, 1});
1986
+
1987
+ // Broadcasting a transposed array matches broadcasting its explicit transpose
1988
+ x = array({0, 1, 2, 3, 4, 5}, {2, 3});
1989
+ x = broadcast_to(transpose(x), {2, 3, 2});
1990
+ CHECK_EQ(x.shape(), Shape{2, 3, 2});
1991
+ y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2});
1992
+ CHECK(array_equal(x, y).item<bool>());
1993
+
1994
+ // Reshaping a broadcast array keeps its values
1995
+ x = array(1.0);
1996
+ x = broadcast_to(x, {2});
1997
+ x = reshape(x, {1, 2});
1998
+ CHECK(array_equal(x, ones({1, 2})).item<bool>());
1999
+ }
2000
+
2001
+ TEST_CASE("test gather") {
2002
+ // An empty source cannot serve a non-empty index/slice request
2003
+ CHECK_THROWS(gather(array({}), array({1}), 0, {1}));
2004
+
2005
+ // A scalar source has no axis to index
2006
+ CHECK_THROWS(gather(array(0), array({1}), 0, {1}));
2007
+
2008
+ // Axis list that does not line up with the index arrays or the source
2009
+ CHECK_THROWS(gather(array({0}), {array({0})}, {0, 1}, {1}));
2010
+ CHECK_THROWS(gather(array({0}), array({0}), -1, {1}));
2011
+
2012
+ // The same axis listed twice
2013
+ CHECK_THROWS(
2014
+ gather(array({0}, {1, 1}), {array({0}), array({0})}, {0, 0}, {1, 1}));
2015
+
2016
+ // Slice sizes that are too large, of the wrong rank, or negative
2017
+ CHECK_THROWS(gather(array({0}), array({0}), 0, {2}));
2018
+ CHECK_THROWS(gather(array({0}), array({0}), 0, {0, 0}));
2019
+ CHECK_THROWS(gather(array({0}), array({0}), 0, {-1}));
2020
+
2021
+ // Floating-point index arrays
2022
+ CHECK_THROWS(gather(array({0}), array({0.0f}), 0, {0}));
2023
+ CHECK_THROWS(
2024
+ gather(array({0}, {1, 1}), {array({0}), array({0.0f})}, {0, 1}, {1, 1}));
2025
+
2026
+ // Index arrays of length 3 and 2 cannot be broadcast together
2027
+ CHECK_THROWS(gather(
2028
+ array({0}, {1, 1}),
2029
+ {array({0, 0, 0}, {3}), array({0, 0}, {2})},
2030
+ {0, 1},
2031
+ {1, 1}));
2032
+
2033
+ // 1-D source: gathering indices 0..9 gives back those same values
2034
+ auto src = arange(20);
2035
+ auto idx = arange(10);
2036
+ auto picked = gather(src, idx, 0, {1});
2037
+ CHECK_EQ(picked.shape(), Shape{10, 1});
2038
+ CHECK(array_equal(reshape(picked, {-1}), idx).item<bool>());
2039
+
2040
+ picked = gather(src, array({15}, uint32), 0, {1});
2041
+ CHECK_EQ(picked.shape(), Shape{1, 1});
2042
+ CHECK_EQ(picked.item<int32_t>(), 15);
2043
+
2044
+ // With no index arrays the call reduces to a leading slice
2045
+ picked = gather(src, {}, std::vector<int>{}, {10});
2046
+ CHECK_EQ(picked.shape(), Shape{10});
2047
+ CHECK(array_equal(picked, arange(10)).item<bool>());
2048
+
2049
+ // 2-D source
2050
+ // 128 consecutive values arranged as 4 rows of 32
2051
+ src = reshape(arange(128), {4, 32});
2052
+ idx = array({0, 1}, uint32);
2053
+ picked = gather(src, idx, 0, {1, 32});
2054
+ CHECK_EQ(picked.shape(), Shape{2, 1, 32});
2055
+ CHECK(array_equal(reshape(picked, {64}), arange(64)).item<bool>());
2056
+
2057
+ src = reshape(src, {64, 2});
2058
+ idx = array({0}, uint32);
2059
+ picked = gather(src, idx, 0, {64, 1});
2060
+ CHECK_EQ(picked.shape(), Shape{1, 64, 1});
2061
+ CHECK(array_equal(picked, reshape(arange(0, 128, 2), {1, 64, 1})).item<bool>());
2062
+
2063
+ // 3-D source
2064
+ // 256 consecutive values arranged as {8, 4, 8}
2065
+ src = reshape(arange(256), {8, 4, 8});
2066
+ idx = array({0}, uint32);
2067
+ picked = gather(src, idx, 0, {8, 1, 1});
2068
+ CHECK_EQ(picked.shape(), Shape{1, 8, 1, 1});
2069
+ CHECK(
2070
+ array_equal(picked, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item<bool>());
2071
+
2072
+ src = broadcast_to(array({1, 2}), {20, 2});
2073
+ picked = gather(src, array({5}), 0, {1, 1});
2074
+ CHECK_EQ(picked.item<int>(), 1);
2075
+ picked = gather(src, {array({5}), array({1})}, {0, 1}, {1, 1});
2076
+ CHECK_EQ(picked.item<int>(), 2);
2077
+ }
2078
+
2079
+ TEST_CASE("test take") {
2080
+ // Empty index arrays: the result takes the shape of the indices
2081
+ auto empty = astype(array({}), int32);
2082
+ auto z = take(array({1}), empty);
2083
+ CHECK_EQ(z.shape(), Shape{0});
2084
+ empty = reshape(empty, {1, 0, 1});
2085
+ z = take(array({1}), empty);
2086
+ CHECK_EQ(z.shape(), Shape{1, 0, 1});
2087
+ // Indexing into an empty source throws unless the indices are empty too
2088
+ CHECK_THROWS(take(array({}), array(1)));
2089
+
2090
+ z = take(array({}), empty);
2091
+ CHECK_EQ(z.size(), 0);
2092
+
2093
+ // Take a single row
2094
+ auto x = reshape(arange(256), {8, 4, 8});
2095
+ z = take(x, array({0}, uint32), 0);
2096
+ CHECK_EQ(z.shape(), Shape{1, 4, 8});
2097
+ z = reshape(z, {32});
2098
+ CHECK(array_equal(z, arange(32)).item<bool>());
2099
+
2100
+ z = take(x, array({1}, uint32), 0);
2101
+ z = reshape(z, {32});
2102
+ CHECK(array_equal(z, arange(32, 64)).item<bool>());
2103
+
2104
+ // Take multiple rows
2105
+ x = arange(256);
2106
+ x = reshape(x, {8, 4, 8});
2107
+ z = take(x, array({0, 1}, uint32), 0);
2108
+ z = reshape(z, {64});
2109
+ CHECK(array_equal(z, arange(64)).item<bool>());
2110
+
2111
+ // Take along middle axis
2112
+ x = reshape(arange(8), {2, 2, 2});
2113
+ z = take(x, array({0}), 1);
2114
+ CHECK(array_equal(z, array({0, 1, 4, 5}, {2, 1, 2})).item<bool>());
2115
+
2116
+ // Indices that are a broadcast view of a single scalar
2117
+ auto a = array({1, 2, 3}, float32);
2118
+ auto indices = broadcast_to(array(0), {10});
2119
+ auto b = take(a, indices);
2120
+ CHECK(array_equal(b, ones({10})).item<bool>());
2121
+
2122
+ // Take with 0 dim index
2123
+ z = take(array({0, 1, 2}), array(0));
2124
+ CHECK_EQ(z.item<int>(), 0);
2125
+ CHECK_EQ(z.ndim(), 0);
2126
+
2127
+ // Floating-point index arrays must throw
2128
+ CHECK_THROWS(take(array({}), array({})));
2129
+ CHECK_THROWS(take(a, array({1.0, 2.0, 3.0})));
2130
+
2131
+ // Out-of-range axis values for a 2-D source
2132
+ a = array({1, 2, 3, 4}, {2, 2});
2133
+ CHECK_THROWS(take(a, array({1}), -3));
2134
+ CHECK_THROWS(take(a, array({1}), 2));
2135
+
2136
+ // Check negative indices
2137
+ a = array({1, 2, 3, 4}, {2, 2});
2138
+ CHECK_EQ(take(a, array({-1})).item<int>(), 4);
2139
+ CHECK(array_equal(take(a, array({1, -1})), array({2, 4})).item<bool>());
2140
+ CHECK(array_equal(take(a, array(-1), 0), array({3, 4})).item<bool>());
2141
+
2142
+ // Output shapes when taking one index along each axis in turn
2143
+ a = zeros({2, 1, 1});
2144
+ auto out = take(a, array({1}), 0);
2145
+ CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());
2146
+ out = take(a, array({0}), 1);
2147
+ CHECK(array_equal(out, zeros({2, 1, 1})).item<bool>());
2148
+ out = take(a, array({0}), 2); // last axis; this line previously repeated the axis-1 call
2149
+ CHECK(array_equal(out, zeros({2, 1, 1})).item<bool>());
2150
+ a = zeros({1, 2, 1});
2151
+ out = take(a, array({0}), 0);
2152
+ CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());
2153
+ out = take(a, array({0}), 1);
2154
+ CHECK(array_equal(out, zeros({1, 1, 1})).item<bool>());
2155
+ out = take(a, array({0, 1}), 1);
2156
+ CHECK(array_equal(out, zeros({1, 2, 1})).item<bool>());
2157
+ // Indices have wrong shape
2158
+ // NOTE(review): zeros() without a dtype is presumably float32, so these may throw on index dtype rather than shape - confirm
2159
+ a = zeros({2, 3, 4});
2160
+ CHECK_THROWS(take(a, zeros({1, 3, 4}), 1));
2161
+ CHECK_THROWS(take(a, zeros({2, 3, 7}), 1));
2162
+ CHECK_THROWS(take(a, zeros({2, 3, 2}), 0));
2163
+ }
2164
+
2165
+ TEST_CASE("test take along axis") {
2166
+ // A scalar source is rejected
2167
+ auto src = array(1);
2168
+ CHECK_THROWS(take_along_axis(src, array(0), 0));
2169
+
2170
+ // Bad axis, or indices whose rank differs from the source
2171
+ src = arange(5);
2172
+ CHECK_THROWS(take_along_axis(src, array({1}), 1));
2173
+ CHECK_THROWS(take_along_axis(src, array({1}, {1, 1}), 0));
2174
+ CHECK_THROWS(take_along_axis(src, array(1), -1));
2175
+ // Axis 0 and axis -1 name the same axis of a 1-D source
2176
+ auto picked = take_along_axis(src, array({1}), 0);
2177
+ CHECK_EQ(picked.item<int>(), 1);
2178
+ picked = take_along_axis(src, array({1}), -1);
2179
+ CHECK_EQ(picked.item<int>(), 1);
2180
+
2181
+ // Source with a zero-sized dimension
2182
+ src = reshape(array({}), {1, 0});
2183
+ CHECK_THROWS(take_along_axis(src, array({1}), 0));
2184
+
2185
+ picked = take_along_axis(src, reshape(array({1}), {1, 1}), 0);
2186
+ eval(picked); // Evaluation itself must succeed
2187
+ CHECK_EQ(picked.shape(), Shape{1, 0});
2188
+
2189
+ auto idx = reshape(astype(array({}), int32), {1, 0});
2190
+ picked = take_along_axis(src, idx, 0);
2191
+ eval(picked); // Evaluation itself must succeed
2192
+ CHECK_EQ(picked.shape(), Shape{1, 0});
2193
+ // 2x2 source with explicit index arrays along each axis
2194
+ src = array({1, 2, 3, 4}, {2, 2});
2195
+ idx = array({0, 1}, {1, 2});
2196
+ picked = take_along_axis(src, idx, 0);
2197
+ CHECK(array_equal(picked, array({1, 4}, {1, 2})).item<bool>());
2198
+
2199
+ idx = array({0, 1, 0, 1, 0, 0, 1, 0}, {4, 2}, int32);
2200
+ picked = take_along_axis(src, idx, 0);
2201
+ CHECK(array_equal(picked, array({1, 4, 1, 4, 1, 2, 3, 2}, {4, 2})).item<bool>());
2202
+
2203
+ idx = array({0, 1}, {2, 1});
2204
+ picked = take_along_axis(src, idx, 1);
2205
+ CHECK(array_equal(picked, array({1, 4}, {2, 1})).item<bool>());
2206
+
2207
+ // A {1, 1} index array is broadcast against the source
2208
+ idx = array({0}, {1, 1});
2209
+ picked = take_along_axis(src, idx, 0);
2210
+ CHECK(array_equal(picked, array({1, 2}, {1, 2})).item<bool>());
2211
+ picked = take_along_axis(src, idx, 1);
2212
+ CHECK(array_equal(picked, array({1, 3}, {2, 1})).item<bool>());
2213
+
2214
+ idx = array({0, 1, 1, 0, 0, 1}, {2, 3}, int32);
2215
+ picked = take_along_axis(src, idx, 1);
2216
+ CHECK(array_equal(picked, array({1, 2, 2, 3, 3, 4}, {2, 3})).item<bool>());
2217
+ // 2x2x2 source, same index array applied along each of the three axes
2218
+ src = reshape(arange(8), {2, 2, 2});
2219
+ idx = array({0, 1, 0, 0, 1, 0, 0, 1}, {2, 2, 2});
2220
+ picked = take_along_axis(src, idx, 0);
2221
+ CHECK(array_equal(picked, array({0, 5, 2, 3, 4, 1, 2, 7}, {2, 2, 2}))
2222
+ .item<bool>());
2223
+ picked = take_along_axis(src, idx, 1);
2224
+ CHECK(array_equal(picked, array({0, 3, 0, 1, 6, 5, 4, 7}, {2, 2, 2}))
2225
+ .item<bool>());
2226
+ picked = take_along_axis(src, idx, 2);
2227
+ CHECK(array_equal(picked, array({0, 1, 2, 2, 5, 4, 6, 7}, {2, 2, 2}))
2228
+ .item<bool>());
2229
+ }
2230
+
2231
+ TEST_CASE("test put along axis") {
2232
+ // A scalar destination is rejected
2233
+ auto dst = array(1);
2234
+ auto val = array(1);
2235
+ CHECK_THROWS(put_along_axis(dst, array(0), val, 0));
2236
+
2237
+ // Bad axis, or indices whose rank differs from the destination
2238
+ dst = arange(5);
2239
+ CHECK_THROWS(put_along_axis(dst, array({1}), array({0}), 1));
2240
+ CHECK_THROWS(put_along_axis(dst, array({1}, {1, 1}), array({0}), 0));
2241
+ CHECK_THROWS(put_along_axis(dst, array(1), array(0), -1));
2242
+ // Writing 0 at position 1 of 0..4
2243
+ auto got = put_along_axis(dst, array({1}), array({0}), 0);
2244
+ auto want = array({0, 0, 2, 3, 4});
2245
+ CHECK(array_equal(got, want).item<bool>());
2246
+
2247
+ // Destination with a zero-sized dimension
2248
+ dst = reshape(array({}), {1, 0});
2249
+ CHECK_THROWS(put_along_axis(dst, array({1}), array({0}), 0));
2250
+ // NOTE(review): this case calls take_along_axis inside the put test - possibly a copy-paste; confirm intent
2251
+ auto idx = reshape(astype(array({}), int32), {1, 0});
2252
+ got = take_along_axis(dst, idx, 0);
2253
+ eval(got); // Evaluation itself must succeed
2254
+ CHECK_EQ(got.shape(), Shape{1, 0});
2255
+ // 2x2 destination, a single value broadcast along axis 0
2256
+ dst = array({1, 2, 3, 4}, {2, 2});
2257
+ idx = array({0, 1}, {1, 2});
2258
+ want = array({0, 2, 3, 0}, {2, 2});
2259
+ got = put_along_axis(dst, idx, array({0}), 0);
2260
+ CHECK(array_equal(got, want).item<bool>());
2261
+ // A full 2x2 block of values written along axis 0
2262
+ idx = array({0, 0, 1, 1}, {2, 2}, int32);
2263
+ auto vals = array({2, 3, 4, 5}, {2, 2}, int32);
2264
+ got = put_along_axis(dst, idx, vals, 0);
2265
+ CHECK(array_equal(got, array({2, 3, 4, 5}, {2, 2})).item<bool>());
2266
+ // A single value broadcast along axis 1
2267
+ idx = array({0, 1}, {2, 1});
2268
+ want = array({0, 2, 3, 0}, {2, 2});
2269
+ got = put_along_axis(dst, idx, array({0}), 1);
2270
+ CHECK(array_equal(got, want).item<bool>());
2271
+ }
2272
+
2273
+ TEST_CASE("test scatter") { // Covers scatter and scatter_add/prod/max/min: invalid-argument cases first, then expected outputs.
2274
+ // More indices than dimensions
2275
+ CHECK_THROWS(scatter(array(0), array({1}), array(1), 0));
2276
+
2277
+ // Mismatch dimensions and indices
2278
+ CHECK_THROWS(scatter(array({0}), {array({0})}, array({1}, {1, 1}), {0, 1}));
2279
+ CHECK_THROWS(scatter(array({0}), array({0}), array({1}, {1, 1}), -1));
2280
+
2281
+ // Repeat dimensions
2282
+ CHECK_THROWS(scatter(
2283
+ array({0}, {1, 1}), {array({0}), array({0})}, array({1}), {0, 0}));
2284
+
2285
+ // Update sizes incorrect
2286
+ CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}), 0));
2287
+ CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}, {2, 1}), 0));
2288
+ CHECK_THROWS(scatter(array({0}, {1}), array({0}), array({0, 1}, {1, 2}), 0));
2289
+
2290
+ // Wrong index type
2291
+ CHECK_THROWS(scatter(array({0}), array({0.0f}), array({0}, {1, 1}), 0));
2292
+ CHECK_THROWS(scatter(
2293
+ array({0}, {1, 1}),
2294
+ {array({0}), array({0.0f})},
2295
+ array({1}, {1, 1, 1}),
2296
+ {0, 1}));
2297
+
2298
+ // Index arrays must be broadcastable
2299
+ CHECK_THROWS(scatter(
2300
+ array({0}, {1, 1}),
2301
+ {array({0, 0, 0}, {3}), array({0, 0}, {2})},
2302
+ ones({3, 2, 1, 1}),
2303
+ {0, 1}));
2304
+
2305
+ // Single element scatter
2306
+ auto in = zeros({4}, float32);
2307
+ auto inds = arange(2);
2308
+ auto updates = ones({2, 1}, float32);
2309
+ auto out = scatter(in, inds, updates, 0);
2310
+ CHECK(array_equal(out, array({1.0f, 1.0f, 0.0f, 0.0f})).item<bool>());
2311
+
2312
+ // Single element scatter add
2313
+ in = ones({4}, float32);
2314
+ inds = array({0, 0, 3});
2315
+ updates = ones({3, 1}, float32);
2316
+ out = scatter_add(in, inds, updates, 0);
2317
+ CHECK(array_equal(out, array({3.0f, 1.0f, 1.0f, 2.0f})).item<bool>()); // index 0 is hit twice: 1 + 1 + 1 = 3
2318
+
2319
+ // Single element scatter prod
2320
+ in = ones({4}, float32);
2321
+ inds = array({0, 0, 3});
2322
+ updates = full({3, 1}, 2.0f, float32);
2323
+ out = scatter_prod(in, inds, updates, 0);
2324
+ CHECK(array_equal(out, array({4.0f, 1.0f, 1.0f, 2.0f})).item<bool>()); // index 0 is hit twice: 1 * 2 * 2 = 4
2325
+
2326
+ // Single element scatter max
2327
+ in = ones({4}, float32);
2328
+ inds = array({0, 0, 3});
2329
+ updates = array({1.0f, 6.0f, -2.0f}, {3, 1});
2330
+ out = scatter_max(in, inds, updates, 0);
2331
+ CHECK(array_equal(out, array({6.0f, 1.0f, 1.0f, 1.0f})).item<bool>()); // the -2 update at index 3 loses to the existing 1
2332
+
2333
+ // Single element scatter min
2334
+ in = ones({4}, float32);
2335
+ inds = array({0, 0, 3});
2336
+ updates = array({1.0f, -6.0f, 2.0f}, {3, 1});
2337
+ out = scatter_min(in, inds, updates, 0);
2338
+ CHECK(array_equal(out, array({-6.0f, 1.0f, 1.0f, 1.0f})).item<bool>()); // the 2 update at index 3 loses to the existing 1
2339
+
2340
+ // Empty scatter
2341
+ in = arange(4, float32);
2342
+ inds = astype(array({}), uint32);
2343
+ updates = reshape(array({}), {0, 1});
2344
+ out = scatter(in, inds, updates, 0);
2345
+ CHECK(array_equal(out, in).item<bool>()); // no indices, so the input must come back unchanged
2346
+
2347
+ // Array scatters
2348
+ in = zeros({4, 4}, float32);
2349
+ inds = array({0, 1, 2, 3});
2350
+ updates = reshape(arange(16, float32), {4, 1, 4});
2351
+ out = scatter(in, inds, updates, 0);
2352
+ CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item<bool>());
2353
+
2354
+ // Array scatters with col contiguous updates
2355
+ in = zeros({4, 4}, float32);
2356
+ inds = array({0, 1, 2, 3});
2357
+ updates = transpose(reshape(arange(16, float32), {4, 1, 4}));
2358
+ out = scatter(in, inds, updates, 0);
2359
+ CHECK(array_equal(out, transpose(reshape(arange(16, float32), {4, 4})))
2360
+ .item<bool>());
2361
+
2362
+ // Irregular strided index and reduce collision test
2363
+ in = zeros({10}, float32);
2364
+ inds = broadcast_to(array(3), {10});
2365
+ updates = ones({10, 1}, float32);
2366
+ out = scatter_add(in, inds, updates, 0);
2367
+ CHECK_EQ(take(out, array(3)).item<float>(), 10); // all ten updates land on index 3
2368
+
2369
+ // 1 element array with 0 dim index
2370
+ in = array({1}, int32);
2371
+ updates = array({2}, int32);
2372
+ out = scatter_max(in, array(0), updates, 0);
2373
+ CHECK_EQ(out.item<int>(), 2);
2374
+
2375
+ // No index arrays or axes
2376
+ out = scatter_max(array(1), {}, array(2), std::vector<int>{});
2377
+ CHECK_EQ(out.item<int>(), 2);
2378
+
2379
+ // Irregularly strided updates test
2380
+ in = ones({3, 3});
2381
+ updates = broadcast_to(array({2, 2, 2}), {1, 3, 3});
2382
+ inds = array({0});
2383
+ out = scatter(in, inds, updates, 0);
2384
+ CHECK(array_equal(out, ones({3, 3}) * 2).item<bool>());
2385
+
2386
+ // Along different axis
2387
+ in = zeros({2, 3});
2388
+ updates = array({1, 2, 3, 4}, {2, 2, 1});
2389
+ inds = array({0, 2});
2390
+ out = scatter(in, inds, updates, 1);
2391
+ auto expected = array({1, 0, 3, 2, 0, 4}, {2, 3});
2392
+ CHECK(array_equal(out, expected).item<bool>());
2393
+
2394
+ // Multiple index arrays
2395
+ in = zeros({2, 2});
2396
+ updates = array({1, 2}, {2, 1, 1});
2397
+ inds = array({0, 1});
2398
+ out = scatter(in, {inds, inds}, updates, {0, 1});
2399
+ CHECK(array_equal(out, array({1, 0, 0, 2}, {2, 2})).item<bool>()); // (0,0) and (1,1) are written, i.e. the diagonal
2400
+
2401
+ // Broadcasted indices
2402
+ in = zeros({2, 2});
2403
+ updates = array({5, 2, 9, 1}, {2, 2, 1, 1});
2404
+ auto inds0 = array({0, 1}, {2, 1});
2405
+ auto inds1 = array({0, 1}, {1, 2});
2406
+ out = scatter(in, {inds0, inds1}, updates, {0, 1});
2407
+ CHECK(array_equal(out, array({5, 2, 9, 1}, {2, 2})).item<bool>());
2408
+
2409
+ // Broadcasted operand
2410
+ in = broadcast_to(array({0, 0}), {2, 2});
2411
+ updates = array({1, 1}, {2, 1, 1});
2412
+ inds = array({0, 1});
2413
+ out = scatter_add(in, inds, updates, 0);
2414
+ CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item<bool>());
2415
+
2416
+ // 1D scatter
2417
+ {
2418
+ auto dst = zeros({2, 4}, int32);
2419
+ auto src = reshape(array({1, 2, 3, 4}), {1, 1, 4});
2420
+ auto idx = array({1});
2421
+ auto expected = reshape(array({0, 0, 0, 0, 1, 2, 3, 4}), {2, 4});
2422
+ auto out = scatter(dst, idx, src, 0);
2423
+ CHECK(array_equal(out, expected).item<bool>());
2424
+ }
2425
+
2426
+ // 1D indices with 2D update
2427
+ {
2428
+ auto dst = zeros({3, 4}, int32);
2429
+ auto indices = {array({1}), array({2})};
2430
+ auto axes = {0, 1};
2431
+ auto updates = reshape(array({1, 2, 3, 4}, int32), {1, 2, 2});
2432
+ auto out = scatter(dst, indices, updates, axes);
2433
+ auto expected =
2434
+ reshape(array({0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4}), {3, 4});
2435
+ CHECK(array_equal(out, expected).item<bool>());
2436
+ }
2437
+ }
2438
+
2439
+ TEST_CASE("test masked_scatter") { // Checks masked_scatter argument validation and results for 1D, all-false and broadcast masks.
2440
+ if (cu::is_available()) { // the whole test case is skipped when cu::is_available() reports true
2441
+ INFO("Skipping masked_scatter cuda ops tests");
2442
+ return;
2443
+ }
2444
+
2445
+ // Wrong mask dtype
2446
+ CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2}))); // mask is built from integer literals, not bools
2447
+
2448
+ // Mask must be broadcastable to self array
2449
+ CHECK_THROWS(masked_scatter(
2450
+ array({1, 2, 3, 4}, {2, 2}),
2451
+ array({false, true, true, false}, {4, 1}),
2452
+ array({1, 2})));
2453
+
2454
+ // 1D mask
2455
+ {
2456
+ auto self = zeros({4}, int32);
2457
+ auto mask = array({true, true, false, true});
2458
+ auto source = array({1, 2, 4});
2459
+ auto out = masked_scatter(self, mask, source);
2460
+ CHECK(array_equal(out, array({1, 2, 0, 4})).item<bool>()); // source values fill the true positions in order
2461
+ }
2462
+
2463
+ // Empty mask
2464
+ {
2465
+ auto self = zeros({4}, int32);
2466
+ auto mask = array({false, false, false, false});
2467
+ auto source = array({1, 2, 4});
2468
+ auto out = masked_scatter(self, mask, source);
2469
+ CHECK(array_equal(out, self).item<bool>()); // nothing selected, so the output equals self
2470
+ }
2471
+
2472
+ // Broadcasted mask
2473
+ {
2474
+ auto self = zeros({2, 2}, int32);
2475
+ auto mask = array({true, false});
2476
+ auto source = array({5, 6, 7, 8}, {2, 2});
2477
+ auto out = masked_scatter(self, mask, source);
2478
+ CHECK(array_equal(out, array({5, 6, 0, 0}, {2, 2})).item<bool>());
2479
+ }
2480
+ }
2481
+
2482
+ TEST_CASE("test is positive infinity") { // Checks isposinf on scalars and arrays, including float16 and bfloat16 inputs.
2483
+ array x(1.0f);
2484
+ CHECK_FALSE(isposinf(x).item<bool>());
2485
+
2486
+ array y(std::numeric_limits<float>::infinity());
2487
+ CHECK(isposinf(y).item<bool>());
2488
+
2489
+ array z = identity(7);
2490
+ CHECK_FALSE(all(isposinf(z)).item<bool>());
2491
+ // Only one of the three elements below is +inf, so all(...) must be false.
2492
+ array w = array({1.0f, std::numeric_limits<float>::infinity(), 2.0f});
2493
+ CHECK_FALSE(all(isposinf(w)).item<bool>());
2494
+ // Same scalar checks repeated with reduced-precision dtypes.
2495
+ array a(1.0f, bfloat16);
2496
+ CHECK_FALSE(isposinf(a).item<bool>());
2497
+
2498
+ array b(std::numeric_limits<float>::infinity(), float16);
2499
+ CHECK(isposinf(b).item<bool>());
2500
+
2501
+ array c(std::numeric_limits<float>::infinity(), bfloat16);
2502
+ CHECK(isposinf(c).item<bool>());
2503
+ }
2504
+
2505
+ TEST_CASE("test is negative infinity") { // Checks isneginf on scalars and arrays, including float16 and bfloat16 inputs.
2506
+ array x(1.0f);
2507
+ CHECK_FALSE(isneginf(x).item<bool>());
2508
+
2509
+ array y(-std::numeric_limits<float>::infinity());
2510
+ CHECK(isneginf(y).item<bool>());
2511
+
2512
+ array z = identity(7);
2513
+ CHECK_FALSE(all(isneginf(z)).item<bool>());
2514
+ // Only one of the three elements below is -inf, so all(...) must be false.
2515
+ array w = array({1.0f, -std::numeric_limits<float>::infinity(), 2.0f});
2516
+ CHECK_FALSE(all(isneginf(w)).item<bool>());
2517
+ // Same scalar checks repeated with reduced-precision dtypes.
2518
+ array a(1.0f, bfloat16);
2519
+ CHECK_FALSE(isneginf(a).item<bool>());
2520
+
2521
+ array b(-std::numeric_limits<float>::infinity(), float16);
2522
+ CHECK(isneginf(b).item<bool>());
2523
+
2524
+ array c(-std::numeric_limits<float>::infinity(), bfloat16);
2525
+ CHECK(isneginf(c).item<bool>());
2526
+ }
2527
+
2528
+ TEST_CASE("test scatter types") { // Runs the same diagonal scatter into a 4x4 zero array for several dtypes.
2529
+ for (auto t : {bool_, uint8, uint16, int8, int16}) { // bool and small integer dtypes: compared exactly with array_equal
2530
+ auto in = zeros({4, 4}, t);
2531
+ auto inds = {arange(4), arange(4)};
2532
+ auto updates = ones({4, 1, 1}, t);
2533
+ auto out = scatter(in, inds, updates, {0, 1});
2534
+ auto expected =
2535
+ array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
2536
+ CHECK(array_equal(out, expected).item<bool>());
2537
+ }
2538
+
2539
+ for (auto t : {float16, bfloat16}) { // half-precision dtypes: compared with allclose instead
2540
+ auto in = zeros({4, 4}, t);
2541
+ auto inds = {arange(4), arange(4)};
2542
+ auto updates = ones({4, 1, 1}, t);
2543
+ auto out = scatter(in, inds, updates, {0, 1});
2544
+ auto expected =
2545
+ array({1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1}, {4, 4}, t);
2546
+ CHECK(allclose(out, expected).item<bool>());
2547
+ }
2548
+ }
2549
+
2550
+ TEST_CASE("test complex ops") { // Checks creation, unary, binary and reduction ops on complex64 arrays.
2551
+ // Creation ops
2552
+ {
2553
+ auto x = full({2, 2}, complex64_t{1, 1});
2554
+ CHECK_EQ(x.dtype(), complex64);
2555
+ std::initializer_list<complex64_t> expected = {
2556
+ {1, 1}, {1, 1}, {1, 1}, {1, 1}};
2557
+ CHECK(array_equal(x, array(expected, {2, 2})).item<bool>());
2558
+ }
2559
+
2560
+ // Unary ops
2561
+ {
2562
+ std::initializer_list<complex64_t> vals = {{0, 1}, {1, 0}, {1, 1}};
2563
+ auto x = array(vals);
2564
+
2565
+ auto y = abs(x);
2566
+ CHECK_EQ(y.dtype(), float32); // abs of a complex64 array is expected to be float32
2567
+ CHECK(array_equal(y, array({1.0f, 1.0f, std::sqrt(2.0f)})).item<bool>());
2568
+
2569
+ y = negative(x);
2570
+ std::initializer_list<complex64_t> expected = {{0, -1}, {-1, 0}, {-1, -1}};
2571
+ CHECK(array_equal(y, array(expected)).item<bool>());
2572
+ // exp/sin/cos are compared with allclose against hard-coded reference values.
2573
+ y = exp(x);
2574
+ {
2575
+ std::initializer_list<complex64_t> expected = {
2576
+ {0.54030231, 0.84147098}, {2.71828183, 0.}, {1.46869394, 2.28735529}};
2577
+ CHECK(allclose(y, array(expected)).item<bool>());
2578
+ }
2579
+
2580
+ y = sin(x);
2581
+ {
2582
+ std::initializer_list<complex64_t> expected = {
2583
+ {0., 1.17520119}, {0.84147098, 0.}, {1.29845758, 0.63496391}};
2584
+ CHECK(allclose(y, array(expected)).item<bool>());
2585
+ }
2586
+
2587
+ y = cos(x);
2588
+ {
2589
+ std::initializer_list<complex64_t> expected = {
2590
+ {1.54308063, -0.}, {0.54030231, -0.}, {0.83373003, -0.98889771}};
2591
+ CHECK(allclose(y, array(expected)).item<bool>());
2592
+ }
2593
+ }
2594
+
2595
+ // Binary ops
2596
+ {
2597
+ std::initializer_list<complex64_t> vals_x = {{0, 1}, {1, 0}, {1, 1}};
2598
+ auto x = array(vals_x);
2599
+
2600
+ std::initializer_list<complex64_t> vals_y = {{2, 0}, {1, 1}, {0, 1}};
2601
+ auto y = array(vals_y);
2602
+
2603
+ auto z = add(x, y);
2604
+ {
2605
+ std::initializer_list<complex64_t> expected = {{2, 1}, {2, 1}, {1, 2}};
2606
+ CHECK(array_equal(z, array(expected)).item<bool>());
2607
+ }
2608
+
2609
+ z = subtract(x, y);
2610
+ {
2611
+ std::initializer_list<complex64_t> expected = {{-2, 1}, {0, -1}, {1, 0}};
2612
+ CHECK(array_equal(z, array(expected)).item<bool>());
2613
+ }
2614
+
2615
+ z = multiply(x, y);
2616
+ {
2617
+ std::initializer_list<complex64_t> expected = {{0, 2}, {1, 1}, {-1, 1}};
2618
+ CHECK(array_equal(z, array(expected)).item<bool>());
2619
+ }
2620
+
2621
+ z = maximum(x, y);
2622
+ {
2623
+ std::initializer_list<complex64_t> expected = {{2, 0}, {1, 1}, {1, 1}};
2624
+ CHECK(array_equal(z, array(expected)).item<bool>());
2625
+ }
2626
+ }
2627
+
2628
+ // Reductions
2629
+ if (default_device() == Device::cpu) { // the reduction checks only run when the default device is the CPU
2630
+ std::initializer_list<complex64_t> vals = {{0, 0}, {1, 0}, {0, 1}};
2631
+ auto x = array(vals);
2632
+ CHECK_EQ(max(x).item<complex64_t>(), complex64_t{1, 0});
2633
+ CHECK_EQ(min(x).item<complex64_t>(), complex64_t{0, 0});
2634
+ CHECK_EQ(sum(x).item<complex64_t>(), complex64_t{1, 1});
2635
+ CHECK_EQ(prod(x).item<complex64_t>(), complex64_t{0, 0});
2636
+ }
2637
+ }
2638
+
2639
+ TEST_CASE("test as_strided op") { // Checks as_strided views with overlapping, zero and offset strides.
2640
+ auto x = arange(10);
2641
+ auto y = as_strided(x, {3, 3}, {1, 1}, 0); // strides {1, 1}: each row starts one element after the previous one
2642
+ auto expected = array({0, 1, 2, 1, 2, 3, 2, 3, 4}, {3, 3});
2643
+ CHECK(array_equal(y, expected).item<bool>());
2644
+
2645
+ y = as_strided(x, {3, 3}, {0, 3}, 0); // row stride 0: every row repeats elements 0, 3, 6
2646
+ expected = array({0, 3, 6, 0, 3, 6, 0, 3, 6}, {3, 3});
2647
+ CHECK(array_equal(y, expected).item<bool>());
2648
+ // The expected values below follow the transposed element order shown in the trailing comments.
2649
+ x = reshape(x, {2, 5}); // 0 1 2 3 ...
2650
+ x = transpose(x, {1, 0}); // 0 5 1 6 2 7 ...
2651
+ y = as_strided(x, {3, 3}, {2, 1}, 1); // final argument 1 makes the view start at the second element (5)
2652
+ expected = array({5, 1, 6, 6, 2, 7, 7, 3, 8}, {3, 3});
2653
+ CHECK(array_equal(y, expected).item<bool>());
2654
+ }
2655
+
2656
+ TEST_CASE("test scan op") { // Checks cumsum over several axes and flag combinations, then under vmap.
2657
+ auto x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
2658
+ auto y = cumsum(x, 1, false, true); // judging by the expected values: 3rd argument reverses the scan, 4th selects inclusive sums
2659
+ auto expected = array({1.0f, 3.0f, 6.0f, 4.0f, 9.0f, 15.0f}, {2, 3});
2660
+ CHECK(array_equal(y, expected).item<bool>());
2661
+
2662
+ y = cumsum(x, 1, false, false); // exclusive: each output omits the current element, so rows start at 0
2663
+ expected = array({0.0f, 1.0f, 3.0f, 0.0f, 4.0f, 9.0f}, {2, 3});
2664
+ CHECK(array_equal(y, expected).item<bool>());
2665
+
2666
+ y = cumsum(x, 1, true, true); // reversed and inclusive: sums accumulate from the end of each row
2667
+ expected = array({6.0f, 5.0f, 3.0f, 15.0f, 11.0f, 6.0f}, {2, 3});
2668
+ CHECK(array_equal(y, expected).item<bool>());
2669
+
2670
+ y = cumsum(x, 1, true, false); // reversed and exclusive: rows end at 0
2671
+ expected = array({5.0f, 3.0f, 0.0f, 11.0f, 6.0f, 0.0f}, {2, 3});
2672
+ CHECK(array_equal(y, expected).item<bool>());
2673
+ // Same checks on a 3-D input, scanning along axis 0 and axis 1.
2674
+ x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});
2675
+ y = cumsum(x, 0, false, true);
2676
+ expected =
2677
+ array({1.0f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}, {2, 2, 2});
2678
+ CHECK(array_equal(y, expected).item<bool>());
2679
+
2680
+ y = cumsum(x, 1, false, true);
2681
+ expected =
2682
+ array({1.0f, 2.0f, 4.0f, 6.0f, 5.0f, 6.0f, 12.0f, 14.0f}, {2, 2, 2});
2683
+ CHECK(array_equal(y, expected).item<bool>());
2684
+
2685
+ x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});
2686
+ y = cumsum(x, 0, true, true);
2687
+ expected =
2688
+ array({6.0f, 8.0f, 10.0f, 12.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2});
2689
+ CHECK(array_equal(y, expected).item<bool>());
2690
+
2691
+ y = cumsum(x, 1, true, true);
2692
+ expected =
2693
+ array({4.0f, 6.0f, 3.0f, 4.0f, 12.0f, 14.0f, 7.0f, 8.0f}, {2, 2, 2});
2694
+ CHECK(array_equal(y, expected).item<bool>());
2695
+
2696
+ x = reshape(x, {4, 2});
2697
+ y = cumsum(x, 0, false, false);
2698
+ expected = array({0.0f, 0.0f, 1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f}, {4, 2});
2699
+ CHECK(array_equal(y, expected).item<bool>());
2700
+
2701
+ y = cumsum(x, 0, true, false);
2702
+ expected =
2703
+ array({15.0f, 18.0f, 12.0f, 14.0f, 7.0f, 8.0f, 0.0f, 0.0f}, {4, 2});
2704
+ CHECK(array_equal(y, expected).item<bool>());
2705
+
2706
+ // Check the vmap implementation
2707
+ auto fun = [](array x) { return cumsum(x, 0, false, true); };
2708
+ y = vmap(fun, 0, 0)(x); // mapped over axis 0: each length-2 row is summed independently
2709
+ expected = array({1.0f, 3.0f, 3.0f, 7.0f, 5.0f, 11.0f, 7.0f, 15.0f}, {4, 2});
2710
+ CHECK(array_equal(y, expected).item<bool>());
2711
+
2712
+ y = vmap(fun, 1, 1)(x); // mapped over axis 1: each length-4 column is summed independently
2713
+ expected = array({1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f, 16.0f, 20.0f}, {4, 2});
2714
+ CHECK(array_equal(y, expected).item<bool>());
2715
+ }
2716
+
2717
+ TEST_CASE("test pad") { // Checks output shapes for the three pad-width forms and the values of a padded 2x2 array.
2718
+ auto x = zeros({1, 2, 3});
2719
+ CHECK_EQ(pad(x, 1).shape(), Shape{3, 4, 5}); // a scalar width adds 1 on both sides of every axis
2720
+ CHECK_EQ(pad(x, {0, 1}).shape(), Shape{2, 3, 4}); // one (before, after) pair: every axis grows by 0 + 1
2721
+ CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), Shape{3, 5, 7}); // one (before, after) pair per axis
2722
+
2723
+ x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
2724
+ auto padded_x = pad(x, 1);
2725
+ auto expected = array( // the 2x2 input surrounded by a one-element border of zeros
2726
+ {0.0f,
2727
+ 0.0f,
2728
+ 0.0f,
2729
+ 0.0f,
2730
+ 0.0f,
2731
+ 1.0f,
2732
+ 2.0f,
2733
+ 0.0f,
2734
+ 0.0f,
2735
+ 3.0f,
2736
+ 4.0f,
2737
+ 0.0f,
2738
+ 0.0f,
2739
+ 0.0f,
2740
+ 0.0f,
2741
+ 0.0f},
2742
+ {4, 4});
2743
+ CHECK(array_equal(padded_x, expected).item<bool>());
2744
+ }
2745
+
2746
+ TEST_CASE("test power") { // Checks power for integer, bool, float and complex64 scalar inputs.
2747
+ CHECK_EQ(power(array(1), array(2)).item<int>(), 1);
2748
+ CHECK_EQ((power(array(-1), array(2))).item<int>(), 1);
2749
+ CHECK_EQ((power(array(-1), array(3))).item<int>(), -1);
2750
+ // Bool inputs: only false raised to true is expected to give false.
2751
+ CHECK_EQ((power(array(true), array(false))).item<bool>(), true);
2752
+ CHECK_EQ((power(array(false), array(false))).item<bool>(), true);
2753
+ CHECK_EQ((power(array(true), array(true))).item<bool>(), true);
2754
+ CHECK_EQ((power(array(false), array(true))).item<bool>(), false);
2755
+
2756
+ auto x = array(2.0f);
2757
+ CHECK_EQ(
2758
+ (power(x, array(0.5))).item<float>(),
2759
+ doctest::Approx(std::pow(2.0f, 0.5f)));
2760
+ CHECK_EQ(power(x, array(2.0f)).item<float>(), 4.0f);
2761
+
2762
+ CHECK(std::isnan((power(array(-1.0f), array(0.5))).item<float>())); // negative base with a fractional exponent must give NaN
2763
+ // Complex inputs are compared component-wise against std::pow.
2764
+ auto a = complex64_t{0.5, 0.5};
2765
+ auto b = complex64_t{0.5, 0.5};
2766
+ auto expected = std::pow(a, b);
2767
+ auto out = (power(array(a), array(b))).item<complex64_t>();
2768
+ CHECK(abs(out.real() - expected.real()) < 1e-7);
2769
+ CHECK(abs(out.imag() - expected.imag()) < 1e-7);
2770
+
2771
+ a = complex64_t{-1.2, 0.1};
2772
+ b = complex64_t{2.2, 0.0};
2773
+ expected = std::pow(a, b);
2774
+ out = (power(array(a), array(b))).item<complex64_t>();
2775
+ CHECK(abs(out.real() - expected.real()) < 1e-6); // this case uses a looser tolerance than the one above
2776
+ CHECK(abs(out.imag() - expected.imag()) < 1e-6);
2777
+ }
2778
+
2779
+ TEST_CASE("test where") { // Checks where() with scalar and array conditions, broadcasting, infinities and higher-rank inputs.
2780
+ const float inf = std::numeric_limits<float>::infinity();
2781
+
2782
+ array condition(true);
2783
+ array x(1.0f);
2784
+ array y(0.0f);
2785
+ auto out = where(condition, x, y);
2786
+ CHECK_EQ(out.dtype(), float32);
2787
+ CHECK_EQ(out.item<float>(), 1.0f);
2788
+ // x is {2, 1} and y is {1, 2}; the expected results below are their {2, 2} broadcasts.
2789
+ x = array({1, 2}, {2, 1});
2790
+ y = array({3, 4}, {1, 2});
2791
+ CHECK(array_equal(where(condition, x, y), broadcast_to(x, {2, 2}))
2792
+ .item<bool>());
2793
+
2794
+ condition = array(false);
2795
+ CHECK(array_equal(where(condition, x, y), broadcast_to(y, {2, 2}))
2796
+ .item<bool>());
2797
+
2798
+ condition = array({true, false});
2799
+ out = where(condition, x, y); // NOTE(review): `out` is not read before it is reassigned; the CHECK below calls where() again
2800
+ auto expected = array({1, 4, 2, 4}, {2, 2});
2801
+ CHECK(array_equal(where(condition, x, y), expected).item<bool>());
2802
+
2803
+ condition = array({true, false, false, true}, {2, 2});
2804
+ out = where(condition, x, y);
2805
+ expected = array({1, 4, 3, 2}, {2, 2});
2806
+ CHECK(array_equal(where(condition, x, y), expected).item<bool>());
2807
+
2808
+ x = array(1);
2809
+ y = array(2);
2810
+ out = where(condition, x, y);
2811
+ expected = array({1, 2, 2, 1}, {2, 2});
2812
+ CHECK(array_equal(where(condition, x, y), expected).item<bool>());
2813
+
2814
+ condition = array(true);
2815
+ x = array({1, 2, 3});
2816
+ y = array({3, 6, 13});
2817
+ CHECK(array_equal(where(condition, x, y), array({1, 2, 3})).item<bool>());
2818
+
2819
+ condition = array(false);
2820
+ x = array({1, 2, 3});
2821
+ y = array({3, 6, 13});
2822
+ CHECK(array_equal(where(condition, x, y), array({3, 6, 13})).item<bool>());
2823
+ // Integer condition: the expected output treats 1 as true and 0 as false.
2824
+ condition = array({1, 1, 0});
2825
+ x = array({1, 2, 3});
2826
+ y = array({11, 12, 13});
2827
+ CHECK(array_equal(where(condition, x, y), array({1, 2, 13})).item<bool>());
2828
+
2829
+ condition = array({true, false}, {2, 1, 1});
2830
+ x = array({1, 2, 3, 4}, {2, 1, 2});
2831
+ y = array({11, 22, 33, 44}, {2, 2, 1});
2832
+ expected = array({1, 2, 1, 2, 33, 33, 44, 44}, {2, 2, 2});
2833
+ CHECK(array_equal(where(condition, x, y), expected).item<bool>());
2834
+ // Infinite values must be passed through unchanged from whichever branch is selected.
2835
+ condition = array({true, false, false});
2836
+ x = array({inf, 2.0, 3.0});
2837
+ y = array({10.0, 20.0, -inf});
2838
+ CHECK(array_equal(where(condition, x, y), array({inf, 20.0, -inf}))
2839
+ .item<bool>());
2840
+
2841
+ // 4-dim optimized case.
2842
+ condition = array({false});
2843
+ x = array({1, 2}, {2, 1, 1, 1});
2844
+ y = array({3, 4}, {1, 1, 2, 1});
2845
+ CHECK(array_equal(where(condition, x, y), array({3, 4, 3, 4}, {2, 1, 2, 1}))
2846
+ .item<bool>());
2847
+
2848
+ // 5-dim optimized case.
2849
+ condition = array({true, false}, {2, 1, 1, 1, 1});
2850
+ x = array({1, 2, 3, 4}, {2, 1, 1, 1, 2});
2851
+ y = array({11, 22}, {1, 1, 2, 1, 1});
2852
+ CHECK(array_equal(
2853
+ where(condition, x, y),
2854
+ array({1, 2, 1, 2, 11, 11, 22, 22}, {2, 1, 2, 1, 2}))
2855
+ .item<bool>());
2856
+ }
2857
+
2858
+ TEST_CASE("test stack") { // Checks stack output shapes, negative axes, the resulting dtype and the error cases.
2859
+ auto x = array({});
2860
+ CHECK_EQ(stack({x}, 0).shape(), Shape{1, 0}); // stacking an empty array still inserts the new axis
2861
+ CHECK_EQ(stack({x}, 1).shape(), Shape{0, 1});
2862
+
2863
+ x = array({1, 2, 3}, {3});
2864
+ CHECK_EQ(stack({x}, 0).shape(), Shape{1, 3});
2865
+ CHECK_EQ(stack({x}, 1).shape(), Shape{3, 1});
2866
+
2867
+ auto y = array({4, 5, 6}, {3});
2868
+ auto z = std::vector<array>{x, y};
2869
+ CHECK_EQ(stack(z).shape(), Shape{2, 3}); // with no axis argument the result matches axis 0
2870
+ CHECK_EQ(stack(z, 0).shape(), Shape{2, 3});
2871
+ CHECK_EQ(stack(z, 1).shape(), Shape{3, 2});
2872
+ CHECK_EQ(stack(z, -1).shape(), Shape{3, 2}); // negative axes give the same shapes as axes 1 and 0 respectively
2873
+ CHECK_EQ(stack(z, -2).shape(), Shape{2, 3});
2874
+ // NOTE(review): in doctest the second argument of CHECK_THROWS_MESSAGE is presumably a failure message, not a pattern matched against the exception text - confirm.
2875
+ CHECK_THROWS_MESSAGE(stack({}, 0), "No arrays provided for stacking");
2876
+
2877
+ x = array({1, 2, 3}, {3}, float16);
2878
+ y = array({4, 5, 6}, {3}, int32);
2879
+ CHECK_EQ(stack({x, y}, 0).dtype(), float16); // mixing float16 and int32 inputs is expected to give float16
2880
+
2881
+ x = array({1, 2, 3}, {3}, int32);
2882
+ y = array({4, 5, 6, 7}, {4}, int32);
2883
+ CHECK_THROWS_MESSAGE(
2884
+ stack({x, y}, 0), "All arrays must have the same shape and dtype");
2885
+ }
2886
+
2887
+ TEST_CASE("test full_like") { // Checks full_like with array and scalar fill values, with and without an explicit dtype.
2888
+ auto base_int = array({1, 2, 3}, {3}, int16);
2889
+ // Array fill value with an explicit dtype: the explicit dtype wins over the base array's int16.
2890
+ auto from_array_with_dtype = full_like(base_int, array(7.5f), float16);
2891
+ auto expected_float16 = array({7.5, 7.5, 7.5}, {3}, float16);
2892
+ CHECK_EQ(from_array_with_dtype.dtype(), float16);
2893
+ CHECK(array_equal(from_array_with_dtype, expected_float16).item<bool>());
2894
+ // No dtype given: the result is expected to take the base array's dtype (int16), not the fill value's.
2895
+ auto from_array_default_dtype = full_like(base_int, array(4.0f));
2896
+ auto expected_int16 = array({4, 4, 4}, {3}, int16);
2897
+ CHECK_EQ(from_array_default_dtype.dtype(), int16);
2898
+ CHECK(array_equal(from_array_default_dtype, expected_int16).item<bool>());
2899
+ // Scalar fill value with an explicit dtype.
2900
+ auto from_scalar_with_dtype = full_like(base_int, 3.25f, float32);
2901
+ auto expected_float32 = array({3.25f, 3.25f, 3.25f}, {3}, float32);
2902
+ CHECK_EQ(from_scalar_with_dtype.dtype(), float32);
2903
+ CHECK(array_equal(from_scalar_with_dtype, expected_float32).item<bool>());
2904
+ // Integer scalar fill value with a float32 base: the result is expected to stay float32.
2905
+ auto base_float = array({1.0f, 2.0f}, {2}, float32);
2906
+ auto from_scalar_default_dtype = full_like(base_float, 2);
2907
+ auto expected_base_float = array({2.0f, 2.0f}, {2}, float32);
2908
+ CHECK_EQ(from_scalar_default_dtype.dtype(), float32);
2909
+ CHECK(
2910
+ array_equal(from_scalar_default_dtype, expected_base_float).item<bool>());
2911
+ }
2912
+
2913
+ TEST_CASE("test eye") { // Checks eye() for a square 3x3 result and a rectangular 3x2 result.
2914
+ auto eye_3 = eye(3);
2915
+ CHECK_EQ(eye_3.shape(), Shape{3, 3});
2916
+ auto expected_eye_3 =
2917
+ array({1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, {3, 3});
2918
+ CHECK(array_equal(eye_3, expected_eye_3).item<bool>());
2919
+
2920
+ auto eye_3x2 = eye(3, 2);
2921
+ CHECK_EQ(eye_3x2.shape(), Shape{3, 2});
2922
+ auto expected_eye_3x2 = array({1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}, {3, 2}); // the third row has no diagonal entry
2923
+ CHECK(array_equal(eye_3x2, expected_eye_3x2).item<bool>());
2924
+ }
2925
+
2926
+ TEST_CASE("test tri") { // Checks that tri(4, 4, 0, float32) is a 4x4 array with ones on and below the main diagonal.
2927
+ auto _tri = tri(4, 4, 0, float32);
2928
+ CHECK_EQ(_tri.shape(), Shape{4, 4});
2929
+ auto expected_tri = array( // sixteen values in row-major order, reshaped to {4, 4}
2930
+ {1.0f,
2931
+ 0.0f,
2932
+ 0.0f,
2933
+ 0.0f,
2934
+ 1.0f,
2935
+ 1.0f,
2936
+ 0.0f,
2937
+ 0.0f,
2938
+ 1.0f,
2939
+ 1.0f,
2940
+ 1.0f,
2941
+ 0.0f,
2942
+ 1.0f,
2943
+ 1.0f,
2944
+ 1.0f,
2945
+ 1.0f},
2946
+ {4, 4});
2947
+ CHECK(array_equal(_tri, expected_tri).item<bool>());
2948
+ }
2949
+
2950
+ TEST_CASE("test tril") { // Checks that tril(..., 0) zeroes everything above the main diagonal of a 4x4 array of 2s.
2951
+ auto _tril = tril(full({4, 4}, 2.0f, float32), 0);
2952
+ CHECK_EQ(_tril.shape(), Shape{4, 4});
2953
+ auto expected_tri = array( // sixteen values in row-major order, reshaped to {4, 4}
2954
+ {2.0f,
2955
+ 0.0f,
2956
+ 0.0f,
2957
+ 0.0f,
2958
+ 2.0f,
2959
+ 2.0f,
2960
+ 0.0f,
2961
+ 0.0f,
2962
+ 2.0f,
2963
+ 2.0f,
2964
+ 2.0f,
2965
+ 0.0f,
2966
+ 2.0f,
2967
+ 2.0f,
2968
+ 2.0f,
2969
+ 2.0f},
2970
+ {4, 4});
2971
+ CHECK(array_equal(_tril, expected_tri).item<bool>());
2972
+ }
2973
+
2974
+ TEST_CASE("test triu") { // Checks that triu(..., 0) zeroes everything below the main diagonal of a 4x4 array of 2s.
2975
+ auto _triu = triu(full({4, 4}, 2.0f, float32), 0);
2976
+ CHECK_EQ(_triu.shape(), Shape{4, 4});
2977
+ auto expected_tri = array( // sixteen values in row-major order, reshaped to {4, 4}
2978
+ {2.0f,
2979
+ 2.0f,
2980
+ 2.0f,
2981
+ 2.0f,
2982
+ 0.0f,
2983
+ 2.0f,
2984
+ 2.0f,
2985
+ 2.0f,
2986
+ 0.0f,
2987
+ 0.0f,
2988
+ 2.0f,
2989
+ 2.0f,
2990
+ 0.0f,
2991
+ 0.0f,
2992
+ 0.0f,
2993
+ 2.0f},
2994
+ {4, 4});
2995
+ CHECK(array_equal(_triu, expected_tri).item<bool>());
2996
+ }
2997
+
2998
+ TEST_CASE("test identity") { // Checks that identity(4) is a 4x4 array with ones on the main diagonal and zeros elsewhere.
2999
+ auto id_4 = identity(4);
3000
+ CHECK_EQ(id_4.shape(), Shape{4, 4});
3001
+ auto expected_id_4 = array( // sixteen values in row-major order, reshaped to {4, 4}
3002
+ {1.0f,
3003
+ 0.0f,
3004
+ 0.0f,
3005
+ 0.0f,
3006
+ 0.0f,
3007
+ 1.0f,
3008
+ 0.0f,
3009
+ 0.0f,
3010
+ 0.0f,
3011
+ 0.0f,
3012
+ 1.0f,
3013
+ 0.0f,
3014
+ 0.0f,
3015
+ 0.0f,
3016
+ 0.0f,
3017
+ 1.0f},
3018
+ {4, 4});
3019
+ CHECK(array_equal(id_4, expected_id_4).item<bool>());
3020
+ }
3021
+
3022
+ TEST_CASE("test eye with positive k offset") { // Checks eye(3, 4, 1): the ones sit one column to the right of the main diagonal.
3023
+ auto eye_3_k1 = eye(3, 4, 1);
3024
+ CHECK_EQ(eye_3_k1.shape(), Shape{3, 4});
3025
+ auto expected_eye_3_k1 = array(
3026
+ {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f},
3027
+ {3, 4});
3028
+ CHECK(array_equal(eye_3_k1, expected_eye_3_k1).item<bool>());
3029
+ }
3030
+
3031
+ TEST_CASE("test eye with negative k offset") { // Checks eye(4, 3, -1): the ones sit one row below the main diagonal, so row 0 is all zeros.
3032
+ auto eye_4_k_minus1 = eye(4, 3, -1);
3033
+ CHECK_EQ(eye_4_k_minus1.shape(), Shape{4, 3});
3034
+ auto expected_eye_4_k_minus1 = array(
3035
+ {0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f},
3036
+ {4, 3});
3037
+ CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
3038
+ }
3039
+
3040
+ TEST_CASE("test basic clipping") { // Checks clip with both bounds: values below 2 become 2 and values above 6 become 6.
3041
+ array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5});
3042
+ array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5});
3043
+ auto clipped = clip(a, array(2.0f), array(6.0f));
3044
+ CHECK(array_equal(clipped, expected).item<bool>());
3045
+ }
3046
+
3047
+ TEST_CASE("test clipping with only min") { // Checks clip with only a lower bound (upper passed as std::nullopt): -1 is raised to 0.
3048
+ array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4});
3049
+ array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4});
3050
+ auto clipped = clip(a, array(0.0f), std::nullopt);
3051
+ CHECK(array_equal(clipped, expected).item<bool>());
3052
+ }
3053
+
3054
+ TEST_CASE("test clipping with only max") { // Checks clip with only an upper bound (lower passed as std::nullopt): 5 is lowered to 4.
3055
+ array a({2.0f, 3.0f, 4.0f, 5.0f}, {4});
3056
+ array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4});
3057
+ auto clipped = clip(a, std::nullopt, array(4.0f));
3058
+ CHECK(array_equal(clipped, expected).item<bool>());
3059
+ }
3060
+
3061
+ TEST_CASE("test linspace") { // Checks linspace for the default dtype, an int32 dtype and a zero-length result.
3062
+ auto x = linspace(0, 10, 5);
3063
+ auto expected = array({0.0f, 2.5f, 5.0f, 7.5f, 10.0f}, {5}); // both endpoints are included
3064
+ CHECK(array_equal(x, expected).item<bool>());
3065
+
3066
+ x = linspace(0, 10, 5, int32);
3067
+ expected = array({0, 2, 5, 7, 10}, {5}); // with int32 the 2.5 and 7.5 steps are expected to come out as 2 and 7
3068
+ CHECK(array_equal(x, expected).item<bool>());
3069
+
3070
+ x = linspace(0, 1, 0); // zero points requested: an empty array is expected
3071
+ expected = array(std::initializer_list<float>{}, {0});
3072
+ CHECK(array_equal(x, expected).item<bool>());
3073
+ }
3074
+
3075
+ TEST_CASE("test quantize dequantize") { // Round-trips a 128x512 array through quantize/dequantize and bounds the error.
3076
+ auto x1 = ones({128, 1});
3077
+ auto x2 = expand_dims(arange(0, 512, float32), 0);
3078
+ auto x = x1 * x2; // broadcasts to {128, 512}: every row holds 0..511
3079
+
3080
+ for (int i = 2; i <= 8; i *= 2) { // i takes the values 2, 4 and 8
3081
+ int el_per_int = 32 / i; // presumably i is the bit width, so this is how many values pack into 32 bits - confirm
3082
+ auto res = quantize(x, 128, i); // NOTE(review): 128 looks like the group size (512 / 128 = 4 matches the scales shape below) - confirm
3083
+ auto x_q = res[0];
3084
+ auto scales = res[1];
3085
+ auto biases = res[2];
3086
+ CHECK_EQ(x_q.shape(), Shape{128, 512 / el_per_int});
3087
+ CHECK_EQ(scales.shape(), Shape{128, 4});
3088
+ CHECK_EQ(biases.shape(), Shape{128, 4});
3089
+
3090
+ auto x_hat = dequantize(x_q, scales, biases, 128, i);
3091
+ auto max_diff = max(abs(x - x_hat)).item<float>();
3092
+ CHECK(max_diff <= 127.0 / (1 << i)); // the allowed reconstruction error shrinks as i grows
3093
+ }
3094
+ }
3095
+
3096
+ TEST_CASE("test repeat") { // Checks repeat along each axis, with no axis, with zero repeats, and rejects negative repeats.
3097
+ auto data = array({13, 3, 16, 6, 14, 4, 15, 5, 11, 1, 12, 2}, {3, 2, 2});
3098
+ auto repeat_axis_0 = repeat(data, 2, 0);
3099
+ auto expected_axis_0 = array(
3100
+ {13, 3, 16, 6, 13, 3, 16, 6, 14, 4, 15, 5,
3101
+ 14, 4, 15, 5, 11, 1, 12, 2, 11, 1, 12, 2},
3102
+ {6, 2, 2});
3103
+
3104
+ auto repeat_axis_1 = repeat(data, 2, 1);
3105
+ auto expected_axis_1 = array(
3106
+ {13, 3, 13, 3, 16, 6, 16, 6, 14, 4, 14, 4,
3107
+ 15, 5, 15, 5, 11, 1, 11, 1, 12, 2, 12, 2},
3108
+ {3, 4, 2});
3109
+
3110
+ auto repeat_axis_2 = repeat(data, 2); // no axis given: the expected result below is flat, shape {24}
3111
+ auto expected_axis_2 = array(
3112
+ {13, 13, 3, 3, 16, 16, 6, 6, 14, 14, 4, 4,
3113
+ 15, 15, 5, 5, 11, 11, 1, 1, 12, 12, 2, 2},
3114
+ {24});
3115
+
3116
+ // check output
3117
+ CHECK(array_equal(repeat_axis_0, expected_axis_0).item<bool>());
3118
+ CHECK(array_equal(repeat_axis_1, expected_axis_1).item<bool>());
3119
+ CHECK(array_equal(repeat_axis_2, expected_axis_2).item<bool>());
3120
+
3121
+ auto data_2 = array({1, 3, 2}, {3});
3122
+ auto repeat_2 = repeat(data_2, 2, 0);
3123
+ auto expected_2 = array({1, 1, 3, 3, 2, 2}, {6});
3124
+ CHECK(array_equal(repeat_2, expected_2).item<bool>());
3125
+
3126
+ auto data_3 = array({1, 2, 3, 4, 5, 4, 0, 1, 2}, {3, 3});
3127
+ auto repeat_3 = repeat(data_3, 2, 0);
3128
+ auto expected_3 =
3129
+ array({1, 2, 3, 1, 2, 3, 4, 5, 4, 4, 5, 4, 0, 1, 2, 0, 1, 2}, {6, 3});
3130
+ CHECK(array_equal(repeat_3, expected_3).item<bool>());
3131
+
3132
+ // 0 repeats
3133
+ auto repeat_4 = repeat(data_3, 0, 0);
3134
+ auto expected_4 = array({});
3135
+ CHECK(array_equal(repeat_4, expected_4).item<bool>()); // compare the zero-repeat result itself, not the earlier repeat_2 pair
3136
+
3137
+ // negative repeats
3138
+ CHECK_THROWS_AS(repeat(data_3, -3, 0), std::invalid_argument);
3139
+ }
3140
+
3141
+ TEST_CASE("tile") { // Checks tile with repetition lists shorter than, equal to and longer than the input rank.
3142
+ auto x = array({1, 2, 3}, {3});
3143
+ auto y = tile(x, {2});
3144
+ auto expected = array({1, 2, 3, 1, 2, 3}, {6});
3145
+ CHECK(array_equal(y, expected).item<bool>());
3146
+ x = array({1, 2, 3, 4}, {2, 2});
3147
+ y = tile(x, {2}); // one repetition count on a 2-D input: only the last axis grows ({2, 2} -> {2, 4})
3148
+ expected = array({1, 2, 1, 2, 3, 4, 3, 4}, {2, 4});
3149
+ CHECK(array_equal(y, expected).item<bool>());
3150
+ x = array({1, 2, 3, 4}, {2, 2});
3151
+ y = tile(x, {4, 1});
3152
+ expected = array({1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}, {8, 2});
3153
+ CHECK(array_equal(y, expected).item<bool>());
3154
+
3155
+ x = array({1, 2, 3, 4}, {2, 2});
3156
+ y = tile(x, {2, 2});
3157
+ expected = array({1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4}, {4, 4});
3158
+ CHECK(array_equal(y, expected).item<bool>());
3159
+ x = array({1, 2, 3}, {3});
3160
+ y = tile(x, {2, 2, 2}); // more repetition counts than input axes: the result gains leading axes ({3} -> {2, 2, 6})
3161
+ expected = array(
3162
+ {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3},
3163
+ {2, 2, 6});
3164
+ CHECK(array_equal(y, expected).item<bool>());
3165
+ }
3166
+
3167
+ TEST_CASE("tensordot") { // Checks tensordot with explicit axis lists, an invalid axis pairing, and an integer axis count.
3168
+ auto x = reshape(arange(60.), {3, 4, 5});
3169
+ auto y = reshape(arange(24.), {4, 3, 2});
3170
+ auto z = tensordot(x, y, {1, 0}, {0, 1}); // pairs x axes (1, 0) with y axes (0, 1), sizes (4, 3); the result is {5, 2}
3171
+ auto expected = array(
3172
+ {4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306}, {5, 2});
3173
+ CHECK(array_equal(z, expected).item<bool>());
3174
+ x = reshape(arange(360.), {3, 4, 5, 6});
3175
+ y = reshape(arange(360.), {6, 4, 5, 3});
3176
+ CHECK_THROWS_AS(tensordot(x, y, {2, 1, 3}, {1, 2, 0}), std::invalid_argument); // x axis 2 has size 5 but is paired with y axis 1 of size 4
3177
+ x = reshape(arange(60.), {3, 4, 5});
3178
+ y = reshape(arange(120.), {4, 5, 6});
3179
+ z = tensordot(x, y, 2); // integer form: {3, 4, 5} with {4, 5, 6} gives {3, 6}
3180
+ expected = array(
3181
+ {14820.,
3182
+ 15010.,
3183
+ 15200.,
3184
+ 15390.,
3185
+ 15580.,
3186
+ 15770.,
3187
+ 37620.,
3188
+ 38210.,
3189
+ 38800.,
3190
+ 39390.,
3191
+ 39980.,
3192
+ 40570.,
3193
+ 60420.,
3194
+ 61410.,
3195
+ 62400.,
3196
+ 63390.,
3197
+ 64380.,
3198
+ 65370.},
3199
+ {3, 6});
3200
+ CHECK(array_equal(z, expected).item<bool>());
3201
+ }
3202
+
3203
+ TEST_CASE("outer") { // Checks outer() of two 1-D arrays against explicit product tables.
3204
+ auto x = arange(1.0, 5.0);
3205
+ auto y = arange(1.0, 4.0);
3206
+ auto z = outer(x, y); // expected below is a {4, 3} table of products x[i] * y[j]
3207
+ auto expected = array(
3208
+ {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}, {4, 3});
3209
+ CHECK(array_equal(z, expected).item<bool>());
3210
+ // With x all ones, every expected row equals the linspace values in y
3211
+ x = ones({5});
3212
+ y = linspace(-2., 2., 5);
3213
+ z = outer(x, y);
3214
+ expected = array(
3215
+ {-2., -1., 0., 1., 2., -2., -1., 0., 1., 2., -2., -1., 0.,
3216
+ 1., 2., -2., -1., 0., 1., 2., -2., -1., 0., 1., 2.},
3217
+ {5, 5});
3218
+ CHECK(array_equal(z, expected).item<bool>());
3219
+ }
3220
+
3221
+ TEST_CASE("inner") { // Checks inner() for mismatched, 1-D, N-D and scalar operands.
3222
+ CHECK_THROWS_AS(
3223
+ inner(reshape(arange(5.), {1, 5}), reshape(arange(6.), {2, 3})), // last dims are 5 and 3
3224
+ std::invalid_argument);
3225
+ auto x = array({1., 2., 3.});
3226
+ auto y = array({0., 1., 0.});
3227
+ auto z = inner(x, y); // two 1-D inputs: result is read back as a single float
3228
+ CHECK_EQ(z.item<float>(), 2.f);
3229
+ // {2, 3, 4} with {4}: expected result has shape {2, 3}
3230
+ x = reshape(arange(24.), {2, 3, 4});
3231
+ y = arange(4.);
3232
+ z = inner(x, y);
3233
+ auto expected = array({14., 38., 62., 86., 110., 134.}, {2, 3});
3234
+ CHECK(array_equal(z, expected).item<bool>());
3235
+ // {1, 1, 2} with {3, 2}: expected result has shape {1, 1, 3}
3236
+ x = reshape(arange(2.), {1, 1, 2});
3237
+ y = reshape(arange(6.), {3, 2});
3238
+ z = inner(x, y);
3239
+ expected = array({1., 3., 5.}, {1, 1, 3});
3240
+ CHECK(array_equal(z, expected).item<bool>());
3241
+ // Scalar second operand: expected result is eye(2) scaled by 7
3242
+ z = inner(eye(2), array(7.));
3243
+ expected = array({7., 0., 0., 7.}, {2, 2});
3244
+ CHECK(array_equal(z, expected).item<bool>());
3245
+ }
3246
+
3247
+ TEST_CASE("test divmod") { // Checks divmod() values (out[0] quotient, out[1] remainder) and evaluation of its two outputs.
3248
+ auto x = array({1, 2, 3});
3249
+ auto y = array({1, 1, 1});
3250
+ auto out = divmod(x, y); // dividing by one: quotient equals x, remainder is zero
3251
+ CHECK(array_equal(out[0], array({1, 2, 3})).item<bool>());
3252
+ CHECK(array_equal(out[1], array({0, 0, 0})).item<bool>());
3253
+ // Integer inputs with non-zero remainders
3254
+ x = array({5, 6, 7});
3255
+ y = array({2, 2, 2});
3256
+ out = divmod(x, y);
3257
+ CHECK(array_equal(out[0], array({2, 3, 3})).item<bool>());
3258
+ CHECK(array_equal(out[1], array({1, 0, 1})).item<bool>());
3259
+
3260
+ // Siblings should be gone after evaling the graph
3261
+ CHECK(out[0].siblings().empty());
3262
+ CHECK(out[1].siblings().empty());
3263
+ // Same values as floating-point inputs
3264
+ x = array({5.0, 6.0, 7.0});
3265
+ y = array({2.0, 2.0, 2.0});
3266
+ out = divmod(x, y);
3267
+ CHECK(array_equal(out[0], array({2.0, 3.0, 3.0})).item<bool>());
3268
+ CHECK(array_equal(out[1], array({1.0, 0.0, 1.0})).item<bool>());
3269
+ // complex64 inputs are expected to throw
3270
+ x = array({1.0}, complex64);
3271
+ y = array({2.0}, complex64);
3272
+ CHECK_THROWS(divmod(x, y));
3273
+
3274
+ // Check that we can eval on both outputs
3275
+ x = array({1.0});
3276
+ y = array({2.0});
3277
+ out = divmod(x, y);
3278
+ eval(out);
3279
+ CHECK_EQ(out[0].item<float>(), 0.0);
3280
+ CHECK_EQ(out[1].item<float>(), 1.0);
3281
+
3282
+ // Check nested in the graph
3283
+ x = array({1.0});
3284
+ y = array({2.0});
3285
+ out = divmod(x, y);
3286
+ auto z = out[0] + out[1];
3287
+ CHECK_EQ(z.item<float>(), 1.0);
3288
+
3289
+ // Check that we can still eval when one output goes out of scope
3290
+ std::vector<array> out_holder;
3291
+ {
3292
+ out_holder.push_back(divmod(x, y)[0]); // keep only the first output; the temporary result vector is discarded
3293
+ }
3294
+ eval(out_holder);
3295
+ CHECK_EQ(out_holder[0].item<float>(), 0.0);
3296
+
3297
+ // Check that we can still eval when the other output goes out of scope
3298
+ out_holder.clear();
3299
+ {
3300
+ out_holder.push_back(divmod(x, y)[1]); // keep only the second output this time
3301
+ }
3302
+ eval(out_holder);
3303
+ CHECK_EQ(out_holder[0].item<float>(), 1.0);
3304
+ }
3305
+
3306
+ TEST_CASE("test diagonal") { // Checks diagonal(x, offset, axis1, axis2) values, shapes and error cases.
3307
+ auto x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2});
3308
+ auto out = diagonal(x); // default arguments on a {4, 2} input
3309
+ CHECK(array_equal(out, array({0, 3}, {2})).item<bool>());
3310
+ // Axis arguments 6 and -3 on a 2-D input are expected to throw std::out_of_range
3311
+ CHECK_THROWS_AS(diagonal(x, 1, 6, 0), std::out_of_range);
3312
+ CHECK_THROWS_AS(diagonal(x, 1, 0, -3), std::out_of_range);
3313
+ // Offset 2 with the axis arguments given as (1, 0)
3314
+ x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 4});
3315
+ out = diagonal(x, 2, 1, 0);
3316
+ CHECK(array_equal(out, array({8}, {1})).item<bool>());
3317
+ // Negative offset
3318
+ out = diagonal(x, -1, 0, 1);
3319
+ CHECK(array_equal(out, array({4, 9}, {2})).item<bool>());
3320
+ // Offset -5 on a {3, 4} input: the result is expected to have shape {0}
3321
+ out = diagonal(x, -5, 0, 1);
3322
+ eval(out);
3323
+ CHECK_EQ(out.shape(), Shape{0});
3324
+ // 3-D input: several axis pairs, including a negative axis
3325
+ x = array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 2, 2});
3326
+ out = diagonal(x, 1, 0, 1);
3327
+ CHECK(array_equal(out, array({2, 3}, {2, 1})).item<bool>());
3328
+
3329
+ out = diagonal(x, 0, 2, 0);
3330
+ CHECK(array_equal(out, array({0, 5, 2, 7}, {2, 2})).item<bool>());
3331
+
3332
+ out = diagonal(x, 1, -1, 0);
3333
+ CHECK(array_equal(out, array({4, 9, 6, 11}, {2, 2})).item<bool>());
3334
+ // 4-D input
3335
+ x = reshape(arange(16), {2, 2, 2, 2});
3336
+ out = diagonal(x, 0, 0, 1);
3337
+ CHECK(array_equal(out, array({0, 12, 1, 13, 2, 14, 3, 15}, {2, 2, 2}))
3338
+ .item<bool>());
3339
+ // Passing the same axis twice is expected to throw std::invalid_argument
3340
+ CHECK_THROWS_AS(diagonal(x, 0, 1, 1), std::invalid_argument);
3341
+ // A 1-D input is expected to throw std::invalid_argument
3342
+ x = array({0, 1}, {2});
3343
+ CHECK_THROWS_AS(diagonal(x, 0, 0, 1), std::invalid_argument);
3344
+ }
3345
+
3346
+ TEST_CASE("test diag") { // Checks diag(x, k): 1-D input gives a square matrix, 2-D input gives a 1-D diagonal.
3347
+ // Too few or too many dimensions
3348
+ CHECK_THROWS(diag(array(0.0))); // scalar input
3349
+ CHECK_THROWS(diag(array({0.0}, {1, 1, 1}))); // 3-D input
3350
+
3351
+ // Test with 1D array
3352
+ auto x = array({0, 1, 2, 3}, {4});
3353
+ auto out = diag(x, 0); // k = 0: expected {4, 4} matrix with x on the main diagonal
3354
+ CHECK(
3355
+ array_equal(
3356
+ out, array({0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3}, {4, 4}))
3357
+ .item<bool>());
3358
+
3359
+ out = diag(x, 1); // k = 1: expected {5, 5} matrix with x one position right of the main diagonal
3360
+ CHECK(array_equal(
3361
+ out,
3362
+ array(
3363
+ {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
3364
+ 2, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0},
3365
+ {5, 5}))
3366
+ .item<bool>());
3367
+
3368
+ out = diag(x, -1); // k = -1: expected {5, 5} matrix with x one position below the main diagonal
3369
+ CHECK(array_equal(
3370
+ out,
3371
+ array(
3372
+ {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
3373
+ 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0},
3374
+ {5, 5}))
3375
+ .item<bool>());
3376
+
3377
+ // Test with 2D array
3378
+ x = array({0, 1, 2, 3, 4, 5, 6, 7, 8}, {3, 3});
3379
+ out = diag(x, 0);
3380
+ CHECK(array_equal(out, array({0, 4, 8}, {3})).item<bool>());
3381
+
3382
+ out = diag(x, 1);
3383
+ CHECK(array_equal(out, array({1, 5}, {2})).item<bool>());
3384
+
3385
+ out = diag(x, -1);
3386
+ CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
3387
+ }
3388
+
3389
+ TEST_CASE("test issubdtype") { // Checks issubdtype() for type/category, type/type and category/category pairs.
3390
+ const auto cats = { // every category exercised below
3391
+ complexfloating,
3392
+ floating,
3393
+ inexact,
3394
+ signedinteger,
3395
+ unsignedinteger,
3396
+ integer,
3397
+ number,
3398
+ generic};
3399
+ const auto types = { // every concrete dtype exercised below
3400
+ bool_,
3401
+ uint8,
3402
+ uint16,
3403
+ uint32,
3404
+ uint64,
3405
+ int8,
3406
+ int16,
3407
+ int32,
3408
+ int64,
3409
+ float16,
3410
+ float32,
3411
+ bfloat16,
3412
+ complex64};
3413
+ for (const auto& type : types) { // part 1: each dtype against every category, selected by its kind
3414
+ CHECK(issubdtype(type, type));
3415
+ CHECK(issubdtype(type, generic));
3416
+ switch (kindof(type)) {
3417
+ case Dtype::Kind::b: // expected to match generic only
3418
+ CHECK_FALSE(issubdtype(type, complexfloating));
3419
+ CHECK_FALSE(issubdtype(type, floating));
3420
+ CHECK_FALSE(issubdtype(type, inexact));
3421
+ CHECK_FALSE(issubdtype(type, signedinteger));
3422
+ CHECK_FALSE(issubdtype(type, unsignedinteger));
3423
+ CHECK_FALSE(issubdtype(type, integer));
3424
+ CHECK_FALSE(issubdtype(type, number));
3425
+ CHECK(issubdtype(type, generic));
3426
+ break;
3427
+ case Dtype::Kind::u: // expected: unsignedinteger, integer, number, generic
3428
+ CHECK_FALSE(issubdtype(type, complexfloating));
3429
+ CHECK_FALSE(issubdtype(type, floating));
3430
+ CHECK_FALSE(issubdtype(type, inexact));
3431
+ CHECK_FALSE(issubdtype(type, signedinteger));
3432
+ CHECK(issubdtype(type, unsignedinteger));
3433
+ CHECK(issubdtype(type, integer));
3434
+ CHECK(issubdtype(type, number));
3435
+ CHECK(issubdtype(type, generic));
3436
+ break;
3437
+ case Dtype::Kind::i: // expected: signedinteger, integer, number, generic
3438
+ CHECK_FALSE(issubdtype(type, complexfloating));
3439
+ CHECK_FALSE(issubdtype(type, floating));
3440
+ CHECK_FALSE(issubdtype(type, inexact));
3441
+ CHECK(issubdtype(type, signedinteger));
3442
+ CHECK_FALSE(issubdtype(type, unsignedinteger));
3443
+ CHECK(issubdtype(type, integer));
3444
+ CHECK(issubdtype(type, number));
3445
+ CHECK(issubdtype(type, generic));
3446
+ break;
3447
+ case Dtype::Kind::f: // expected: floating, inexact, number, generic
3448
+ CHECK_FALSE(issubdtype(type, complexfloating));
3449
+ CHECK(issubdtype(type, floating));
3450
+ CHECK(issubdtype(type, inexact));
3451
+ CHECK_FALSE(issubdtype(type, signedinteger));
3452
+ CHECK_FALSE(issubdtype(type, unsignedinteger));
3453
+ CHECK_FALSE(issubdtype(type, integer));
3454
+ CHECK(issubdtype(type, number));
3455
+ CHECK(issubdtype(type, generic));
3456
+ break;
3457
+ case Dtype::Kind::c: // expected: complexfloating, inexact, number, generic
3458
+ CHECK(issubdtype(type, complexfloating));
3459
+ CHECK_FALSE(issubdtype(type, floating));
3460
+ CHECK(issubdtype(type, inexact));
3461
+ CHECK_FALSE(issubdtype(type, signedinteger));
3462
+ CHECK_FALSE(issubdtype(type, unsignedinteger));
3463
+ CHECK_FALSE(issubdtype(type, integer));
3464
+ CHECK(issubdtype(type, number));
3465
+ CHECK(issubdtype(type, generic));
3466
+ break;
3467
+ case Dtype::Kind::V: // same expectations as Kind::f; NOTE(review): presumably the kind of bfloat16 - confirm
3468
+ CHECK_FALSE(issubdtype(type, complexfloating));
3469
+ CHECK(issubdtype(type, floating));
3470
+ CHECK(issubdtype(type, inexact));
3471
+ CHECK_FALSE(issubdtype(type, signedinteger));
3472
+ CHECK_FALSE(issubdtype(type, unsignedinteger));
3473
+ CHECK_FALSE(issubdtype(type, integer));
3474
+ CHECK(issubdtype(type, number));
3475
+ CHECK(issubdtype(type, generic));
3476
+ break;
3477
+ }
3478
+ }
3479
+ // Part 2: between two concrete dtypes, issubdtype must hold exactly when they are equal
3480
+ for (const auto& type : types) {
3481
+ CHECK(issubdtype(type, type));
3482
+ CHECK(issubdtype(type, generic));
3483
+ for (auto type1 : types) {
3484
+ CHECK_EQ(issubdtype(type, type1), type == type1);
3485
+ }
3486
+ }
3487
+ // Part 3: each category against every category
3488
+ for (const auto& cat : cats) {
3489
+ CHECK(issubdtype(cat, cat));
3490
+ switch (cat) {
3491
+ case Dtype::Category::complexfloating: // expected: itself, inexact, number, generic
3492
+ CHECK(issubdtype(cat, complexfloating));
3493
+ CHECK_FALSE(issubdtype(cat, floating));
3494
+ CHECK(issubdtype(cat, inexact));
3495
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3496
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3497
+ CHECK_FALSE(issubdtype(cat, integer));
3498
+ CHECK(issubdtype(cat, number));
3499
+ CHECK(issubdtype(cat, generic));
3500
+ break;
3501
+ case Dtype::Category::floating: // expected: itself, inexact, number, generic
3502
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3503
+ CHECK(issubdtype(cat, floating));
3504
+ CHECK(issubdtype(cat, inexact));
3505
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3506
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3507
+ CHECK_FALSE(issubdtype(cat, integer));
3508
+ CHECK(issubdtype(cat, number));
3509
+ CHECK(issubdtype(cat, generic));
3510
+ break;
3511
+ case Dtype::Category::inexact: // expected: itself, number, generic
3512
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3513
+ CHECK_FALSE(issubdtype(cat, floating));
3514
+ CHECK(issubdtype(cat, inexact));
3515
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3516
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3517
+ CHECK_FALSE(issubdtype(cat, integer));
3518
+ CHECK(issubdtype(cat, number));
3519
+ CHECK(issubdtype(cat, generic));
3520
+ break;
3521
+ case Dtype::Category::signedinteger: // expected: itself, integer, number, generic
3522
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3523
+ CHECK_FALSE(issubdtype(cat, floating));
3524
+ CHECK_FALSE(issubdtype(cat, inexact));
3525
+ CHECK(issubdtype(cat, signedinteger));
3526
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3527
+ CHECK(issubdtype(cat, integer));
3528
+ CHECK(issubdtype(cat, number));
3529
+ CHECK(issubdtype(cat, generic));
3530
+ break;
3531
+ case Dtype::Category::unsignedinteger: // expected: itself, integer, number, generic
3532
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3533
+ CHECK_FALSE(issubdtype(cat, floating));
3534
+ CHECK_FALSE(issubdtype(cat, inexact));
3535
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3536
+ CHECK(issubdtype(cat, unsignedinteger));
3537
+ CHECK(issubdtype(cat, integer));
3538
+ CHECK(issubdtype(cat, number));
3539
+ CHECK(issubdtype(cat, generic));
3540
+ break;
3541
+ case Dtype::Category::integer: // expected: itself, number, generic
3542
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3543
+ CHECK_FALSE(issubdtype(cat, floating));
3544
+ CHECK_FALSE(issubdtype(cat, inexact));
3545
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3546
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3547
+ CHECK(issubdtype(cat, integer));
3548
+ CHECK(issubdtype(cat, number));
3549
+ CHECK(issubdtype(cat, generic));
3550
+ break;
3551
+ case Dtype::Category::number: // expected: itself, generic
3552
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3553
+ CHECK_FALSE(issubdtype(cat, floating));
3554
+ CHECK_FALSE(issubdtype(cat, inexact));
3555
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3556
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3557
+ CHECK_FALSE(issubdtype(cat, integer));
3558
+ CHECK(issubdtype(cat, number));
3559
+ CHECK(issubdtype(cat, generic));
3560
+ break;
3561
+ case Dtype::Category::generic: // expected: itself only
3562
+ CHECK_FALSE(issubdtype(cat, complexfloating));
3563
+ CHECK_FALSE(issubdtype(cat, floating));
3564
+ CHECK_FALSE(issubdtype(cat, inexact));
3565
+ CHECK_FALSE(issubdtype(cat, signedinteger));
3566
+ CHECK_FALSE(issubdtype(cat, unsignedinteger));
3567
+ CHECK_FALSE(issubdtype(cat, integer));
3568
+ CHECK_FALSE(issubdtype(cat, number));
3569
+ CHECK(issubdtype(cat, generic));
3570
+ break;
3571
+ }
3572
+ }
3573
+ }
3574
+
3575
+ TEST_CASE("test atleast_1d") { // Checks result rank and shape of atleast_1d() for 0-D, 1-D and 2-D inputs.
3576
+ auto x = array(1);
3577
+ auto out = atleast_1d(x); // scalar input: expected shape {1}
3578
+ CHECK_EQ(out.ndim(), 1);
3579
+ CHECK_EQ(out.shape(), Shape{1});
3580
+ // 1-D input: shape is expected to stay {3}
3581
+ x = array({1, 2, 3}, {3});
3582
+ out = atleast_1d(x);
3583
+ CHECK_EQ(out.ndim(), 1);
3584
+ CHECK_EQ(out.shape(), Shape{3});
3585
+ // 2-D input: shape is expected to stay {3, 1}
3586
+ x = array({1, 2, 3}, {3, 1});
3587
+ out = atleast_1d(x);
3588
+ CHECK_EQ(out.ndim(), 2);
3589
+ CHECK_EQ(out.shape(), Shape{3, 1});
3590
+ }
3591
+
3592
+ TEST_CASE("test atleast_1d vector") { // Same expectations as the single-array atleast_1d test, via the std::vector<array> overload.
3593
+ auto x = std::vector<array>{
3594
+ array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
3595
+ auto out = atleast_1d(x);
3596
+ CHECK_EQ(out.size(), 3); // one output per input
3597
+ CHECK_EQ(out[0].ndim(), 1);
3598
+ CHECK_EQ(out[0].shape(), Shape{1});
3599
+ CHECK_EQ(out[1].ndim(), 1);
3600
+ CHECK_EQ(out[1].shape(), Shape{3});
3601
+ CHECK_EQ(out[2].ndim(), 2);
3602
+ CHECK_EQ(out[2].shape(), Shape{3, 1});
3603
+ }
3604
+
3605
+ TEST_CASE("test atleast_2d") { // Checks result rank and shape of atleast_2d() for 0-D, 1-D and 2-D inputs.
3606
+ auto x = array(1);
3607
+ auto out = atleast_2d(x); // scalar input: expected shape {1, 1}
3608
+ CHECK_EQ(out.ndim(), 2);
3609
+ CHECK_EQ(out.shape(), Shape{1, 1});
3610
+ // 1-D input of length 3: expected shape {1, 3}
3611
+ x = array({1, 2, 3}, {3});
3612
+ out = atleast_2d(x);
3613
+ CHECK_EQ(out.ndim(), 2);
3614
+ CHECK_EQ(out.shape(), Shape{1, 3});
3615
+ // 2-D input: shape is expected to stay {3, 1}
3616
+ x = array({1, 2, 3}, {3, 1});
3617
+ out = atleast_2d(x);
3618
+ CHECK_EQ(out.ndim(), 2);
3619
+ CHECK_EQ(out.shape(), Shape{3, 1});
3620
+ }
3621
+
3622
+ TEST_CASE("test atleast_2d vector") { // Same expectations as the single-array atleast_2d test, via the std::vector<array> overload.
3623
+ auto x = std::vector<array>{
3624
+ array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
3625
+ auto out = atleast_2d(x);
3626
+ CHECK_EQ(out.size(), 3); // one output per input
3627
+ CHECK_EQ(out[0].ndim(), 2);
3628
+ CHECK_EQ(out[0].shape(), Shape{1, 1});
3629
+ CHECK_EQ(out[1].ndim(), 2);
3630
+ CHECK_EQ(out[1].shape(), Shape{1, 3});
3631
+ CHECK_EQ(out[2].ndim(), 2);
3632
+ CHECK_EQ(out[2].shape(), Shape{3, 1});
3633
+ }
3634
+
3635
+ TEST_CASE("test atleast_3d") { // Checks result rank and shape of atleast_3d() for 0-D, 1-D and 2-D inputs.
3636
+ auto x = array(1);
3637
+ auto out = atleast_3d(x); // scalar input: expected shape {1, 1, 1}
3638
+ CHECK_EQ(out.ndim(), 3);
3639
+ CHECK_EQ(out.shape(), Shape{1, 1, 1});
3640
+ // 1-D input of length 3: expected shape {1, 3, 1}
3641
+ x = array({1, 2, 3}, {3});
3642
+ out = atleast_3d(x);
3643
+ CHECK_EQ(out.ndim(), 3);
3644
+ CHECK_EQ(out.shape(), Shape{1, 3, 1});
3645
+ // 2-D input {3, 1}: expected shape {3, 1, 1}
3646
+ x = array({1, 2, 3}, {3, 1});
3647
+ out = atleast_3d(x);
3648
+ CHECK_EQ(out.ndim(), 3);
3649
+ CHECK_EQ(out.shape(), Shape{3, 1, 1});
3650
+ }
3651
+
3652
+ TEST_CASE("test atleast_3d vector") { // Same expectations as the single-array atleast_3d test, via the std::vector<array> overload.
3653
+ auto x = std::vector<array>{
3654
+ array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
3655
+ auto out = atleast_3d(x);
3656
+ CHECK_EQ(out.size(), 3); // one output per input
3657
+ CHECK_EQ(out[0].ndim(), 3);
3658
+ CHECK_EQ(out[0].shape(), Shape{1, 1, 1});
3659
+ CHECK_EQ(out[1].ndim(), 3);
3660
+ CHECK_EQ(out[1].shape(), Shape{1, 3, 1});
3661
+ CHECK_EQ(out[2].ndim(), 3);
3662
+ CHECK_EQ(out[2].shape(), Shape{3, 1, 1});
3663
+ }
3664
+
3665
+ TEST_CASE("test topk") { // Checks topk(x, k, axis) on a {2, 5} array holding 0..9.
3666
+ auto x = reshape(arange(10), {2, 5});
3667
+
3668
+ { // k = 1 along axis 1: the largest value of each row
3669
+ auto y = topk(x, 1, 1);
3670
+ CHECK(array_equal(y, array({4, 9}, {2, 1})).item<bool>());
3671
+ }
3672
+
3673
+ { // k = 2 along axis 0, which has length 2: the result is expected to equal x
3674
+ auto y = topk(x, 2, 0);
3675
+ CHECK(array_equal(y, x).item<bool>());
3676
+ }
3677
+
3678
+ { // k = 1 along axis 0: the second row holds the larger values
3679
+ auto y = topk(x, 1, 0);
3680
+ CHECK(array_equal(y, array({5, 6, 7, 8, 9}, {1, 5})).item<bool>());
3681
+ }
3682
+ }
3683
+
3684
+ TEST_CASE("test meshgrid") { // Checks meshgrid() for one input, two inputs of different length, and the sparse flag.
3685
+ // Test default
3686
+ auto x = array({1, 2, 3}, {3});
3687
+ auto in = std::vector<array>{x};
3688
+ auto out = meshgrid(in); // single input: the output is expected to equal the input
3689
+ CHECK(array_equal(out[0], x).item<bool>());
3690
+
3691
+ // Test different lengths
3692
+ auto y = array({4, 5}, {2});
3693
+ in = std::vector<array>{x, y};
3694
+ out = meshgrid(in); // lengths 3 and 2: both expected outputs have shape {2, 3}
3695
+ auto expected_zero = array({1, 2, 3, 1, 2, 3}, {2, 3});
3696
+ auto expected_one = array({4, 4, 4, 5, 5, 5}, {2, 3});
3697
+ CHECK(array_equal(out[0], expected_zero).item<bool>());
3698
+ CHECK(array_equal(out[1], expected_one).item<bool>());
3699
+
3700
+ // Test sparse true
3701
+ in = std::vector<array>{x, x};
3702
+ out = meshgrid(in, true); // expected outputs have shapes {1, 3} and {3, 1} rather than full grids
3703
+ expected_zero = array({1, 2, 3}, {1, 3});
3704
+ expected_one = array({1, 2, 3}, {3, 1});
3705
+ CHECK(array_equal(out[0], expected_zero).item<bool>());
3706
+ CHECK(array_equal(out[1], expected_one).item<bool>());
3707
+ }
3708
+
3709
+ TEST_CASE("test conv1d") { // Checks conv1d() against precomputed outputs for groups = 1 and groups = 2.
3710
+ auto in = astype( // shared input of shape {1, 3, 2}, cast to float16
3711
+ array(
3712
+ {0.5488135,
3713
+ 0.71518937,
3714
+ 0.60276338,
3715
+ 0.54488318,
3716
+ 0.4236548,
3717
+ 0.64589411},
3718
+ {1, 3, 2}),
3719
+ float16);
3720
+
3721
+ int stride = 1;
3722
+ int padding = 1;
3723
+
3724
+ { // groups = 1: float16 weights of shape {4, 3, 2}
3725
+ int groups = 1;
3726
+ auto wt = astype(
3727
+ array(
3728
+ {
3729
+
3730
+ 0.43758721, 0.891773, 0.96366276, 0.38344152,
3731
+ 0.79172504, 0.52889492,
3732
+
3733
+ 0.56804456, 0.92559664, 0.07103606, 0.0871293,
3734
+ 0.0202184, 0.83261985,
3735
+
3736
+ 0.77815675, 0.87001215, 0.97861834, 0.79915856,
3737
+ 0.46147936, 0.78052918,
3738
+
3739
+ 0.11827443, 0.63992102, 0.14335329, 0.94466892,
3740
+ 0.52184832, 0.41466194
3741
+
3742
+ },
3743
+ {4, 3, 2}),
3744
+ float16);
3745
+
3746
+ auto expected = array( // precomputed reference output of shape {1, 3, 4}
3747
+ {1.56836,
3748
+ 0.567383,
3749
+ 1.8125,
3750
+ 1.29492,
3751
+ 2.34375,
3752
+ 1.61035,
3753
+ 2.77539,
3754
+ 1.61328,
3755
+ 1.40527,
3756
+ 0.933105,
3757
+ 1.87402,
3758
+ 1.09082},
3759
+ {1, 3, 4});
3760
+
3761
+ auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
3762
+ CHECK(allclose(out, expected).item<bool>()); // tolerance-based comparison
3763
+ }
3764
+
3765
+ { // groups = 2: weights of shape {4, 3, 1}, not cast to float16 here
3766
+ int groups = 2;
3767
+ auto wt = array(
3768
+ {0.43758721,
3769
+ 0.891773,
3770
+ 0.96366276,
3771
+
3772
+ 0.38344152,
3773
+ 0.79172504,
3774
+ 0.52889492,
3775
+
3776
+ 0.56804456,
3777
+ 0.92559664,
3778
+ 0.07103606,
3779
+
3780
+ 0.0871293,
3781
+ 0.0202184,
3782
+ 0.83261985
3783
+
3784
+ },
3785
+ {4, 3, 1});
3786
+
3787
+ auto expected = array( // precomputed reference output of shape {1, 3, 4}
3788
+ {1.07007,
3789
+ 0.753201,
3790
+ 0.700818,
3791
+ 0.468176,
3792
+ 1.18568,
3793
+ 0.91152,
3794
+ 0.956607,
3795
+ 0.611213,
3796
+ 0.641404,
3797
+ 0.566401,
3798
+ 0.907472,
3799
+ 0.0605397},
3800
+ {1, 3, 4});
3801
+
3802
+ auto out = conv1d(in, wt, stride, padding, /* dilation= */ 1, groups);
3803
+ CHECK(allclose(out, expected).item<bool>());
3804
+ }
3805
+ }
3806
+
3807
+ TEST_CASE("test conv2d") { // Checks conv2d() against precomputed outputs for groups = 1 and groups = 2.
3808
+ auto in = array( // input of shape {1, 2, 2, 2}
3809
+ {0.57429284,
3810
+ -0.21628855,
3811
+ -0.18673691,
3812
+ -0.3793517,
3813
+
3814
+ 0.3059678,
3815
+ -0.8137168,
3816
+ 0.6168841,
3817
+ -0.26912728},
3818
+ {1, 2, 2, 2});
3819
+
3820
+ std::pair<int, int> stride{1, 1};
3821
+ std::pair<int, int> padding{0, 0};
3822
+
3823
+ { // groups = 1: weights of shape {4, 2, 2, 2}
3824
+ int groups = 1;
3825
+
3826
+ auto wt = array(
3827
+ {0.3190391, -0.24937038, 1.4621079, -2.0601406, -0.3224172,
3828
+ -0.38405436, 1.1337694, -1.0998913, -0.1724282, -0.8778584,
3829
+ 0.04221375, 0.58281523, -1.1006192, 1.1447237, 0.9015907,
3830
+ 0.50249434, 0.90085596, -0.68372786, -0.12289023, -0.93576944,
3831
+ -0.26788807, 0.53035545, -0.69166076, -0.39675352, -0.6871727,
3832
+ -0.84520566, -0.6712461, -0.0126646, -1.1173104, 0.2344157,
3833
+ 1.6598022, 0.74204415},
3834
+ {4, 2, 2, 2});
3835
+
3836
+ auto expected =
3837
+ array({1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});
3838
+ auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
3839
+ CHECK(allclose(out, expected).item<bool>()); // tolerance-based comparison
3840
+ }
3841
+
3842
+ { // groups = 2: weights of shape {4, 2, 2, 1}
3843
+ int groups = 2;
3844
+ auto wt = array(
3845
+ {0.3190391,
3846
+ -0.24937038,
3847
+
3848
+ 1.46210794,
3849
+ -2.06014071,
3850
+
3851
+ -0.3224172,
3852
+ -0.38405435,
3853
+
3854
+ 1.13376944,
3855
+ -1.09989127,
3856
+
3857
+ -0.17242821,
3858
+ -0.87785842,
3859
+
3860
+ 0.04221375,
3861
+ 0.58281521,
3862
+
3863
+ -1.10061918,
3864
+ 1.14472371,
3865
+
3866
+ 0.90159072,
3867
+ 0.50249434},
3868
+ {4, 2, 2, 1});
3869
+
3870
+ auto expected = array(
3871
+ {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
3872
+
3873
+ auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
3874
+ CHECK(allclose(out, expected).item<bool>());
3875
+ }
3876
+
3877
+ { // groups = 2 again, with the same eight input values repeated to form a batch of two
3878
+ in = array(
3879
+ {0.57429284,
3880
+ -0.21628855,
3881
+ -0.18673691,
3882
+ -0.3793517,
3883
+
3884
+ 0.3059678,
3885
+ -0.8137168,
3886
+ 0.6168841,
3887
+ -0.26912728,
3888
+
3889
+ 0.57429284,
3890
+ -0.21628855,
3891
+ -0.18673691,
3892
+ -0.3793517,
3893
+
3894
+ 0.3059678,
3895
+ -0.8137168,
3896
+ 0.6168841,
3897
+ -0.26912728},
3898
+ {2, 2, 2, 2});
3899
+
3900
+ int groups = 2;
3901
+ auto wt = array(
3902
+ {0.3190391,
3903
+ -0.24937038,
3904
+
3905
+ 1.46210794,
3906
+ -2.06014071,
3907
+
3908
+ -0.3224172,
3909
+ -0.38405435,
3910
+
3911
+ 1.13376944,
3912
+ -1.09989127,
3913
+
3914
+ -0.17242821,
3915
+ -0.87785842,
3916
+
3917
+ 0.04221375,
3918
+ 0.58281521,
3919
+
3920
+ -1.10061918,
3921
+ 1.14472371,
3922
+
3923
+ 0.90159072,
3924
+ 0.50249434},
3925
+ {4, 2, 2, 1});
3926
+ // NOTE(review): expected keeps a leading dimension of 1 although the input batch is 2; the check presumably relies on allclose broadcasting - confirm
3927
+ auto expected = array(
3928
+ {-0.59372161, -0.44505326, 0.17910982, -1.06507601}, {1, 1, 1, 4});
3929
+
3930
+ auto out = conv2d(in, wt, stride, padding, /* dilation= */ {1, 1}, groups);
3931
+ CHECK(allclose(out, expected).item<bool>());
3932
+ }
3933
+ }
3934
+
3935
+ TEST_CASE("test trace") { // Checks trace() on 2-D and 3-D inputs, including explicit axes and an output dtype.
3936
+ auto in = eye(3);
3937
+ auto out = trace(in).item<float>(); // trace of the 3x3 identity
3938
+ CHECK_EQ(out, 3.0);
3939
+ // int32 matrix: 1 + 5 + 9
3940
+ in = array({1, 2, 3, 4, 5, 6, 7, 8, 9}, {3, 3}, int32);
3941
+ auto out2 = trace(in).item<int>();
3942
+ CHECK_EQ(out2, 15);
3943
+ // 3-D input with offset 0 over axes (0, 1): the result is expected to have shape {2}
3944
+ in = reshape(arange(8), {2, 2, 2});
3945
+ auto out3 = trace(in, 0, 0, 1);
3946
+ CHECK(array_equal(out3, array({6, 8}, {2})).item<bool>());
3947
+ // Axes (1, 2) with float32 passed as the last argument
3948
+ auto out4 = trace(in, 0, 1, 2, float32);
3949
+ CHECK(array_equal(out4, array({3, 11}, {2})).item<bool>());
3950
+ }
3951
+
3952
+ TEST_CASE("test view") { // Checks view() error cases and an int64 -> int32 reinterpretation.
3953
+ auto in = array(3);
3954
+ CHECK_THROWS(view(in, int64)); // scalar input viewed as int64 is expected to throw
3955
+
3956
+ in = array({1, 2, 3});
3957
+ CHECK_THROWS(view(in, int64)); // NOTE(review): presumably throws because three int32 elements cannot be regrouped into int64 - confirm
3958
+
3959
+ in = array({1, 2, 3, 4}, int64);
3960
+ auto out = view(in, int32); // four int64 values are expected to appear as eight int32 values
3961
+ CHECK(array_equal(out, array({1, 0, 2, 0, 3, 0, 4, 0})).item<bool>()); // NOTE(review): this value order presumably assumes a little-endian layout - confirm
3962
+ }
3963
+
3964
+ TEST_CASE("test roll") { // Checks roll() with scalar and list shifts/axes on a {2, 5} array holding 0..9.
3965
+ auto x = reshape(arange(10), {2, 5});
3966
+
3967
+ auto y = roll(x, 2); // no axis given: expected values wrap around across both rows
3968
+ CHECK(array_equal(y, array({8, 9, 0, 1, 2, 3, 4, 5, 6, 7}, {2, 5}))
3969
+ .item<bool>());
3970
+
3971
+ y = roll(x, -2); // negative shift, no axis
3972
+ CHECK(array_equal(y, array({2, 3, 4, 5, 6, 7, 8, 9, 0, 1}, {2, 5}))
3973
+ .item<bool>());
3974
+
3975
+ y = roll(x, 2, 1); // shift along axis 1: each row wraps independently
3976
+ CHECK(array_equal(y, array({3, 4, 0, 1, 2, 8, 9, 5, 6, 7}, {2, 5}))
3977
+ .item<bool>());
3978
+
3979
+ y = roll(x, -2, 1);
3980
+ CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
3981
+ .item<bool>());
3982
+
3983
+ y = roll(x, 2, {0, 0, 0}); // axis 0 listed three times: the expected result equals x
3984
+ CHECK(array_equal(y, array({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 5}))
3985
+ .item<bool>());
3986
+
3987
+ y = roll(x, 1, {1, 1, 1}); // axis 1 listed three times: expected result matches roll(x, -2, 1) above
3988
+ CHECK(array_equal(y, array({2, 3, 4, 0, 1, 7, 8, 9, 5, 6}, {2, 5}))
3989
+ .item<bool>());
3990
+
3991
+ y = roll(x, {1, 2}, {0, 1}); // a separate shift for each axis
3992
+ CHECK(array_equal(y, array({8, 9, 5, 6, 7, 3, 4, 0, 1, 2}, {2, 5}))
3993
+ .item<bool>());
3994
+
3995
+ y = roll(array({}), 0, 0); // empty input
3996
+ CHECK(array_equal(y, array({})).item<bool>());
3997
+ }
3998
+
3999
+ TEST_CASE("test contiguous") { // Checks the contiguity flags and strides reported after contiguous().
3999
+ auto x = array({1, 2, 3});
4000
+ x = contiguous(broadcast_to(x, {2, 2, 3})); // broadcast input, default second argument
4001
+ eval(x);
4002
+ CHECK(x.flags().row_contiguous);
4003
+ CHECK_EQ(x.strides(), decltype(x.strides()){6, 3, 1});
4004
+ // Transposed input with true as the second argument: col_contiguous is checked instead
4005
+ x = array({1, 2, 1, 2}, {2, 2});
4006
+ x = contiguous(transpose(x), true);
4007
+ eval(x);
4008
+ CHECK(x.flags().col_contiguous);
4009
+ CHECK_EQ(x.strides(), decltype(x.strides()){1, 2});
4010
+ }
4012
+
4013
+ TEST_CASE("test bitwise shift operations") { // Checks left_shift()/right_shift() values and result dtypes.
4013
+ std::vector<Dtype> dtypes = {
4014
+ int8, int16, int32, int64, uint8, uint16, uint32, uint64};
4015
+ // For every integer dtype: 1 << 2 == 4 and 4 >> 2 == 1, with the dtype preserved
4016
+ for (const auto& dtype : dtypes) {
4017
+ array x = full({4}, 1, dtype);
4018
+ array y = full({4}, 2, dtype);
4019
+
4020
+ auto left_shift_result = left_shift(x, y);
4021
+ CHECK_EQ(left_shift_result.dtype(), dtype);
4022
+ CHECK(array_equal(left_shift_result, array({4, 4, 4, 4}, dtype))
4023
+ .item<bool>());
4024
+
4025
+ auto right_shift_result = right_shift(full({4}, 4, dtype), y);
4026
+ CHECK_EQ(right_shift_result.dtype(), dtype);
4027
+ CHECK(array_equal(right_shift_result, full({4}, 1, dtype)).item<bool>());
4028
+ }
4029
+ // int8 extremes shifted by one bit
4030
+ array x = array({127, -128}, int8);
4031
+ array y = array({1, 1}, int8);
4032
+ auto left_shift_result = left_shift(x, y);
4033
+ auto right_shift_result = right_shift(x, y);
4034
+
4035
+ CHECK(array_equal(left_shift_result, array({-2, 0}, int8)).item<bool>()); // expected values wrap within int8
4036
+ CHECK(array_equal(right_shift_result, array({63, -64}, int8)).item<bool>()); // expected values keep the sign
4037
+ // Boolean operands: the results are expected to have dtype uint8
4038
+ array x_bool = full({4}, true, bool_);
4039
+ array y_bool = full({4}, true, bool_);
4040
+ auto left_shift_bool_result = left_shift(x_bool, y_bool);
4041
+ auto right_shift_bool_result = right_shift(x_bool, y_bool);
4042
+
4043
+ CHECK_EQ(left_shift_bool_result.dtype(), uint8);
4044
+ CHECK(array_equal(left_shift_bool_result, full({4}, 2, uint8)).item<bool>());
4045
+
4046
+ CHECK_EQ(right_shift_bool_result.dtype(), uint8);
4047
+ CHECK(array_equal(right_shift_bool_result, full({4}, 0, uint8)).item<bool>());
4048
+ }
4050
+
4051
+ TEST_CASE("test conv_transpose1d with output_padding") { // Checks conv_transpose1d() with output_padding = 1 against a fixed expected array.
4051
+ auto in = array({1.0, 2.0, 3.0}, {1, 1, 3});
4052
+ auto wt = array({1.0, 1.0, 1.0}, {1, 1, 3}); // all-ones weights
4053
+ int stride = 2;
4054
+ int padding = 0;
4055
+ int dilation = 1;
4056
+ int output_padding = 1;
4057
+ int groups = 1;
4058
+
4059
+ auto out = conv_transpose1d(
4060
+ in, wt, stride, padding, dilation, output_padding, groups);
4061
+ auto expected = array({6.0, 0.0}, {1, 2, 1}); // 6 = 1 + 2 + 3; the second expected entry is zero
4062
+ CHECK(array_equal(out, expected).item<bool>());
4063
+ }
4065
+
4066
+ TEST_CASE("test conv_transpose2d with output_padding") { // Checks conv_transpose2d() with output_padding = {1, 1} against a fixed expected array.
4066
+ auto in = array({1.0, 2.0, 3.0, 4.0}, {1, 1, 2, 2});
4067
+ auto wt = array({1.0, 1.0, 1.0, 1.0}, {2, 1, 1, 2}); // all-ones weights
4068
+ std::pair<int, int> stride{2, 2};
4069
+ std::pair<int, int> padding{0, 0};
4070
+ std::pair<int, int> output_padding{1, 1};
4071
+ std::pair<int, int> dilation{1, 1};
4072
+ int groups = 1;
4073
+
4074
+ auto out = conv_transpose2d(
4075
+ in, wt, stride, padding, dilation, output_padding, groups);
4076
+ auto expected = array( // expected shape {1, 2, 4, 2}; only the first and third pairs are non-zero
4077
+ {3.0,
4078
+ 3.0,
4079
+ 0.0,
4080
+ 0.0,
4081
+ 7.0,
4082
+ 7.0,
4083
+ 0.0,
4084
+ 0.0,
4085
+ 0.0,
4086
+ 0.0,
4087
+ 0.0,
4088
+ 0.0,
4089
+ 0.0,
4090
+ 0.0,
4091
+ 0.0,
4092
+ 0.0},
4093
+ {1, 2, 4, 2});
4094
+ CHECK(array_equal(out, expected).item<bool>());
4095
+ }
4097
+
4098
+ TEST_CASE("test conv_transpose3d with output_padding") { // Checks conv_transpose3d() with output_padding = {1, 1, 1} against a fixed expected array.
4098
+ auto in = array({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}, {1, 1, 2, 2, 2});
4099
+ auto wt = array({1.0, 1.0}, {1, 1, 1, 1, 2}); // all-ones weights
4100
+ std::tuple<int, int, int> stride{2, 2, 2};
4101
+ std::tuple<int, int, int> padding{0, 0, 0};
4102
+ std::tuple<int, int, int> output_padding{1, 1, 1};
4103
+ std::tuple<int, int, int> dilation{1, 1, 1};
4104
+ int groups = 1;
4105
+
4106
+ auto out = conv_transpose3d(
4107
+ in, wt, stride, padding, dilation, output_padding, groups);
4108
+ auto expected = array( // expected shape {1, 2, 4, 4, 1}; non-zero entries are 3, 7, 11 and 15
4109
+ {3.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0, 0.0, 15.0,
4110
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
4111
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
4112
+ {1, 2, 4, 4, 1});
4113
+ CHECK(array_equal(out, expected).item<bool>());
4114
+ }
4116
+
4117
+ TEST_CASE("test fp8 conversion") { // Checks to_fp8()/from_fp8() round trips, rounding of nearby values, and overflow.
4117
+ for (auto t : {float32, float16, bfloat16}) { // these values are expected to round-trip unchanged for each dtype
4118
+ array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0}, t);
4119
+ auto in_fp8 = to_fp8(in);
4120
+ auto out = from_fp8(in_fp8, t);
4121
+ CHECK(array_equal(out, in).item<bool>());
4122
+ }
4123
+ // Slightly perturbed inputs are expected to come back as the unperturbed values in `in`
4124
+ array in({-1.125, -1.0, 0.0, 1.0, 1.125, 4.5, 448.0});
4125
+ array noisy_in({-1.135, -1.01, 0.0001, 1.01, 1.135, 4.6, 447.0});
4126
+ auto in_fp8 = to_fp8(noisy_in);
4127
+ auto out = from_fp8(in_fp8, float32);
4128
+ CHECK(array_equal(out, in).item<bool>());
4129
+
4130
+ // Overflow
4131
+ in = array({-600.0, 600.0});
4132
+ in_fp8 = to_fp8(in);
4133
+ out = from_fp8(in_fp8, float32);
4134
+ // Inputs of magnitude 600 are expected to come back as -448 and 448
4135
+ auto expected = array({-448.0f, 448.0f});
4136
+ CHECK(array_equal(out, expected, true).item<bool>());
4137
+ }
4139
+
4140
+ TEST_CASE("test max min with nan") { // Checks that maximum()/minimum() yield NaN wherever either operand is NaN.
4141
+ // Test maximum and minimum with NaN values
4142
+ auto x = array({0.0f, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
4143
+ auto y = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); // NaN at index 0 only in y, at index 2 in both
4144
+ auto expected_max = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
4145
+ auto expected_min = array({NAN, 1.0f, NAN, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
4146
+ auto max_result = maximum(x, y);
4147
+ auto min_result = minimum(x, y);
4148
+ CHECK(array_equal(max_result, expected_max, true).item<bool>()); // third argument true so that the NaN entries can compare equal
4149
+ CHECK(array_equal(min_result, expected_min, true).item<bool>());
4150
+
4151
+ // Test with all NaN values
4152
+ x = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
4153
+ y = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
4154
+ max_result = maximum(x, y);
4155
+ min_result = minimum(x, y);
4156
+ auto expected = array({NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
4157
+ CHECK(array_equal(max_result, expected, true).item<bool>());
4158
+ CHECK(array_equal(min_result, expected, true).item<bool>());
4159
+ }