warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2972 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+
25
# Verifies that each element of a 1D array equals its thread index, then doubles it
# with a plain store and adds one with an atomic update.
@wp.kernel
def kernel_1d(a: wp.array(dtype=int, ndim=1)):
    idx = wp.tid()

    wp.expect_eq(a[idx], idx)

    a[idx] = 2 * a[idx]
    wp.atomic_add(a, idx, 1)

    # After doubling and the atomic increment the element must be 2 * idx + 1.
    wp.expect_eq(a[idx], 2 * idx + 1)
35
+
36
+
37
def test_1d(test, device):
    # Launch kernel_1d over a length-4 int32 array holding 0..3.
    length = 4
    host = np.arange(length, dtype=np.int32)
    arr = wp.array(host, device=device)

    # The Warp array must report the same metadata as its NumPy source.
    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_1d, dim=arr.size, inputs=[arr], device=device)
50
+
51
+
52
# Recovers a (row, col) coordinate from the flat thread index (`n` is the column
# count; `m` is unused), checks both indexing syntaxes, then doubles and increments.
@wp.kernel
def kernel_2d(a: wp.array(dtype=int, ndim=2), m: int, n: int):
    flat = wp.tid()
    row = flat // n
    col = flat % n

    # Multi-index and chained-index syntax must read the same element.
    wp.expect_eq(a[row, col], flat)
    wp.expect_eq(a[row][col], flat)

    a[row, col] = 2 * a[row, col]
    wp.atomic_add(a, row, col, 1)

    wp.expect_eq(a[row, col], 2 * flat + 1)
64
+
65
+
66
def test_2d(test, device):
    # A (4, 2) int32 array whose elements equal their row-major flat index.
    rows, cols = 4, 2
    host = np.arange(rows * cols, dtype=np.int32).reshape(rows, cols)
    arr = wp.array(host, device=device)

    # The Warp array must report the same metadata as its NumPy source.
    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_2d, dim=arr.size, inputs=[arr, rows, cols], device=device)
81
+
82
+
83
# Recovers an (x, y, z) coordinate from the flat thread index using the trailing
# extents n and o (`m` is unused), checks both indexing syntaxes, then mutates.
@wp.kernel
def kernel_3d(a: wp.array(dtype=int, ndim=3), m: int, n: int, o: int):
    flat = wp.tid()
    plane = n * o
    x = flat // plane
    y = flat % plane // o
    z = flat % o

    wp.expect_eq(a[x, y, z], flat)
    wp.expect_eq(a[x][y][z], flat)

    # Double once through each indexing syntax, then add one atomically,
    # so the final value is 4 * flat + 1.
    a[x, y, z] = 2 * a[x, y, z]
    a[x][y][z] = 2 * a[x][y][z]
    wp.atomic_add(a, x, y, z, 1)

    wp.expect_eq(a[x, y, z], 4 * flat + 1)
97
+
98
+
99
def test_3d(test, device):
    # An (8, 4, 2) int32 array whose elements equal their row-major flat index.
    depth, rows, cols = 8, 4, 2
    host = np.arange(depth * rows * cols, dtype=np.int32).reshape(depth, rows, cols)
    arr = wp.array(host, device=device)

    # The Warp array must report the same metadata as its NumPy source.
    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_3d, dim=arr.size, inputs=[arr, depth, rows, cols], device=device)
115
+
116
+
117
# Checks that each element of a 4D array equals the row-major flat index of its
# coordinate, reading it through both multi-index and chained-index syntax.
# m, n, o, p are the array extents; only the trailing three are needed here.
@wp.kernel
def kernel_4d(a: wp.array(dtype=int, ndim=4), m: int, n: int, o: int, p: int):
    flat = wp.tid()
    i = flat // (n * o * p)
    j = flat % (n * o * p) // (o * p)
    # Floor division, like the other components (was `/`, which obscured the
    # integer intent).
    k = flat % (o * p) // p
    l = flat % p

    wp.expect_eq(a[i, j, k, l], flat)
    wp.expect_eq(a[i][j][k][l], flat)
126
+
127
+
128
def test_4d(test, device):
    # A (16, 8, 4, 2) int32 array whose elements equal their row-major flat index.
    extents = (16, 8, 4, 2)
    host = np.arange(int(np.prod(extents)), dtype=np.int32).reshape(extents)
    arr = wp.array(host, device=device)

    # The Warp array must report the same metadata as its NumPy source.
    for attr in ("shape", "size", "ndim"):
        test.assertEqual(getattr(arr, attr), getattr(host, attr))

    with CheckOutput(test):
        wp.launch(kernel_4d, dim=arr.size, inputs=[arr, *extents], device=device)
145
+
146
+
147
# Receives the extents (m, n, o, p) of the untransposed array and a fully transposed
# view `a` of it. Decomposes the flat thread index into the untransposed coordinate
# (i, j, k, l), then reads the view with the indices reversed; the value there must
# equal the row-major flat index.
@wp.kernel
def kernel_4d_transposed(a: wp.array(dtype=int, ndim=4), m: int, n: int, o: int, p: int):
    flat = wp.tid()
    i = flat // (n * o * p)
    j = flat % (n * o * p) // (o * p)
    # Floor division, like the other components (was `/`, which obscured the
    # integer intent).
    k = flat % (o * p) // p
    l = flat % p

    wp.expect_eq(a[l, k, j, i], flat)
    wp.expect_eq(a[l][k][j][i], flat)
156
+
157
+
158
def test_4d_transposed(test, device):
    # A (16, 8, 4, 2) int32 array holding its row-major flat indices.
    extents = (16, 8, 4, 2)
    host = np.arange(int(np.prod(extents)), dtype=np.int32).reshape(extents)
    arr = wp.array(host, device=device)

    # Alias arr's memory through a manually built transposed view. Passing host.T to
    # the wp.array() constructor would make it contiguous first, so the shape and
    # strides are taken from NumPy's transposed view instead.
    host_t = host.T
    strides_t = host_t.__array_interface__["strides"]
    arr_t = wp.array(
        dtype=arr.dtype,
        shape=host_t.shape,
        strides=strides_t,
        capacity=arr.capacity,
        ptr=arr.ptr,
        requires_grad=arr.requires_grad,
        device=device,
    )

    test.assertFalse(arr_t.is_contiguous)
    test.assertEqual(arr_t.shape, host_t.shape)
    test.assertEqual(arr_t.strides, strides_t)
    test.assertEqual(arr_t.size, host_t.size)
    test.assertEqual(arr_t.ndim, host_t.ndim)

    # The kernel takes the untransposed extents and indexes the view in reverse order.
    with CheckOutput(test):
        wp.launch(kernel_4d_transposed, dim=arr_t.size, inputs=[arr_t, *extents], device=device)
189
+
190
+
191
# For each query values[t], stores wp.lower_bound(arr, values[t]) into indices[t].
@wp.kernel
def lower_bound_kernel(values: wp.array(dtype=float), arr: wp.array(dtype=float), indices: wp.array(dtype=int)):
    t = wp.tid()
    query = values[t]
    indices[t] = wp.lower_bound(arr, query)
196
+
197
+
198
def test_lower_bound(test, device):
    """Check wp.lower_bound against fixed queries on a sorted float array."""
    sorted_np = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0], dtype=float)
    queries_np = np.array([-0.1, 0.0, 2.5, 4.0, 5.0, 5.5], dtype=float)
    # Expected index of the first element >= each query. The last query (5.5) exceeds
    # every element; the expected result there is the last index (5), not len(arr).
    expected = np.array([0, 0, 3, 4, 5, 5])

    arr = wp.array(sorted_np, dtype=float, device=device)
    values = wp.array(queries_np, dtype=float, device=device)
    indices = wp.zeros(len(queries_np), dtype=int, device=device)

    wp.launch(kernel=lower_bound_kernel, dim=len(queries_np), inputs=[values, arr, indices], device=device)

    # Unlike assertTrue((a == b).all()), this reports the mismatching positions on failure.
    np.testing.assert_array_equal(indices.numpy(), expected)
206
+
207
+
208
# Checks the extent of a 1D array that test_shape allocates with 10 elements.
@wp.kernel
def f1(arr: wp.array(dtype=float)):
    length = arr.shape[0]
    wp.expect_eq(length, 10)
211
+
212
+
213
# Checks the extents of a (10, 20) array and of the 1D view obtained by fixing
# the leading index.
@wp.kernel
def f2(arr: wp.array2d(dtype=float)):
    wp.expect_eq(arr.shape[0], 10)
    wp.expect_eq(arr.shape[1], 20)

    # Named `row` rather than `slice` to avoid shadowing the Python builtin.
    row = arr[0]
    wp.expect_eq(row.shape[0], 20)
220
+
221
+
222
# Checks the extents of a (10, 20, 30) array and of the 1D view obtained by
# fixing the two leading indices.
@wp.kernel
def f3(arr: wp.array3d(dtype=float)):
    wp.expect_eq(arr.shape[0], 10)
    wp.expect_eq(arr.shape[1], 20)
    wp.expect_eq(arr.shape[2], 30)

    # Named `row` rather than `slice` to avoid shadowing the Python builtin.
    row = arr[0, 0]
    wp.expect_eq(row.shape[0], 30)
230
+
231
+
232
# Checks the extents of a (10, 20, 30, 40) array and of the 1D view obtained by
# fixing the three leading indices.
@wp.kernel
def f4(arr: wp.array4d(dtype=float)):
    wp.expect_eq(arr.shape[0], 10)
    wp.expect_eq(arr.shape[1], 20)
    wp.expect_eq(arr.shape[2], 30)
    wp.expect_eq(arr.shape[3], 40)

    # Named `row` rather than `slice` to avoid shadowing the Python builtin.
    row = arr[0, 0, 0]
    wp.expect_eq(row.shape[0], 40)
241
+
242
+
243
def test_shape(test, device):
    # Each shape-checking kernel paired with the array shape it expects.
    cases = (
        (f1, 10),
        (f2, (10, 20)),
        (f3, (10, 20, 30)),
        (f4, (10, 20, 30, 40)),
    )
    with CheckOutput(test):
        # Keep every buffer referenced until the block exits, as the original did.
        buffers = []
        for kernel, shape in cases:
            buffers.append(wp.zeros(dtype=float, shape=shape, device=device))
            wp.launch(kernel, dim=1, inputs=[buffers[-1]], device=device)
256
+
257
+
258
def test_negative_shape(test, device):
    # Every negative extent must be rejected: a small scalar, a large scalar that
    # does not fit in 32 bits, and a negative entry inside a shape tuple.
    for bad_shape in (-1, -(2**32), (10, -1)):
        with test.assertRaisesRegex(ValueError, "Array shapes must be non-negative"):
            _ = wp.zeros(shape=bad_shape, dtype=int, device=device)
267
+
268
+
269
# Reduction used by the gradient tests: atomically adds every element of `arr`
# into loss[0].
@wp.kernel
def sum_array(arr: wp.array(dtype=float), loss: wp.array(dtype=float)):
    wp.atomic_add(loss, 0, arr[wp.tid()])
273
+
274
+
275
def test_flatten(test, device):
    # Flattening a 3D array must match NumPy and must propagate gradients.
    src = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=float)
    arr = wp.array(src, dtype=float, shape=src.shape, device=device, requires_grad=True)
    flat = arr.flatten()
    assert_array_equal(flat, wp.array(src.flatten(), dtype=float, device=device))

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(flat), inputs=[flat, loss], device=device)
    tape.backward(loss=loss)

    # Each element enters the sum with weight 1, and 1 + 2 + ... + 8 == 36.
    expected_grad = wp.array(np.ones(8, dtype=float), dtype=float, device=device)
    assert_array_equal(tape.gradients[flat], expected_grad)
    test.assertEqual(loss.numpy()[0], 36)
300
+
301
+
302
def test_reshape(test, device):
    src = np.arange(6, dtype=float)
    arr = wp.array(src, dtype=float, device=device, requires_grad=True)

    # 1D -> 2D must match NumPy's reshape ...
    view_2d = arr.reshape((3, 2))
    assert_array_equal(view_2d, wp.array(src.reshape((3, 2)), dtype=float, device=device))

    # ... and reshaping back to 1D must recover the original contents.
    view_1d = view_2d.reshape(6)
    assert_array_equal(view_1d, arr)

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(view_1d), inputs=[view_1d, loss], device=device)
    tape.backward(loss=loss)

    # Each element enters the sum with weight 1, and 0 + 1 + ... + 5 == 15.
    expected_grad = wp.array(np.ones(6, dtype=float), dtype=float, device=device)
    assert_array_equal(tape.gradients[view_1d], expected_grad)
    test.assertEqual(loss.numpy()[0], 15)

    # A -1 extent must be inferred from the total size, as in NumPy.
    plain = wp.array(src, dtype=float, device=device)
    inferred = plain.reshape((-1, 3))
    assert_array_equal(inferred, wp.array(src.reshape((-1, 3)), dtype=float, device=device))
336
+
337
+
338
# Checks, on the device, the contents of the stepped window arr[0:3:2, 0, :]
# that test_slicing launches it with: the first sub-row of outer entries 0 and 2.
# The check runs inside a kernel because, per test_slicing, Warp does not copy
# to or from non-contiguous arrays.
@wp.kernel
def compare_stepped_window_a(x: wp.array2d(dtype=float)):
    wp.expect_eq(x[0, 0], 1.0)
    wp.expect_eq(x[0, 1], 2.0)
    wp.expect_eq(x[1, 0], 9.0)
    wp.expect_eq(x[1, 1], 10.0)
344
+
345
+
346
# Checks, on the device, the contents of the non-contiguous view arr[:, 1] that
# test_slicing launches it with: the second sub-row of each of the three outer
# entries.
@wp.kernel
def compare_stepped_window_b(x: wp.array2d(dtype=float)):
    wp.expect_eq(x[0, 0], 3.0)
    wp.expect_eq(x[0, 1], 4.0)
    wp.expect_eq(x[1, 0], 7.0)
    wp.expect_eq(x[1, 1], 8.0)
    wp.expect_eq(x[2, 0], 11.0)
    wp.expect_eq(x[2, 1], 12.0)
