warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3756 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+
25
# Each thread owns one element: it checks the element equals its thread index,
# doubles it, adds 1 atomically, and checks the final value is 2 * index + 1.
@wp.kernel
def kernel_1d(a: wp.array(dtype=int, ndim=1)):
    idx = wp.tid()

    wp.expect_eq(a[idx], idx)

    a[idx] = a[idx] * 2
    wp.atomic_add(a, idx, 1)

    wp.expect_eq(a[idx], idx * 2 + 1)
35
+
36
+
37
def test_1d(test, device):
    # A 1D array holding 0..dim_x-1 must mirror NumPy's metadata, and the kernel
    # must see each element equal to its thread index (no device output expected).
    dim_x = 4
    host = np.arange(0, dim_x, dtype=np.int32)
    arr = wp.array(host, device=device)

    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_1d, dim=arr.size, inputs=[arr], device=device)
50
+
51
+
52
# Each thread maps its flat index to (row, col) in an m x n row-major array.
# It checks both indexing syntaxes, doubles the element, adds 1 atomically,
# and checks the result. `m` is not needed for the decomposition.
@wp.kernel
def kernel_2d(a: wp.array(dtype=int, ndim=2), m: int, n: int):
    flat = wp.tid()
    row = flat // n
    col = flat % n

    wp.expect_eq(a[row, col], flat)
    wp.expect_eq(a[row][col], flat)

    a[row, col] = a[row, col] * 2
    wp.atomic_add(a, row, col, 1)

    wp.expect_eq(a[row, col], flat * 2 + 1)
64
+
65
+
66
def test_2d(test, device):
    # A (dim_x, dim_y) array filled with its row-major flat indices.
    dim_x, dim_y = 4, 2
    host = np.arange(0, dim_x * dim_y, dtype=np.int32).reshape(dim_x, dim_y)
    arr = wp.array(host, device=device)

    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_2d, dim=arr.size, inputs=[arr, dim_x, dim_y], device=device)
81
+
82
+
83
# Each thread decomposes its flat index into (x, y, z) for an m x n x o
# row-major array. It checks both indexing syntaxes, then doubles the element
# twice (once per syntax) and adds 1 atomically, so the result is 4 * index + 1.
@wp.kernel
def kernel_3d(a: wp.array(dtype=int, ndim=3), m: int, n: int, o: int):
    flat = wp.tid()
    plane = n * o
    xi = flat // plane
    yi = (flat % plane) // o
    zi = flat % o

    wp.expect_eq(a[xi, yi, zi], flat)
    wp.expect_eq(a[xi][yi][zi], flat)

    a[xi, yi, zi] = a[xi, yi, zi] * 2
    a[xi][yi][zi] = a[xi][yi][zi] * 2
    wp.atomic_add(a, xi, yi, zi, 1)

    wp.expect_eq(a[xi, yi, zi], flat * 4 + 1)
97
+
98
+
99
def test_3d(test, device):
    # A (dim_x, dim_y, dim_z) array filled with its row-major flat indices.
    dim_x, dim_y, dim_z = 8, 4, 2
    host = np.arange(0, dim_x * dim_y * dim_z, dtype=np.int32).reshape(dim_x, dim_y, dim_z)
    arr = wp.array(host, device=device)

    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_3d, dim=arr.size, inputs=[arr, dim_x, dim_y, dim_z], device=device)
115
+
116
+
117
# Each thread decomposes its flat index into (i, j, k, l) for an m x n x o x p
# row-major array and checks both indexing syntaxes read back the flat index.
# All four coordinates use floor division (`//`) so each index is unambiguously
# an integer, consistent across axes. `m` is not needed for the decomposition.
@wp.kernel
def kernel_4d(a: wp.array(dtype=int, ndim=4), m: int, n: int, o: int, p: int):
    tid = wp.tid()
    i = tid // (n * o * p)
    j = tid % (n * o * p) // (o * p)
    k = tid % (o * p) // p
    l = tid % p

    wp.expect_eq(a[i, j, k, l], tid)
    wp.expect_eq(a[i][j][k][l], tid)
126
+
127
+
128
def test_4d(test, device):
    # A (dim_x, dim_y, dim_z, dim_w) array filled with its row-major flat indices.
    dim_x, dim_y, dim_z, dim_w = 16, 8, 4, 2
    host = np.arange(0, dim_x * dim_y * dim_z * dim_w, dtype=np.int32).reshape(dim_x, dim_y, dim_z, dim_w)
    arr = wp.array(host, device=device)

    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_4d, dim=arr.size, inputs=[arr, dim_x, dim_y, dim_z, dim_w], device=device)
145
+
146
+
147
# Same decomposition as kernel_4d: (i, j, k, l) is computed for the untransposed
# m x n x o x p row-major layout. The array passed in is a transposed view of
# that layout, so the element for flat index `tid` is read at a[l, k, j, i].
# Floor division (`//`) is used on every axis so the indices stay integral.
@wp.kernel
def kernel_4d_transposed(a: wp.array(dtype=int, ndim=4), m: int, n: int, o: int, p: int):
    tid = wp.tid()
    i = tid // (n * o * p)
    j = tid % (n * o * p) // (o * p)
    k = tid % (o * p) // p
    l = tid % p

    wp.expect_eq(a[l, k, j, i], tid)
    wp.expect_eq(a[l][k][j][i], tid)
156
+
157
+
158
def test_4d_transposed(test, device):
    # Build a non-contiguous transposed view over the memory of a C-contiguous
    # 4D array, then check metadata and that the kernel reads the transposed layout.
    dim_x, dim_y, dim_z, dim_w = 16, 8, 4, 2

    host = np.arange(0, dim_x * dim_y * dim_z * dim_w, dtype=np.int32).reshape(dim_x, dim_y, dim_z, dim_w)
    arr = wp.array(host, device=device)

    # wp.array(host.T) would copy into a contiguous buffer first, so the view is
    # assembled by hand from the transposed NumPy array's shape and byte strides.
    host_t = host.T
    strides_t = host_t.__array_interface__["strides"]
    arr_t = wp.array(
        dtype=arr.dtype,
        shape=host_t.shape,
        strides=strides_t,
        capacity=arr.capacity,
        ptr=arr.ptr,
        requires_grad=arr.requires_grad,
        device=device,
    )

    test.assertFalse(arr_t.is_contiguous)
    test.assertEqual(arr_t.shape, host_t.shape)
    test.assertEqual(arr_t.strides, strides_t)
    test.assertEqual(arr_t.size, host_t.size)
    test.assertEqual(arr_t.ndim, host_t.ndim)

    with CheckOutput(test):
        wp.launch(kernel_4d_transposed, dim=arr_t.size, inputs=[arr_t, dim_x, dim_y, dim_z, dim_w], device=device)
189
+
190
+
191
# One thread per query: store wp.lower_bound(arr, values[q]) into indices[q].
@wp.kernel
def lower_bound_kernel(values: wp.array(dtype=float), arr: wp.array(dtype=float), indices: wp.array(dtype=int)):
    q = wp.tid()
    pos = wp.lower_bound(arr, values[q])
    indices[q] = pos
196
+
197
+
198
def test_lower_bound(test, device):
    # Check wp.lower_bound on a sorted array, including out-of-range queries.
    # Queries below the first element map to 0. Per the expected values, a query
    # above every element (5.5) maps to the last valid index (5).
    arr = wp.array(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], dtype=float), dtype=float, device=device)
    values = wp.array(np.array([-0.1, 0.0, 2.5, 4.0, 5.0, 5.5], dtype=float), dtype=float, device=device)
    expected = np.array([0, 0, 3, 4, 5, 5])

    # Size the output and the launch from the query array rather than a literal.
    indices = wp.zeros(len(values), dtype=int, device=device)

    wp.launch(kernel=lower_bound_kernel, dim=len(values), inputs=[values, arr, indices], device=device)

    # assert_array_equal reports which positions differ on failure, unlike
    # assertTrue on a reduced boolean.
    np.testing.assert_array_equal(indices.numpy(), expected)
206
+
207
+
208
# Checks that a 1D input has 10 elements.
@wp.kernel
def f1(arr: wp.array(dtype=float)):
    length = arr.shape[0]
    wp.expect_eq(length, 10)
211
+
212
+
213
# Checks a (10, 20) input's shape and that indexing one row yields a 1D view of length 20.
@wp.kernel
def f2(arr: wp.array2d(dtype=float)):
    rows = arr.shape[0]
    cols = arr.shape[1]
    wp.expect_eq(rows, 10)
    wp.expect_eq(cols, 20)

    first_row = arr[0]
    wp.expect_eq(first_row.shape[0], 20)
221
+
222
# Checks a (10, 20, 30) input's shape and that indexing two leading axes
# yields a 1D view of length 30.
@wp.kernel
def f3(arr: wp.array3d(dtype=float)):
    d0 = arr.shape[0]
    d1 = arr.shape[1]
    d2 = arr.shape[2]
    wp.expect_eq(d0, 10)
    wp.expect_eq(d1, 20)
    wp.expect_eq(d2, 30)

    line = arr[0, 0]
    wp.expect_eq(line.shape[0], 30)
230
+
231
+
232
# Checks a (10, 20, 30, 40) input's shape and that indexing three leading axes
# yields a 1D view of length 40.
@wp.kernel
def f4(arr: wp.array4d(dtype=float)):
    d0 = arr.shape[0]
    d1 = arr.shape[1]
    d2 = arr.shape[2]
    d3 = arr.shape[3]
    wp.expect_eq(d0, 10)
    wp.expect_eq(d1, 20)
    wp.expect_eq(d2, 30)
    wp.expect_eq(d3, 40)

    line = arr[0, 0, 0]
    wp.expect_eq(line.shape[0], 40)
241
+
242
+
243
def test_shape(test, device):
    # Each shape-checking kernel runs once on a zero array of the shape it expects.
    # Arrays are kept in `live` so they all stay alive until the block ends.
    cases = (
        (f1, 10),
        (f2, (10, 20)),
        (f3, (10, 20, 30)),
        (f4, (10, 20, 30, 40)),
    )
    with CheckOutput(test):
        live = []
        for kernel, shape in cases:
            live.append(wp.zeros(dtype=float, shape=shape, device=device))
            wp.launch(kernel, dim=1, inputs=[live[-1]], device=device)
256
+
257
+
258
def test_negative_shape(test, device):
    # Negative extents (scalar, very large, or inside a tuple) must be rejected.
    for bad_shape in (-1, -(2**32), (10, -1)):
        with test.assertRaisesRegex(ValueError, "Array shapes must be non-negative"):
            _ = wp.zeros(shape=bad_shape, dtype=int, device=device)
267
+
268
+
269
# Atomically accumulates every element of `arr` into loss[0]; one thread per element.
@wp.kernel
def sum_array(arr: wp.array(dtype=float), loss: wp.array(dtype=float)):
    idx = wp.tid()
    wp.atomic_add(loss, 0, arr[idx])
273
+
274
+
275
def test_flatten(test, device):
    # Flattening a 3D array must keep NumPy's element order, and gradients of a
    # sum over the flattened view must all be 1.
    src = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=float)
    arr = wp.array(src, dtype=float, shape=src.shape, device=device, requires_grad=True)
    flat = arr.flatten()
    assert_array_equal(flat, wp.array(src.flatten(), dtype=float, device=device))

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(flat), inputs=[flat, loss], device=device)

    tape.backward(loss=loss)
    flat_grad = tape.gradients[flat]

    expected_grad = wp.array(np.ones((8,), dtype=float), dtype=float, device=device)
    assert_array_equal(flat_grad, expected_grad)
    # 1 + 2 + ... + 8
    test.assertEqual(loss.numpy()[0], 36)
300
+
301
+
302
def test_reshape(test, device):
    # Round-trip reshape (1D -> 2D -> 1D) must preserve values and gradients.
    src = np.arange(6, dtype=float)
    arr = wp.array(src, dtype=float, device=device, requires_grad=True)

    as_2d = arr.reshape((3, 2))
    assert_array_equal(as_2d, wp.array(src.reshape((3, 2)), dtype=float, device=device))

    back_1d = as_2d.reshape(6)
    assert_array_equal(back_1d, arr)

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(back_1d), inputs=[back_1d, loss], device=device)

    tape.backward(loss=loss)
    back_grad = tape.gradients[back_1d]

    ones = wp.array(np.ones((6,), dtype=float), dtype=float, device=device)
    assert_array_equal(back_grad, ones)
    # 0 + 1 + ... + 5
    test.assertEqual(loss.numpy()[0], 15)

    # A -1 extent must be inferred the same way NumPy infers it.
    src2 = np.arange(6, dtype=float)
    plain = wp.array(src2, dtype=float, device=device)
    inferred = plain.reshape((-1, 3))
    assert_array_equal(inferred, wp.array(src2.reshape((-1, 3)), dtype=float, device=device))
336
+
337
+
338
@wp.kernel
def compare_stepped_window_a(x: wp.array2d(dtype=float)):
    # Device-side check of the stepped view arr[0:3:2, 0, :] built in test_slicing:
    # its two rows should be [1, 2] and [9, 10].
    # The check runs in a kernel because test_slicing notes that non-contiguous
    # views are not copied to/from the host.
    wp.expect_eq(x[0, 0], 1.0)
    wp.expect_eq(x[0, 1], 2.0)
    wp.expect_eq(x[1, 0], 9.0)
    wp.expect_eq(x[1, 1], 10.0)
344
+
345
+
346
@wp.kernel
def compare_stepped_window_b(x: wp.array2d(dtype=float)):
    # Device-side check of the view arr[:, 1] built in test_slicing:
    # its three rows should be [3, 4], [7, 8] and [11, 12].
    wp.expect_eq(x[0, 0], 3.0)
    wp.expect_eq(x[0, 1], 4.0)
    wp.expect_eq(x[1, 0], 7.0)
    wp.expect_eq(x[1, 1], 8.0)
    wp.expect_eq(x[2, 0], 11.0)
    wp.expect_eq(x[2, 1], 12.0)
354
+
355
+
356
def test_slicing(test, device):
    """Check integer indexing and range, negative, stepped, empty and chained slicing of
    Warp arrays against NumPy, plus gradient flow through a flattened slice.

    Stepped (non-contiguous) views cannot be copied to the host, so those are verified
    inside kernels instead of with ``assert_array_equal``.
    """
    np_arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], dtype=float)
    arr = wp.array(np_arr, dtype=float, shape=np_arr.shape, device=device, requires_grad=True)

    slice_a = arr[1, :, :]  # test indexing
    slice_b = arr[1:2, :, :]  # test slicing
    slice_c = arr[-1, :, :]  # test negative indexing
    slice_d = arr[-2:-1, :, :]  # test negative slicing
    slice_e = arr[-1:3, :, :]  # test mixed slicing
    slice_e2 = slice_e[0, 0, :]  # test 2x slicing
    slice_f = arr[0:3:2, 0, :]  # test step
    slice_g = arr[1:1, :0, -1:-1]  # test empty slice

    assert_array_equal(slice_a, wp.array(np_arr[1, :, :], dtype=float, device=device))
    assert_array_equal(slice_b, wp.array(np_arr[1:2, :, :], dtype=float, device=device))
    assert_array_equal(slice_c, wp.array(np_arr[-1, :, :], dtype=float, device=device))
    assert_array_equal(slice_d, wp.array(np_arr[-2:-1, :, :], dtype=float, device=device))
    assert_array_equal(slice_e, wp.array(np_arr[-1:3, :, :], dtype=float, device=device))
    assert_array_equal(slice_e2, wp.array(np_arr[2, 0, :], dtype=float, device=device))

    # Use TestCase assertions rather than a bare ``assert``: bare asserts are removed
    # under ``python -O``, and this empty-slice check would then silently never run.
    test.assertEqual(np_arr[1:1, :0, -1:-1].shape, (0, 0, 0))
    test.assertEqual(slice_g.shape, (0, 0, 0))

    # wp does not support copying from/to non-contiguous arrays, so stepped windows
    # are read by a kernel on the device the original array was created on.
    wp.launch(kernel=compare_stepped_window_a, dim=1, inputs=[slice_f], device=device)

    # Gradient flow through a flattened slice: d(sum)/d(x_i) == 1 for its 4 elements.
    slice_flat = slice_b.flatten()
    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(slice_flat), inputs=[slice_flat, loss], device=device)

    tape.backward(loss=loss)
    grad = tape.gradients[slice_flat]

    ones = wp.array(np.ones((4,), dtype=float), dtype=float, device=device)
    assert_array_equal(grad, ones)
    test.assertEqual(loss.numpy()[0], 26)  # 5 + 6 + 7 + 8

    # Indexing with fewer indices than dimensions.
    index_a = arr[1]
    index_b = arr[2, 1]
    index_c = arr[1, :]
    index_d = arr[:, 1]  # non-contiguous view; checked in a kernel below

    assert_array_equal(index_a, wp.array(np_arr[1], dtype=float, device=device))
    assert_array_equal(index_b, wp.array(np_arr[2, 1], dtype=float, device=device))
    assert_array_equal(index_c, wp.array(np_arr[1, :], dtype=float, device=device))
    wp.launch(kernel=compare_stepped_window_b, dim=1, inputs=[index_d], device=device)

    # 1-D slicing, including a trailing comma in the index. Distinct values are used so a
    # wrong start/stop mapping is actually detected (an all-zero array would hide it).
    np_arr = np.arange(10, dtype=int)
    wp_arr = wp.array(np_arr, dtype=int, device=device)

    assert_array_equal(wp_arr[:5], wp.array(np_arr[:5], dtype=int, device=device))
    assert_array_equal(wp_arr[1:5], wp.array(np_arr[1:5], dtype=int, device=device))
    assert_array_equal(wp_arr[-9:-5:1], wp.array(np_arr[-9:-5:1], dtype=int, device=device))
    assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device))