354
+
355
+
356
def test_slicing(test, device):
    """Check Warp array indexing and slicing against NumPy, plus gradients through a slice.

    Covers integer indexing, ranges with negative bounds, stepped slices,
    slicing a slice, and backpropagation through ``flatten()`` of a slice.
    """
    np_arr = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], dtype=float)
    arr = wp.array(np_arr, dtype=float, shape=np_arr.shape, device=device, requires_grad=True)

    slice_a = arr[1, :, :]  # test indexing
    slice_b = arr[1:2, :, :]  # test slicing
    slice_c = arr[-1, :, :]  # test negative indexing
    slice_d = arr[-2:-1, :, :]  # test negative slicing
    slice_e = arr[-1:3, :, :]  # test mixed slicing
    slice_e2 = slice_e[0, 0, :]  # test 2x slicing
    slice_f = arr[0:3:2, 0, :]  # test step

    assert_array_equal(slice_a, wp.array(np_arr[1, :, :], dtype=float, device=device))
    assert_array_equal(slice_b, wp.array(np_arr[1:2, :, :], dtype=float, device=device))
    assert_array_equal(slice_c, wp.array(np_arr[-1, :, :], dtype=float, device=device))
    assert_array_equal(slice_d, wp.array(np_arr[-2:-1, :, :], dtype=float, device=device))
    assert_array_equal(slice_e, wp.array(np_arr[-1:3, :, :], dtype=float, device=device))
    assert_array_equal(slice_e2, wp.array(np_arr[2, 0, :], dtype=float, device=device))

    # Warp does not support copying from/to non-contiguous arrays, so the stepped
    # window is checked element-wise by a kernel on the device it lives on.
    wp.launch(kernel=compare_stepped_window_a, dim=1, inputs=[slice_f], device=device)

    # Backpropagate through a flattened slice: sum_array accumulates the slice
    # into ``loss``, so each sliced element should receive a gradient of 1.
    slice_flat = slice_b.flatten()
    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=sum_array, dim=len(slice_flat), inputs=[slice_flat, loss], device=device)

    tape.backward(loss=loss)
    grad = tape.gradients[slice_flat]

    ones = wp.array(np.ones((4,), dtype=float), dtype=float, device=device)
    assert_array_equal(grad, ones)
    test.assertEqual(loss.numpy()[0], 26)  # 5 + 6 + 7 + 8

    index_a = arr[1]
    index_b = arr[2, 1]
    index_c = arr[1, :]
    index_d = arr[:, 1]  # non-contiguous: checked in a kernel below

    assert_array_equal(index_a, wp.array(np_arr[1], dtype=float, device=device))
    assert_array_equal(index_b, wp.array(np_arr[2, 1], dtype=float, device=device))
    assert_array_equal(index_c, wp.array(np_arr[1, :], dtype=float, device=device))
    wp.launch(kernel=compare_stepped_window_b, dim=1, inputs=[index_d], device=device)

    # 1-D slices. Use distinct values rather than zeros: with an all-zero array,
    # every slice of the right length compares equal regardless of its start
    # offset, so an indexing bug could not be detected.
    np_arr = np.arange(10, dtype=int)
    wp_arr = wp.array(np_arr, dtype=int, device=device)

    assert_array_equal(wp_arr[:5], wp.array(np_arr[:5], dtype=int, device=device))
    assert_array_equal(wp_arr[1:5], wp.array(np_arr[1:5], dtype=int, device=device))
    assert_array_equal(wp_arr[-9:-5:1], wp.array(np_arr[-9:-5:1], dtype=int, device=device))
    assert_array_equal(wp_arr[:5,], wp.array(np_arr[:5], dtype=int, device=device))  # noqa: E231
416
+
417
+
418
def test_view(test, device):
    """Check that ``array.view`` reinterprets the underlying bits like ``np.ndarray.view``."""
    np_arr_a = np.arange(1, 10, 1, dtype=np.uint32)
    np_arr_b = np.arange(1, 10, 1, dtype=np.float32)
    np_arr_c = np.arange(1, 10, 1, dtype=np.uint16)
    np_arr_d = np.arange(1, 10, 1, dtype=np.float16)
    np_arr_e = np.ones((4, 4), dtype=np.float32)

    # Build each Warp array from the exact NumPy array it is compared against
    # below. The uint16/float16 arrays were previously built from the
    # uint32/float32 sources, so they relied on an implicit conversion at
    # upload time rather than holding the reference data directly.
    wp_arr_a = wp.array(np_arr_a, dtype=wp.uint32, device=device)
    wp_arr_b = wp.array(np_arr_b, dtype=wp.float32, device=device)
    wp_arr_c = wp.array(np_arr_c, dtype=wp.uint16, device=device)
    wp_arr_d = wp.array(np_arr_d, dtype=wp.float16, device=device)
    # vec4 and quat arrays built from the same (4, 4) float32 data
    wp_arr_e = wp.array(np_arr_e, dtype=wp.vec4, device=device)
    wp_arr_f = wp.array(np_arr_e, dtype=wp.quat, device=device)

    # same-size reinterpretation must match NumPy bit for bit
    assert_np_equal(wp_arr_a.view(dtype=wp.float32).numpy(), np_arr_a.view(dtype=np.float32))
    assert_np_equal(wp_arr_b.view(dtype=wp.uint32).numpy(), np_arr_b.view(dtype=np.uint32))
    assert_np_equal(wp_arr_c.view(dtype=wp.float16).numpy(), np_arr_c.view(dtype=np.float16))
    assert_np_equal(wp_arr_d.view(dtype=wp.uint16).numpy(), np_arr_d.view(dtype=np.uint16))
    assert_array_equal(wp_arr_e.view(dtype=wp.quat), wp_arr_f)
437
+
438
+
439
def test_clone_adjoint(test, device):
    """A gradient seeded on the output of ``wp.clone`` must flow back unchanged to the input."""
    source = wp.from_numpy(
        np.array([1.0, 2.0, 3.0], dtype=np.float32),
        dtype=wp.float32,
        requires_grad=True,
        device=device,
    )

    tape = wp.Tape()
    with tape:
        cloned = wp.clone(source)

    # seed the clone's adjoint with ones and propagate backwards
    seed = wp.from_numpy(np.ones(3, dtype=np.float32), dtype=wp.float32, device=device)
    tape.backward(grads={cloned: seed})

    expected = np.ones(3, dtype=np.float32)
    assert_np_equal(source.grad.numpy(), expected)
452
+
453
+
454
def test_assign_adjoint(test, device):
    """A gradient seeded on the destination of ``assign`` must reach the source array."""
    source = wp.from_numpy(
        np.array([1.0, 2.0, 3.0], dtype=np.float32),
        dtype=wp.float32,
        requires_grad=True,
        device=device,
    )
    destination = wp.zeros(source.shape, dtype=wp.float32, requires_grad=True, device=device)

    tape = wp.Tape()
    with tape:
        destination.assign(source)

    # seed the destination's adjoint with ones and propagate backwards
    seed = wp.from_numpy(np.ones(3, dtype=np.float32), dtype=wp.float32, device=device)
    tape.backward(grads={destination: seed})

    expected = np.ones(3, dtype=np.float32)
    assert_np_equal(source.grad.numpy(), expected)
468
+
469
+
470
# Element-wise equality check used by test_transpose: writes 1 into z[i, j]
# wherever x[i, j] == y[i, j]. Mismatching entries of z are left untouched, so
# callers pass a zero-initialized z and expect it to become all ones.
@wp.kernel
def compare_2darrays(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), z: wp.array2d(dtype=int)):
    # one thread per (i, j) of the 2-D launch
    i, j = wp.tid()

    if x[i, j] == y[i, j]:
        z[i, j] = 1
476
+
477
+
478
# Element-wise equality check used by test_transpose: writes 1 into z[i, j, k]
# wherever x[i, j, k] == y[i, j, k]. Mismatching entries of z are left
# untouched, so callers pass a zero-initialized z and expect it to become all ones.
@wp.kernel
def compare_3darrays(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), z: wp.array3d(dtype=int)):
    # one thread per (i, j, k) of the 3-D launch
    i, j, k = wp.tid()

    if x[i, j, k] == y[i, j, k]:
        z[i, j, k] = 1
484
+
485
+
486
def test_transpose(test, device):
    """Compare ``array.transpose`` with NumPy's transpose for 2-D, 3-D and 1-D arrays."""

    def expect_match(kernel, transposed, reference, shape):
        # Transposed arrays are non-contiguous and Warp cannot copy them to/from
        # the host, so compare element-wise in a kernel that writes 1 into
        # ``matches`` wherever both inputs agree.
        matches = wp.zeros(shape=shape, dtype=int, device=device)
        wp.launch(kernel, dim=shape, inputs=[transposed, reference, matches], device=device)
        assert_np_equal(matches.numpy(), np.ones(shape, dtype=int))

    # 2-D, non-square, default axes
    host_2d = np.array([[1, 2], [3, 4], [5, 6]], dtype=float)
    dev_2d = wp.array(host_2d, dtype=float, device=device)
    expect_match(
        compare_2darrays,
        dev_2d.transpose(),
        wp.array(host_2d.transpose(), dtype=float, device=device),
        (2, 3),
    )

    # 3-D array of shape (3, 2, 2) with explicit axes
    host_3d = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], dtype=float)
    dev_3d = wp.array3d(host_3d, dtype=float, shape=host_3d.shape, device=device, requires_grad=True)
    expect_match(
        compare_3darrays,
        dev_3d.transpose((0, 2, 1)),
        wp.array3d(host_3d.transpose((0, 2, 1)), dtype=float, device=device),
        (3, 2, 2),
    )

    # same 3-D array without axes: the dimension order is reversed
    expect_match(
        compare_3darrays,
        dev_3d.transpose(),
        wp.array3d(host_3d.transpose(), dtype=float, device=device),
        (2, 2, 3),
    )

    # 1-D: transpose should be a no-op
    host_1d = np.array([1, 2, 3], dtype=float)
    dev_1d = wp.array(host_1d, dtype=float, device=device)
    assert_np_equal(dev_1d.transpose().numpy(), host_1d.transpose())
522
+
523
+
524
def test_fill_scalar(test, device):
    """Exercise ``fill_`` and ``zero_`` on scalar arrays of rank 1 to 4 for every supported dtype."""
    dim_x = 4

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def zero_all(arrays):
        for arr in arrays:
            arr.zero_()

    def expect_all(arrays, value, nptype):
        # every element of every array must equal ``value`` (converted to nptype)
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.full(arr.shape, value, dtype=nptype))

    for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
        arrays = (
            wp.zeros(dim_x, dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x), dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x, dim_x), dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x, dim_x, dim_x), dtype=wptype, device=device),
        )
        expect_all(arrays, 0, nptype)

        # integer fill value
        fill_all(arrays, 42)
        expect_all(arrays, 42, nptype)

        zero_all(arrays)
        expect_all(arrays, 0, nptype)

        if wptype in wp.types.float_types:
            # float fill value
            fill_all(arrays, 13.37)
            expect_all(arrays, 13.37, nptype)

        # Warp scalar instance as fill value
        scalar = wptype(17)
        fill_all(arrays, scalar)
        expect_all(arrays, scalar.value, nptype)
587
+
588
+
589
def test_fill_vector(test, device):
    """Fill vector arrays of rank 1 to 4 with scalars and with per-element vector values.

    Vector values may be given as a list, a NumPy array, or a Warp vector instance.
    """
    dim_x = 4

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def zero_all(arrays):
        for arr in arrays:
            arr.zero_()

    def expect_uniform(arrays, value, nptype, inner_shape):
        # every component of every element equals ``value``
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.full((*arr.shape, *inner_shape), value, dtype=nptype))

    def expect_repeated(arrays, element, inner_shape):
        # every element equals the 1-D array ``element``
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.tile(element, arr.size).reshape((*arr.shape, *inner_shape)))

    for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
        vector_types = [wp.types.vector(length, wptype) for length in (2, 3, 4, 5)]

        for vec_type in vector_types:
            vec_len = vec_type._length_
            inner = (vec_len,)

            arrays = (
                wp.zeros(dim_x, dtype=vec_type, device=device),
                wp.zeros((dim_x, dim_x), dtype=vec_type, device=device),
                wp.zeros((dim_x, dim_x, dim_x), dtype=vec_type, device=device),
                wp.zeros((dim_x, dim_x, dim_x, dim_x), dtype=vec_type, device=device),
            )
            expect_uniform(arrays, 0, nptype, inner)

            # an integer scalar is broadcast to every component
            fill_all(arrays, 42)
            expect_uniform(arrays, 42, nptype, inner)

            zero_all(arrays)
            expect_uniform(arrays, 0, nptype, inner)

            # the same vector value as a list, a NumPy array and a Warp vector
            int_values = [17, 42, 99, 101, 127][:vec_len]
            int_array = np.array(int_values, dtype=nptype)
            int_vector = vec_type(int_values)

            for position, value in enumerate((int_values, int_array, int_vector)):
                if position > 0:
                    # clear the previous fill so each form is verified on its own
                    zero_all(arrays)
                fill_all(arrays, value)
                expect_repeated(arrays, int_array, inner)

            if wptype in wp.types.float_types:
                # float scalar broadcast to every component
                fill_all(arrays, 13.37)
                expect_uniform(arrays, 13.37, nptype, inner)

                # float list of vector length
                float_values = [-2.5, -1.25, 1.25, 2.5, 5.0][:vec_len]
                fill_all(arrays, float_values)
                expect_repeated(arrays, np.array(float_values, dtype=nptype), inner)
726
+
727
+
728
def test_fill_matrix(test, device):
    """Fill matrix arrays of rank 1 to 4 with scalars and with per-element matrix values.

    Matrix values may be a 1-D or 2-D NumPy array, a flat or nested list, or a
    Warp matrix instance.
    """
    dim_x = 4
    # square shapes first, then non-square ones
    matrix_shapes = [(2, 2), (3, 3), (4, 4), (5, 5), (2, 3), (3, 2), (3, 4), (4, 3)]

    def fill_all(arrays, value):
        for arr in arrays:
            arr.fill_(value)

    def zero_all(arrays):
        for arr in arrays:
            arr.zero_()

    def expect_uniform(arrays, value, nptype, inner_shape):
        # every entry of every matrix equals ``value``
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.full((*arr.shape, *inner_shape), value, dtype=nptype))

    def expect_repeated(arrays, element, inner_shape):
        # every matrix equals the row-major data in the 1-D array ``element``
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.tile(element, arr.size).reshape((*arr.shape, *inner_shape)))

    for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
        matrix_types = [wp.types.matrix(shape, wptype) for shape in matrix_shapes]

        for mat_type in matrix_types:
            mat_len = mat_type._length_
            mat_shape = mat_type._shape_

            arrays = (
                wp.zeros(dim_x, dtype=mat_type, device=device),
                wp.zeros((dim_x, dim_x), dtype=mat_type, device=device),
                wp.zeros((dim_x, dim_x, dim_x), dtype=mat_type, device=device),
                wp.zeros((dim_x, dim_x, dim_x, dim_x), dtype=mat_type, device=device),
            )
            expect_uniform(arrays, 0, nptype, mat_shape)

            # a scalar is broadcast to every entry
            fill_all(arrays, 42)
            expect_uniform(arrays, 42, nptype, mat_shape)

            zero_all(arrays)
            expect_uniform(arrays, 0, nptype, mat_shape)

            # bool matrices are filled with ones instead of an arange sequence
            flat = np.arange(mat_len, dtype=nptype) if wptype != wp.bool else np.ones(mat_len, dtype=nptype)
            grid = flat.reshape(mat_shape)
            candidate_values = (
                flat,  # 1-D numpy array
                grid,  # 2-D numpy array
                list(flat),  # flat list
                [list(row) for row in grid],  # nested list
                mat_type(flat),  # Warp matrix instance
            )

            for position, value in enumerate(candidate_values):
                if position > 0:
                    # clear the previous fill so each form is verified on its own
                    zero_all(arrays)
                fill_all(arrays, value)
                expect_repeated(arrays, flat, mat_shape)
879
+
880
+
881
# Struct used by test_fill_struct to exercise fill_() with every member
# category: scalars, vectors, matrices and arrays. Scalar member names encode
# the size in bytes (i1 = int8 ... i8 = int64, f2 = float16 ... f8 = float64).
@wp.struct
class FillStruct:
    # scalar members (make sure to test float16)
    i1: wp.int8
    i2: wp.int16
    i4: wp.int32
    i8: wp.int64
    f2: wp.float16
    f4: wp.float32
    # 8-byte float, matching i8/int64. This was declared float16, which
    # duplicated f2 and left 64-bit scalar floats untested.
    f8: wp.float64
    # vector members (make sure to test vectors of float16)
    v2: wp.types.vector(2, wp.int64)
    v3: wp.types.vector(3, wp.float32)
    v4: wp.types.vector(4, wp.float16)
    v5: wp.types.vector(5, wp.uint8)
    # matrix members (make sure to test matrices of float16)
    m2: wp.types.matrix((2, 2), wp.float64)
    m3: wp.types.matrix((3, 3), wp.int32)
    m4: wp.types.matrix((4, 4), wp.float16)
    m5: wp.types.matrix((5, 5), wp.int8)
    # arrays
    a1: wp.array(dtype=float)
    a2: wp.array2d(dtype=float)
    a3: wp.array3d(dtype=float)
    a4: wp.array4d(dtype=float)
906
+
907
+
908
def test_fill_struct(test, device):
    """Fill struct arrays of rank 1 to 4 with a default and a fully populated ``FillStruct``."""
    dim_x = 4
    nptype = FillStruct.numpy_dtype()

    arrays = (
        wp.zeros(dim_x, dtype=FillStruct, device=device),
        wp.zeros((dim_x, dim_x), dtype=FillStruct, device=device),
        wp.zeros((dim_x, dim_x, dim_x), dtype=FillStruct, device=device),
        wp.zeros((dim_x, dim_x, dim_x, dim_x), dtype=FillStruct, device=device),
    )

    def fill_all(value):
        for arr in arrays:
            arr.fill_(value)

    def expect_all_zero():
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.zeros(arr.shape, dtype=nptype))

    expect_all_zero()

    # a default-constructed struct should fill with zeros
    s = FillStruct()
    fill_all(s)
    expect_all_zero()

    # scalars
    s.i1 = -17
    s.i2 = 42
    s.i4 = 99
    s.i8 = 101
    s.f2 = -1.25
    s.f4 = 13.37
    s.f8 = 0.125
    # vectors
    s.v2 = [21, 22]
    s.v3 = [31, 32, 33]
    s.v4 = [41, 42, 43, 44]
    s.v5 = [51, 52, 53, 54, 55]
    # matrices (every row identical)
    s.m2 = [[61, 62]] * 2
    s.m3 = [[71, 72, 73]] * 3
    s.m4 = [[81, 82, 83, 84]] * 4
    s.m5 = [[91, 92, 93, 94, 95]] * 5
    # arrays: member a<n> receives an n-dimensional array with extent 2 per axis
    for ndim in range(1, 5):
        setattr(s, f"a{ndim}", wp.zeros((2,) * ndim, dtype=float, device=device))

    fill_all(s)

    # every element must equal the struct's NumPy representation
    ns = s.numpy_value()
    for arr in arrays:
        expected = np.empty(arr.shape, dtype=nptype)
        expected.fill(ns)
        assert_np_equal(arr.numpy(), expected)

    # clearing must reset every element back to zero
    for arr in arrays:
        arr.zero_()
    expect_all_zero()
993
+
994
+
995
def test_fill_slices(test, device):
    """Check ``fill_`` and ``zero_`` on non-contiguous (strided) slices.

    Each array is split into interleaved even/odd slices along its first axis;
    writing one half must never disturb the other.
    Note: we don't need to test the whole range of dtypes (vectors, matrices, structs) here.
    """
    dim_x = 8
    fill_a = 17
    fill_b = 42

    def expect(arrays, value, nptype):
        for arr in arrays:
            assert_np_equal(arr.numpy(), np.full(arr.shape, value, dtype=nptype))

    for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
        whole = (
            wp.zeros(dim_x, dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x), dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x, dim_x), dtype=wptype, device=device),
            wp.zeros((dim_x, dim_x, dim_x, dim_x), dtype=wptype, device=device),
        )
        expect(whole, 0, nptype)

        # partition each array into even and odd slices along the first axis
        evens = [arr[::2] for arr in whole]
        odds = [arr[1::2] for arr in whole]

        # fill the even slices; the odd ones must stay zero
        for view in evens:
            view.fill_(fill_a)
        expect(evens, fill_a, nptype)
        expect(odds, 0, nptype)

        # fill the odd slices; the even ones must keep their value
        for view in odds:
            view.fill_(fill_b)
        expect(odds, fill_b, nptype)
        expect(evens, fill_a, nptype)

        # clear the even slices; the odd ones must keep their value
        for view in evens:
            view.zero_()
        expect(evens, 0, nptype)
        expect(odds, fill_b, nptype)

        # re-fill the even slices, then clear the odd ones
        for view in evens:
            view.fill_(fill_a)
        for view in odds:
            view.zero_()
        expect(odds, 0, nptype)
        expect(evens, fill_a, nptype)
1101
+
1102
+
1103
def test_full_scalar(test, device):
    """Check ``wp.full`` for scalar values at ranks 1 to 4, with explicit and inferred dtypes."""
    dim = 4

    def check_full(shape, value, expected_wp_dtype, expected_np_dtype, **full_kwargs):
        # build the array, then verify its shape/dtype on both the Warp and NumPy side
        arr = wp.full(shape, value, device=device, **full_kwargs)
        host = arr.numpy()

        test.assertEqual(arr.shape, shape)
        test.assertEqual(arr.dtype, expected_wp_dtype)
        test.assertEqual(host.shape, shape)
        test.assertEqual(host.dtype, expected_np_dtype)
        assert_np_equal(host, np.full(shape, value, dtype=expected_np_dtype))

    for ndim in range(1, 5):
        shape = (dim,) * ndim

        for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
            # explicit dtype, integer value
            check_full(shape, 42, wptype, nptype, dtype=wptype)

            if wptype in wp.types.float_types:
                # explicit dtype, float value
                check_full(shape, 13.37, wptype, nptype, dtype=wptype)

        # inferred dtype: a Python int becomes int32, a Python float becomes float32
        check_full(shape, 42, wp.int32, np.int32)
        check_full(shape, 13.37, wp.float32, np.float32)