418
+
419
+
420
def test_view(test, device):
    """Reinterpret array memory with ``array.view(dtype=...)`` and compare with NumPy's ``ndarray.view``.

    Covers 32-bit (uint32 <-> float32) and 16-bit (uint16 <-> float16) scalar
    reinterpretation, plus a vec4 -> quat view of a 2-D array.
    """
    np_arr_a = np.arange(1, 10, 1, dtype=np.uint32)
    np_arr_b = np.arange(1, 10, 1, dtype=np.float32)
    np_arr_c = np.arange(1, 10, 1, dtype=np.uint16)
    np_arr_d = np.arange(1, 10, 1, dtype=np.float16)
    np_arr_e = np.ones((4, 4), dtype=np.float32)

    wp_arr_a = wp.array(np_arr_a, dtype=wp.uint32, device=device)
    wp_arr_b = wp.array(np_arr_b, dtype=wp.float32, device=device)
    # Build the 16-bit arrays from their matching 16-bit NumPy sources (previously
    # np_arr_a/np_arr_b by mistake), so the Warp and NumPy views below reinterpret
    # exactly the same bits.
    wp_arr_c = wp.array(np_arr_c, dtype=wp.uint16, device=device)
    wp_arr_d = wp.array(np_arr_d, dtype=wp.float16, device=device)
    wp_arr_e = wp.array(np_arr_e, dtype=wp.vec4, device=device)
    wp_arr_f = wp.array(np_arr_e, dtype=wp.quat, device=device)

    assert_np_equal(wp_arr_a.view(dtype=wp.float32).numpy(), np_arr_a.view(dtype=np.float32))
    assert_np_equal(wp_arr_b.view(dtype=wp.uint32).numpy(), np_arr_b.view(dtype=np.uint32))
    assert_np_equal(wp_arr_c.view(dtype=wp.float16).numpy(), np_arr_c.view(dtype=np.float16))
    assert_np_equal(wp_arr_d.view(dtype=wp.uint16).numpy(), np_arr_d.view(dtype=np.uint16))
    assert_array_equal(wp_arr_e.view(dtype=wp.quat), wp_arr_f)
439
+
440
+
441
def test_clone_adjoint(test, device):
    """Gradients recorded on a tape must flow back through ``wp.clone`` unchanged."""
    source = wp.from_numpy(
        np.array([1.0, 2.0, 3.0], dtype=np.float32), dtype=wp.float32, requires_grad=True, device=device
    )

    recorder = wp.Tape()
    with recorder:
        cloned = wp.clone(source)

    # Seed the clone's adjoint with ones; the copy's Jacobian is the identity.
    seed = wp.from_numpy(np.ones(3, dtype=np.float32), dtype=wp.float32, device=device)
    recorder.backward(grads={cloned: seed})

    assert_np_equal(source.grad.numpy(), np.ones(3, dtype=np.float32))
454
+
455
+
456
def test_assign_adjoint(test, device):
    """Gradients recorded on a tape must flow back through ``array.assign`` unchanged."""
    source = wp.from_numpy(
        np.array([1.0, 2.0, 3.0], dtype=np.float32), dtype=wp.float32, requires_grad=True, device=device
    )
    target = wp.zeros(source.shape, dtype=wp.float32, requires_grad=True, device=device)

    recorder = wp.Tape()
    with recorder:
        target.assign(source)

    # Seed the target's adjoint with ones; assignment's Jacobian is the identity.
    seed = wp.from_numpy(np.ones(3, dtype=np.float32), dtype=wp.float32, device=device)
    recorder.backward(grads={target: seed})

    assert_np_equal(source.grad.numpy(), np.ones(3, dtype=np.float32))
470
+
471
+
472
@wp.kernel
def compare_2darrays(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), z: wp.array2d(dtype=int)):
    # Element-wise equality flag: write 1 into z[i, j] where x[i, j] == y[i, j].
    # z is left untouched where the values differ, so the caller must pass a zeroed z
    # (test_transpose allocates it with wp.zeros) and launch with dim equal to the array shape.
    i, j = wp.tid()

    if x[i, j] == y[i, j]:
        z[i, j] = 1
478
+
479
+
480
@wp.kernel
def compare_3darrays(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), z: wp.array3d(dtype=int)):
    # 3-D counterpart of compare_2darrays: write 1 into z[i, j, k] where x and y agree.
    # z is left untouched where the values differ, so the caller must pass a zeroed z
    # and launch with dim equal to the array shape.
    i, j, k = wp.tid()

    if x[i, j, k] == y[i, j, k]:
        z[i, j, k] = 1
486
+
487
+
488
def test_transpose(test, device):
    """Check ``array.transpose`` for 2-D, 3-D (explicit and default axes) and 1-D arrays.

    Transposed views are non-contiguous and cannot be copied from/to the host, so
    element-wise comparison against a NumPy-built reference runs in a kernel that
    sets a flag for every matching element.
    """

    def assert_all_match(kernel, actual, expected):
        flags = wp.zeros(shape=expected.shape, dtype=int, device=device)
        wp.launch(kernel, dim=expected.shape, inputs=[actual, expected, flags], device=device)
        assert_np_equal(flags.numpy(), np.ones(expected.shape, dtype=int))

    # Non-square 2-D default transpose: (3, 2) -> (2, 3).
    src2 = np.array([[1, 2], [3, 4], [5, 6]], dtype=float)
    mat = wp.array(src2, dtype=float, device=device)
    flipped = mat.transpose()
    assert_all_match(compare_2darrays, flipped, wp.array(src2.transpose(), dtype=float, device=device))

    # 3-D with explicit axes, swapping the two trailing dimensions: (3, 2, 2) -> (3, 2, 2).
    src3 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], dtype=float)
    cube = wp.array3d(src3, dtype=float, shape=src3.shape, device=device, requires_grad=True)
    swapped = cube.transpose((0, 2, 1))
    assert_all_match(
        compare_3darrays, swapped, wp.array3d(src3.transpose((0, 2, 1)), dtype=float, device=device)
    )

    # 3-D without axes reverses the dimensions: (3, 2, 2) -> (2, 2, 3).
    reversed_view = cube.transpose()
    assert_all_match(compare_3darrays, reversed_view, wp.array3d(src3.transpose(), dtype=float, device=device))

    # Transposing a 1-D array is a no-op.
    src1 = np.array([1, 2, 3], dtype=float)
    line = wp.array(src1, dtype=float, device=device)
    assert_np_equal(line.transpose().numpy(), src1.transpose())
524
+
525
+
526
def test_fill_scalar(test, device):
    """Fill and zero scalar arrays of rank 1-4 for every supported NumPy/Warp scalar dtype."""
    dim_x = 4
    shapes = (dim_x, (dim_x, dim_x), (dim_x, dim_x, dim_x), (dim_x, dim_x, dim_x, dim_x))

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def check_all(arrays, value, nptype):
        # Each array must equal a host array of the same shape filled with `value`.
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.full(arr.shape, value, dtype=nptype))

    for nptype, wptype in wp._src.types.np_dtype_to_warp_type.items():
        arrays = [wp.zeros(shape, dtype=wptype, device=device) for shape in shapes]
        check_all(arrays, 0, nptype)

        # Integer fill value.
        fill_all(arrays, 42)
        check_all(arrays, 42, nptype)

        for arr in arrays:
            arr.zero_()
        check_all(arrays, 0, nptype)

        if wptype in wp._src.types.float_types:
            # Non-integral fill value, only meaningful for float dtypes.
            fill_all(arrays, 13.37)
            check_all(arrays, 13.37, nptype)

        # Fill with an instance of the Warp scalar type itself.
        warp_value = wptype(17)
        fill_all(arrays, warp_value)
        check_all(arrays, warp_value.value, nptype)
589
+
590
+
591
def test_fill_vector(test, device):
    """Fill and zero vector arrays (rank 1-4, vector lengths 2-5) from a scalar, a list,
    a NumPy array or a vector instance, for every supported scalar dtype.
    """
    dim_x = 4
    shapes = (dim_x, (dim_x, dim_x), (dim_x, dim_x, dim_x), (dim_x, dim_x, dim_x, dim_x))

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def zero_all(arrays):
        for arr in arrays:
            arr.zero_()

    def check_all(arrays, expected):
        for arr, want in zip(arrays, expected):
            assert_np_equal(arr.numpy(), want)

    def uniform(arrays, inner_shape, value, nptype):
        # Host reference where every component of every element equals `value`.
        return [np.full((*arr.shape, *inner_shape), value, dtype=nptype) for arr in arrays]

    def tiled(arrays, inner_shape, pattern):
        # Host reference where every element equals the component sequence `pattern`.
        return [np.tile(pattern, arr.size).reshape((*arr.shape, *inner_shape)) for arr in arrays]

    for nptype, wptype in wp._src.types.np_dtype_to_warp_type.items():
        vector_types = [wp._src.types.vector(length, wptype) for length in (2, 3, 4, 5)]

        for vec_type in vector_types:
            vec_len = vec_type._length_
            inner = (vec_len,)

            arrays = [wp.zeros(shape, dtype=vec_type, device=device) for shape in shapes]
            zeros = uniform(arrays, inner, 0, nptype)
            check_all(arrays, zeros)

            # A scalar int is broadcast to every component.
            fill_all(arrays, 42)
            check_all(arrays, uniform(arrays, inner, 42, nptype))

            zero_all(arrays)
            check_all(arrays, zeros)

            # A vector value may be given as a list, a NumPy array or a Warp vector instance.
            components = [17, 42, 99, 101, 127][:vec_len]
            pattern = np.array(components, dtype=nptype)
            sources = (components, pattern, vec_type(components))
            expected = tiled(arrays, inner, pattern)

            for position, value in enumerate(sources):
                if position > 0:
                    zero_all(arrays)  # clear the previous fill before trying the next form
                fill_all(arrays, value)
                check_all(arrays, expected)

            if wptype in wp._src.types.float_types:
                # Non-integral scalar broadcast.
                fill_all(arrays, 13.37)
                check_all(arrays, uniform(arrays, inner, 13.37, nptype))

                # Non-integral list of vector length.
                float_components = [-2.5, -1.25, 1.25, 2.5, 5.0][:vec_len]
                fill_all(arrays, float_components)
                check_all(arrays, tiled(arrays, inner, np.array(float_components, dtype=nptype)))
728
+
729
+
730
def test_fill_matrix(test, device):
    """Fill and zero matrix arrays (rank 1-4, square and non-square matrix types) from a
    scalar, 1-D/2-D NumPy arrays, flat/nested lists or a matrix instance.
    """
    dim_x = 4
    shapes = (dim_x, (dim_x, dim_x), (dim_x, dim_x, dim_x), (dim_x, dim_x, dim_x, dim_x))
    # Square matrices first, then non-square ones.
    matrix_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def zero_all(arrays):
        for arr in arrays:
            arr.zero_()

    def check_all(arrays, expected):
        for arr, want in zip(arrays, expected):
            assert_np_equal(arr.numpy(), want)

    def uniform(arrays, inner_shape, value, nptype):
        # Host reference where every entry of every matrix equals `value`.
        return [np.full((*arr.shape, *inner_shape), value, dtype=nptype) for arr in arrays]

    for nptype, wptype in wp._src.types.np_dtype_to_warp_type.items():
        matrix_types = [wp._src.types.matrix(mshape, wptype) for mshape in matrix_shapes]

        for mat_type in matrix_types:
            mat_len = mat_type._length_
            mat_shape = mat_type._shape_

            arrays = [wp.zeros(shape, dtype=mat_type, device=device) for shape in shapes]
            zeros = uniform(arrays, mat_shape, 0, nptype)
            check_all(arrays, zeros)

            # A scalar is broadcast to every entry.
            fill_all(arrays, 42)
            check_all(arrays, uniform(arrays, mat_shape, 42, nptype))

            zero_all(arrays)
            check_all(arrays, zeros)

            # Bool matrices are filled with all-ones rather than a 0..n-1 ramp.
            flat = np.ones(mat_len, dtype=nptype) if wptype == wp.bool else np.arange(mat_len, dtype=nptype)
            grid = flat.reshape(mat_shape)
            sources = (
                flat,  # 1-D NumPy array
                grid,  # 2-D NumPy array
                list(flat),  # flat list
                [list(row) for row in grid],  # nested list
                mat_type(flat),  # Warp matrix instance
            )
            expected = [np.tile(flat, arr.size).reshape((*arr.shape, *mat_shape)) for arr in arrays]

            for position, value in enumerate(sources):
                if position > 0:
                    zero_all(arrays)  # clear the previous fill before trying the next form
                fill_all(arrays, value)
                check_all(arrays, expected)
881
+
882
+
883
@wp.struct
class FillStruct:
    """Struct with scalar, vector, matrix and array members, used to test ``array.fill_``
    on struct dtypes.

    Scalar member names encode the type's size in bytes: ``i1``..``i8`` are int8..int64,
    ``f2``/``f4``/``f8`` are float16/float32/float64.
    """

    # scalar members (make sure to test float16)
    i1: wp.int8
    i2: wp.int16
    i4: wp.int32
    i8: wp.int64
    f2: wp.float16
    f4: wp.float32
    # float64, matching the byte-size naming; float16 here duplicated f2 and left
    # 64-bit float scalar members untested.
    f8: wp.float64
    # vector members (make sure to test vectors of float16)
    v2: wp._src.types.vector(2, wp.int64)
    v3: wp._src.types.vector(3, wp.float32)
    v4: wp._src.types.vector(4, wp.float16)
    v5: wp._src.types.vector(5, wp.uint8)
    # matrix members (make sure to test matrices of float16)
    m2: wp._src.types.matrix((2, 2), wp.float64)
    m3: wp._src.types.matrix((3, 3), wp.int32)
    m4: wp._src.types.matrix((4, 4), wp.float16)
    m5: wp._src.types.matrix((5, 5), wp.int8)
    # arrays
    a1: wp.array(dtype=float)
    a2: wp.array2d(dtype=float)
    a3: wp.array3d(dtype=float)
    a4: wp.array4d(dtype=float)