1154
+
1155
+
1156
+ def test_full_vector(test, device):
1157
+ dim = 4
1158
+
1159
+ for ndim in range(1, 5):
1160
+ shape = (dim,) * ndim
1161
+
1162
+ # full from scalar
1163
+ for veclen in [2, 3, 4, 5]:
1164
+ npshape = (*shape, veclen)
1165
+
1166
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1167
+ vectype = wp.types.vector(veclen, wptype)
1168
+
1169
+ # fill with scalar int value and specific dtype
1170
+ fill_value = 42
1171
+ a = wp.full(shape, fill_value, dtype=vectype, device=device)
1172
+ na = a.numpy()
1173
+
1174
+ test.assertEqual(a.shape, shape)
1175
+ test.assertEqual(a.dtype, vectype)
1176
+ test.assertEqual(na.shape, npshape)
1177
+ test.assertEqual(na.dtype, nptype)
1178
+ assert_np_equal(na, np.full(a.size * veclen, fill_value, dtype=nptype).reshape(npshape))
1179
+
1180
+ if wptype in wp.types.float_types:
1181
+ # fill with scalar float value and specific dtype
1182
+ fill_value = 13.37
1183
+ a = wp.full(shape, fill_value, dtype=vectype, device=device)
1184
+ na = a.numpy()
1185
+
1186
+ test.assertEqual(a.shape, shape)
1187
+ test.assertEqual(a.dtype, vectype)
1188
+ test.assertEqual(na.shape, npshape)
1189
+ test.assertEqual(na.dtype, nptype)
1190
+ assert_np_equal(na, np.full(a.size * veclen, fill_value, dtype=nptype).reshape(npshape))
1191
+
1192
+ # fill with vector value and specific dtype
1193
+ fill_vec = vectype(42)
1194
+ a = wp.full(shape, fill_vec, dtype=vectype, device=device)
1195
+ na = a.numpy()
1196
+
1197
+ test.assertEqual(a.shape, shape)
1198
+ test.assertEqual(a.dtype, vectype)
1199
+ test.assertEqual(na.shape, npshape)
1200
+ test.assertEqual(na.dtype, nptype)
1201
+ assert_np_equal(na, np.full(a.size * veclen, 42, dtype=nptype).reshape(npshape))
1202
+
1203
+ # fill with vector value and automatically inferred dtype
1204
+ a = wp.full(shape, fill_vec, device=device)
1205
+ na = a.numpy()
1206
+
1207
+ test.assertEqual(a.shape, shape)
1208
+ test.assertEqual(a.dtype, vectype)
1209
+ test.assertEqual(na.shape, npshape)
1210
+ test.assertEqual(na.dtype, nptype)
1211
+ assert_np_equal(na, np.full(a.size * veclen, 42, dtype=nptype).reshape(npshape))
1212
+
1213
+ fill_lists = [
1214
+ [17, 42],
1215
+ [17, 42, 99],
1216
+ [17, 42, 99, 101],
1217
+ [17, 42, 99, 101, 127],
1218
+ ]
1219
+
1220
+ # full from list and numpy array
1221
+ for fill_list in fill_lists:
1222
+ veclen = len(fill_list)
1223
+ npshape = (*shape, veclen)
1224
+
1225
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1226
+ vectype = wp.types.vector(veclen, wptype)
1227
+
1228
+ # fill with list and specific dtype
1229
+ a = wp.full(shape, fill_list, dtype=vectype, device=device)
1230
+ na = a.numpy()
1231
+
1232
+ test.assertEqual(a.shape, shape)
1233
+ test.assertEqual(a.dtype, vectype)
1234
+ test.assertEqual(na.shape, npshape)
1235
+ test.assertEqual(na.dtype, nptype)
1236
+
1237
+ expected = np.tile(np.array(fill_list, dtype=nptype), a.size).reshape(npshape)
1238
+ assert_np_equal(na, expected)
1239
+
1240
+ fill_arr = np.array(fill_list, dtype=nptype)
1241
+
1242
+ # fill with numpy array and specific dtype
1243
+ a = wp.full(shape, fill_arr, dtype=vectype, device=device)
1244
+ na = a.numpy()
1245
+
1246
+ test.assertEqual(a.shape, shape)
1247
+ test.assertEqual(a.dtype, vectype)
1248
+ test.assertEqual(na.shape, npshape)
1249
+ test.assertEqual(na.dtype, nptype)
1250
+ assert_np_equal(na, expected)
1251
+
1252
+ # fill with numpy array and automatically infer dtype
1253
+ a = wp.full(shape, fill_arr, device=device)
1254
+ na = a.numpy()
1255
+
1256
+ test.assertEqual(a.shape, shape)
1257
+ test.assertTrue(wp.types.types_equal(a.dtype, vectype))
1258
+ test.assertEqual(na.shape, npshape)
1259
+ test.assertEqual(na.dtype, nptype)
1260
+ assert_np_equal(na, expected)
1261
+
1262
+ # fill with list and automatically infer dtype
1263
+ a = wp.full(shape, fill_list, device=device)
1264
+ na = a.numpy()
1265
+
1266
+ test.assertEqual(a.shape, shape)
1267
+
1268
+ # check that the inferred dtype is a vector
1269
+ # Note that we cannot guarantee the scalar type, because it depends on numpy and may vary by platform
1270
+ # (e.g. int64 on Linux and int32 on Windows).
1271
+ test.assertEqual(a.dtype._wp_generic_type_str_, "vec_t")
1272
+ test.assertEqual(a.dtype._length_, veclen)
1273
+
1274
+ expected = np.tile(np.array(fill_list), a.size).reshape(npshape)
1275
+ assert_np_equal(na, expected)
1276
+
1277
+
1278
+ def test_full_matrix(test, device):
1279
+ dim = 4
1280
+
1281
+ for ndim in range(1, 5):
1282
+ shape = (dim,) * ndim
1283
+
1284
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1285
+ matrix_types = [
1286
+ # square matrices
1287
+ wp.types.matrix((2, 2), wptype),
1288
+ wp.types.matrix((3, 3), wptype),
1289
+ wp.types.matrix((4, 4), wptype),
1290
+ wp.types.matrix((5, 5), wptype),
1291
+ # non-square matrices
1292
+ wp.types.matrix((2, 3), wptype),
1293
+ wp.types.matrix((3, 2), wptype),
1294
+ wp.types.matrix((3, 4), wptype),
1295
+ wp.types.matrix((4, 3), wptype),
1296
+ ]
1297
+
1298
+ for mattype in matrix_types:
1299
+ npshape = (*shape, *mattype._shape_)
1300
+
1301
+ # fill with scalar int value and specific dtype
1302
+ fill_value = 42
1303
+ a = wp.full(shape, fill_value, dtype=mattype, device=device)
1304
+ na = a.numpy()
1305
+
1306
+ test.assertEqual(a.shape, shape)
1307
+ test.assertEqual(a.dtype, mattype)
1308
+ test.assertEqual(na.shape, npshape)
1309
+ test.assertEqual(na.dtype, nptype)
1310
+ assert_np_equal(na, np.full(a.size * mattype._length_, fill_value, dtype=nptype).reshape(npshape))
1311
+
1312
+ if wptype in wp.types.float_types:
1313
+ # fill with scalar float value and specific dtype
1314
+ fill_value = 13.37
1315
+ a = wp.full(shape, fill_value, dtype=mattype, device=device)
1316
+ na = a.numpy()
1317
+
1318
+ test.assertEqual(a.shape, shape)
1319
+ test.assertEqual(a.dtype, mattype)
1320
+ test.assertEqual(na.shape, npshape)
1321
+ test.assertEqual(na.dtype, nptype)
1322
+ assert_np_equal(na, np.full(a.size * mattype._length_, fill_value, dtype=nptype).reshape(npshape))
1323
+
1324
+ # fill with matrix value and specific dtype
1325
+ fill_mat = mattype(42)
1326
+ a = wp.full(shape, fill_mat, dtype=mattype, device=device)
1327
+ na = a.numpy()
1328
+
1329
+ test.assertEqual(a.shape, shape)
1330
+ test.assertEqual(a.dtype, mattype)
1331
+ test.assertEqual(na.shape, npshape)
1332
+ test.assertEqual(na.dtype, nptype)
1333
+ assert_np_equal(na, np.full(a.size * mattype._length_, 42, dtype=nptype).reshape(npshape))
1334
+
1335
+ # fill with matrix value and automatically inferred dtype
1336
+ fill_mat = mattype(42)
1337
+ a = wp.full(shape, fill_mat, device=device)
1338
+ na = a.numpy()
1339
+
1340
+ test.assertEqual(a.shape, shape)
1341
+ test.assertEqual(a.dtype, mattype)
1342
+ test.assertEqual(na.shape, npshape)
1343
+ test.assertEqual(na.dtype, nptype)
1344
+ assert_np_equal(na, np.full(a.size * mattype._length_, 42, dtype=nptype).reshape(npshape))
1345
+
1346
+ # fill with 1d numpy array and specific dtype
1347
+ if wptype != wp.bool:
1348
+ fill_arr1d = np.arange(mattype._length_, dtype=nptype)
1349
+ else:
1350
+ fill_arr1d = np.ones(mattype._length_, dtype=nptype)
1351
+ a = wp.full(shape, fill_arr1d, dtype=mattype, device=device)
1352
+ na = a.numpy()
1353
+
1354
+ test.assertEqual(a.shape, shape)
1355
+ test.assertEqual(a.dtype, mattype)
1356
+ test.assertEqual(na.shape, npshape)
1357
+ test.assertEqual(na.dtype, nptype)
1358
+
1359
+ expected = np.tile(fill_arr1d, a.size).reshape(npshape)
1360
+ assert_np_equal(na, expected)
1361
+
1362
+ # fill with 2d numpy array and specific dtype
1363
+ fill_arr2d = fill_arr1d.reshape(mattype._shape_)
1364
+ a = wp.full(shape, fill_arr2d, dtype=mattype, device=device)
1365
+ na = a.numpy()
1366
+
1367
+ test.assertEqual(a.shape, shape)
1368
+ test.assertEqual(a.dtype, mattype)
1369
+ test.assertEqual(na.shape, npshape)
1370
+ test.assertEqual(na.dtype, nptype)
1371
+ assert_np_equal(na, expected)
1372
+
1373
+ # fill with 2d numpy array and automatically infer dtype
1374
+ a = wp.full(shape, fill_arr2d, device=device)
1375
+ na = a.numpy()
1376
+
1377
+ test.assertEqual(a.shape, shape)
1378
+ test.assertTrue(wp.types.types_equal(a.dtype, mattype))
1379
+ test.assertEqual(na.shape, npshape)
1380
+ test.assertEqual(na.dtype, nptype)
1381
+ assert_np_equal(na, expected)
1382
+
1383
+ # fill with flat list and specific dtype
1384
+ fill_list1d = list(fill_arr1d)
1385
+ a = wp.full(shape, fill_list1d, dtype=mattype, device=device)
1386
+ na = a.numpy()
1387
+
1388
+ test.assertEqual(a.shape, shape)
1389
+ test.assertEqual(a.dtype, mattype)
1390
+ test.assertEqual(na.shape, npshape)
1391
+ test.assertEqual(na.dtype, nptype)
1392
+ assert_np_equal(na, expected)
1393
+
1394
+ # fill with nested list and specific dtype
1395
+ fill_list2d = [list(row) for row in fill_arr2d]
1396
+ a = wp.full(shape, fill_list2d, dtype=mattype, device=device)
1397
+ na = a.numpy()
1398
+
1399
+ test.assertEqual(a.shape, shape)
1400
+ test.assertEqual(a.dtype, mattype)
1401
+ test.assertEqual(na.shape, npshape)
1402
+ test.assertEqual(na.dtype, nptype)
1403
+ assert_np_equal(na, expected)
1404
+
1405
+ mat_lists = [
1406
+ # square matrices
1407
+ [[1, 2], [3, 4]],
1408
+ [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1409
+ [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
1410
+ # non-square matrices
1411
+ [[1, 2, 3, 4], [5, 6, 7, 8]],
1412
+ [[1, 2], [3, 4], [5, 6], [7, 8]],
1413
+ ]
1414
+
1415
+ # fill with nested lists and automatically infer dtype
1416
+ for fill_list in mat_lists:
1417
+ num_rows = len(fill_list)
1418
+ num_cols = len(fill_list[0])
1419
+ npshape = (*shape, num_rows, num_cols)
1420
+
1421
+ a = wp.full(shape, fill_list, device=device)
1422
+ na = a.numpy()
1423
+
1424
+ test.assertEqual(a.shape, shape)
1425
+
1426
+ # check that the inferred dtype is a correctly shaped matrix
1427
+ # Note that we cannot guarantee the scalar type, because it depends on numpy and may vary by platform
1428
+ # (e.g. int64 on Linux and int32 on Windows).
1429
+ test.assertEqual(a.dtype._wp_generic_type_str_, "mat_t")
1430
+ test.assertEqual(a.dtype._shape_, (num_rows, num_cols))
1431
+
1432
+ expected = np.tile(np.array(fill_list).flatten(), a.size).reshape(npshape)
1433
+ assert_np_equal(na, expected)
1434
+
1435
+
1436
+ def test_full_struct(test, device):
1437
+ dim = 4
1438
+
1439
+ for ndim in range(1, 5):
1440
+ shape = (dim,) * ndim
1441
+
1442
+ s = FillStruct()
1443
+
1444
+ # fill with default struct (should be zeros)
1445
+ a = wp.full(shape, s, dtype=FillStruct, device=device)
1446
+ na = a.numpy()
1447
+
1448
+ test.assertEqual(a.shape, shape)
1449
+ test.assertEqual(a.dtype, FillStruct)
1450
+ test.assertEqual(na.shape, shape)
1451
+ test.assertEqual(na.dtype, FillStruct.numpy_dtype())
1452
+ assert_np_equal(na, np.zeros(a.shape, dtype=FillStruct.numpy_dtype()))
1453
+
1454
+ # scalars
1455
+ s.i1 = -17
1456
+ s.i2 = 42
1457
+ s.i4 = 99
1458
+ s.i8 = 101
1459
+ s.f2 = -1.25
1460
+ s.f4 = 13.37
1461
+ s.f8 = 0.125
1462
+ # vectors
1463
+ s.v2 = [21, 22]
1464
+ s.v3 = [31, 32, 33]
1465
+ s.v4 = [41, 42, 43, 44]
1466
+ s.v5 = [51, 52, 53, 54, 55]
1467
+ # matrices
1468
+ s.m2 = [[61, 62]] * 2
1469
+ s.m3 = [[71, 72, 73]] * 3
1470
+ s.m4 = [[81, 82, 83, 84]] * 4
1471
+ s.m5 = [[91, 92, 93, 94, 95]] * 5
1472
+ # arrays
1473
+ s.a1 = wp.zeros((2,) * 1, dtype=float, device=device)
1474
+ s.a2 = wp.zeros((2,) * 2, dtype=float, device=device)
1475
+ s.a3 = wp.zeros((2,) * 3, dtype=float, device=device)
1476
+ s.a4 = wp.zeros((2,) * 4, dtype=float, device=device)
1477
+
1478
+ # fill with initialized struct and explicit dtype
1479
+ a = wp.full(shape, s, dtype=FillStruct, device=device)
1480
+ na = a.numpy()
1481
+
1482
+ test.assertEqual(a.shape, shape)
1483
+ test.assertEqual(a.dtype, FillStruct)
1484
+ test.assertEqual(na.shape, shape)
1485
+ test.assertEqual(na.dtype, FillStruct.numpy_dtype())
1486
+
1487
+ expected = np.empty(shape, dtype=FillStruct.numpy_dtype())
1488
+ expected.fill(s.numpy_value())
1489
+ assert_np_equal(na, expected)
1490
+
1491
+ # fill with initialized struct and automatically inferred dtype
1492
+ a = wp.full(shape, s, device=device)
1493
+ na = a.numpy()
1494
+
1495
+ test.assertEqual(a.shape, shape)
1496
+ test.assertEqual(a.dtype, FillStruct)
1497
+ test.assertEqual(na.shape, shape)
1498
+ test.assertEqual(na.dtype, FillStruct.numpy_dtype())
1499
+ assert_np_equal(na, expected)
1500
+
1501
+
1502
+ def test_ones_scalar(test, device):
1503
+ dim = 4
1504
+
1505
+ for ndim in range(1, 5):
1506
+ shape = (dim,) * ndim
1507
+
1508
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1509
+ a = wp.ones(shape, dtype=wptype, device=device)
1510
+ na = a.numpy()
1511
+
1512
+ test.assertEqual(a.shape, shape)
1513
+ test.assertEqual(a.dtype, wptype)
1514
+ test.assertEqual(na.shape, shape)
1515
+ test.assertEqual(na.dtype, nptype)
1516
+ assert_np_equal(na, np.ones(shape, dtype=nptype))
1517
+
1518
+
1519
+ def test_ones_vector(test, device):
1520
+ dim = 4
1521
+
1522
+ for ndim in range(1, 5):
1523
+ shape = (dim,) * ndim
1524
+
1525
+ for veclen in [2, 3, 4, 5]:
1526
+ npshape = (*shape, veclen)
1527
+
1528
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1529
+ vectype = wp.types.vector(veclen, wptype)
1530
+
1531
+ a = wp.ones(shape, dtype=vectype, device=device)
1532
+ na = a.numpy()
1533
+
1534
+ test.assertEqual(a.shape, shape)
1535
+ test.assertEqual(a.dtype, vectype)
1536
+ test.assertEqual(na.shape, npshape)
1537
+ test.assertEqual(na.dtype, nptype)
1538
+ assert_np_equal(na, np.ones(npshape, dtype=nptype))
1539
+
1540
+
1541
+ def test_ones_matrix(test, device):
1542
+ dim = 4
1543
+
1544
+ for ndim in range(1, 5):
1545
+ shape = (dim,) * ndim
1546
+
1547
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1548
+ matrix_types = [
1549
+ # square matrices
1550
+ wp.types.matrix((2, 2), wptype),
1551
+ wp.types.matrix((3, 3), wptype),
1552
+ wp.types.matrix((4, 4), wptype),
1553
+ wp.types.matrix((5, 5), wptype),
1554
+ # non-square matrices
1555
+ wp.types.matrix((2, 3), wptype),
1556
+ wp.types.matrix((3, 2), wptype),
1557
+ wp.types.matrix((3, 4), wptype),
1558
+ wp.types.matrix((4, 3), wptype),
1559
+ ]
1560
+
1561
+ for mattype in matrix_types:
1562
+ npshape = (*shape, *mattype._shape_)
1563
+
1564
+ a = wp.ones(shape, dtype=mattype, device=device)
1565
+ na = a.numpy()
1566
+
1567
+ test.assertEqual(a.shape, shape)
1568
+ test.assertEqual(a.dtype, mattype)
1569
+ test.assertEqual(na.shape, npshape)
1570
+ test.assertEqual(na.dtype, nptype)
1571
+ assert_np_equal(na, np.ones(npshape, dtype=nptype))
1572
+
1573
+
1574
+ def test_ones_like_scalar(test, device):
1575
+ dim = 4
1576
+
1577
+ for ndim in range(1, 5):
1578
+ shape = (dim,) * ndim
1579
+
1580
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1581
+ # source array
1582
+ a = wp.zeros(shape, dtype=wptype, device=device)
1583
+ na = a.numpy()
1584
+ test.assertEqual(a.shape, shape)
1585
+ test.assertEqual(a.dtype, wptype)
1586
+ test.assertEqual(na.shape, shape)
1587
+ test.assertEqual(na.dtype, nptype)
1588
+ assert_np_equal(na, np.zeros(shape, dtype=nptype))
1589
+
1590
+ # ones array
1591
+ b = wp.ones_like(a)
1592
+ nb = b.numpy()
1593
+ test.assertEqual(b.shape, shape)
1594
+ test.assertEqual(b.dtype, wptype)
1595
+ test.assertEqual(nb.shape, shape)
1596
+ test.assertEqual(nb.dtype, nptype)
1597
+ assert_np_equal(nb, np.ones(shape, dtype=nptype))
1598
+
1599
+
1600
+ def test_ones_like_vector(test, device):
1601
+ dim = 4
1602
+
1603
+ for ndim in range(1, 5):
1604
+ shape = (dim,) * ndim
1605
+
1606
+ for veclen in [2, 3, 4, 5]:
1607
+ npshape = (*shape, veclen)
1608
+
1609
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1610
+ vectype = wp.types.vector(veclen, wptype)
1611
+
1612
+ # source array
1613
+ a = wp.zeros(shape, dtype=vectype, device=device)
1614
+ na = a.numpy()
1615
+ test.assertEqual(a.shape, shape)
1616
+ test.assertEqual(a.dtype, vectype)
1617
+ test.assertEqual(na.shape, npshape)
1618
+ test.assertEqual(na.dtype, nptype)
1619
+ assert_np_equal(na, np.zeros(npshape, dtype=nptype))
1620
+
1621
+ # ones array
1622
+ b = wp.ones_like(a)
1623
+ nb = b.numpy()
1624
+ test.assertEqual(b.shape, shape)
1625
+ test.assertEqual(b.dtype, vectype)
1626
+ test.assertEqual(nb.shape, npshape)
1627
+ test.assertEqual(nb.dtype, nptype)
1628
+ assert_np_equal(nb, np.ones(npshape, dtype=nptype))
1629
+
1630
+
1631
+ def test_ones_like_matrix(test, device):
1632
+ dim = 4
1633
+
1634
+ for ndim in range(1, 5):
1635
+ shape = (dim,) * ndim
1636
+
1637
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1638
+ matrix_types = [
1639
+ # square matrices
1640
+ wp.types.matrix((2, 2), wptype),
1641
+ wp.types.matrix((3, 3), wptype),
1642
+ wp.types.matrix((4, 4), wptype),
1643
+ wp.types.matrix((5, 5), wptype),
1644
+ # non-square matrices
1645
+ wp.types.matrix((2, 3), wptype),
1646
+ wp.types.matrix((3, 2), wptype),
1647
+ wp.types.matrix((3, 4), wptype),
1648
+ wp.types.matrix((4, 3), wptype),
1649
+ ]
1650
+
1651
+ for mattype in matrix_types:
1652
+ npshape = (*shape, *mattype._shape_)
1653
+
1654
+ # source array
1655
+ a = wp.zeros(shape, dtype=mattype, device=device)
1656
+ na = a.numpy()
1657
+ test.assertEqual(a.shape, shape)
1658
+ test.assertEqual(a.dtype, mattype)
1659
+ test.assertEqual(na.shape, npshape)
1660
+ test.assertEqual(na.dtype, nptype)
1661
+ assert_np_equal(na, np.zeros(npshape, dtype=nptype))
1662
+
1663
+ # ones array
1664
+ b = wp.ones_like(a)
1665
+ nb = b.numpy()
1666
+ test.assertEqual(b.shape, shape)
1667
+ test.assertEqual(b.dtype, mattype)
1668
+ test.assertEqual(nb.shape, npshape)
1669
+ test.assertEqual(nb.dtype, nptype)
1670
+ assert_np_equal(nb, np.ones(npshape, dtype=nptype))
1671
+
1672
+
1673
+ def test_round_trip(test, device):
1674
+ rng = np.random.default_rng(123)
1675
+ dim_x = 4
1676
+
1677
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1678
+ a_np = rng.standard_normal(size=dim_x).astype(nptype)
1679
+ a = wp.array(a_np, device=device)
1680
+ test.assertEqual(a.dtype, wptype)
1681
+
1682
+ assert_np_equal(a.numpy(), a_np)
1683
+
1684
+ v_np = rng.standard_normal(size=(dim_x, 3)).astype(nptype)
1685
+ v = wp.array(v_np, dtype=wp.types.vector(3, wptype), device=device)
1686
+
1687
+ assert_np_equal(v.numpy(), v_np)
1688
+
1689
+
1690
+ def test_empty_array(test, device):
1691
+ # Test whether common operations work with empty (zero-sized) arrays
1692
+ # without throwing exceptions.
1693
+
1694
+ def test_empty_ops(ndim, nrows, ncols, wptype, nptype):
1695
+ shape = (0,) * ndim
1696
+ dtype_shape = ()
1697
+
1698
+ if wptype in wp.types.scalar_types:
1699
+ # scalar, vector, or matrix
1700
+ if ncols > 0:
1701
+ if nrows > 0:
1702
+ wptype = wp.types.matrix((nrows, ncols), wptype)
1703
+ else:
1704
+ wptype = wp.types.vector(ncols, wptype)
1705
+ dtype_shape = wptype._shape_
1706
+ fill_value = wptype(42)
1707
+ else:
1708
+ # struct
1709
+ fill_value = wptype()
1710
+
1711
+ # create a zero-sized array
1712
+ a = wp.empty(shape, dtype=wptype, device=device, requires_grad=True)
1713
+
1714
+ test.assertEqual(a.ptr, None)
1715
+ test.assertEqual(a.size, 0)
1716
+ test.assertEqual(a.shape, shape)
1717
+ test.assertEqual(a.grad.ptr, None)
1718
+ test.assertEqual(a.grad.size, 0)
1719
+ test.assertEqual(a.grad.shape, shape)
1720
+
1721
+ # all of these methods should succeed with zero-sized arrays
1722
+ a.zero_()
1723
+ a.fill_(fill_value)
1724
+ b = a.flatten()
1725
+ b = a.reshape((0,))
1726
+ b = a.transpose()
1727
+ b = a.contiguous()
1728
+
1729
+ b = wp.empty_like(a)
1730
+ b = wp.zeros_like(a)
1731
+ b = wp.full_like(a, fill_value)
1732
+ b = wp.clone(a)
1733
+
1734
+ wp.copy(a, b)
1735
+ a.assign(b)
1736
+
1737
+ na = a.numpy()
1738
+ test.assertEqual(na.size, 0)
1739
+ test.assertEqual(na.shape, (*shape, *dtype_shape))
1740
+ test.assertEqual(na.dtype, nptype)
1741
+
1742
+ test.assertEqual(a.list(), [])
1743
+
1744
+ for ndim in range(1, 5):
1745
+ # test with scalars, vectors, and matrices
1746
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1747
+ # scalars
1748
+ test_empty_ops(ndim, 0, 0, wptype, nptype)
1749
+
1750
+ for ncols in [2, 3, 4, 5]:
1751
+ # vectors
1752
+ test_empty_ops(ndim, 0, ncols, wptype, nptype)
1753
+ # square matrices
1754
+ test_empty_ops(ndim, ncols, ncols, wptype, nptype)
1755
+
1756
+ # non-square matrices
1757
+ test_empty_ops(ndim, 2, 3, wptype, nptype)
1758
+ test_empty_ops(ndim, 3, 2, wptype, nptype)
1759
+ test_empty_ops(ndim, 3, 4, wptype, nptype)
1760
+ test_empty_ops(ndim, 4, 3, wptype, nptype)
1761
+
1762
+ # test with structs
1763
+ test_empty_ops(ndim, 0, 0, FillStruct, FillStruct.numpy_dtype())
1764
+
1765
+
1766
+ def test_empty_from_numpy(test, device):
1767
+ # Test whether wrapping an empty (zero-sized) numpy array works correctly
1768
+
1769
+ def test_empty_from_data(ndim, nrows, ncols, wptype, nptype):
1770
+ shape = (0,) * ndim
1771
+ dtype_shape = ()
1772
+
1773
+ if ncols > 0:
1774
+ if nrows > 0:
1775
+ wptype = wp.types.matrix((nrows, ncols), wptype)
1776
+ else:
1777
+ wptype = wp.types.vector(ncols, wptype)
1778
+ dtype_shape = wptype._shape_
1779
+
1780
+ npshape = (*shape, *dtype_shape)
1781
+
1782
+ na = np.empty(npshape, dtype=nptype)
1783
+ a = wp.array(na, dtype=wptype, device=device)
1784
+ test.assertEqual(a.size, 0)
1785
+ test.assertEqual(a.shape, shape)
1786
+
1787
+ for ndim in range(1, 5):
1788
+ # test with scalars, vectors, and matrices
1789
+ for nptype, wptype in wp.types.np_dtype_to_warp_type.items():
1790
+ # scalars
1791
+ test_empty_from_data(ndim, 0, 0, wptype, nptype)
1792
+
1793
+ for ncols in [2, 3, 4, 5]:
1794
+ # vectors
1795
+ test_empty_from_data(ndim, 0, ncols, wptype, nptype)
1796
+ # square matrices
1797
+ test_empty_from_data(ndim, ncols, ncols, wptype, nptype)
1798
+
1799
+ # non-square matrices
1800
+ test_empty_from_data(ndim, 2, 3, wptype, nptype)
1801
+ test_empty_from_data(ndim, 3, 2, wptype, nptype)
1802
+ test_empty_from_data(ndim, 3, 4, wptype, nptype)
1803
+ test_empty_from_data(ndim, 4, 3, wptype, nptype)
1804
+
1805
+
1806
+ def test_empty_from_list(test, device):
1807
+ # Test whether creating an array from an empty Python list works correctly
1808
+
1809
+ def test_empty_from_data(nrows, ncols, wptype):
1810
+ if ncols > 0:
1811
+ if nrows > 0:
1812
+ wptype = wp.types.matrix((nrows, ncols), wptype)
1813
+ else:
1814
+ wptype = wp.types.vector(ncols, wptype)
1815
+
1816
+ a = wp.array([], dtype=wptype, device=device)
1817
+ test.assertEqual(a.size, 0)
1818
+ test.assertEqual(a.shape, (0,))
1819
+
1820
+ # test with scalars, vectors, and matrices
1821
+ for wptype in wp.types.scalar_types:
1822
+ # scalars
1823
+ test_empty_from_data(0, 0, wptype)
1824
+
1825
+ for ncols in [2, 3, 4, 5]:
1826
+ # vectors
1827
+ test_empty_from_data(0, ncols, wptype)
1828
+ # square matrices
1829
+ test_empty_from_data(ncols, ncols, wptype)
1830
+
1831
+ # non-square matrices
1832
+ test_empty_from_data(2, 3, wptype)
1833
+ test_empty_from_data(3, 2, wptype)
1834
+ test_empty_from_data(3, 4, wptype)
1835
+ test_empty_from_data(4, 3, wptype)
1836
+
1837
+
1838
+ def test_to_list_scalar(test, device):
1839
+ dim = 3
1840
+ fill_value = 42
1841
+
1842
+ for ndim in range(1, 5):
1843
+ shape = (dim,) * ndim
1844
+
1845
+ for wptype in wp.types.scalar_types:
1846
+ a = wp.full(shape, fill_value, dtype=wptype, device=device)
1847
+ l = a.list()
1848
+
1849
+ test.assertEqual(len(l), a.size)
1850
+ test.assertTrue(all(x == fill_value for x in l))
1851
+
1852
+
1853
+ def test_to_list_vector(test, device):
1854
+ dim = 3
1855
+
1856
+ for ndim in range(1, 5):
1857
+ shape = (dim,) * ndim
1858
+
1859
+ for veclen in [2, 3, 4, 5]:
1860
+ for wptype in wp.types.scalar_types:
1861
+ vectype = wp.types.vector(veclen, wptype)
1862
+ fill_value = vectype(42)
1863
+
1864
+ a = wp.full(shape, fill_value, dtype=vectype, device=device)
1865
+ l = a.list()
1866
+
1867
+ test.assertEqual(len(l), a.size)
1868
+ test.assertTrue(all(x == fill_value for x in l))
1869
+
1870
+
1871
+ def test_to_list_matrix(test, device):
1872
+ dim = 3
1873
+
1874
+ for ndim in range(1, 5):
1875
+ shape = (dim,) * ndim
1876
+
1877
+ for wptype in wp.types.scalar_types:
1878
+ matrix_types = [
1879
+ # square matrices
1880
+ wp.types.matrix((2, 2), wptype),
1881
+ wp.types.matrix((3, 3), wptype),
1882
+ wp.types.matrix((4, 4), wptype),
1883
+ wp.types.matrix((5, 5), wptype),
1884
+ # non-square matrices
1885
+ wp.types.matrix((2, 3), wptype),
1886
+ wp.types.matrix((3, 2), wptype),
1887
+ wp.types.matrix((3, 4), wptype),
1888
+ wp.types.matrix((4, 3), wptype),
1889
+ ]
1890
+
1891
+ for mattype in matrix_types:
1892
+ fill_value = mattype(42)
1893
+
1894
+ a = wp.full(shape, fill_value, dtype=mattype, device=device)
1895
+ l = a.list()
1896
+
1897
+ test.assertEqual(len(l), a.size)
1898
+ test.assertTrue(all(x == fill_value for x in l))
1899
+
1900
+
1901
+ def test_to_list_struct(test, device):
1902
+ @wp.struct
1903
+ class Inner:
1904
+ h: wp.float16
1905
+ v: wp.vec3
1906
+
1907
+ @wp.struct
1908
+ class ListStruct:
1909
+ i: int
1910
+ f: float
1911
+ h: wp.float16
1912
+ vi: wp.vec2i
1913
+ vf: wp.vec3f
1914
+ vh: wp.vec4h
1915
+ mi: wp.types.matrix((2, 2), int)
1916
+ mf: wp.types.matrix((3, 3), float)
1917
+ mh: wp.types.matrix((4, 4), wp.float16)
1918
+ inner: Inner
1919
+ a1: wp.array(dtype=int)
1920
+ a2: wp.array2d(dtype=float)
1921
+ a3: wp.array3d(dtype=wp.float16)
1922
+ bool: wp.bool
1923
+
1924
+ dim = 3
1925
+
1926
+ s = ListStruct()
1927
+ s.i = 42
1928
+ s.f = 2.5
1929
+ s.h = -1.25
1930
+ s.vi = wp.vec2i(1, 2)
1931
+ s.vf = wp.vec3f(0.1, 0.2, 0.3)
1932
+ s.vh = wp.vec4h(1.0, 2.0, 3.0, 4.0)
1933
+ s.mi = [[1, 2], [3, 4]]
1934
+ s.mf = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
1935
+ s.mh = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
1936
+ s.inner = Inner()
1937
+ s.inner.h = 1.5
1938
+ s.inner.v = [1, 2, 3]
1939
+ s.a1 = wp.empty(1, dtype=int, device=device)
1940
+ s.a2 = wp.empty((1, 1), dtype=float, device=device)
1941
+ s.a3 = wp.empty((1, 1, 1), dtype=wp.float16, device=device)
1942
+ s.bool = True
1943
+
1944
+ for ndim in range(1, 5):
1945
+ shape = (dim,) * ndim
1946
+
1947
+ a = wp.full(shape, s, dtype=ListStruct, device=device)
1948
+ l = a.list()
1949
+
1950
+ for i in range(a.size):
1951
+ test.assertEqual(l[i].i, s.i)
1952
+ test.assertEqual(l[i].f, s.f)
1953
+ test.assertEqual(l[i].h, s.h)
1954
+ test.assertEqual(l[i].vi, s.vi)
1955
+ test.assertEqual(l[i].vf, s.vf)
1956
+ test.assertEqual(l[i].vh, s.vh)
1957
+ test.assertEqual(l[i].mi, s.mi)
1958
+ test.assertEqual(l[i].mf, s.mf)
1959
+ test.assertEqual(l[i].mh, s.mh)
1960
+ test.assertEqual(l[i].bool, s.bool)
1961
+ test.assertEqual(l[i].inner.h, s.inner.h)
1962
+ test.assertEqual(l[i].inner.v, s.inner.v)
1963
+ test.assertEqual(l[i].a1.dtype, s.a1.dtype)
1964
+ test.assertEqual(l[i].a1.ndim, s.a1.ndim)
1965
+ test.assertEqual(l[i].a2.dtype, s.a2.dtype)
1966
+ test.assertEqual(l[i].a2.ndim, s.a2.ndim)
1967
+ test.assertEqual(l[i].a3.dtype, s.a3.dtype)
1968
+ test.assertEqual(l[i].a3.ndim, s.a3.ndim)
1969
+
1970
+
1971
+ @wp.kernel
1972
+ def kernel_array_to_bool(array_null: wp.array(dtype=float), array_valid: wp.array(dtype=float)):
1973
+ if not array_null:
1974
+ # always succeed
1975
+ wp.expect_eq(0, 0)
1976
+ else:
1977
+ # force failure
1978
+ wp.expect_eq(1, 2)
1979
+
1980
+ if array_valid:
1981
+ # always succeed
1982
+ wp.expect_eq(0, 0)
1983
+ else:
1984
+ # force failure
1985
+ wp.expect_eq(1, 2)
1986
+
1987
+
1988
+ def test_array_to_bool(test, device):
1989
+ arr = wp.zeros(8, dtype=float, device=device)
1990
+
1991
+ wp.launch(kernel_array_to_bool, dim=1, inputs=[None, arr], device=device)
1992
+
1993
+
1994
+ @wp.struct
1995
+ class InputStruct:
1996
+ param1: int
1997
+ param2: float
1998
+ param3: wp.vec3
1999
+ param4: wp.array(dtype=float)
2000
+
2001
+
2002
+ @wp.struct
2003
+ class OutputStruct:
2004
+ param1: int
2005
+ param2: float
2006
+ param3: wp.vec3
2007
+
2008
+
2009
+ @wp.kernel
2010
+ def struct_array_kernel(inputs: wp.array(dtype=InputStruct), outputs: wp.array(dtype=OutputStruct)):
2011
+ tid = wp.tid()
2012
+
2013
+ wp.expect_eq(inputs[tid].param1, tid)
2014
+ wp.expect_eq(inputs[tid].param2, float(tid * tid))
2015
+
2016
+ wp.expect_eq(inputs[tid].param3[0], 1.0)
2017
+ wp.expect_eq(inputs[tid].param3[1], 2.0)
2018
+ wp.expect_eq(inputs[tid].param3[2], 3.0)
2019
+
2020
+ wp.expect_eq(inputs[tid].param4[0], 1.0)
2021
+ wp.expect_eq(inputs[tid].param4[1], 2.0)
2022
+ wp.expect_eq(inputs[tid].param4[2], 3.0)
2023
+
2024
+ o = OutputStruct()
2025
+ o.param1 = inputs[tid].param1
2026
+ o.param2 = inputs[tid].param2
2027
+ o.param3 = inputs[tid].param3
2028
+
2029
+ outputs[tid] = o
2030
+
2031
+
2032
+ def test_array_of_structs(test, device):
2033
+ num_items = 10
2034
+
2035
+ l = []
2036
+ for i in range(num_items):
2037
+ s = InputStruct()
2038
+ s.param1 = i
2039
+ s.param2 = float(i * i)
2040
+ s.param3 = wp.vec3(1.0, 2.0, 3.0)
2041
+ s.param4 = wp.array([1.0, 2.0, 3.0], dtype=float, device=device)
2042
+
2043
+ l.append(s)
2044
+
2045
+ # initialize array from list of structs
2046
+ inputs = wp.array(l, dtype=InputStruct, device=device)
2047
+ outputs = wp.zeros(num_items, dtype=OutputStruct, device=device)
2048
+
2049
+ # pass to our compute kernel
2050
+ wp.launch(struct_array_kernel, dim=num_items, inputs=[inputs, outputs], device=device)
2051
+
2052
+ out_numpy = outputs.numpy()
2053
+ out_list = outputs.list()
2054
+ out_cptr = outputs.to("cpu").cptr()
2055
+
2056
+ for i in range(num_items):
2057
+ test.assertEqual(out_numpy[i][0], l[i].param1)
2058
+ test.assertEqual(out_numpy[i][1], l[i].param2)
2059
+ assert_np_equal(out_numpy[i][2], np.array(l[i].param3))
2060
+
2061
+ # test named slices of numpy structured array
2062
+ test.assertEqual(out_numpy["param1"][i], l[i].param1)
2063
+ test.assertEqual(out_numpy["param2"][i], l[i].param2)
2064
+ assert_np_equal(out_numpy["param3"][i], np.array(l[i].param3))
2065
+
2066
+ test.assertEqual(out_list[i].param1, l[i].param1)
2067
+ test.assertEqual(out_list[i].param2, l[i].param2)
2068
+ test.assertEqual(out_list[i].param3, l[i].param3)
2069
+
2070
+ test.assertEqual(out_cptr[i].param1, l[i].param1)
2071
+ test.assertEqual(out_cptr[i].param2, l[i].param2)
2072
+ test.assertEqual(out_cptr[i].param3, l[i].param3)
2073
+
2074
+
2075
+ @wp.struct
2076
+ class GradStruct:
2077
+ param1: int
2078
+ param2: float
2079
+ param3: wp.vec3
2080
+
2081
+
2082
+ @wp.kernel
2083
+ def test_array_of_structs_grad_kernel(inputs: wp.array(dtype=GradStruct), loss: wp.array(dtype=float)):
2084
+ tid = wp.tid()
2085
+
2086
+ wp.atomic_add(loss, 0, inputs[tid].param2 * 2.0)
2087
+
2088
+
2089
+ def test_array_of_structs_grad(test, device):
2090
+ num_items = 10
2091
+
2092
+ l = []
2093
+ for i in range(num_items):
2094
+ g = GradStruct()
2095
+ g.param2 = float(i)
2096
+
2097
+ l.append(g)
2098
+
2099
+ a = wp.array(l, dtype=GradStruct, device=device, requires_grad=True)
2100
+ loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)
2101
+
2102
+ with wp.Tape() as tape:
2103
+ wp.launch(test_array_of_structs_grad_kernel, dim=num_items, inputs=[a, loss], device=device)
2104
+
2105
+ tape.backward(loss)
2106
+
2107
+ grads = a.grad.numpy()
2108
+ assert_np_equal(grads["param2"], np.full(num_items, 2.0, dtype=np.float32))
2109
+
2110
+
2111
+ @wp.struct
2112
+ class NumpyStruct:
2113
+ x: int
2114
+ v: wp.vec3
2115
+
2116
+
2117
+ def test_array_of_structs_from_numpy(test, device):
2118
+ num_items = 10
2119
+
2120
+ na = np.zeros(num_items, dtype=NumpyStruct.numpy_dtype())
2121
+ na["x"] = 17
2122
+ na["v"] = (1, 2, 3)
2123
+
2124
+ a = wp.array(data=na, dtype=NumpyStruct, device=device)
2125
+
2126
+ assert_np_equal(a.numpy(), na)
2127
+
2128
+
2129
+ def test_array_of_structs_roundtrip(test, device):
2130
+ num_items = 10
2131
+
2132
+ value = NumpyStruct()
2133
+ value.x = 17
2134
+ value.v = wp.vec3(1.0, 2.0, 3.0)
2135
+
2136
+ # create Warp structured array
2137
+ a = wp.full(num_items, value, device=device)
2138
+
2139
+ # convert to NumPy structured array
2140
+ na = a.numpy()
2141
+
2142
+ expected = np.zeros(num_items, dtype=NumpyStruct.numpy_dtype())
2143
+ expected["x"] = value.x
2144
+ expected["v"] = value.v
2145
+
2146
+ assert_np_equal(na, expected)
2147
+
2148
+ # modify a field
2149
+ na["x"] = 42
2150
+
2151
+ # convert back to Warp array
2152
+ a = wp.from_numpy(na, NumpyStruct, device=device)
2153
+
2154
+ expected["x"] = 42
2155
+
2156
+ assert_np_equal(a.numpy(), expected)
2157
+
2158
+
2159
def test_array_from_numpy(test, device):
    """Exercise ``wp.from_numpy`` dtype inference and explicit ``dtype``/``shape`` overrides."""

    def check(actual, expected_values, dtype, shape):
        # Build the reference on the default device and compare host copies.
        reference = wp.array(expected_values, dtype=dtype, shape=shape)
        assert_np_equal(actual.numpy(), reference.numpy())

    # 1-D source: three floats, read as scalars or as a single vec3.
    row = (1.0, 2.0, 3.0)
    src = np.array(row, dtype=float)
    check(wp.from_numpy(src, device=device), row, wp.float32, (3,))
    check(wp.from_numpy(src, dtype=wp.vec3, device=device), (row,), wp.vec3, (1,))

    # 2-D source: two rows of three, read as vec3s, as a 2-D float grid, or flattened.
    rows = ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0))
    src = np.array(rows, dtype=float)
    check(wp.from_numpy(src, device=device), rows, wp.vec3, (2,))
    check(wp.from_numpy(src, dtype=wp.float32, device=device), rows, wp.float32, (2, 3))
    check(
        wp.from_numpy(src, dtype=wp.float32, shape=(6,), device=device),
        tuple(value for r in rows for value in r),
        wp.float32,
        (6,),
    )

    # 3-D source: two 4x4 blocks, read as mat44s, floats, vec4 rows, or fully flattened.
    mats = (
        (
            (1.0, 2.0, 3.0, 4.0),
            (2.0, 3.0, 4.0, 5.0),
            (3.0, 4.0, 5.0, 6.0),
            (4.0, 5.0, 6.0, 7.0),
        ),
        (
            (2.0, 3.0, 4.0, 5.0),
            (3.0, 4.0, 5.0, 6.0),
            (4.0, 5.0, 6.0, 7.0),
            (5.0, 6.0, 7.0, 8.0),
        ),
    )
    src = np.array(mats, dtype=float)
    check(wp.from_numpy(src, device=device), mats, wp.mat44, (2,))
    check(wp.from_numpy(src, dtype=wp.float32, device=device), mats, wp.float32, (2, 4, 4))

    # The vec4 view has shape (2, 4); reshape it to (8,) before comparing.
    mat_rows = tuple(r for m in mats for r in m)
    check(wp.from_numpy(src, dtype=wp.vec4, device=device).reshape((8,)), mat_rows, wp.vec4, (8,))
    check(
        wp.from_numpy(src, dtype=wp.float32, shape=(32,), device=device),
        tuple(value for r in mat_rows for value in r),
        wp.float32,
        (32,),
    )