908
+
909
+
910
def test_fill_struct(test, device):
    """Fill and zero ``FillStruct`` arrays of rank 1-4 with a default and a populated struct."""
    dim_x = 4
    shapes = (dim_x, (dim_x, dim_x), (dim_x, dim_x, dim_x), (dim_x, dim_x, dim_x, dim_x))
    nptype = FillStruct.numpy_dtype()

    arrays = [wp.zeros(shape, dtype=FillStruct, device=device) for shape in shapes]

    def fill_all(value):
        for arr in arrays:
            arr.fill_(value)

    def check_all_zero():
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.zeros(arr.shape, dtype=nptype))

    check_all_zero()

    # A default-constructed struct is expected to fill with all zeros.
    record = FillStruct()
    fill_all(record)
    check_all_zero()

    # Populate every member category.
    record.i1 = -17
    record.i2 = 42
    record.i4 = 99
    record.i8 = 101
    record.f2 = -1.25
    record.f4 = 13.37
    record.f8 = 0.125

    record.v2 = [21, 22]
    record.v3 = [31, 32, 33]
    record.v4 = [41, 42, 43, 44]
    record.v5 = [51, 52, 53, 54, 55]

    record.m2 = [[61, 62]] * 2
    record.m3 = [[71, 72, 73]] * 3
    record.m4 = [[81, 82, 83, 84]] * 4
    record.m5 = [[91, 92, 93, 94, 95]] * 5

    record.a1 = wp.zeros((2,) * 1, dtype=float, device=device)
    record.a2 = wp.zeros((2,) * 2, dtype=float, device=device)
    record.a3 = wp.zeros((2,) * 3, dtype=float, device=device)
    record.a4 = wp.zeros((2,) * 4, dtype=float, device=device)

    fill_all(record)

    # Every element must equal the struct's NumPy representation.
    packed = record.numpy_value()
    for arr in arrays:
        reference = np.empty(arr.shape, dtype=nptype)
        reference.fill(packed)
        assert_np_equal(arr.numpy(), reference)

    for arr in arrays:
        arr.zero_()
    check_all_zero()
995
+
996
+
997
def test_fill_slices(test, device):
    """Check ``fill_``/``zero_`` on non-contiguous (stepped) views: writing the even-index
    slices must leave the odd-index slices untouched, and vice versa.

    Only scalar dtypes are exercised; vector, matrix and struct fills are covered by the
    other fill tests.
    """
    dim_x = 8
    shapes = (dim_x, (dim_x, dim_x), (dim_x, dim_x, dim_x), (dim_x, dim_x, dim_x, dim_x))
    fill_a = 17
    fill_b = 42

    def fill_all(views, value):
        for view in views:
            view.fill_(value)

    def zero_all(views):
        for view in views:
            view.zero_()

    def expect(views, value, nptype):
        for view in views:
            assert_np_equal(view.numpy(), np.full(view.shape, value, dtype=nptype))

    for nptype, wptype in wp._src.types.np_dtype_to_warp_type.items():
        arrays = [wp.zeros(shape, dtype=wptype, device=device) for shape in shapes]
        expect(arrays, 0, nptype)

        # Split each array along its first axis into even- and odd-indexed slices.
        evens = [arr[::2] for arr in arrays]
        odds = [arr[1::2] for arr in arrays]

        fill_all(evens, fill_a)
        expect(evens, fill_a, nptype)
        expect(odds, 0, nptype)

        fill_all(odds, fill_b)
        expect(odds, fill_b, nptype)
        expect(evens, fill_a, nptype)

        zero_all(evens)
        expect(evens, 0, nptype)
        expect(odds, fill_b, nptype)

        # Restore the even slices, then clear the odd ones.
        fill_all(evens, fill_a)
        zero_all(odds)
        expect(odds, 0, nptype)
        expect(evens, fill_a, nptype)
1103
+
1104
+
1105
def test_full_scalar(test, device):
    """``wp.full`` with scalar fill values, using explicit and inferred dtypes, for ranks 1-4."""
    dim = 4

    def verify(arr, shape, wp_dtype, np_dtype, value):
        host = arr.numpy()
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, shape)
        test.assertEqual(host.dtype, np_dtype)
        assert_np_equal(host, np.full(shape, value, dtype=np_dtype))

    for ndim in range(1, 5):
        shape = (dim,) * ndim

        for nptype, wptype in wp._src.types.np_dtype_to_warp_type.items():
            # Explicit dtype with an integer value.
            verify(wp.full(shape, 42, dtype=wptype, device=device), shape, wptype, nptype, 42)

            if wptype in wp._src.types.float_types:
                # Explicit float dtype with a non-integral value.
                verify(wp.full(shape, 13.37, dtype=wptype, device=device), shape, wptype, nptype, 13.37)

        # Without a dtype, Python int infers int32 and Python float infers float32.
        verify(wp.full(shape, 42, device=device), shape, wp.int32, np.int32, 42)
        verify(wp.full(shape, 13.37, device=device), shape, wp.float32, np.float32, 13.37)
1156
+
1157
+
1158
def test_full_vector(test, device):
    """Check ``wp.full`` for vector dtypes filled from scalars, vectors, lists and NumPy arrays."""
    wpt = wp._src.types
    edge = 4
    fill_lists = [
        [17, 42],
        [17, 42, 99],
        [17, 42, 99, 101],
        [17, 42, 99, 101, 127],
    ]

    def check_meta(arr, host, shape, np_shape, wp_dtype, np_dtype):
        # Warp keeps the logical shape; the NumPy view appends the vector length.
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, np_shape)
        test.assertEqual(host.dtype, np_dtype)

    for rank in range(1, 5):
        shape = (edge,) * rank

        # vectors filled from a scalar or from a vector instance
        for length in (2, 3, 4, 5):
            np_shape = (*shape, length)

            for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
                vec_t = wpt.vector(length, wp_dtype)

                # scalar fills with explicit dtype; the fractional value only for float dtypes
                scalar_fills = (42, 13.37) if wp_dtype in wpt.float_types else (42,)
                for value in scalar_fills:
                    arr = wp.full(shape, value, dtype=vec_t, device=device)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, vec_t, np_dtype)
                    assert_np_equal(host, np.full(arr.size * length, value, dtype=np_dtype).reshape(np_shape))

                # vector-instance fill: explicit dtype first, then dtype inferred from the value
                fill_vec = vec_t(42)
                for extra in ({"dtype": vec_t}, {}):
                    arr = wp.full(shape, fill_vec, device=device, **extra)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, vec_t, np_dtype)
                    assert_np_equal(host, np.full(arr.size * length, 42, dtype=np_dtype).reshape(np_shape))

        # vectors filled from a Python list or a NumPy array
        for fill_list in fill_lists:
            length = len(fill_list)
            np_shape = (*shape, length)

            for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
                vec_t = wpt.vector(length, wp_dtype)
                fill_arr = np.array(fill_list, dtype=np_dtype)

                # explicit dtype, from the list and then from the equivalent NumPy array
                for source in (fill_list, fill_arr):
                    arr = wp.full(shape, source, dtype=vec_t, device=device)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, vec_t, np_dtype)
                    expected = np.tile(fill_arr, arr.size).reshape(np_shape)
                    assert_np_equal(host, expected)

                # NumPy array with inferred dtype must map onto the equivalent vector type
                arr = wp.full(shape, fill_arr, device=device)
                host = arr.numpy()
                test.assertEqual(arr.shape, shape)
                test.assertTrue(wpt.types_equal(arr.dtype, vec_t))
                test.assertEqual(host.shape, np_shape)
                test.assertEqual(host.dtype, np_dtype)
                assert_np_equal(host, expected)

            # Plain list with inferred dtype: the scalar type comes from NumPy's platform default
            # (e.g. int64 on Linux, int32 on Windows), so only the vector-ness and length are checked.
            arr = wp.full(shape, fill_list, device=device)
            host = arr.numpy()
            test.assertEqual(arr.shape, shape)
            test.assertEqual(arr.dtype._wp_generic_type_str_, "vec_t")
            test.assertEqual(arr.dtype._length_, length)
            assert_np_equal(host, np.tile(np.array(fill_list), arr.size).reshape(np_shape))
1278
+
1279
+
1280
def test_full_matrix(test, device):
    """Check ``wp.full`` for matrix dtypes filled from scalars, matrices, lists and NumPy arrays."""
    wpt = wp._src.types
    edge = 4
    # square shapes first, then non-square ones
    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))
    mat_lists = [
        # square matrices
        [[1, 2], [3, 4]],
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
        # non-square matrices
        [[1, 2, 3, 4], [5, 6, 7, 8]],
        [[1, 2], [3, 4], [5, 6], [7, 8]],
    ]

    def check_meta(arr, host, shape, np_shape, wp_dtype, np_dtype):
        # Warp keeps the logical shape; the NumPy view appends (rows, cols).
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, np_shape)
        test.assertEqual(host.dtype, np_dtype)

    for rank in range(1, 5):
        shape = (edge,) * rank

        for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
            for dims in mat_shapes:
                mat_t = wpt.matrix(dims, wp_dtype)
                np_shape = (*shape, *mat_t._shape_)
                count = mat_t._length_

                # scalar fills with explicit dtype; the fractional value only for float dtypes
                scalar_fills = (42, 13.37) if wp_dtype in wpt.float_types else (42,)
                for value in scalar_fills:
                    arr = wp.full(shape, value, dtype=mat_t, device=device)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, mat_t, np_dtype)
                    assert_np_equal(host, np.full(arr.size * count, value, dtype=np_dtype).reshape(np_shape))

                # matrix-instance fill: explicit dtype first, then dtype inferred from the value
                fill_mat = mat_t(42)
                for extra in ({"dtype": mat_t}, {}):
                    arr = wp.full(shape, fill_mat, device=device, **extra)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, mat_t, np_dtype)
                    assert_np_equal(host, np.full(arr.size * count, 42, dtype=np_dtype).reshape(np_shape))

                # distinct per-entry values; bool uses all-ones instead of a counting sequence
                if wp_dtype == wp.bool:
                    flat = np.ones(count, dtype=np_dtype)
                else:
                    flat = np.arange(count, dtype=np_dtype)
                square = flat.reshape(mat_t._shape_)

                # explicit dtype from a flat and then a 2D NumPy array
                for source in (flat, square):
                    arr = wp.full(shape, source, dtype=mat_t, device=device)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, mat_t, np_dtype)
                    expected = np.tile(flat, arr.size).reshape(np_shape)
                    assert_np_equal(host, expected)

                # 2D NumPy array with inferred dtype must map onto the equivalent matrix type
                arr = wp.full(shape, square, device=device)
                host = arr.numpy()
                test.assertEqual(arr.shape, shape)
                test.assertTrue(wpt.types_equal(arr.dtype, mat_t))
                test.assertEqual(host.shape, np_shape)
                test.assertEqual(host.dtype, np_dtype)
                assert_np_equal(host, expected)

                # explicit dtype from a flat and then a nested Python list
                for source in (list(flat), [list(row) for row in square]):
                    arr = wp.full(shape, source, dtype=mat_t, device=device)
                    host = arr.numpy()
                    check_meta(arr, host, shape, np_shape, mat_t, np_dtype)
                    assert_np_equal(host, expected)

        # Nested lists with inferred dtype: the scalar type comes from NumPy's platform default
        # (e.g. int64 on Linux, int32 on Windows), so only matrix-ness and shape are checked.
        for fill_list in mat_lists:
            rows, cols = len(fill_list), len(fill_list[0])
            np_shape = (*shape, rows, cols)

            arr = wp.full(shape, fill_list, device=device)
            host = arr.numpy()
            test.assertEqual(arr.shape, shape)
            test.assertEqual(arr.dtype._wp_generic_type_str_, "mat_t")
            test.assertEqual(arr.dtype._shape_, (rows, cols))
            assert_np_equal(host, np.tile(np.array(fill_list).flatten(), arr.size).reshape(np_shape))
1436
+
1437
+
1438
def test_full_struct(test, device):
    """Check ``wp.full`` with the ``FillStruct`` dtype, default-initialized and fully populated."""
    edge = 4
    struct_np = FillStruct.numpy_dtype()

    def check_meta(arr, host, shape):
        # struct arrays map to NumPy structured records with the same shape
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, FillStruct)
        test.assertEqual(host.shape, shape)
        test.assertEqual(host.dtype, struct_np)

    for rank in range(1, 5):
        shape = (edge,) * rank
        s = FillStruct()

        # a default-constructed struct is expected to produce all-zero records
        arr = wp.full(shape, s, dtype=FillStruct, device=device)
        host = arr.numpy()
        check_meta(arr, host, shape)
        assert_np_equal(host, np.zeros(arr.shape, dtype=struct_np))

        # scalar members
        s.i1, s.i2, s.i4, s.i8 = -17, 42, 99, 101
        s.f2, s.f4, s.f8 = -1.25, 13.37, 0.125
        # vector members
        s.v2 = [21, 22]
        s.v3 = [31, 32, 33]
        s.v4 = [41, 42, 43, 44]
        s.v5 = [51, 52, 53, 54, 55]
        # matrix members (every row identical)
        s.m2 = [[61, 62]] * 2
        s.m3 = [[71, 72, 73]] * 3
        s.m4 = [[81, 82, 83, 84]] * 4
        s.m5 = [[91, 92, 93, 94, 95]] * 5
        # array members a1..a4 of rank 1..4, every dimension of extent 2
        for k in range(1, 5):
            setattr(s, f"a{k}", wp.zeros((2,) * k, dtype=float, device=device))

        expected = np.empty(shape, dtype=struct_np)
        expected.fill(s.numpy_value())

        # populated struct: explicit dtype first, then dtype inferred from the value
        for extra in ({"dtype": FillStruct}, {}):
            arr = wp.full(shape, s, device=device, **extra)
            host = arr.numpy()
            check_meta(arr, host, shape)
            assert_np_equal(host, expected)
1502
+
1503
+
1504
def test_ones_scalar(test, device):
    """Check ``wp.ones`` for every NumPy-mappable scalar dtype and ranks 1-4."""
    edge = 4
    for rank in range(1, 5):
        shape = (edge,) * rank

        for np_dtype, wp_dtype in wp._src.types.np_dtype_to_warp_type.items():
            arr = wp.ones(shape, dtype=wp_dtype, device=device)
            host = arr.numpy()

            test.assertEqual(arr.shape, shape)
            test.assertEqual(arr.dtype, wp_dtype)
            test.assertEqual(host.shape, shape)
            test.assertEqual(host.dtype, np_dtype)
            assert_np_equal(host, np.ones(shape, dtype=np_dtype))
1519
+
1520
+
1521
def test_ones_vector(test, device):
    """Check ``wp.ones`` for vector dtypes of length 2-5 over every scalar type and ranks 1-4."""
    wpt = wp._src.types
    edge = 4
    for rank in range(1, 5):
        shape = (edge,) * rank

        for length in (2, 3, 4, 5):
            # the NumPy view appends the vector length to the array shape
            np_shape = (*shape, length)

            for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
                vec_t = wpt.vector(length, wp_dtype)
                arr = wp.ones(shape, dtype=vec_t, device=device)
                host = arr.numpy()

                test.assertEqual(arr.shape, shape)
                test.assertEqual(arr.dtype, vec_t)
                test.assertEqual(host.shape, np_shape)
                test.assertEqual(host.dtype, np_dtype)
                assert_np_equal(host, np.ones(np_shape, dtype=np_dtype))
1541
+
1542
+
1543
def test_ones_matrix(test, device):
    """Check ``wp.ones`` for square and non-square matrix dtypes over every scalar type."""
    wpt = wp._src.types
    edge = 4
    # square shapes first, then non-square ones
    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))

    for rank in range(1, 5):
        shape = (edge,) * rank

        for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
            for dims in mat_shapes:
                mat_t = wpt.matrix(dims, wp_dtype)
                np_shape = (*shape, *mat_t._shape_)

                arr = wp.ones(shape, dtype=mat_t, device=device)
                host = arr.numpy()

                test.assertEqual(arr.shape, shape)
                test.assertEqual(arr.dtype, mat_t)
                test.assertEqual(host.shape, np_shape)
                test.assertEqual(host.dtype, np_dtype)
                assert_np_equal(host, np.ones(np_shape, dtype=np_dtype))
1574
+
1575
+
1576
def test_ones_like_scalar(test, device):
    """Check ``wp.ones_like`` copies shape and dtype from a zero-filled scalar source array."""

    def verify(arr, shape, wp_dtype, np_dtype, reference):
        # `reference` is np.zeros or np.ones and builds the expected contents
        host = arr.numpy()
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, shape)
        test.assertEqual(host.dtype, np_dtype)
        assert_np_equal(host, reference(shape, dtype=np_dtype))

    edge = 4
    for rank in range(1, 5):
        shape = (edge,) * rank

        for np_dtype, wp_dtype in wp._src.types.np_dtype_to_warp_type.items():
            source = wp.zeros(shape, dtype=wp_dtype, device=device)
            verify(source, shape, wp_dtype, np_dtype, np.zeros)

            verify(wp.ones_like(source), shape, wp_dtype, np_dtype, np.ones)
1600
+
1601
+
1602
def test_ones_like_vector(test, device):
    """Check ``wp.ones_like`` copies shape and vector dtype from a zero-filled source array."""
    wpt = wp._src.types

    def verify(arr, shape, np_shape, wp_dtype, np_dtype, reference):
        # `reference` is np.zeros or np.ones and builds the expected contents
        host = arr.numpy()
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, np_shape)
        test.assertEqual(host.dtype, np_dtype)
        assert_np_equal(host, reference(np_shape, dtype=np_dtype))

    edge = 4
    for rank in range(1, 5):
        shape = (edge,) * rank

        for length in (2, 3, 4, 5):
            np_shape = (*shape, length)

            for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
                vec_t = wpt.vector(length, wp_dtype)

                source = wp.zeros(shape, dtype=vec_t, device=device)
                verify(source, shape, np_shape, vec_t, np_dtype, np.zeros)

                verify(wp.ones_like(source), shape, np_shape, vec_t, np_dtype, np.ones)
1631
+
1632
+
1633
def test_ones_like_matrix(test, device):
    """Check ``wp.ones_like`` copies shape and matrix dtype from a zero-filled source array."""
    wpt = wp._src.types
    # square shapes first, then non-square ones
    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))

    def verify(arr, shape, np_shape, wp_dtype, np_dtype, reference):
        # `reference` is np.zeros or np.ones and builds the expected contents
        host = arr.numpy()
        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, wp_dtype)
        test.assertEqual(host.shape, np_shape)
        test.assertEqual(host.dtype, np_dtype)
        assert_np_equal(host, reference(np_shape, dtype=np_dtype))

    edge = 4
    for rank in range(1, 5):
        shape = (edge,) * rank

        for np_dtype, wp_dtype in wpt.np_dtype_to_warp_type.items():
            for dims in mat_shapes:
                mat_t = wpt.matrix(dims, wp_dtype)
                np_shape = (*shape, *mat_t._shape_)

                source = wp.zeros(shape, dtype=mat_t, device=device)
                verify(source, shape, np_shape, mat_t, np_dtype, np.zeros)

                verify(wp.ones_like(source), shape, np_shape, mat_t, np_dtype, np.ones)
1673
+
1674
+
1675
def test_round_trip(test, device):
    """Round-trip random NumPy data through Warp arrays of each scalar dtype and of 3-vectors."""
    rng = np.random.default_rng(123)
    count = 4

    for np_dtype, wp_dtype in wp._src.types.np_dtype_to_warp_type.items():
        # scalar array: the Warp dtype is inferred from the NumPy input
        scalars = rng.standard_normal(size=count).astype(np_dtype)
        arr = wp.array(scalars, device=device)
        test.assertEqual(arr.dtype, wp_dtype)
        assert_np_equal(arr.numpy(), scalars)

        # (count, 3) NumPy data viewed as an array of 3-vectors
        vectors = rng.standard_normal(size=(count, 3)).astype(np_dtype)
        vec_arr = wp.array(vectors, dtype=wp._src.types.vector(3, wp_dtype), device=device)
        assert_np_equal(vec_arr.numpy(), vectors)
1690
+
1691
+
1692
def test_empty_array(test, device):
    """Check that common operations on zero-sized arrays succeed without raising.

    Covers scalar, vector and (square and non-square) matrix dtypes built from every
    NumPy-mappable Warp scalar type, including ``wp.bool``, plus a struct dtype, for
    array ranks 1 through 4.
    """
    wpt = wp._src.types

    def test_empty_ops(ndim, nrows, ncols, wptype, nptype):
        # ndim:        rank of the zero-sized array (every dimension is 0)
        # nrows/ncols: (0, 0) keeps ``wptype``; (0, n) wraps it in vector(n); (r, c) in matrix((r, c))
        # wptype:      Warp scalar type (including wp.bool) or a struct type
        # nptype:      NumPy dtype expected from ``a.numpy()``
        shape = (0,) * ndim
        dtype_shape = ()

        # wp.bool is checked explicitly: it is not a member of ``scalar_types``, so without this it fell
        # into the struct branch and the requested bool vector/matrix cases were silently tested as scalars.
        if wptype in wpt.scalar_types or wptype is wp.bool:
            # scalar, vector, or matrix
            if ncols > 0:
                if nrows > 0:
                    wptype = wpt.matrix((nrows, ncols), wptype)
                else:
                    wptype = wpt.vector(ncols, wptype)
                dtype_shape = wptype._shape_
            fill_value = wptype(42)
        else:
            # struct
            fill_value = wptype()

        # create a zero-sized array; no memory is expected to be allocated for it or its gradient
        a = wp.empty(shape, dtype=wptype, device=device, requires_grad=True)

        test.assertEqual(a.ptr, None)
        test.assertEqual(a.size, 0)
        test.assertEqual(a.shape, shape)
        test.assertEqual(a.grad.ptr, None)
        test.assertEqual(a.grad.size, 0)
        test.assertEqual(a.grad.shape, shape)

        # all of these methods should succeed with zero-sized arrays
        a.zero_()
        a.fill_(fill_value)
        a.flatten()
        a.reshape((0,))
        a.transpose()
        a.contiguous()

        b = wp.empty_like(a)
        b = wp.zeros_like(a)
        b = wp.full_like(a, fill_value)
        b = wp.clone(a)

        wp.copy(a, b)
        a.assign(b)

        na = a.numpy()
        test.assertEqual(na.size, 0)
        test.assertEqual(na.shape, (*shape, *dtype_shape))
        test.assertEqual(na.dtype, nptype)

        test.assertEqual(a.list(), [])

    for ndim in range(1, 5):
        # test with scalars, vectors, and matrices
        for nptype, wptype in wpt.np_dtype_to_warp_type.items():
            # scalars
            test_empty_ops(ndim, 0, 0, wptype, nptype)

            for ncols in [2, 3, 4, 5]:
                # vectors
                test_empty_ops(ndim, 0, ncols, wptype, nptype)
                # square matrices
                test_empty_ops(ndim, ncols, ncols, wptype, nptype)

            # non-square matrices
            test_empty_ops(ndim, 2, 3, wptype, nptype)
            test_empty_ops(ndim, 3, 2, wptype, nptype)
            test_empty_ops(ndim, 3, 4, wptype, nptype)
            test_empty_ops(ndim, 4, 3, wptype, nptype)

        # test with structs
        test_empty_ops(ndim, 0, 0, FillStruct, FillStruct.numpy_dtype())
1766
+
1767
+
1768
def test_empty_from_numpy(test, device):
    """Check that wrapping zero-sized NumPy arrays yields empty Warp arrays of the right shape."""
    wpt = wp._src.types
    square_sizes = (2, 3, 4, 5)
    rect_shapes = ((2, 3), (3, 2), (3, 4), (4, 3))

    def check_from_numpy(ndim, nrows, ncols, wptype, nptype):
        # (nrows, ncols): (0, 0) scalar element, (0, n) vector(n), (r, c) matrix((r, c))
        shape = (0,) * ndim
        if ncols <= 0:
            elem_shape = ()
        else:
            wptype = wpt.matrix((nrows, ncols), wptype) if nrows > 0 else wpt.vector(ncols, wptype)
            elem_shape = wptype._shape_

        data = np.empty((*shape, *elem_shape), dtype=nptype)
        arr = wp.array(data, dtype=wptype, device=device)
        test.assertEqual(arr.size, 0)
        test.assertEqual(arr.shape, shape)

    for ndim in range(1, 5):
        for nptype, wptype in wpt.np_dtype_to_warp_type.items():
            check_from_numpy(ndim, 0, 0, wptype, nptype)  # scalar

            for n in square_sizes:
                check_from_numpy(ndim, 0, n, wptype, nptype)  # vector
                check_from_numpy(ndim, n, n, wptype, nptype)  # square matrix

            for rows, cols in rect_shapes:
                check_from_numpy(ndim, rows, cols, wptype, nptype)  # non-square matrix
1806
+
1807
+
1808
def test_empty_from_list(test, device):
    """Check that ``wp.array([])`` yields an empty 1D array for scalar, vector and matrix dtypes."""
    wpt = wp._src.types

    def check_from_list(nrows, ncols, wptype):
        # (nrows, ncols): (0, 0) scalar element, (0, n) vector(n), (r, c) matrix((r, c))
        if ncols > 0:
            wptype = wpt.matrix((nrows, ncols), wptype) if nrows > 0 else wpt.vector(ncols, wptype)

        arr = wp.array([], dtype=wptype, device=device)
        test.assertEqual(arr.size, 0)
        test.assertEqual(arr.shape, (0,))

    # scalar, then (vector, square matrix) per size, then the non-square matrices
    element_dims = [(0, 0)]
    for n in (2, 3, 4, 5):
        element_dims += [(0, n), (n, n)]
    element_dims += [(2, 3), (3, 2), (3, 4), (4, 3)]

    for wptype in wpt.scalar_types:
        for nrows, ncols in element_dims:
            check_from_list(nrows, ncols, wptype)
1838
+
1839
+
1840
def test_to_list_scalar(test, device):
    """Check ``array.list()`` on scalar arrays: one entry per element, each equal to the fill value."""
    edge = 3
    fill_value = 42

    for rank in range(1, 5):
        shape = (edge,) * rank

        for wp_dtype in wp._src.types.scalar_types:
            arr = wp.full(shape, fill_value, dtype=wp_dtype, device=device)
            items = arr.list()

            test.assertEqual(len(items), arr.size)
            test.assertTrue(all(item == fill_value for item in items))
1853
+
1854
+
1855
def test_to_list_vector(test, device):
    """Check ``array.list()`` on vector arrays: one entry per element, each equal to the fill vector."""
    wpt = wp._src.types
    edge = 3

    for rank in range(1, 5):
        shape = (edge,) * rank

        for length in (2, 3, 4, 5):
            for wp_dtype in wpt.scalar_types:
                vec_t = wpt.vector(length, wp_dtype)
                fill_value = vec_t(42)

                arr = wp.full(shape, fill_value, dtype=vec_t, device=device)
                items = arr.list()

                test.assertEqual(len(items), arr.size)
                test.assertTrue(all(item == fill_value for item in items))
1871
+
1872
+
1873
def test_to_list_matrix(test, device):
    """Check ``array.list()`` on matrix arrays: one entry per element, each equal to the fill matrix."""
    wpt = wp._src.types
    edge = 3
    # square shapes first, then non-square ones
    mat_shapes = ((2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3))

    for rank in range(1, 5):
        shape = (edge,) * rank

        for wp_dtype in wpt.scalar_types:
            for dims in mat_shapes:
                mat_t = wpt.matrix(dims, wp_dtype)
                fill_value = mat_t(42)

                arr = wp.full(shape, fill_value, dtype=mat_t, device=device)
                items = arr.list()

                test.assertEqual(len(items), arr.size)
                test.assertTrue(all(item == fill_value for item in items))
1901
+
1902
+
1903
def test_to_list_struct(test, device):
    """Check ``array.list()`` on a struct array reproduces every member of the fill struct."""

    @wp.struct
    class Inner:
        h: wp.float16
        v: wp.vec3

    @wp.struct
    class ListStruct:
        i: int
        f: float
        h: wp.float16
        vi: wp.vec2i
        vf: wp.vec3f
        vh: wp.vec4h
        mi: wp._src.types.matrix((2, 2), int)
        mf: wp._src.types.matrix((3, 3), float)
        mh: wp._src.types.matrix((4, 4), wp.float16)
        inner: Inner
        a1: wp.array(dtype=int)
        a2: wp.array2d(dtype=float)
        a3: wp.array3d(dtype=wp.float16)
        bool: wp.bool

    def counting(n):
        # n x n nested list holding 1..n*n in row-major order
        return [[row * n + col + 1 for col in range(n)] for row in range(n)]

    edge = 3

    fill = ListStruct()
    # scalar members
    fill.i = 42
    fill.f = 2.5
    fill.h = -1.25
    # vector members
    fill.vi = wp.vec2i(1, 2)
    fill.vf = wp.vec3f(0.1, 0.2, 0.3)
    fill.vh = wp.vec4h(1.0, 2.0, 3.0, 4.0)
    # matrix members
    fill.mi = counting(2)
    fill.mf = counting(3)
    fill.mh = counting(4)
    # nested struct: assigned first, then its members are set through the parent
    fill.inner = Inner()
    fill.inner.h = 1.5
    fill.inner.v = [1, 2, 3]
    # array members
    fill.a1 = wp.empty(1, dtype=int, device=device)
    fill.a2 = wp.empty((1, 1), dtype=float, device=device)
    fill.a3 = wp.empty((1, 1, 1), dtype=wp.float16, device=device)
    fill.bool = True

    value_fields = ("i", "f", "h", "vi", "vf", "vh", "mi", "mf", "mh", "bool")
    array_fields = ("a1", "a2", "a3")

    for rank in range(1, 5):
        shape = (edge,) * rank

        arr = wp.full(shape, fill, dtype=ListStruct, device=device)
        items = arr.list()

        for idx in range(arr.size):
            item = items[idx]

            for field in value_fields:
                test.assertEqual(getattr(item, field), getattr(fill, field))

            test.assertEqual(item.inner.h, fill.inner.h)
            test.assertEqual(item.inner.v, fill.inner.v)

            # array members are compared by dtype and rank only
            for field in array_fields:
                test.assertEqual(getattr(item, field).dtype, getattr(fill, field).dtype)
                test.assertEqual(getattr(item, field).ndim, getattr(fill, field).ndim)
1971
+
1972
+
1973
# Kernel checking array truthiness: `array_null` must evaluate false and `array_valid` true.
# Each violated expectation is reported through a deliberately failing expect_eq(1, 2).
@wp.kernel
def kernel_array_to_bool(array_null: wp.array(dtype=float), array_valid: wp.array(dtype=float)):
    if array_null:
        # null array was truthy: force failure
        wp.expect_eq(1, 2)
    else:
        # always succeed
        wp.expect_eq(0, 0)

    if not array_valid:
        # valid array was falsy: force failure
        wp.expect_eq(1, 2)
    else:
        # always succeed
        wp.expect_eq(0, 0)
1988
+
1989
+
1990
def test_array_to_bool(test, device):
    """Launch ``kernel_array_to_bool`` with a ``None`` array and a real array."""
    valid = wp.zeros(8, dtype=float, device=device)

    # None stands in for the null array argument
    wp.launch(kernel=kernel_array_to_bool, dim=1, inputs=[None, valid], device=device)
1994
+
1995
+
1996
@wp.struct
class InputStruct:
    """Kernel input record mixing scalar, vector and array members.

    Field order is significant: it defines the layout of the corresponding NumPy record.
    """

    param1: int
    param2: float
    param3: wp.vec3
    # array member carried inside the struct
    param4: wp.array(dtype=float)
2002
+
2003
+
2004
@wp.struct
class OutputStruct:
    """Kernel output record holding the plain-data members copied from an ``InputStruct``.

    Field order is significant: tests read the NumPy record positionally (0, 1, 2).
    """

    param1: int
    param2: float
    param3: wp.vec3
2009
+
2010
+
2011
# Verifies each input record against the expected pattern (param1 == i, param2 == i*i,
# param3 == (1, 2, 3), param4[0:3] == 1, 2, 3) and copies its plain-data members to outputs[i].
@wp.kernel
def struct_array_kernel(inputs: wp.array(dtype=InputStruct), outputs: wp.array(dtype=OutputStruct)):
    i = wp.tid()
    item = inputs[i]

    # scalar members encode the element index
    wp.expect_eq(item.param1, i)
    wp.expect_eq(item.param2, float(i * i))

    # vector member
    wp.expect_eq(item.param3[0], 1.0)
    wp.expect_eq(item.param3[1], 2.0)
    wp.expect_eq(item.param3[2], 3.0)

    # array member
    wp.expect_eq(item.param4[0], 1.0)
    wp.expect_eq(item.param4[1], 2.0)
    wp.expect_eq(item.param4[2], 3.0)

    result = OutputStruct()
    result.param1 = item.param1
    result.param2 = item.param2
    result.param3 = item.param3

    outputs[i] = result