2305
+
2306
+
2307
def test_array_aliasing_from_numpy(test, device):
    """Check when ``wp.array(..., copy=False)`` aliases a NumPy buffer and when it must copy.

    Registered for the CPU device only, since it compares host pointers directly.
    """
    device = wp.get_device(device)
    # A unittest assertion (not a bare ``assert``) keeps this precondition active under
    # ``python -O`` and reports it as a regular test failure.
    test.assertTrue(device.is_cpu)

    # np.int32 source with dtype=int: the test expects a zero-copy alias here.
    a_np = np.ones(8, dtype=np.int32)
    a_wp = wp.array(a_np, dtype=int, copy=False, device=device)
    test.assertIs(a_wp._ref, a_np)  # check that some ref is kept to original array
    test.assertEqual(a_wp.ptr, a_np.ctypes.data)

    a_np_2 = a_wp.numpy()
    test.assertTrue((a_np_2 == 1).all())

    # updating source array should update aliased array
    a_np.fill(2)
    test.assertTrue((a_np_2 == 2).all())

    # Aliasing from a different dtype (int64) cannot be zero-copy, so a copy is made.
    # Do it twice to check that the copy buffer is not reused for different arrays.
    b_np = np.ones(8, dtype=np.int64)
    c_np = np.zeros(8, dtype=np.int64)
    b_wp = wp.array(b_np, dtype=int, copy=False, device=device)
    c_wp = wp.array(c_np, dtype=int, copy=False, device=device)

    # Neither converted array may point at its int64 source, nor at each other.
    test.assertNotEqual(b_wp.ptr, b_np.ctypes.data)
    test.assertNotEqual(c_wp.ptr, c_np.ctypes.data)
    test.assertNotEqual(b_wp.ptr, c_wp.ptr)

    b_np_2 = b_wp.numpy()
    c_np_2 = c_wp.numpy()
    test.assertTrue((b_np_2 == 1).all())
    test.assertTrue((c_np_2 == 0).all())