2032
+
2033
+
2034
def test_array_of_structs(test, device):
    """Pass a list-built struct array through a kernel and read results via numpy(), list() and cptr()."""
    num_items = 10

    def make_input(idx):
        # record matching what struct_array_kernel expects for thread `idx`
        s = InputStruct()
        s.param1 = idx
        s.param2 = float(idx * idx)
        s.param3 = wp.vec3(1.0, 2.0, 3.0)
        s.param4 = wp.array([1.0, 2.0, 3.0], dtype=float, device=device)
        return s

    host_inputs = [make_input(idx) for idx in range(num_items)]

    # initialize the device array from the list of structs
    inputs = wp.array(host_inputs, dtype=InputStruct, device=device)
    outputs = wp.zeros(num_items, dtype=OutputStruct, device=device)

    wp.launch(struct_array_kernel, dim=num_items, inputs=[inputs, outputs], device=device)

    out_numpy = outputs.numpy()
    out_list = outputs.list()
    out_cptr = outputs.to("cpu").cptr()

    for idx, ref in enumerate(host_inputs):
        # positional access into each structured record
        test.assertEqual(out_numpy[idx][0], ref.param1)
        test.assertEqual(out_numpy[idx][1], ref.param2)
        assert_np_equal(out_numpy[idx][2], np.array(ref.param3))

        # named-field access into the structured array
        test.assertEqual(out_numpy["param1"][idx], ref.param1)
        test.assertEqual(out_numpy["param2"][idx], ref.param2)
        assert_np_equal(out_numpy["param3"][idx], np.array(ref.param3))

        # Python struct instances from list()
        test.assertEqual(out_list[idx].param1, ref.param1)
        test.assertEqual(out_list[idx].param2, ref.param2)
        test.assertEqual(out_list[idx].param3, ref.param3)

        # ctypes view from cptr()
        test.assertEqual(out_cptr[idx].param1, ref.param1)
        test.assertEqual(out_cptr[idx].param2, ref.param2)
        test.assertEqual(out_cptr[idx].param3, ref.param3)
2075
+
2076
+
2077
@wp.struct
class GradStruct:
    """Struct used to check gradients flow into a single member (``param2``) of a struct array.

    Field order is significant: it defines the layout of the corresponding NumPy record.
    """

    param1: int
    param2: float
    param3: wp.vec3
2082
+
2083
+
2084
# Accumulates param2 * 2.0 of every element into loss[0].
@wp.kernel
def test_array_of_structs_grad_kernel(inputs: wp.array(dtype=GradStruct), loss: wp.array(dtype=float)):
    i = wp.tid()

    contribution = inputs[i].param2 * 2.0
    wp.atomic_add(loss, 0, contribution)
2089
+
2090
+
2091
def test_array_of_structs_grad(test, device):
    """Backpropagate through a struct array and check the gradient of the ``param2`` field."""
    num_items = 10

    def make(idx):
        g = GradStruct()
        g.param2 = float(idx)
        return g

    arr = wp.array([make(idx) for idx in range(num_items)], dtype=GradStruct, device=device, requires_grad=True)
    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(test_array_of_structs_grad_kernel, dim=num_items, inputs=[arr, loss], device=device)

    tape.backward(loss)

    # loss accumulates param2 * 2.0 per element, so each param2 gradient is 2.0
    grads = arr.grad.numpy()
    assert_np_equal(grads["param2"], np.full(num_items, 2.0, dtype=np.float32))
2111
+
2112
+
2113
@wp.struct
class NumpyStruct:
    # Struct used by the NumPy structured-array interop tests; its fields are
    # addressed as the "x" and "v" fields of NumpyStruct.numpy_dtype().
    x: int
    v: wp.vec3
2117
+
2118
+
2119
def test_array_of_structs_from_numpy(test, device):
    """Build a struct array from a NumPy structured array and read it back unchanged."""
    count = 10

    source = np.zeros(count, dtype=NumpyStruct.numpy_dtype())
    source["x"] = 17
    source["v"] = (1, 2, 3)

    structs = wp.array(data=source, dtype=NumpyStruct, device=device)

    assert_np_equal(structs.numpy(), source)
2129
+
2130
+
2131
def test_array_of_structs_roundtrip(test, device):
    """Round-trip a struct array Warp -> NumPy -> Warp, editing one field in between."""
    count = 10

    fill = NumpyStruct()
    fill.x = 17
    fill.v = wp.vec3(1.0, 2.0, 3.0)

    # Warp structured array made of copies of `fill`
    warp_arr = wp.full(count, fill, device=device)

    # NumPy structured array holding the same data
    host = warp_arr.numpy()

    reference = np.zeros(count, dtype=NumpyStruct.numpy_dtype())
    reference["x"] = fill.x
    reference["v"] = fill.v

    assert_np_equal(host, reference)

    # edit one field on the NumPy side, then convert back to Warp
    host["x"] = 42
    warp_arr = wp.from_numpy(host, NumpyStruct, device=device)

    reference["x"] = 42

    assert_np_equal(warp_arr.numpy(), reference)
2159
+
2160
+
2161
def test_array_from_numpy(test, device):
    """Check wp.from_numpy dtype inference and explicit dtype/shape overrides."""

    def check(result, expected):
        # compare on the host; `expected` lives on the default device
        assert_np_equal(result.numpy(), expected.numpy())

    # --------------------------------------------------------------------------
    # 1D input of length 3

    vec = (1.0, 2.0, 3.0)
    arr = np.array(vec, dtype=float)

    check(wp.from_numpy(arr, device=device), wp.array(vec, dtype=wp.float32, shape=(3,)))
    check(wp.from_numpy(arr, dtype=wp.vec3, device=device), wp.array((vec,), dtype=wp.vec3, shape=(1,)))

    # --------------------------------------------------------------------------
    # 2D input of shape (2, 3)

    rows = ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0))
    arr = np.array(rows, dtype=float)

    # without a dtype, the trailing dimension of 3 becomes a vec3
    check(wp.from_numpy(arr, device=device), wp.array(rows, dtype=wp.vec3, shape=(2,)))
    check(wp.from_numpy(arr, dtype=wp.float32, device=device), wp.array(rows, dtype=wp.float32, shape=(2, 3)))
    check(
        wp.from_numpy(arr, dtype=wp.float32, shape=(6,), device=device),
        wp.array((1.0, 2.0, 3.0, 4.0, 5.0, 6.0), dtype=wp.float32, shape=(6,)),
    )

    # --------------------------------------------------------------------------
    # 3D input of shape (2, 4, 4)

    mats = (
        (
            (1.0, 2.0, 3.0, 4.0),
            (2.0, 3.0, 4.0, 5.0),
            (3.0, 4.0, 5.0, 6.0),
            (4.0, 5.0, 6.0, 7.0),
        ),
        (
            (2.0, 3.0, 4.0, 5.0),
            (3.0, 4.0, 5.0, 6.0),
            (4.0, 5.0, 6.0, 7.0),
            (5.0, 6.0, 7.0, 8.0),
        ),
    )
    arr = np.array(mats, dtype=float)

    # without a dtype, the trailing (4, 4) dimensions become a mat44
    check(wp.from_numpy(arr, device=device), wp.array(mats, dtype=wp.mat44, shape=(2,)))
    check(wp.from_numpy(arr, dtype=wp.float32, device=device), wp.array(mats, dtype=wp.float32, shape=(2, 4, 4)))

    # the same values viewed as 8 rows of 4, and as 32 scalars
    mat_rows = tuple(row for mat in mats for row in mat)
    flat = tuple(value for row in mat_rows for value in row)

    check(
        wp.from_numpy(arr, dtype=wp.vec4, device=device).reshape((8,)),  # Reshape from (2, 4)
        wp.array(mat_rows, dtype=wp.vec4, shape=(8,)),
    )
    check(
        wp.from_numpy(arr, dtype=wp.float32, shape=(32,), device=device),
        wp.array(flat, dtype=wp.float32, shape=(32,)),
    )
2307
+
2308
+
2309
def test_array_aliasing_from_numpy(test, device):
    """wp.array(copy=False) aliases a matching NumPy buffer and copies a mismatched one."""
    device = wp.get_device(device)
    assert device.is_cpu

    host = np.ones(8, dtype=np.int32)
    alias = wp.array(host, dtype=int, copy=False, device=device)
    test.assertIs(alias._ref, host)  # check that some ref is kept to original array
    test.assertEqual(alias.ptr, host.ctypes.data)

    view = alias.numpy()
    test.assertTrue((view == 1).all())

    # writes to the source buffer must be visible through the alias
    host.fill(2)
    test.assertTrue((view == 2).all())

    # an int64 source does not match Warp's `int`, so a copy is expected;
    # build two such arrays to make sure they do not share a conversion buffer
    ones64 = np.ones(8, dtype=np.int64)
    zeros64 = np.zeros(8, dtype=np.int64)
    ones_wp = wp.array(ones64, dtype=int, copy=False, device=device)
    zeros_wp = wp.array(zeros64, dtype=int, copy=False, device=device)

    test.assertNotEqual(ones_wp.ptr, ones64.ctypes.data)
    test.assertNotEqual(ones_wp.ptr, zeros_wp.ptr)

    ones_view = ones_wp.numpy()
    zeros_view = zeros_wp.numpy()
    test.assertTrue((ones_view == 1).all())
    test.assertTrue((zeros_view == 0).all())
2340
+
2341
+
2342
def test_array_from_cai(test, device):
    """Wrap torch tensors as Warp arrays, including a transposed (re-strided) view."""
    import torch

    @wp.kernel
    def first_row_plus_one(x: wp.array2d(dtype=float)):
        i, j = wp.tid()
        if i == 0:
            x[i, j] += 1.0

    # start from a zero torch tensor on the matching torch device
    tensor = torch.zeros((3, 3))
    tensor = tensor.to(wp.device_to_torch(device))

    # wrap as warp array via __cuda_array_interface__ and bump row 0
    warp_view = wp.array(tensor, device=device)
    wp.launch(kernel=first_row_plus_one, dim=(3, 3), inputs=[warp_view], device=device)

    # back to torch, then transpose by swapping the two strides
    tensor = wp.to_torch(warp_view)
    row_stride, col_stride = tensor.stride(0), tensor.stride(1)
    tensor = torch.as_strided(tensor, size=(3, 3), stride=(col_stride, row_stride))

    # re-wrap with the new strides: row 0 of this view is column 0 of the data
    warp_view = wp.array(tensor, device=device)
    wp.launch(kernel=first_row_plus_one, dim=(3, 3), inputs=[warp_view], device=device)

    assert_np_equal(warp_view.numpy(), np.array([[2, 1, 1], [1, 0, 0], [1, 0, 0]]))
2373
+
2374
+
2375
def test_array_from_data(test, device):
    """Create Warp arrays from existing Warp arrays, reinterpreting dtype and shape.

    Covers scalar reshapes, packing scalars into vectors/matrices (optionally
    with an explicit target shape), and strided (sliced) source arrays.
    """

    def ramp(count, shape=None):
        # float32 values 0 .. count - 1, optionally reshaped
        values = np.arange(count, dtype=np.float32)
        return values if shape is None else values.reshape(shape)

    with wp.ScopedDevice(device):
        # =========================================
        # scalars, reshaping

        data = ramp(12, (3, 4))
        src = wp.array(data)

        assert src.device == device

        for dtype in (Any, wp.float32):
            for shape in (None, (3, 4), (12,), (3, 2, 2)):
                with test.subTest(msg=f"scalar, dtype={dtype}, shape={shape}"):
                    dst = wp.array(src, dtype=dtype, shape=shape)
                    assert dst.device == src.device
                    # dtype=Any keeps the source dtype
                    assert dst.dtype == (src.dtype if dtype is Any else dtype)
                    if shape is None:
                        assert dst.shape == src.shape
                        assert_np_equal(dst.numpy(), data)
                    else:
                        assert dst.shape == shape
                        assert_np_equal(dst.numpy(), data.reshape(shape))

        # =========================================
        # vectors and matrices, optionally reshaped or taken from strided sources
        #
        # each case: (subtest message, source data, slice applied to the Warp
        # source or None, target dtype, explicit shape or None, expected Warp
        # shape, shape the expected NumPy values are reshaped to)
        full = slice(None)
        cases = (
            ("vector, single", ramp(3), None, wp.vec3, None, (1,), (1, 3)),
            ("vector, multiple in 1d", ramp(12), None, wp.vec3, None, (4,), (4, 3)),
            ("vector, singles in 2d", ramp(12, (4, 3)), None, wp.vec3, None, (4,), (4, 3)),
            ("vector, multiples in 2d", ramp(24, (4, 6)), None, wp.vec3, None, (4, 2), (4, 2, 3)),
            ("vector, singles in 2d, reshape", ramp(12, (4, 3)), None, wp.vec3, (2, 2), (2, 2), (2, 2, 3)),
            (
                "vector, multiples in 2d, reshape",
                ramp(24, (4, 6)),
                None,
                wp.vec3,
                (2, 2, 2),
                (2, 2, 2),
                (2, 2, 2, 3),
            ),
            ("matrix, single in 2d", ramp(4, (2, 2)), None, wp.mat22, None, (1,), (1, 2, 2)),
            ("matrix, single in 1d", ramp(4), None, wp.mat22, None, (1,), (1, 2, 2)),
            ("matrix, multiples in 1d", ramp(12), None, wp.mat22, None, (3,), (3, 2, 2)),
            ("matrix, multiples in 1d, reshape", ramp(16), None, wp.mat22, (4,), (4,), (4, 2, 2)),
            ("matrix, multiples in 2d", ramp(12, (3, 4)), None, wp.mat22, None, (3,), (3, 2, 2)),
            ("matrix, multiples in 2d, reshape", ramp(16, (4, 4)), None, wp.mat22, (2, 2), (2, 2), (2, 2, 2, 2)),
            ("matrix, multiples in 3d", ramp(12, (3, 2, 2)), None, wp.mat22, None, (3,), (3, 2, 2)),
            ("matrix, multiples in 3d, reshape", ramp(16, (4, 2, 2)), None, wp.mat22, (2, 2), (2, 2), (2, 2, 2, 2)),
            ("vector, singles in 2d, strided", ramp(20, (4, 5)), (full, slice(2, None)), wp.vec3, None, (4,), (4, 3)),
            (
                "vector, multiples in 2d, strided",
                ramp(14, (2, 7)),
                (full, slice(1, None)),
                wp.vec3,
                None,
                (2, 2),
                (2, 2, 3),
            ),
            (
                "matrix, multiples in 2d, strided",
                ramp(15, (3, 5)),
                (full, slice(1, None)),
                wp.mat22,
                None,
                (3,),
                (3, 2, 2),
            ),
            (
                "matrix, multiples in 3d, strided",
                ramp(18, (3, 3, 2)),
                (full, slice(1, None)),
                wp.mat22,
                None,
                (3,),
                (3, 2, 2),
            ),
        )

        for msg, data, view, dtype, shape, expected_shape, value_shape in cases:
            with test.subTest(msg=msg):
                src = wp.array(data)
                if view is not None:
                    src = src[view]  # source with strides
                dst = wp.array(src, dtype=dtype, shape=shape)
                assert dst.dtype == dtype
                assert dst.shape == expected_shape
                values = data if view is None else data[view]
                assert_np_equal(dst.numpy(), values.reshape(value_shape))
2624
+
2625
+
2626
@wp.kernel
def inplace_add_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # In-place `+=` of y into x on 1D arrays; the adjoints are checked by
    # test_array_inplace_diff_ops.
    i = wp.tid()
    x[i] += y[i]
2630
+
2631
+
2632
@wp.kernel
def inplace_add_2d(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # In-place `+=` of y into x on 2D arrays.
    i, j = wp.tid()
    x[i, j] += y[i, j]
2636
+
2637
+
2638
@wp.kernel
def inplace_add_3d(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)):
    # In-place `+=` of y into x on 3D arrays.
    i, j, k = wp.tid()
    x[i, j, k] += y[i, j, k]
2642
+
2643
+
2644
@wp.kernel
def inplace_add_4d(x: wp.array4d(dtype=float), y: wp.array4d(dtype=float)):
    # In-place `+=` of y into x on 4D arrays.
    i, j, k, l = wp.tid()
    x[i, j, k, l] += y[i, j, k, l]
2648
+
2649
+
2650
@wp.kernel
def inplace_sub_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # In-place `-=` of y from x on 1D arrays; the adjoints are checked by
    # test_array_inplace_diff_ops.
    i = wp.tid()
    x[i] -= y[i]
2654
+
2655
+
2656
@wp.kernel
def inplace_sub_2d(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # In-place `-=` of y from x on 2D arrays.
    i, j = wp.tid()
    x[i, j] -= y[i, j]
2660
+
2661
+
2662
@wp.kernel
def inplace_sub_3d(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)):
    # In-place `-=` of y from x on 3D arrays.
    i, j, k = wp.tid()
    x[i, j, k] -= y[i, j, k]
2666
+
2667
+
2668
@wp.kernel
def inplace_sub_4d(x: wp.array4d(dtype=float), y: wp.array4d(dtype=float)):
    # In-place `-=` of y from x on 4D arrays.
    i, j, k, l = wp.tid()
    x[i, j, k, l] -= y[i, j, k, l]
2672
+
2673
+
2674
@wp.kernel
def inplace_add_vecs(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=wp.vec3)):
    # In-place `+=` on array elements of vector type (vec3).
    i = wp.tid()
    x[i] += y[i]
2678
+
2679
+
2680
@wp.kernel
def inplace_add_mats(x: wp.array(dtype=wp.mat33), y: wp.array(dtype=wp.mat33)):
    # In-place `+=` on array elements of matrix type (mat33).
    i = wp.tid()
    x[i] += y[i]
2684
+
2685
+
2686
@wp.kernel
def inplace_add_rhs(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    # `+=` applied to a local variable (not an array element), with the
    # result accumulated into z[0] so gradients flow back to both x and y.
    i = wp.tid()
    a = y[i]
    a += x[i]
    wp.atomic_add(z, 0, a)
2692
+
2693
+
2694
# Custom 9-component float vector type, used to test `+=` on a
# user-defined vector length.
vec9 = wp.vec(length=9, dtype=float)
2695
+
2696
+
2697
@wp.kernel
def inplace_add_custom_vec(x: wp.array(dtype=vec9), y: wp.array(dtype=vec9)):
    # y is added into x twice on purpose: the test expects x == 2 * y and a
    # gradient of 2 for y.
    i = wp.tid()
    x[i] += y[i]
    x[i] += y[i]
2702
+
2703
+
2704
def test_array_inplace_diff_ops(test, device):
    """Check forward values and adjoints of in-place `+=` / `-=` on arrays of rank 1-4,
    on vector/matrix elements, on a local variable, and on a custom vector type."""
    N = 3
    dims = [N, (N, N), (N, N, N), (N, N, N, N)]
    add_kernels = [inplace_add_1d, inplace_add_2d, inplace_add_3d, inplace_add_4d]
    sub_kernels = [inplace_sub_1d, inplace_sub_2d, inplace_sub_3d, inplace_sub_4d]

    xs = [wp.ones(dim, dtype=float, requires_grad=True, device=device) for dim in dims]
    ys = [wp.clone(arr, requires_grad=True, device=device) for arr in xs]

    v1 = wp.ones(1, dtype=wp.vec3, requires_grad=True, device=device)
    v2 = wp.clone(v1, requires_grad=True, device=device)

    m1 = wp.ones(1, dtype=wp.mat33, requires_grad=True, device=device)
    m2 = wp.clone(m1, requires_grad=True, device=device)

    x = wp.ones(1, dtype=float, requires_grad=True, device=device)
    y = wp.clone(x, requires_grad=True, device=device)
    z = wp.zeros(1, dtype=float, requires_grad=True, device=device)

    np_ones = [np.ones(dim, dtype=float) for dim in dims]
    np_twos = [np.full(dim, 2.0, dtype=float) for dim in dims]

    tape = wp.Tape()

    def run_and_check(kernels, expected_y_grads, expected_x_values):
        # record x op= y for every rank, seed dL/dx = 1 and backpropagate
        with tape:
            for kernel, dim, dst, src in zip(kernels, dims, xs, ys):
                wp.launch(kernel, dim, inputs=[dst, src], device=device)

        tape.backward(grads={arr: wp.ones_like(arr) for arr in xs})

        for arr, expected in zip(xs, np_ones):
            assert_np_equal(arr.grad.numpy(), expected)
        for arr, expected in zip(ys, expected_y_grads):
            assert_np_equal(arr.grad.numpy(), expected)
        for arr, expected in zip(xs, expected_x_values):
            assert_np_equal(arr.numpy(), expected)

        for arr in xs:
            arr.grad.zero_()
        tape.reset()

    # x += y: dL/dy = 1 and x goes from 1 to 2
    run_and_check(add_kernels, np_ones, np_twos)

    # x -= y: dL/dy = -1 and x goes back from 2 to 1
    run_and_check(sub_kernels, [-ones for ones in np_ones], np_ones)

    # vector/matrix elements and a local-variable accumulation
    with tape:
        wp.launch(inplace_add_vecs, 1, inputs=[v1, v2], device=device)
        wp.launch(inplace_add_mats, 1, inputs=[m1, m2], device=device)
        wp.launch(inplace_add_rhs, 1, inputs=[x, y, z], device=device)

    tape.backward(loss=z, grads={v1: wp.ones_like(v1, requires_grad=False), m1: wp.ones_like(m1, requires_grad=False)})

    vec_shape = (1, 3)
    mat_shape = (1, 3, 3)

    assert_np_equal(v1.numpy(), np.full(shape=vec_shape, fill_value=2.0, dtype=float))
    assert_np_equal(v1.grad.numpy(), np.ones(shape=vec_shape, dtype=float))
    assert_np_equal(v2.grad.numpy(), np.ones(shape=vec_shape, dtype=float))

    assert_np_equal(m1.numpy(), np.full(shape=mat_shape, fill_value=2.0, dtype=float))
    assert_np_equal(m1.grad.numpy(), np.ones(shape=mat_shape, dtype=float))
    assert_np_equal(m2.grad.numpy(), np.ones(shape=mat_shape, dtype=float))

    assert_np_equal(x.grad.numpy(), np.ones(1, dtype=float))
    assert_np_equal(y.grad.numpy(), np.ones(1, dtype=float))
    tape.reset()

    # custom vector length: y is added twice, so x == 2 and dL/dy == 2
    x = wp.zeros(1, dtype=vec9, requires_grad=True, device=device)
    y = wp.ones(1, dtype=vec9, requires_grad=True, device=device)

    with tape:
        wp.launch(inplace_add_custom_vec, 1, inputs=[x, y], device=device)

    tape.backward(grads={x: wp.ones_like(x)})

    assert_np_equal(x.numpy(), np.full((1, 9), 2.0, dtype=float))
    assert_np_equal(y.grad.numpy(), np.full((1, 9), 2.0, dtype=float))
2824
+
2825
+
2826
@wp.kernel
def inplace_mul_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # In-place `*=` of x by y on 1D arrays (forward values only are tested).
    i = wp.tid()
    x[i] *= y[i]
2830
+
2831
+
2832
@wp.kernel
def inplace_div_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # In-place `/=` of x by y on 1D arrays (forward values only are tested).
    i = wp.tid()
    x[i] /= y[i]
2836
+
2837
+
2838
@wp.kernel
def inplace_add_non_atomic_types(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
    # Generic (dtype=Any) in-place `+=`; instantiated by the test for the
    # types in wp._src.types.non_atomic_types and a few small-int vectors.
    i = wp.tid()
    x[i] += y[i]
2842
+
2843
+
2844
# 3-component uint16 vector type used in test_array_inplace_non_diff_ops.
uint16vec3 = wp.vec(length=3, dtype=wp.uint16)
2845
+
2846
+
2847
def test_array_inplace_non_diff_ops(test, device):
    """Check `*=`, `/=` and `+=` on arrays, including the types in non_atomic_types."""
    N = 3
    lhs = wp.full(N, value=10.0, dtype=float, device=device)
    rhs = wp.full(N, value=5.0, dtype=float, device=device)

    wp.launch(inplace_mul_1d, N, inputs=[lhs, rhs], device=device)
    assert_np_equal(lhs.numpy(), np.full(N, fill_value=50.0, dtype=float))

    lhs.fill_(10.0)
    rhs.fill_(5.0)
    wp.launch(inplace_div_1d, N, inputs=[lhs, rhs], device=device)
    assert_np_equal(lhs.numpy(), np.full(N, fill_value=2.0, dtype=float))

    small_vector_types = (wp.vec2b, wp.vec2ub, wp.vec2s, wp.vec2us, uint16vec3)
    for dtype in (*wp._src.types.non_atomic_types, *small_vector_types):
        dst = wp.full(N, value=0, dtype=dtype, device=device)
        src = wp.full(N, value=1, dtype=dtype, device=device)

        # 0 += 1 must leave dst equal to src
        wp.launch(inplace_add_non_atomic_types, N, inputs=[dst, src], device=device)
        assert_np_equal(dst.numpy(), src.numpy())
2866
+
2867
+
2868
@wp.kernel
def inc_scalar(a: wp.array(dtype=float)):
    # Add 1.0 to every element of a float array.
    tid = wp.tid()
    a[tid] = a[tid] + 1.0
2872
+
2873
+
2874
@wp.kernel
def inc_vector(a: wp.array(dtype=wp.vec3f)):
    # Add 1.0 to every component of every vec3f element.
    tid = wp.tid()
    a[tid] = a[tid] + wp.vec3f(1.0)
2878
+
2879
+
2880
@wp.kernel
def inc_matrix(a: wp.array(dtype=wp.mat22f)):
    # Add wp.mat22f(1.0) to every element; test_direct_from_numpy expects
    # this to increment all four entries by 1.
    tid = wp.tid()
    a[tid] = a[tid] + wp.mat22f(1.0)
2884
+
2885
+
2886
def test_direct_from_numpy(test, device):
    """Pass NumPy arrays to Warp kernels directly"""

    n = 12

    scalars = np.arange(n, dtype=np.float32)
    vectors = np.arange(n, dtype=np.float32).reshape((n // 3, 3))
    matrices = np.arange(n, dtype=np.float32).reshape((n // 4, 2, 2))

    # each kernel increments every scalar component of its input by 1
    launches = (
        (inc_scalar, n, scalars),
        (inc_vector, n // 3, vectors),
        (inc_matrix, n // 4, matrices),
    )
    for kernel, count, host in launches:
        wp.launch(kernel, dim=count, inputs=[host], device=device)

    expected = np.arange(1, n + 1, dtype=np.float32)

    for host in (scalars, vectors, matrices):
        assert_np_equal(host.reshape(n), expected)
2904
+
2905
+
2906
@wp.kernel
def kernel_array_from_ptr(arr_orig: wp.array2d(dtype=wp.float32)):
    # Build a (2, 3) float32 array view over arr_orig's pointer inside the
    # kernel and write 1, 2, 3 into its first row.
    arr = wp.array(ptr=arr_orig.ptr, shape=(2, 3), dtype=wp.float32)
    arr[0, 0] = 1.0
    arr[0, 1] = 2.0
    arr[0, 2] = 3.0
2912
+
2913
+
2914
def test_kernel_array_from_ptr(test, device):
    """Writes through an in-kernel, pointer-constructed view must reach the source array."""
    target = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)
    wp.launch(kernel_array_from_ptr, dim=(1,), inputs=(target,), device=device)

    result = target.numpy()
    assert_np_equal(result, np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
2918
+
2919
+
2920
@wp.struct
class MyStruct:
    # Three float32 fields, written through a pointer-constructed array view
    # in kernel_array_from_ptr_struct.
    a: wp.float32
    b: wp.float32
    c: wp.float32
2925
+
2926
+
2927
@wp.kernel
def kernel_array_from_ptr_struct(arr_orig: wp.array(dtype=MyStruct)):
    # Build a 2-element MyStruct view over arr_orig's pointer and write each
    # field of both elements through it.
    arr = wp.array(ptr=arr_orig.ptr, shape=(2,), dtype=MyStruct)
    arr[0].a = 1.0
    arr[0].b = 2.0
    arr[0].c = 3.0
    arr[1].a = 4.0
    arr[1].b = 5.0
    arr[1].c = 6.0
2936
+
2937
+
2938
def test_kernel_array_from_ptr_struct(test, device):
    """Field writes through an in-kernel struct-array view must reach the source array."""
    structs = wp.zeros(shape=(2,), dtype=MyStruct, device=device)
    wp.launch(kernel_array_from_ptr_struct, dim=(1,), inputs=(structs,), device=device)

    host = structs.numpy()
    expected = np.zeros_like(host)
    for index, fields in enumerate(((1.0, 2.0, 3.0), (4.0, 5.0, 6.0))):
        expected[index] = fields
    assert_np_equal(host, expected)
2946
+
2947
+
2948
@wp.kernel
def kernel_array_from_ptr_variable_shape(
    ptr: wp.uint64,
    shape_x: int,
    shape_y: int,
):
    # Build a float32 array view from a raw pointer, with a shape supplied as
    # launch arguments, then write into its first row.
    arr = wp.array(ptr=ptr, shape=(shape_x, shape_y), dtype=wp.float32)
    arr[0, 0] = 1.0
    arr[0, 1] = 2.0
    # only write column 2 when the runtime shape includes it
    if shape_y > 2:
        arr[0, 2] = 3.0
2959
+
2960
+
2961
def test_kernel_array_from_ptr_variable_shape(test, device):
    """Pointer-constructed views whose shape comes from kernel arguments."""
    buffer = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)

    # (number of columns passed to the kernel, expected buffer contents)
    cases = (
        (2, ((1.0, 2.0, 0.0), (0.0, 0.0, 0.0))),
        (3, ((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))),
    )
    for columns, expected in cases:
        wp.launch(kernel_array_from_ptr_variable_shape, dim=(1,), inputs=(buffer.ptr, 2, columns), device=device)
        assert_np_equal(buffer.numpy(), np.array(expected))
2967
+
2968
+
2969
def test_array_from_int32_domain(test, device):
    """Allocate an array whose shape is given as a NumPy int32 array.

    1504 * 1080 * 520 = 844,646,400 elements fits in int32, but the float32
    byte size (about 3.4 GB) does not. Presumably this guards against
    overflow in size arithmetic done with NumPy int32 values.
    """
    dims = np.array([1504, 1080, 520], dtype=np.int32)
    arr = wp.zeros(dims, dtype=wp.float32, device=device)

    # verify the result, not just that the allocation did not raise
    test.assertEqual(tuple(arr.shape), (1504, 1080, 520))
    test.assertEqual(arr.size, 1504 * 1080 * 520)
2971
+
2972
+
2973
def test_array_from_int64_domain(test, device):
    """Allocate an array whose shape is given as a NumPy int64 array.

    The int64 counterpart of test_array_from_int32_domain, using the same
    dimensions (about 3.4 GB of float32 data).
    """
    dims = np.array([1504, 1080, 520], dtype=np.int64)
    arr = wp.zeros(dims, dtype=wp.float32, device=device)

    # verify the result, not just that the allocation did not raise
    test.assertEqual(tuple(arr.shape), (1504, 1080, 520))
    test.assertEqual(arr.size, 1504 * 1080 * 520)
2975
+
2976
+
2977
def test_numpy_array_interface(test, device):
    # Warp and NumPy arrays should convert into each other through
    # __array_interface__ on CPU; every type in Warp's scalar_types is checked.

    n = 10

    for dtype in wp._src.types.scalar_types:
        # Warp -> NumPy -> Warp round trip must preserve dtype, shape and strides
        original = wp.zeros(n, dtype=dtype, device="cpu")
        host = np.array(original)
        roundtrip = wp.array(host, device="cpu")

        assert original.dtype == roundtrip.dtype
        assert original.shape == roundtrip.shape
        assert original.strides == roundtrip.strides
2994
+
2995
+
2996
@wp.kernel
def kernel_indexing_types(
    arr_1d: wp.array(dtype=wp.int32, ndim=1),
    arr_2d: wp.array(dtype=wp.int32, ndim=2),
    arr_3d: wp.array(dtype=wp.int32, ndim=3),
    arr_4d: wp.array(dtype=wp.int32, ndim=4),
):
    # Exercise array indexing with uint8 / int16 / uint32 / int64 indices
    # for reads, writes and atomics, on arrays of rank 1 to 4.
    # Diagonal entry k (all indices equal to k) uses one index type per k.

    # reads (results are intentionally unused)
    x = arr_1d[wp.uint8(0)]
    y = arr_1d[wp.int16(1)]
    z = arr_1d[wp.uint32(2)]
    w = arr_1d[wp.int64(3)]

    x = arr_2d[wp.uint8(0), wp.uint8(0)]
    y = arr_2d[wp.int16(1), wp.int16(1)]
    z = arr_2d[wp.uint32(2), wp.uint32(2)]
    w = arr_2d[wp.int64(3), wp.int64(3)]

    x = arr_3d[wp.uint8(0), wp.uint8(0), wp.uint8(0)]
    y = arr_3d[wp.int16(1), wp.int16(1), wp.int16(1)]
    z = arr_3d[wp.uint32(2), wp.uint32(2), wp.uint32(2)]
    w = arr_3d[wp.int64(3), wp.int64(3), wp.int64(3)]

    x = arr_4d[wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0)]
    y = arr_4d[wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1)]
    z = arr_4d[wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2)]
    w = arr_4d[wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3)]

    # writes: store 123 at each diagonal entry
    arr_1d[wp.uint8(0)] = 123
    arr_1d[wp.int16(1)] = 123
    arr_1d[wp.uint32(2)] = 123
    arr_1d[wp.int64(3)] = 123

    arr_2d[wp.uint8(0), wp.uint8(0)] = 123
    arr_2d[wp.int16(1), wp.int16(1)] = 123
    arr_2d[wp.uint32(2), wp.uint32(2)] = 123
    arr_2d[wp.int64(3), wp.int64(3)] = 123

    arr_3d[wp.uint8(0), wp.uint8(0), wp.uint8(0)] = 123
    arr_3d[wp.int16(1), wp.int16(1), wp.int16(1)] = 123
    arr_3d[wp.uint32(2), wp.uint32(2), wp.uint32(2)] = 123
    arr_3d[wp.int64(3), wp.int64(3), wp.int64(3)] = 123

    arr_4d[wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0)] = 123
    arr_4d[wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1)] = 123
    arr_4d[wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2)] = 123
    arr_4d[wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3)] = 123

    # atomics: add / sub / min / max with 123 on diagonal entries 0 / 1 / 2 / 3
    wp.atomic_add(arr_1d, wp.uint8(0), 123)
    wp.atomic_sub(arr_1d, wp.int16(1), 123)
    wp.atomic_min(arr_1d, wp.uint32(2), 123)
    wp.atomic_max(arr_1d, wp.int64(3), 123)

    wp.atomic_add(arr_2d, wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_2d, wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_2d, wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_2d, wp.int64(3), wp.int64(3), 123)

    wp.atomic_add(arr_3d, wp.uint8(0), wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_3d, wp.int16(1), wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_3d, wp.uint32(2), wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_3d, wp.int64(3), wp.int64(3), wp.int64(3), 123)

    wp.atomic_add(arr_4d, wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_4d, wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_4d, wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_4d, wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3), 123)