2338
+
2339
+
2340
def test_array_from_cai(test, device):
    """Wrap Torch tensors as Warp arrays via ``__cuda_array_interface__``, incl. a strided view."""
    import torch

    @wp.kernel
    def first_row_plus_one(x: wp.array2d(dtype=float)):
        i, j = wp.tid()
        if i == 0:
            x[i, j] += 1.0

    def launch_on(target):
        # One thread per element of the 3x3 array; only row 0 is modified.
        wp.launch(kernel=first_row_plus_one, dim=(3, 3), inputs=[target], device=device)

    # Start from a Torch tensor on the device matching the Warp device, then wrap it.
    tensor = torch.zeros((3, 3)).to(wp.device_to_torch(device))
    warp_view = wp.array(tensor, device=device)
    launch_on(warp_view)

    # Round-trip to Torch and build a transposed view by swapping the two strides.
    tensor = wp.to_torch(warp_view)
    row_stride, col_stride = tensor.stride(0), tensor.stride(1)
    tensor = torch.as_strided(tensor, size=(3, 3), stride=(col_stride, row_stride))

    # Re-wrap the strided view; Warp has to honour the swapped strides.
    warp_view = wp.array(tensor, device=device)
    launch_on(warp_view)

    assert_np_equal(warp_view.numpy(), np.array([[2, 1, 1], [1, 0, 0], [1, 0, 0]]))
2371
+
2372
+
2373
@wp.kernel
def inplace_add_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # One thread per element: augmented add on a 1-D array element. The `+=` form is
    # kept deliberately; test_array_inplace_diff_ops checks its values and adjoints.
    i = wp.tid()
    x[i] += y[i]
2377
+
2378
+
2379
@wp.kernel
def inplace_add_2d(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # One thread per (i, j): augmented add on a 2-D array element.
    i, j = wp.tid()
    x[i, j] += y[i, j]
2383
+
2384
+
2385
@wp.kernel
def inplace_add_3d(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)):
    # One thread per (i, j, k): augmented add on a 3-D array element.
    i, j, k = wp.tid()
    x[i, j, k] += y[i, j, k]
2389
+
2390
+
2391
@wp.kernel
def inplace_add_4d(x: wp.array4d(dtype=float), y: wp.array4d(dtype=float)):
    # One thread per (i, j, k, l): augmented add on a 4-D array element.
    i, j, k, l = wp.tid()
    x[i, j, k, l] += y[i, j, k, l]
2395
+
2396
+
2397
@wp.kernel
def inplace_sub_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # One thread per element: augmented subtract on a 1-D array element
    # (test_array_inplace_diff_ops expects y's adjoint to be negated).
    i = wp.tid()
    x[i] -= y[i]
2401
+
2402
+
2403
@wp.kernel
def inplace_sub_2d(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # One thread per (i, j): augmented subtract on a 2-D array element.
    i, j = wp.tid()
    x[i, j] -= y[i, j]
2407
+
2408
+
2409
@wp.kernel
def inplace_sub_3d(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)):
    # One thread per (i, j, k): augmented subtract on a 3-D array element.
    i, j, k = wp.tid()
    x[i, j, k] -= y[i, j, k]
2413
+
2414
+
2415
@wp.kernel
def inplace_sub_4d(x: wp.array4d(dtype=float), y: wp.array4d(dtype=float)):
    # One thread per (i, j, k, l): augmented subtract on a 4-D array element.
    i, j, k, l = wp.tid()
    x[i, j, k, l] -= y[i, j, k, l]
2419
+
2420
+
2421
@wp.kernel
def inplace_add_vecs(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=wp.vec3)):
    # Augmented add where each array element is a vec3.
    i = wp.tid()
    x[i] += y[i]
2425
+
2426
+
2427
@wp.kernel
def inplace_add_mats(x: wp.array(dtype=wp.mat33), y: wp.array(dtype=wp.mat33)):
    # Augmented add where each array element is a mat33.
    i = wp.tid()
    x[i] += y[i]
2431
+
2432
+
2433
@wp.kernel
def inplace_add_rhs(x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    # Augmented add on a local variable (not an array element), then accumulate the
    # result into z[0]. With z as the loss, the test expects gradient 1 for x and y.
    i = wp.tid()
    a = y[i]
    a += x[i]
    wp.atomic_add(z, 0, a)
2439
+
2440
+
2441
# Custom 9-component float vector type built with wp.vec.
vec9 = wp.vec(length=9, dtype=float)


@wp.kernel
def inplace_add_custom_vec(x: wp.array(dtype=vec9), y: wp.array(dtype=vec9)):
    # `+=` is applied twice on purpose: with x = 0 and y = 1 the test expects
    # x == 2 and y.grad == 2 (one adjoint contribution per statement).
    i = wp.tid()
    x[i] += y[i]
    x[i] += y[i]
2449
+
2450
+
2451
def test_array_inplace_diff_ops(test, device):
    """Check forward values and adjoints of differentiable in-place ``+=`` / ``-=``.

    Covers float arrays of rank 1-4, vec3 and mat33 arrays, an in-place add on a local
    variable, and a custom 9-component vector type.
    """
    N = 3
    shapes = (N, (N, N), (N, N, N), (N, N, N, N))
    add_kernels = (inplace_add_1d, inplace_add_2d, inplace_add_3d, inplace_add_4d)
    sub_kernels = (inplace_sub_1d, inplace_sub_2d, inplace_sub_3d, inplace_sub_4d)

    xs = [wp.ones(shape, dtype=float, requires_grad=True, device=device) for shape in shapes]
    ys = [wp.clone(src, requires_grad=True, device=device) for src in xs]

    v1 = wp.ones(1, dtype=wp.vec3, requires_grad=True, device=device)
    v2 = wp.clone(v1, requires_grad=True, device=device)

    m1 = wp.ones(1, dtype=wp.mat33, requires_grad=True, device=device)
    m2 = wp.clone(m1, requires_grad=True, device=device)

    x = wp.ones(1, dtype=float, requires_grad=True, device=device)
    y = wp.clone(x, requires_grad=True, device=device)
    z = wp.zeros(1, dtype=float, requires_grad=True, device=device)

    ones = [np.ones(shape, dtype=float) for shape in shapes]
    twos = [np.full(shape, 2.0, dtype=float) for shape in shapes]

    tape = wp.Tape()

    def record_and_backprop(kernels):
        # One launch per rank (launch dim == array shape), then seed each x's adjoint with ones.
        with tape:
            for kernel, shape, dst, src in zip(kernels, shapes, xs, ys):
                wp.launch(kernel, shape, inputs=[dst, src], device=device)
        tape.backward(grads={dst: wp.ones_like(dst) for dst in xs})

    def check_all(arrays, expected):
        for arr, ref in zip(arrays, expected):
            assert_np_equal(arr.numpy(), ref)

    def clear():
        for dst in xs:
            dst.grad.zero_()
        tape.reset()

    # x += y  ->  x == 2, dx == 1, dy == 1
    record_and_backprop(add_kernels)
    check_all([dst.grad for dst in xs], ones)
    check_all([src.grad for src in ys], ones)
    check_all(xs, twos)
    clear()

    # x -= y  ->  x back to 1, dx == 1, dy == -1
    record_and_backprop(sub_kernels)
    check_all([dst.grad for dst in xs], ones)
    check_all([src.grad for src in ys], [-ref for ref in ones])
    check_all(xs, ones)
    clear()

    # vec3 / mat33 elements and an in-place add on a local variable
    with tape:
        wp.launch(inplace_add_vecs, 1, inputs=[v1, v2], device=device)
        wp.launch(inplace_add_mats, 1, inputs=[m1, m2], device=device)
        wp.launch(inplace_add_rhs, 1, inputs=[x, y, z], device=device)

    tape.backward(loss=z, grads={v1: wp.ones_like(v1, requires_grad=False), m1: wp.ones_like(m1, requires_grad=False)})

    assert_np_equal(v1.numpy(), np.full(shape=(1, 3), fill_value=2.0, dtype=float))
    assert_np_equal(v1.grad.numpy(), np.ones(shape=(1, 3), dtype=float))
    assert_np_equal(v2.grad.numpy(), np.ones(shape=(1, 3), dtype=float))

    assert_np_equal(m1.numpy(), np.full(shape=(1, 3, 3), fill_value=2.0, dtype=float))
    assert_np_equal(m1.grad.numpy(), np.ones(shape=(1, 3, 3), dtype=float))
    assert_np_equal(m2.grad.numpy(), np.ones(shape=(1, 3, 3), dtype=float))

    assert_np_equal(x.grad.numpy(), np.ones(1, dtype=float))
    assert_np_equal(y.grad.numpy(), np.ones(1, dtype=float))
    tape.reset()

    # custom-length vector: the kernel adds y twice
    cx = wp.zeros(1, dtype=vec9, requires_grad=True, device=device)
    cy = wp.ones(1, dtype=vec9, requires_grad=True, device=device)

    with tape:
        wp.launch(inplace_add_custom_vec, 1, inputs=[cx, cy], device=device)

    tape.backward(grads={cx: wp.ones_like(cx)})

    assert_np_equal(cx.numpy(), np.full((1, 9), 2.0, dtype=float))
    assert_np_equal(cy.grad.numpy(), np.full((1, 9), 2.0, dtype=float))
2571
+
2572
+
2573
@wp.kernel
def inplace_mul_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Augmented multiply on a 1-D array element; test_array_inplace_non_diff_ops
    # checks only forward values for it (no tape).
    i = wp.tid()
    x[i] *= y[i]
2577
+
2578
+
2579
@wp.kernel
def inplace_div_1d(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    # Augmented divide on a 1-D array element; checked for forward values only.
    i = wp.tid()
    x[i] /= y[i]
2583
+
2584
+
2585
@wp.kernel
def inplace_add_non_atomic_types(x: wp.array(dtype=Any), y: wp.array(dtype=Any)):
    # Generic (dtype=Any) augmented add. test_array_inplace_non_diff_ops launches it for
    # every type in wp.types.non_atomic_types plus several small-integer vector types.
    i = wp.tid()
    x[i] += y[i]


# Custom 3-component uint16 vector type, also exercised by test_array_inplace_non_diff_ops.
uint16vec3 = wp.vec(length=3, dtype=wp.uint16)
2592
+
2593
+
2594
def test_array_inplace_non_diff_ops(test, device):
    """Check forward results of in-place ``*=``, ``/=`` and ``+=`` without a tape.

    The ``+=`` case is run for every type in ``wp.types.non_atomic_types`` and for several
    small-integer vector types. Each dtype runs in its own ``subTest`` so a failure names it.
    """
    N = 3
    x1 = wp.full(N, value=10.0, dtype=float, device=device)
    y1 = wp.full(N, value=5.0, dtype=float, device=device)

    wp.launch(inplace_mul_1d, N, inputs=[x1, y1], device=device)
    assert_np_equal(x1.numpy(), np.full(N, fill_value=50.0, dtype=float))

    x1.fill_(10.0)
    y1.fill_(5.0)
    wp.launch(inplace_div_1d, N, inputs=[x1, y1], device=device)
    assert_np_equal(x1.numpy(), np.full(N, fill_value=2.0, dtype=float))

    for dtype in wp.types.non_atomic_types + (wp.vec2b, wp.vec2ub, wp.vec2s, wp.vec2us, uint16vec3):
        with test.subTest(dtype=dtype):
            x = wp.full(N, value=0, dtype=dtype, device=device)
            y = wp.full(N, value=1, dtype=dtype, device=device)

            # 0 + 1 == 1, so x must end up equal to y
            wp.launch(inplace_add_non_atomic_types, N, inputs=[x, y], device=device)
            assert_np_equal(x.numpy(), y.numpy())
2613
+
2614
+
2615
@wp.kernel
def inc_scalar(a: wp.array(dtype=float)):
    # Add 1 to each scalar element in place (written as an explicit read-modify-write).
    tid = wp.tid()
    a[tid] = a[tid] + 1.0
2619
+
2620
+
2621
@wp.kernel
def inc_vector(a: wp.array(dtype=wp.vec3f)):
    # Add 1 to every component of each vec3f element (wp.vec3f(1.0) fills all components).
    tid = wp.tid()
    a[tid] = a[tid] + wp.vec3f(1.0)
2625
+
2626
+
2627
@wp.kernel
def inc_matrix(a: wp.array(dtype=wp.mat22f)):
    # Add wp.mat22f(1.0) to each element; test_direct_from_numpy expects every entry
    # to increase by 1, i.e. that constructor fills all four entries.
    tid = wp.tid()
    a[tid] = a[tid] + wp.mat22f(1.0)
2631
+
2632
+
2633
def test_direct_from_numpy(test, device):
    """Launch kernels directly on NumPy arrays and check the in-place updates land in them."""
    n = 12

    scalars = np.arange(n, dtype=np.float32)
    vectors = np.arange(n, dtype=np.float32).reshape((n // 3, 3))
    matrices = np.arange(n, dtype=np.float32).reshape((n // 4, 2, 2))

    # (kernel, host array, launch dim = number of elements of the kernel's dtype)
    launches = (
        (inc_scalar, scalars, n),
        (inc_vector, vectors, n // 3),
        (inc_matrix, matrices, n // 4),
    )
    for kernel, host, count in launches:
        wp.launch(kernel, dim=count, inputs=[host], device=device)

    # Every underlying float must have been incremented by one.
    expected = np.arange(1, n + 1, dtype=np.float32)
    for host in (scalars, vectors, matrices):
        assert_np_equal(host.reshape(n), expected)
2651
+
2652
+
2653
@wp.kernel
def kernel_array_from_ptr(
    ptr: wp.uint64,
):
    # Construct a (2, 3) float32 array over a raw address inside the kernel and write
    # its first row. Assumes ptr addresses at least six contiguous float32 values on the
    # launch device (test_kernel_array_from_ptr passes the .ptr of a (2, 3) float32 array).
    arr = wp.array(ptr=ptr, shape=(2, 3), dtype=wp.float32)
    arr[0, 0] = 1.0
    arr[0, 1] = 2.0
    arr[0, 2] = 3.0
2661
+
2662
+
2663
def test_kernel_array_from_ptr(test, device):
    """Writes through an in-kernel array built from a raw pointer must reach the source array."""
    target = wp.zeros(shape=(2, 3), dtype=wp.float32, device=device)
    wp.launch(kernel_array_from_ptr, dim=(1,), inputs=(target.ptr,), device=device)

    observed = target.numpy()
    # Only the first row is written by the kernel; the second stays zero.
    assert_np_equal(observed, np.array(((1.0, 2.0, 3.0), (0.0, 0.0, 0.0))))
2667
+
2668
+
2669
def test_array_from_int32_domain(test, device):
    """Allocate an array whose shape is given as an ``np.int32`` NumPy array.

    The element count (1504 * 1080 * 520 = 844,646,400) fits in int32, but the float32
    byte size (3,378,585,600) does not. NOTE(review): presumably the test guards against
    int32 overflow when sizing the allocation; the assertions below pin the expected sizes.
    """
    dims = (1504, 1080, 520)
    arr = wp.zeros(np.array(dims, dtype=np.int32), dtype=wp.float32, device=device)

    test.assertEqual(tuple(int(s) for s in arr.shape), dims)
    test.assertEqual(int(arr.size), 1504 * 1080 * 520)
    test.assertEqual(int(arr.capacity), 1504 * 1080 * 520 * 4)
2671
+
2672
+
2673
def test_array_from_int64_domain(test, device):
    """Allocate an array whose shape is given as an ``np.int64`` NumPy array.

    Same dimensions as the int32 variant; checks that shape, element count and byte
    capacity (3,378,585,600 bytes for float32) come out exact.
    """
    dims = (1504, 1080, 520)
    arr = wp.zeros(np.array(dims, dtype=np.int64), dtype=wp.float32, device=device)

    test.assertEqual(tuple(int(s) for s in arr.shape), dims)
    test.assertEqual(int(arr.size), 1504 * 1080 * 520)
    test.assertEqual(int(arr.capacity), 1504 * 1080 * 520 * 4)
2675
+
2676
+
2677
def test_numpy_array_interface(test, device):
    """Round-trip every Warp scalar type between Warp and NumPy via ``__array_interface__``.

    Runs on the CPU only (registered with ``devices=["cpu"]``). Each dtype is its own
    ``subTest``, and unittest assertions replace bare ``assert`` so the checks survive
    ``python -O`` and report the mismatching values.
    """
    n = 10

    for dtype in wp.types.scalar_types:
        with test.subTest(dtype=dtype):
            # test round trip
            a1 = wp.zeros(n, dtype=dtype, device="cpu")
            na = np.array(a1)
            a2 = wp.array(na, device="cpu")

            test.assertEqual(a1.dtype, a2.dtype)
            test.assertEqual(a1.shape, a2.shape)
            test.assertEqual(a1.strides, a2.strides)
2695
+
2696
@wp.kernel
def kernel_indexing_types(
    arr_1d: wp.array(dtype=wp.int32, ndim=1),
    arr_2d: wp.array(dtype=wp.int32, ndim=2),
    arr_3d: wp.array(dtype=wp.int32, ndim=3),
    arr_4d: wp.array(dtype=wp.int32, ndim=4),
):
    # Exercises array reads, writes and atomics with uint8/int16/uint32/int64 indices on
    # arrays of rank 1-4. The exact statements are what is being compiled and tested.

    # Reads: results are intentionally discarded; only index-type handling matters.
    x = arr_1d[wp.uint8(0)]
    y = arr_1d[wp.int16(1)]
    z = arr_1d[wp.uint32(2)]
    w = arr_1d[wp.int64(3)]

    x = arr_2d[wp.uint8(0), wp.uint8(0)]
    y = arr_2d[wp.int16(1), wp.int16(1)]
    z = arr_2d[wp.uint32(2), wp.uint32(2)]
    w = arr_2d[wp.int64(3), wp.int64(3)]

    x = arr_3d[wp.uint8(0), wp.uint8(0), wp.uint8(0)]
    y = arr_3d[wp.int16(1), wp.int16(1), wp.int16(1)]
    z = arr_3d[wp.uint32(2), wp.uint32(2), wp.uint32(2)]
    w = arr_3d[wp.int64(3), wp.int64(3), wp.int64(3)]

    x = arr_4d[wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0)]
    y = arr_4d[wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1)]
    z = arr_4d[wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2)]
    w = arr_4d[wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3)]

    # Writes: element k on the "diagonal" (all indices == k) is set to 123 for k = 0..3.
    arr_1d[wp.uint8(0)] = 123
    arr_1d[wp.int16(1)] = 123
    arr_1d[wp.uint32(2)] = 123
    arr_1d[wp.int64(3)] = 123

    arr_2d[wp.uint8(0), wp.uint8(0)] = 123
    arr_2d[wp.int16(1), wp.int16(1)] = 123
    arr_2d[wp.uint32(2), wp.uint32(2)] = 123
    arr_2d[wp.int64(3), wp.int64(3)] = 123

    arr_3d[wp.uint8(0), wp.uint8(0), wp.uint8(0)] = 123
    arr_3d[wp.int16(1), wp.int16(1), wp.int16(1)] = 123
    arr_3d[wp.uint32(2), wp.uint32(2), wp.uint32(2)] = 123
    arr_3d[wp.int64(3), wp.int64(3), wp.int64(3)] = 123

    arr_4d[wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0)] = 123
    arr_4d[wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1)] = 123
    arr_4d[wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2)] = 123
    arr_4d[wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3)] = 123

    # Atomics on the same diagonal elements: add at k=0, sub at k=1, min at k=2, max at k=3.
    wp.atomic_add(arr_1d, wp.uint8(0), 123)
    wp.atomic_sub(arr_1d, wp.int16(1), 123)
    wp.atomic_min(arr_1d, wp.uint32(2), 123)
    wp.atomic_max(arr_1d, wp.int64(3), 123)

    wp.atomic_add(arr_2d, wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_2d, wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_2d, wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_2d, wp.int64(3), wp.int64(3), 123)

    wp.atomic_add(arr_3d, wp.uint8(0), wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_3d, wp.int16(1), wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_3d, wp.uint32(2), wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_3d, wp.int64(3), wp.int64(3), wp.int64(3), 123)

    wp.atomic_add(arr_4d, wp.uint8(0), wp.uint8(0), wp.uint8(0), wp.uint8(0), 123)
    wp.atomic_sub(arr_4d, wp.int16(1), wp.int16(1), wp.int16(1), wp.int16(1), 123)
    wp.atomic_min(arr_4d, wp.uint32(2), wp.uint32(2), wp.uint32(2), wp.uint32(2), 123)
    wp.atomic_max(arr_4d, wp.int64(3), wp.int64(3), wp.int64(3), wp.int64(3), 123)