3062
+
3063
+
3064
def test_indexing_types(test, device):
    """Launch ``kernel_indexing_types`` on zero-filled int32 arrays of rank 1 to 4.

    Every axis has extent 4, so the indices 0..3 used by the kernel are all in range.
    """
    arrays = [wp.zeros(shape=(4,) * ndim, dtype=wp.int32, device=device) for ndim in range(1, 5)]
    wp.launch(
        kernel=kernel_indexing_types,
        dim=1,
        inputs=tuple(arrays),
        device=device,
    )
3075
+
3076
+
3077
def test_alloc_strides(test, device):
    """Check that explicit-stride allocations reserve the same capacity as a default one.

    For each shape/dtype, three arrays are allocated:
    - one with implicit (contiguous) strides;
    - one with the same contiguous strides passed explicitly;
    - one with reversed shape and reversed strides (a transposed layout).
    All three must report the same ``capacity``.
    """

    def test_transposed(shape, dtype):
        # reference: allocate without specifying strides
        a1 = wp.zeros(shape, dtype=dtype)

        # same shape with the contiguous strides passed explicitly
        strides = wp._src.types.strides_from_shape(shape, dtype)
        a2 = wp.zeros(shape, dtype=dtype, strides=strides)

        # transposed layout: reversing both shape and strides spans the same memory extent
        rshape = shape[::-1]
        rstrides = strides[::-1]
        a3 = wp.zeros(rshape, dtype=dtype, strides=rstrides)

        # TestCase assertions (unlike bare ``assert``) still run under ``python -O``
        # and report both values on failure
        test.assertEqual(a2.capacity, a1.capacity)
        test.assertEqual(a3.capacity, a1.capacity)

    with wp.ScopedDevice(device):
        shapes = [(5, 5), (5, 3), (3, 5), (2, 3, 4), (4, 2, 3), (3, 2, 4)]
        for shape in shapes:
            with test.subTest(msg=f"shape={shape}"):
                test_transposed(shape, wp.int8)
                test_transposed(shape, wp.float32)
                test_transposed(shape, wp.vec3)
3102
+
3103
+
3104
def test_casting(test, device):
    """Build a ``wp.vec3i`` array from a (4, 3) int32 Warp array.

    Twelve int32 values are reshaped to (4, 3). They are then passed to ``wp.array``
    with ``dtype=wp.vec3i`` and ``shape=4``. The result must be a 1-D vec3i array of
    length 4 with a 12-byte stride (3 x int32).
    """
    flat = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), device=device, dtype=wp.int32)
    rows = flat.reshape((-1, 3))
    vecs = wp.array(rows, shape=rows.shape[0], dtype=wp.vec3i, device=device)

    # TestCase assertions still run under ``python -O`` and give informative failures
    test.assertIs(vecs.dtype, wp.vec3i)
    test.assertEqual(vecs.shape, (4,))
    test.assertEqual(vecs.strides, (12,))
3111
+
3112
+
3113
# Exercises the built-in ``len()`` on Warp arrays inside a kernel. The expected values
# match the arrays allocated by test_array_len: 123 elements for ``a1`` and shape
# (2, 3, 4) for ``a2``. The lengths are also written to ``out`` so the host can check them.
@wp.kernel
def array_len_kernel(
    a1: wp.array(dtype=int),
    a2: wp.array(dtype=float, ndim=3),
    out: wp.array(dtype=int),
):
    # ``length`` is never read; presumably the assignment checks that len() can be
    # bound to a kernel-local variable (TODO confirm intent)
    length = len(a1)
    wp.expect_eq(len(a1), 123)
    out[0] = len(a1)

    # for the 3-D array, len() is expected to be the extent of the first axis (2)
    length = len(a2)
    wp.expect_eq(len(a2), 2)
    out[1] = len(a2)
3126
+
3127
+
3128
def test_array_len(test, device):
    """Verify in-kernel ``len()`` for a 1-D array of 123 ints and a (2, 3, 4) float array.

    ``array_len_kernel`` writes both lengths to a 2-element output, which is read back
    on the host and compared.
    """
    flat = wp.zeros(123, dtype=int, device=device)
    cube = wp.zeros((2, 3, 4), dtype=float, device=device)
    result = wp.empty(2, dtype=int, device=device)

    wp.launch(array_len_kernel, dim=(1,), inputs=(flat, cube), outputs=(result,), device=device)

    test.assertEqual(result.numpy()[0], 123)
    test.assertEqual(result.numpy()[1], 2)
3145
+
3146
+
3147
def test_cuda_interface_conversion(test, device):
    """Create Warp arrays from a duck-typed object exposing the array interfaces.

    ``MyArrayInterface`` wraps a NumPy array and publishes that array's
    ``__array_interface__`` under both ``__array_interface__`` and
    ``__cuda_array_interface__``. For each (data, NumPy dtype, Warp dtype) case, the
    Warp array built from it must have a non-null data pointer. Element values are
    not checked here.
    """

    class MyArrayInterface:
        def __init__(self, data, npdtype):
            self.data = np.array(data, dtype=npdtype)
            self.__array_interface__ = self.data.__array_interface__
            # NOTE(review): the CUDA interface reuses the host-side (NumPy) description,
            # so its pointer refers to host memory -- confirm this is the intended setup
            self.__cuda_array_interface__ = self.data.__array_interface__
            # Instance attribute only: Python's built-in len() looks __len__ up on the
            # type, so len(obj) does not see this. It is visible only to code that
            # reads the attribute explicitly.
            self.__len__ = self.data.__len__

    cases = (
        ((1, 2, 3), np.int8, wp.int8),
        ((1, 2, 3), np.float32, wp.float32),
        # three float32 values supplied for a vec3 dtype
        ((1, 2, 3), np.float32, wp.vec3),
        # four float32 values supplied for a 2x2 matrix dtype
        ((1, 2, 3, 4), np.float32, wp.mat22),
    )

    for data, npdtype, wp_dtype in cases:
        with test.subTest(msg=f"{np.dtype(npdtype)} -> {wp_dtype.__name__}"):
            source = MyArrayInterface(data, npdtype)
            wp_array = wp.array(source, dtype=wp_dtype, device=device)
            # TestCase assertion (unlike bare ``assert``) still runs under ``python -O``
            test.assertNotEqual(wp_array.ptr, 0)
3170
+
3171
+
3172
# In-kernel slicing of a 1-D array: forward, reversed, stepped and chained slices.
# Expected values assume ``arr`` holds 0..15, as launched by test_array1d_slicing.
# Each case also checks that shape entries past ``ndim`` read as 0.
@wp.kernel
def test_array1d_slicing_kernel(arr: wp.array1d(dtype=int)):
    # leading slice -> 0, 1, 2
    sub = arr[:3]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 3)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 0)
    wp.expect_eq(sub[2], 2)

    # interior slice -> 3, 4
    sub = arr[3:5]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 3)
    wp.expect_eq(sub[1], 4)

    # reversed from index 3 down to the start -> 3, 2, 1, 0
    sub = arr[3::-1]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 4)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 3)
    wp.expect_eq(sub[3], 0)

    # chained: [::-3] -> 15, 12, 9, 6, 3, 0; then [::2] -> 15, 9, 3
    sub = arr[::-3]
    sub = sub[::2]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 3)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 15)
    wp.expect_eq(sub[2], 3)
3202
+
3203
+
3204
def test_array1d_slicing(test, device):
    """Run the 1-D in-kernel slicing checks on the values 0..15."""
    values = wp.array(tuple(range(16)), dtype=int, device=device)
    wp.launch(kernel=test_array1d_slicing_kernel, dim=1, inputs=(values,), device=device)
3207
+
3208
+
3209
# In-kernel slicing of a 2-D array, mixing slices with integer indices (which drop an axis).
# Expected values assume ``arr`` holds 0..31 with shape (8, 4), i.e. arr[i, j] == 4 * i + j,
# as launched by test_array2d_slicing. Shape entries past ``ndim`` must read as 0.
@wp.kernel
def test_array2d_slicing_kernel(arr: wp.array2d(dtype=int)):
    # first two rows, full width
    sub = arr[:2]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 0)
    wp.expect_eq(sub[1, 3], 7)

    # column 1 of the first two rows -> 1, 5 (integer index drops the column axis)
    sub = arr[:2, 1]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 1)
    wp.expect_eq(sub[1], 5)

    # negative row index: row 4, first three columns -> 16, 17, 18
    sub = arr[-4, :3]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 3)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 16)
    wp.expect_eq(sub[2], 18)

    # rows 3-4, columns 3 then 2 (reversed)
    sub = arr[3:5, 3:1:-1]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 15)
    wp.expect_eq(sub[1, 1], 18)

    # chained: rows 0 and 4, then columns 3 and 0
    sub = arr[::4]
    sub = sub[:, ::-3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 3)
    wp.expect_eq(sub[1, 1], 16)
3249
+
3250
+
3251
def test_array2d_slicing(test, device):
    """Run the 2-D in-kernel slicing checks on 0..31 laid out as an (8, 4) array."""
    values = wp.array(tuple(range(32)), dtype=int, shape=(8, 4), device=device)
    wp.launch(kernel=test_array2d_slicing_kernel, dim=1, inputs=(values,), device=device)
3254
+
3255
+
3256
# In-kernel slicing of a 3-D array, mixing slices, negative indices, steps and integer indices.
# Expected values assume ``arr`` holds 0..63 with shape (2, 8, 4), i.e.
# arr[i, j, k] == 32 * i + 4 * j + k, as launched by test_array3d_slicing.
# Shape entries past ``ndim`` must read as 0.
@wp.kernel
def test_array3d_slicing_kernel(arr: wp.array3d(dtype=int)):
    # last block along axis 0, kept as a size-1 axis
    sub = arr[-1:]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 8)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 32)
    wp.expect_eq(sub[0, 7, 3], 63)

    # row 5 (== -3) of both blocks
    sub = arr[:2, -3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 20)
    wp.expect_eq(sub[1, 3], 55)

    # block 1, rows 2..7
    sub = arr[1, 2:]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 6)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 40)
    wp.expect_eq(sub[5, 3], 63)

    # block 0, rows 3 then 2 (reversed)
    sub = arr[:1, 3:1:-1]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 12)
    wp.expect_eq(sub[0, 1, 3], 11)

    # stepping back by 2 over 2 blocks selects only block 1; element [1, 1, 3]
    sub = arr[::-2, 1, 3]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 39)

    # block 0, rows 2..4, column 1 (== -3)
    sub = arr[0, 2:5, -3]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 3)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 9)
    wp.expect_eq(sub[2], 17)

    # block 0, row 6 (== -2), columns 0 and 2
    sub = arr[0, -2, ::2]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 24)
    wp.expect_eq(sub[1], 26)

    # block 1 (kept as size-1 axis), rows 0, 2, 4, column 0
    sub = arr[-1:, :5:2, 0]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 3)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 32)
    wp.expect_eq(sub[0, 2], 48)

    # both blocks, row 0, columns 0 and 2
    sub = arr[:, 0, ::2]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 0)
    wp.expect_eq(sub[1, 1], 34)

    # block 1, rows 7 and 3, columns 3 and 0 (both reversed with steps)
    sub = arr[1, ::-4, ::-3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 63)
    wp.expect_eq(sub[1, 1], 44)

    # block 0, rows 0..2, last two columns
    sub = arr[::2, :3:, -2:]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 3)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 2)
    wp.expect_eq(sub[0, 2, 1], 11)

    # chained: row 0 of each block, then the first two columns
    sub = arr[:, :1]
    sub = sub[:, :, :2]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 0)
    wp.expect_eq(sub[1, 0, 1], 33)
3354
+
3355
+
3356
def test_array3d_slicing(test, device):
    """Run the 3-D in-kernel slicing checks on 0..63 laid out as a (2, 8, 4) array."""
    values = wp.array(tuple(range(64)), dtype=int, shape=(2, 8, 4), device=device)
    wp.launch(kernel=test_array3d_slicing_kernel, dim=1, inputs=(values,), device=device)