2762
+
2763
+
2764
def test_indexing_types(test, device):
    """Index rank 1-4 arrays with several integer index types and verify the final contents.

    Previously the test only checked that ``kernel_indexing_types`` compiled and launched.
    It now also checks the values the kernel's writes and atomics must leave behind.
    """
    arr_1d = wp.zeros(shape=(4,), dtype=wp.int32, device=device)
    arr_2d = wp.zeros(shape=(4, 4), dtype=wp.int32, device=device)
    arr_3d = wp.zeros(shape=(4, 4, 4), dtype=wp.int32, device=device)
    arr_4d = wp.zeros(shape=(4, 4, 4, 4), dtype=wp.int32, device=device)
    wp.launch(
        kernel=kernel_indexing_types,
        dim=1,
        inputs=(arr_1d, arr_2d, arr_3d, arr_4d),
        device=device,
    )

    # The kernel (a single thread) stores 123 at each diagonal element k (all indices == k),
    # then applies one atomic op with operand 123 per element:
    #   k=0: add -> 246   k=1: sub -> 0   k=2: min -> 123   k=3: max -> 123
    # All other elements are never written and stay zero.
    diagonal = (246, 0, 123, 123)
    for arr in (arr_1d, arr_2d, arr_3d, arr_4d):
        rank = len(arr.shape)
        expected = np.zeros(arr.shape, dtype=np.int32)
        for k, value in enumerate(diagonal):
            expected[(k,) * rank] = value
        assert_np_equal(arr.numpy(), expected)
2775
+
2776
+
2777
def test_alloc_strides(test, device):
    """Allocating with explicit contiguous or reversed strides must reserve the default capacity.

    Uses ``test.assertEqual`` instead of bare ``assert`` so the checks survive ``python -O``
    and report both capacities on failure.
    """

    def check_transposed(shape, dtype):
        # allocate without specifying strides
        a1 = wp.zeros(shape, dtype=dtype)

        # allocate with contiguous strides
        strides = wp.types.strides_from_shape(shape, dtype)
        a2 = wp.zeros(shape, dtype=dtype, strides=strides)

        # allocate with transposed (reversed) shape/strides
        rshape = shape[::-1]
        rstrides = strides[::-1]
        a3 = wp.zeros(rshape, dtype=dtype, strides=rstrides)

        # ensure that correct capacity was allocated
        test.assertEqual(a2.capacity, a1.capacity)
        test.assertEqual(a3.capacity, a1.capacity)

    with wp.ScopedDevice(device):
        shapes = [(5, 5), (5, 3), (3, 5), (2, 3, 4), (4, 2, 3), (3, 2, 4)]
        for shape in shapes:
            with test.subTest(msg=f"shape={shape}"):
                check_transposed(shape, wp.int8)
                check_transposed(shape, wp.float32)
                check_transposed(shape, wp.vec3)
2802
+
2803
+
2804
def test_casting(test, device):
    """Reinterpret a (4, 3) integer array as a 1-D array of ``wp.vec3i``.

    Each intermediate gets its own name instead of rebinding ``idxs``. Unittest
    assertions replace bare ``assert`` so the checks survive ``python -O``.
    """
    values = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
    grid = wp.array(values, device=device).reshape((-1, 3))
    idxs = wp.array(grid, shape=grid.shape[0], dtype=wp.vec3i, device=device)

    test.assertIs(idxs.dtype, wp.vec3i)
    test.assertEqual(idxs.shape, (4,))
    test.assertEqual(idxs.strides, (12,))
2811
+
2812
+
2813
@wp.kernel
def array_len_kernel(
    a1: wp.array(dtype=int),
    a2: wp.array(dtype=float, ndim=3),
    out: wp.array(dtype=int),
):
    # Checks len() on arrays inside a kernel: it is expected to return the size of the
    # first dimension (123 for a1, 2 for the (2, 3, 4) array a2).
    # `length` is assigned but never read; presumably this checks that len() can be bound
    # to a local variable.
    length = len(a1)
    wp.expect_eq(len(a1), 123)
    out[0] = len(a1)

    length = len(a2)
    wp.expect_eq(len(a2), 2)
    out[1] = len(a2)
2826
+
2827
+
2828
def test_array_len(test, device):
    """In-kernel ``len()`` must report the first-dimension size of 1-D and 3-D arrays."""
    a1 = wp.zeros(123, dtype=int, device=device)
    a2 = wp.zeros((2, 3, 4), dtype=float, device=device)
    out = wp.empty(2, dtype=int, device=device)

    wp.launch(array_len_kernel, dim=(1,), inputs=(a1, a2), outputs=(out,), device=device)

    len_a1, len_a2 = out.numpy()
    test.assertEqual(len_a1, 123)
    test.assertEqual(len_a2, 2)
2845
+
2846
+
2847
def test_cuda_interface_conversion(test, device):
    """Build Warp arrays of several dtypes from an object exposing both array interfaces.

    ``__cuda_array_interface__`` is set to NumPy's ``__array_interface__`` dict, i.e. it
    describes host memory. NOTE(review): presumably intentional, to exercise dtype
    conversion on objects that advertise the CUDA interface. Confirm the intent.
    Each case runs in its own ``subTest``, and unittest assertions replace bare ``assert``.
    """

    class MyArrayInterface:
        def __init__(self, data):
            self.data = np.array(data)
            self.__array_interface__ = self.data.__array_interface__
            self.__cuda_array_interface__ = self.data.__array_interface__
            # Instance attribute only: Python looks special methods up on the type, so this
            # does not make len(obj) work. Kept unchanged to preserve the conversion path.
            self.__len__ = self.data.__len__

    cases = (
        ((1, 2, 3), wp.int8),
        ((1, 2, 3), wp.float32),
        ((1, 2, 3), wp.vec3),
        ((1, 2, 3, 4), wp.mat22),
    )
    for data, dtype in cases:
        with test.subTest(dtype=dtype):
            # Keep the source object alive for the duration of the check.
            source = MyArrayInterface(data)
            wp_array = wp.array(source, dtype=dtype, device=device)
            test.assertNotEqual(wp_array.ptr, 0)
2870
+
2871
+
2872
devices = get_test_devices()


class TestArray(unittest.TestCase):
    """Array test case; device-parametrized tests are attached with add_function_test."""

    def test_array_new_del(self):
        # An array created with __new__ alone never ran __init__. Calling its destructor
        # directly simulates garbage collection of such an object and must not fail.
        uninitialized = wp.array.__new__(wp.array)
        uninitialized.__del__()
2880
+
2881
+
2882
# Registration table: (test name attached to TestArray, test function, devices to run it on).
# Rows keep the original registration order; CPU-only tests use ["cpu"].
_ARRAY_FUNCTION_TESTS = (
    ("test_shape", test_shape, devices),
    ("test_negative_shape", test_negative_shape, devices),
    ("test_flatten", test_flatten, devices),
    ("test_reshape", test_reshape, devices),
    ("test_slicing", test_slicing, devices),
    ("test_transpose", test_transpose, devices),
    ("test_view", test_view, devices),
    ("test_clone_adjoint", test_clone_adjoint, devices),
    ("test_assign_adjoint", test_assign_adjoint, devices),
    ("test_1d_array", test_1d, devices),
    ("test_2d_array", test_2d, devices),
    ("test_3d_array", test_3d, devices),
    ("test_4d_array", test_4d, devices),
    ("test_4d_array_transposed", test_4d_transposed, devices),
    ("test_fill_scalar", test_fill_scalar, devices),
    ("test_fill_vector", test_fill_vector, devices),
    ("test_fill_matrix", test_fill_matrix, devices),
    ("test_fill_struct", test_fill_struct, devices),
    ("test_fill_slices", test_fill_slices, devices),
    ("test_full_scalar", test_full_scalar, devices),
    ("test_full_vector", test_full_vector, devices),
    ("test_full_matrix", test_full_matrix, devices),
    ("test_full_struct", test_full_struct, devices),
    ("test_ones_scalar", test_ones_scalar, devices),
    ("test_ones_vector", test_ones_vector, devices),
    ("test_ones_matrix", test_ones_matrix, devices),
    ("test_ones_like_scalar", test_ones_like_scalar, devices),
    ("test_ones_like_vector", test_ones_like_vector, devices),
    ("test_ones_like_matrix", test_ones_like_matrix, devices),
    ("test_empty_array", test_empty_array, devices),
    ("test_empty_from_numpy", test_empty_from_numpy, devices),
    ("test_empty_from_list", test_empty_from_list, devices),
    ("test_to_list_scalar", test_to_list_scalar, devices),
    ("test_to_list_vector", test_to_list_vector, devices),
    ("test_to_list_matrix", test_to_list_matrix, devices),
    ("test_to_list_struct", test_to_list_struct, devices),
    ("test_lower_bound", test_lower_bound, devices),
    ("test_round_trip", test_round_trip, devices),
    ("test_array_to_bool", test_array_to_bool, devices),
    ("test_array_of_structs", test_array_of_structs, devices),
    ("test_array_of_structs_grad", test_array_of_structs_grad, devices),
    ("test_array_of_structs_from_numpy", test_array_of_structs_from_numpy, devices),
    ("test_array_of_structs_roundtrip", test_array_of_structs_roundtrip, devices),
    ("test_array_from_numpy", test_array_from_numpy, devices),
    ("test_array_aliasing_from_numpy", test_array_aliasing_from_numpy, ["cpu"]),
    ("test_numpy_array_interface", test_numpy_array_interface, ["cpu"]),
    ("test_array_inplace_diff_ops", test_array_inplace_diff_ops, devices),
    ("test_array_inplace_non_diff_ops", test_array_inplace_non_diff_ops, devices),
    ("test_direct_from_numpy", test_direct_from_numpy, ["cpu"]),
    ("test_kernel_array_from_ptr", test_kernel_array_from_ptr, devices),
    ("test_array_from_int32_domain", test_array_from_int32_domain, devices),
    ("test_array_from_int64_domain", test_array_from_int64_domain, devices),
    ("test_indexing_types", test_indexing_types, devices),
    ("test_alloc_strides", test_alloc_strides, devices),
    ("test_casting", test_casting, devices),
    ("test_array_len", test_array_len, devices),
    ("test_cuda_interface_conversion", test_cuda_interface_conversion, devices),
)

for _name, _func, _devices in _ARRAY_FUNCTION_TESTS:
    add_function_test(TestArray, _name, _func, devices=_devices)

try:
    import torch

    # Probe each Warp device with a tiny Torch op. CUDA devices may fail if Torch was
    # built without CUDA support; such devices are reported and skipped.
    torch_compatible_devices = []
    torch_compatible_cuda_devices = []

    for d in devices:
        try:
            t = torch.arange(10, device=wp.device_to_torch(d))
            t += 1
        except Exception as e:
            print(f"Skipping Array tests that use Torch on device '{d}' due to exception: {e}")
        else:
            torch_compatible_devices.append(d)
            if d.is_cuda:
                torch_compatible_cuda_devices.append(d)

    add_function_test(TestArray, "test_array_from_cai", test_array_from_cai, devices=torch_compatible_cuda_devices)

except Exception as e:
    print(f"Skipping Array tests that use Torch due to exception: {e}")


if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)