3359
+
3360
+
3361
# In-kernel slicing of a 4-D array with every combination of slices and integer indices.
# Expected values assume ``arr`` holds 0..63 with shape (4, 2, 2, 4), i.e.
# arr[i, j, k, l] == 16 * i + 8 * j + 4 * k + l, as launched by test_array4d_slicing.
# When the result has fewer than 4 dims, shape entries past ``ndim`` must read as 0.
@wp.kernel
def test_array4d_slicing_kernel(arr: wp.array4d(dtype=int)):
    # i = 0 kept as a size-1 axis
    sub = arr[:1]
    wp.expect_eq(sub.ndim, 4)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 4)
    wp.expect_eq(sub[0, 0, 0, 0], 0)
    wp.expect_eq(sub[0, 1, 1, 3], 15)

    # i = 2, 3; j = 0
    sub = arr[2:, 0]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 32)
    wp.expect_eq(sub[1, 1, 3], 55)

    # i = 3; j = 1 kept as a size-1 axis
    sub = arr[-1, -1:]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 56)
    wp.expect_eq(sub[0, 1, 3], 63)

    # i = 3 and j = 0, both kept as size-1 axes
    sub = arr[3:4, :1]
    wp.expect_eq(sub.ndim, 4)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 4)
    wp.expect_eq(sub[0, 0, 0, 0], 48)
    wp.expect_eq(sub[0, 0, 1, 3], 55)

    # i = 2, 3; j = 0; k = 1
    sub = arr[2::, 0, -1]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 36)
    wp.expect_eq(sub[1, 3], 55)

    # i = 2; j = 0 only (step 2 over extent 2); k = 0
    sub = arr[-2, ::2, -2]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 32)
    wp.expect_eq(sub[0, 3], 35)

    # i = 1; j = 1; k = 1 only (reverse step 3 over extent 2)
    sub = arr[1, -1, ::-3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 4)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 28)
    wp.expect_eq(sub[0, 3], 31)

    # i = 1, 3; j = 0; k = 0
    sub = arr[1::2, :-1, 0]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 16)
    wp.expect_eq(sub[1, 0, 3], 51)

    # i = 0, 1; j = 1; k = 1
    sub = arr[:2, 1, 1:]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 12)
    wp.expect_eq(sub[1, 0, 3], 31)

    # i = 3; j = 0; k = 1
    sub = arr[-1, :1, ::-3]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 52)
    wp.expect_eq(sub[0, 0, 3], 55)

    # i = 3 only (reverse step 4); j = 0; k = 1 -- all four axes kept
    sub = arr[::-4, :1, 1:]
    wp.expect_eq(sub.ndim, 4)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 1)
    wp.expect_eq(sub.shape[3], 4)
    wp.expect_eq(sub[0, 0, 0, 0], 52)
    wp.expect_eq(sub[0, 0, 0, 3], 55)

    # i = 0, 1 at (j, k, l) = (0, 1, 2)
    sub = arr[:2, 0, 1, 2]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 6)
    wp.expect_eq(sub[1], 22)

    # i = 1; j = 0 only; k = 0; l = 2
    sub = arr[-3, ::2, 0, 2]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 18)

    # i = 2; j = 0; k = 0 only; l = 1
    sub = arr[2, 0, :-1, 1]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 33)

    # full innermost row at (1, 0, 1) -> 20..23
    sub = arr[1, 0, 1, :]
    wp.expect_eq(sub.ndim, 1)
    wp.expect_eq(sub.shape[0], 4)
    wp.expect_eq(sub.shape[1], 0)
    wp.expect_eq(sub[0], 20)
    wp.expect_eq(sub[3], 23)

    # i = 1..3; j = 0, 1; k = 1; l = 1 (== -3)
    sub = arr[1:, :2, 1, -3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 3)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 21)
    wp.expect_eq(sub[2, 1], 61)

    # i = 2, 3; j = 0; k = 0, 1; l = 1
    sub = arr[2:, 0, :2, 1]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 33)
    wp.expect_eq(sub[1, 1], 53)

    # i = 3, 1; j = 0; k = 0; l = 0, 3
    sub = arr[::-2, 0, 0, ::3]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 48)
    wp.expect_eq(sub[1, 1], 19)

    # i = 2; j = 1; k = 1, 0 (reversed); l = 0
    sub = arr[-2, 1:2, ::-1, 0]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 44)
    wp.expect_eq(sub[0, 1], 40)

    # i = 1; j = 0, 1; k = 0; l = 3, 1
    sub = arr[1, :2, 0, ::-2]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 19)
    wp.expect_eq(sub[1, 1], 25)

    # i = 3; j = 0; k = 0 only; l = 0..2
    sub = arr[-1, 0, ::3, -4:-1]
    wp.expect_eq(sub.ndim, 2)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 3)
    wp.expect_eq(sub.shape[2], 0)
    wp.expect_eq(sub[0, 0], 48)
    wp.expect_eq(sub[0, 2], 50)

    # i = 2, 3; j = 1; k = 0 only; l = 1
    sub = arr[-2:, 1:2, ::3, 1]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 1)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 41)
    wp.expect_eq(sub[1, 0, 0], 57)

    # i = 0; j = 0, 1; k = 1; l = 2, 3
    sub = arr[:1, :, 1, -2:]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 6)
    wp.expect_eq(sub[0, 1, 1], 15)

    # i = 3 only (reverse, stopping before 2); j = 0; k = 1; l = 3..0
    sub = arr[:2:-1, 0, -1:, ::-1]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 55)
    wp.expect_eq(sub[0, 0, 3], 52)

    # i = 2; j = 1, 0 (reversed); k = 0, 1; l = 1, 2
    sub = arr[-2, ::-1, -2:, 1:3]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 2)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 41)
    wp.expect_eq(sub[1, 1, 1], 38)

    # i = 0, 1; j = 1; k = 1; l = 0, 1 -- all four axes kept
    sub = arr[:2, 1:, 1:, :-2]
    wp.expect_eq(sub.ndim, 4)
    wp.expect_eq(sub.shape[0], 2)
    wp.expect_eq(sub.shape[1], 1)
    wp.expect_eq(sub.shape[2], 1)
    wp.expect_eq(sub.shape[3], 2)
    wp.expect_eq(sub[0, 0, 0, 0], 12)
    wp.expect_eq(sub[1, 0, 0, 1], 29)

    # chained: i = 2, 3; j = 1; k reversed; then [::2] keeps only i = 2
    sub = arr[-2:, 1, ::-1]
    sub = sub[::2]
    wp.expect_eq(sub.ndim, 3)
    wp.expect_eq(sub.shape[0], 1)
    wp.expect_eq(sub.shape[1], 2)
    wp.expect_eq(sub.shape[2], 4)
    wp.expect_eq(sub.shape[3], 0)
    wp.expect_eq(sub[0, 0, 0], 44)
    wp.expect_eq(sub[0, 1, 3], 43)
3587
+
3588
+
3589
def test_array4d_slicing(test, device):
    """Run the 4-D in-kernel slicing checks on 0..63 laid out as a (4, 2, 2, 4) array."""
    values = wp.array(tuple(range(64)), dtype=int, shape=(4, 2, 2, 4), device=device)
    wp.launch(kernel=test_array4d_slicing_kernel, dim=1, inputs=(values,), device=device)
3592
+
3593
+
3594
def test_graph_fill_vecmat(test, device):
    """Make sure the fill value persists with the graph.

    Each array is filled inside its own graph capture, using a vector/matrix value that
    exists only as a Python temporary during the capture. Replaying the graph later
    must still write that value. Contiguous, strided and indexed arrays of every
    vector/matrix type are covered.
    """

    def _capture_fill(target, fill_scalar):
        with wp.ScopedCapture() as capture:
            # the dtype instance is a local temporary, so the graph must keep its own copy
            target.fill_(target.dtype(fill_scalar))
        return capture

    def _check_group(group):
        # capture all fills first, each with a distinct value (1, 2, 3, ...)
        captured = [_capture_fill(target, pos + 1) for pos, target in enumerate(group)]

        # then replay each graph and verify its array holds the matching value
        for pos, (target, capture) in enumerate(zip(group, captured)):
            with test.subTest(msg=f"array type={type(target)}, dtype={target.dtype}"):
                wp.capture_launch(capture.graph)

                scalar_np_type = wp.dtype_to_numpy(target.dtype._wp_scalar_type_)
                full_shape = (*target.shape, *target.dtype._shape_)
                reference = np.full(full_shape, pos + 1, dtype=scalar_np_type)

                assert_np_equal(target.numpy(), reference)

    with wp.ScopedDevice(device):
        from warp._src.types import vector_types

        count = 1000
        contiguous, strided, indexed = [], [], []
        selection = wp.array(np.arange(count, dtype=np.int32))
        for vtype in vector_types:
            contiguous.append(wp.zeros(count, dtype=vtype))
            strided.append(wp.zeros(count * 2, dtype=vtype)[::2])
            indexed.append(wp.zeros(count, dtype=vtype)[selection])

        for group in (contiguous, strided, indexed):
            _check_group(group)
3640
+
3641
+
3642
# Devices the array tests are registered on; ``cuda_devices`` is the CUDA-only subset,
# used below for the graph-capture test.
devices = get_test_devices()
cuda_devices = get_cuda_test_devices()
3644
+
3645
+
3646
class TestArray(unittest.TestCase):
    """Container for the array tests; per-device tests are attached with add_function_test."""

    def test_array_new_del(self):
        # An array created via __new__ but never initialized (as can happen if
        # construction fails before __init__ runs) must still be safe to destroy.
        uninitialized = wp.array.__new__(wp.array)
        uninitialized.__del__()
3651
+
3652
+
3653
# Attach each module-level test function to TestArray once per device.
# ``devices=["cpu"]`` restricts a test to the CPU; ``cuda_devices`` restricts it to CUDA devices.

# shape manipulation
add_function_test(TestArray, "test_shape", test_shape, devices=devices)
add_function_test(TestArray, "test_negative_shape", test_negative_shape, devices=devices)
add_function_test(TestArray, "test_flatten", test_flatten, devices=devices)
add_function_test(TestArray, "test_reshape", test_reshape, devices=devices)

# views and adjoints
add_function_test(TestArray, "test_slicing", test_slicing, devices=devices)
add_function_test(TestArray, "test_transpose", test_transpose, devices=devices)
add_function_test(TestArray, "test_view", test_view, devices=devices)
add_function_test(TestArray, "test_clone_adjoint", test_clone_adjoint, devices=devices)
add_function_test(TestArray, "test_assign_adjoint", test_assign_adjoint, devices=devices)

# basic N-D arrays
add_function_test(TestArray, "test_1d_array", test_1d, devices=devices)
add_function_test(TestArray, "test_2d_array", test_2d, devices=devices)
add_function_test(TestArray, "test_3d_array", test_3d, devices=devices)
add_function_test(TestArray, "test_4d_array", test_4d, devices=devices)
add_function_test(TestArray, "test_4d_array_transposed", test_4d_transposed, devices=devices)

# fill / full / ones / empty constructors and list conversion
add_function_test(TestArray, "test_fill_scalar", test_fill_scalar, devices=devices)
add_function_test(TestArray, "test_fill_vector", test_fill_vector, devices=devices)
add_function_test(TestArray, "test_fill_matrix", test_fill_matrix, devices=devices)
add_function_test(TestArray, "test_fill_struct", test_fill_struct, devices=devices)
add_function_test(TestArray, "test_fill_slices", test_fill_slices, devices=devices)
add_function_test(TestArray, "test_full_scalar", test_full_scalar, devices=devices)
add_function_test(TestArray, "test_full_vector", test_full_vector, devices=devices)
add_function_test(TestArray, "test_full_matrix", test_full_matrix, devices=devices)
add_function_test(TestArray, "test_full_struct", test_full_struct, devices=devices)
add_function_test(TestArray, "test_ones_scalar", test_ones_scalar, devices=devices)
add_function_test(TestArray, "test_ones_vector", test_ones_vector, devices=devices)
add_function_test(TestArray, "test_ones_matrix", test_ones_matrix, devices=devices)
add_function_test(TestArray, "test_ones_like_scalar", test_ones_like_scalar, devices=devices)
add_function_test(TestArray, "test_ones_like_vector", test_ones_like_vector, devices=devices)
add_function_test(TestArray, "test_ones_like_matrix", test_ones_like_matrix, devices=devices)
add_function_test(TestArray, "test_empty_array", test_empty_array, devices=devices)
add_function_test(TestArray, "test_empty_from_numpy", test_empty_from_numpy, devices=devices)
add_function_test(TestArray, "test_empty_from_list", test_empty_from_list, devices=devices)
add_function_test(TestArray, "test_to_list_scalar", test_to_list_scalar, devices=devices)
add_function_test(TestArray, "test_to_list_vector", test_to_list_vector, devices=devices)
add_function_test(TestArray, "test_to_list_matrix", test_to_list_matrix, devices=devices)
add_function_test(TestArray, "test_to_list_struct", test_to_list_struct, devices=devices)

# misc, structs, and NumPy interop (aliasing/interface tests are CPU-only)
add_function_test(TestArray, "test_lower_bound", test_lower_bound, devices=devices)
add_function_test(TestArray, "test_round_trip", test_round_trip, devices=devices)
add_function_test(TestArray, "test_array_to_bool", test_array_to_bool, devices=devices)
add_function_test(TestArray, "test_array_of_structs", test_array_of_structs, devices=devices)
add_function_test(TestArray, "test_array_of_structs_grad", test_array_of_structs_grad, devices=devices)
add_function_test(TestArray, "test_array_of_structs_from_numpy", test_array_of_structs_from_numpy, devices=devices)
add_function_test(TestArray, "test_array_of_structs_roundtrip", test_array_of_structs_roundtrip, devices=devices)
add_function_test(TestArray, "test_array_from_numpy", test_array_from_numpy, devices=devices)
add_function_test(TestArray, "test_array_aliasing_from_numpy", test_array_aliasing_from_numpy, devices=["cpu"])
add_function_test(TestArray, "test_numpy_array_interface", test_numpy_array_interface, devices=["cpu"])

# in-place ops and arrays built from raw pointers
add_function_test(TestArray, "test_array_inplace_diff_ops", test_array_inplace_diff_ops, devices=devices)
add_function_test(TestArray, "test_array_inplace_non_diff_ops", test_array_inplace_non_diff_ops, devices=devices)
add_function_test(TestArray, "test_direct_from_numpy", test_direct_from_numpy, devices=["cpu"])
add_function_test(TestArray, "test_kernel_array_from_ptr", test_kernel_array_from_ptr, devices=devices)
add_function_test(TestArray, "test_kernel_array_from_ptr_struct", test_kernel_array_from_ptr_struct, devices=devices)
add_function_test(
    TestArray, "test_kernel_array_from_ptr_variable_shape", test_kernel_array_from_ptr_variable_shape, devices=devices
)

# index domains and index types
add_function_test(TestArray, "test_array_from_int32_domain", test_array_from_int32_domain, devices=devices)
add_function_test(TestArray, "test_array_from_int64_domain", test_array_from_int64_domain, devices=devices)
add_function_test(TestArray, "test_indexing_types", test_indexing_types, devices=devices)

# allocation strides, dtype casting, len(), and foreign array interfaces
add_function_test(TestArray, "test_alloc_strides", test_alloc_strides, devices=devices)
add_function_test(TestArray, "test_casting", test_casting, devices=devices)
add_function_test(TestArray, "test_array_len", test_array_len, devices=devices)
add_function_test(TestArray, "test_cuda_interface_conversion", test_cuda_interface_conversion, devices=devices)
add_function_test(TestArray, "test_array_from_data", test_array_from_data, devices=devices)

# in-kernel slicing
add_function_test(TestArray, "test_array1d_slicing", test_array1d_slicing, devices=devices)
add_function_test(TestArray, "test_array2d_slicing", test_array2d_slicing, devices=devices)
add_function_test(TestArray, "test_array3d_slicing", test_array3d_slicing, devices=devices)
add_function_test(TestArray, "test_array4d_slicing", test_array4d_slicing, devices=devices)

# graph capture is exercised only on CUDA devices
add_function_test(TestArray, "test_graph_fill_vecmat", test_graph_fill_vecmat, devices=cuda_devices)
3729
+
3730
# Optional Torch-dependent tests. Both this probe and the registration are best-effort:
# any failure (including Torch not being installed) is reported and the tests are skipped,
# so the rest of the module still loads.
try:
    import torch

    # check which Warp devices work with Torch
    # CUDA devices may fail if Torch was not compiled with CUDA support
    torch_compatible_devices = []
    torch_compatible_cuda_devices = []

    for d in devices:
        try:
            # smoke test: allocate and modify a small tensor on the matching Torch device
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
            torch_compatible_devices.append(d)
            if d.is_cuda:
                torch_compatible_cuda_devices.append(d)
        except Exception as e:
            print(f"Skipping Array tests that use Torch on device '{d}' due to exception: {e}")

    # only the CUDA-capable Torch devices are used for this test
    add_function_test(TestArray, "test_array_from_cai", test_array_from_cai, devices=torch_compatible_cuda_devices)

except Exception as e:
    print(f"Skipping Array tests that use Torch due to exception: {e}")
3752
+
3753
+
3754
if __name__ == "__main__":
    # clear Warp's kernel cache first -- presumably so kernels are recompiled from
    # this source rather than loaded from a previous run
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)