warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -738,6 +738,28 @@ def test_tile_extract_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=flo
738
738
  wp.atomic_add(b, i, j, wp.tile_extract(tile, x, y))
739
739
 
740
740
 
741
# Loads a 1-D tile of vec3 values from `x`, reads component 1 of element i
# (i = wp.tid()) and writes that scalar to y[i]. Used to check both the
# forward value and the gradient of extracting a vector component from a
# tile element.
@wp.kernel
def test_tile_extract_vec_kernel(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=float)):
    i = wp.tid()

    # Explicit 1-tuple shape: `(TILE_M)` without the comma is just an int.
    tile = wp.tile_load(x, shape=(TILE_M,))

    a = tile[i][1]

    y[i] = a
750
+
751
+
752
# Loads a 1-D tile of mat33 values from `x`, reads entry (1, 1) of element i
# (i = wp.tid()) and writes that scalar to y[i]. Used to check both the
# forward value and the gradient of extracting a matrix entry from a tile
# element.
@wp.kernel
def test_tile_extract_mat_kernel(x: wp.array(dtype=wp.mat33), y: wp.array(dtype=float)):
    i = wp.tid()

    # Explicit 1-tuple shape: `(TILE_M)` without the comma is just an int.
    tile = wp.tile_load(x, shape=(TILE_M,))

    a = tile[i][1, 1]

    y[i] = a
761
+
762
+
741
763
  def test_tile_extract(test, device):
742
764
  block_dim = 16
743
765
 
@@ -763,6 +785,40 @@ def test_tile_extract(test, device):
763
785
  expected_grad = np.ones_like(input)
764
786
  assert_np_equal(a.grad.numpy(), expected_grad)
765
787
 
788
+ # vector element test
789
+ x = wp.ones(TILE_M, dtype=wp.vec3, requires_grad=True, device=device)
790
+ y = wp.zeros(TILE_M, dtype=float, requires_grad=True, device=device)
791
+
792
+ with wp.Tape() as tape:
793
+ wp.launch(test_tile_extract_vec_kernel, dim=[TILE_M], inputs=[x, y], block_dim=TILE_DIM, device=device)
794
+
795
+ y.grad = wp.ones_like(y)
796
+
797
+ tape.backward()
798
+
799
+ x_grad_np = np.zeros((TILE_M, 3), dtype=float)
800
+ x_grad_np[:, 1] = 1.0
801
+
802
+ assert_np_equal(x.grad.numpy(), x_grad_np)
803
+ assert_np_equal(y.numpy(), np.ones(TILE_M, dtype=float))
804
+
805
+ # matrix element test
806
+ x = wp.ones(TILE_M, dtype=wp.mat33, requires_grad=True, device=device)
807
+ y = wp.zeros(TILE_M, dtype=float, requires_grad=True, device=device)
808
+
809
+ with wp.Tape() as tape:
810
+ wp.launch(test_tile_extract_mat_kernel, dim=[TILE_M], inputs=[x, y], block_dim=TILE_DIM, device=device)
811
+
812
+ y.grad = wp.ones_like(y)
813
+
814
+ tape.backward()
815
+
816
+ x_grad_np = np.zeros((TILE_M, 3, 3), dtype=float)
817
+ x_grad_np[:, 1, 1] = 1.0
818
+
819
+ assert_np_equal(y.numpy(), np.ones(TILE_M, dtype=float))
820
+ assert_np_equal(x.grad.numpy(), x_grad_np)
821
+
766
822
 
767
823
  @wp.kernel(module="unique")
768
824
  def test_tile_extract_repeated_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
@@ -822,6 +878,28 @@ def test_tile_assign_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
822
878
  wp.tile_atomic_add(y, a, offset=(0,))
823
879
 
824
880
 
881
# Creates a zero-filled vec3 tile, stores x[tid] into component 1 of the
# element indexed by this thread's tid, then atomically adds the whole tile
# into `y` starting at offset 0.
@wp.kernel
def test_tile_assign_vec_kernel(x: wp.array(dtype=float), y: wp.array(dtype=wp.vec3)):
    lane = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M,), dtype=wp.vec3)

    # component-wise write into a single tile element
    acc[lane][1] = x[lane]

    wp.tile_atomic_add(y, acc, offset=(0,))
890
+
891
+
892
# Creates a zero-filled mat33 tile, stores x[tid] into entry (1, 1) of the
# element indexed by this thread's tid, then atomically adds the whole tile
# into `y` starting at offset 0.
@wp.kernel
def test_tile_assign_mat_kernel(x: wp.array(dtype=float), y: wp.array(dtype=wp.mat33)):
    lane = wp.tid()
    acc = wp.tile_zeros(shape=(TILE_M,), dtype=wp.mat33)

    # entry-wise write into a single tile element
    acc[lane][1, 1] = x[lane]

    wp.tile_atomic_add(y, acc, offset=(0,))
901
+
902
+
825
903
  def test_tile_assign(test, device):
826
904
  x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
827
905
  y = wp.zeros(TILE_M, dtype=float, device=device, requires_grad=True)
@@ -836,6 +914,100 @@ def test_tile_assign(test, device):
836
914
  assert_np_equal(y.numpy(), np.full(TILE_M, 2.0, dtype=np.float32))
837
915
  assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
838
916
 
917
+ # vector element test
918
+ x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
919
+ y = wp.zeros(TILE_M, dtype=wp.vec3, device=device, requires_grad=True)
920
+
921
+ tape = wp.Tape()
922
+ with tape:
923
+ wp.launch(test_tile_assign_vec_kernel, dim=[TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
924
+
925
+ y.grad = wp.ones_like(y)
926
+ tape.backward()
927
+
928
+ y_np = np.zeros((TILE_M, 3), dtype=float)
929
+ y_np[:, 1] = 2.0
930
+
931
+ assert_np_equal(y.numpy(), y_np)
932
+ assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
933
+
934
+ # matrix element test
935
+ x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
936
+ y = wp.zeros(TILE_M, dtype=wp.mat33, device=device, requires_grad=True)
937
+
938
+ tape = wp.Tape()
939
+ with tape:
940
+ wp.launch(test_tile_assign_mat_kernel, dim=[TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
941
+
942
+ y.grad = wp.ones_like(y)
943
+ tape.backward()
944
+
945
+ y_np = np.zeros((TILE_M, 3, 3), dtype=float)
946
+ y_np[:, 1, 1] = 2.0
947
+
948
+ assert_np_equal(y.numpy(), y_np)
949
+ assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
950
+
951
+
952
# Loads `x` and `y` twice each (storage="register" and storage="shared"),
# binds one of the four tiles to a single variable depending on `select`
# (0: x/register, 1: y/register, 2: x/shared, anything else: y/shared),
# and stores the chosen tile into `z`.
@wp.kernel
def test_tile_where_kernel(select: int, x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
    # tiles loaded with storage="register"
    x_in_regs = wp.tile_load(x, shape=(TILE_M,), storage="register")
    y_in_regs = wp.tile_load(y, shape=(TILE_M,), storage="register")

    # tiles loaded with storage="shared"
    x_in_smem = wp.tile_load(x, shape=(TILE_M,), storage="shared")
    y_in_smem = wp.tile_load(y, shape=(TILE_M,), storage="shared")

    # the same variable is bound in every branch to a tile of possibly
    # different storage; this is the behaviour being exercised
    if select == 0:
        chosen = x_in_regs
    elif select == 1:
        chosen = y_in_regs
    elif select == 2:
        chosen = x_in_smem
    else:
        chosen = y_in_smem

    wp.tile_store(z, chosen)
970
+
971
+
972
def test_tile_where(test, device):
    """Check values and gradients of ``test_tile_where_kernel`` for each ``select``.

    Selectors 0 and 2 pick a tile loaded from ``x`` (1.0), selectors 1 and 3
    a tile loaded from ``y`` (2.0). With ``z.grad`` set to ones, the chosen
    input must receive a unit gradient and the other input none.
    """
    x = wp.full((TILE_M,), 1.0, dtype=float, device=device, requires_grad=True)
    y = wp.full((TILE_M,), 2.0, dtype=float, device=device, requires_grad=True)
    # explicit 1-tuple shape, matching x and y (`(TILE_M)` would be a plain int)
    z = wp.zeros((TILE_M,), dtype=float, device=device, requires_grad=True)

    ones = np.full(TILE_M, 1.0, dtype=np.float32)
    twos = np.full(TILE_M, 2.0, dtype=np.float32)
    zeros = np.full(TILE_M, 0.0, dtype=np.float32)

    # one row per selector: (expected z, expected x.grad, expected y.grad)
    expectations = [
        (ones, ones, zeros),
        (twos, zeros, ones),
        (ones, ones, zeros),
        (twos, zeros, ones),
    ]

    for select, (z_expected, x_grad_expected, y_grad_expected) in enumerate(expectations):
        tape = wp.Tape()
        with tape:
            wp.launch_tiled(
                test_tile_where_kernel, dim=[1], inputs=[select, x, y], outputs=[z], block_dim=32, device=device
            )

        z.grad = wp.ones_like(z)

        tape.backward()

        assert_np_equal(z.numpy(), z_expected)
        assert_np_equal(x.grad.numpy(), x_grad_expected)
        assert_np_equal(y.grad.numpy(), y_grad_expected)

        # clear gradients so the next selector starts from a clean state
        tape.zero()
1010
+
839
1011
 
840
1012
  @wp.kernel
841
1013
  def test_tile_transpose_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
@@ -1118,6 +1290,71 @@ def test_tile_len(test, device):
1118
1290
  test.assertEqual(out.numpy()[0], TILE_M)
1119
1291
 
1120
1292
 
1293
# Struct used as a tile_full fill value in test_tile_construction_kernel.
@wp.struct
class TestStruct:
    x: wp.float32  # set to 2.0 by the construction kernel
    y: wp.vec3  # set to vec3(1.0) by the construction kernel


# Builds one tile with each construction builtin (tile_zeros, tile_ones,
# tile_arange, and tile_full with scalar / vec3 / mat33 / struct values) and
# stores each tile into its own output array.
@wp.kernel
def test_tile_construction_kernel(
    out_zeros: wp.array(dtype=float),
    out_ones: wp.array(dtype=float),
    out_arange: wp.array(dtype=float),
    out_full_twos: wp.array(dtype=float),
    out_full_vecs: wp.array(dtype=wp.vec3),
    out_full_mats: wp.array(dtype=wp.mat33),
    out_full_structs: wp.array(dtype=TestStruct),
):
    t_zeros = wp.tile_zeros(TILE_M, dtype=float)
    t_ones = wp.tile_ones(TILE_M, dtype=float)
    t_ramp = wp.tile_arange(TILE_M, dtype=float)
    t_twos = wp.tile_full(TILE_M, value=2.0, dtype=float)
    t_vecs = wp.tile_full(TILE_M, value=wp.vec3(1.0), dtype=wp.vec3)
    t_mats = wp.tile_full(TILE_M, value=wp.mat33(1.0), dtype=wp.mat33)

    # struct-valued fill
    fill = TestStruct()
    fill.x = wp.float32(2.0)
    fill.y = wp.vec3(1.0)
    t_structs = wp.tile_full(TILE_M, value=fill, dtype=TestStruct)

    wp.tile_store(out_zeros, t_zeros)
    wp.tile_store(out_ones, t_ones)
    wp.tile_store(out_arange, t_ramp)
    wp.tile_store(out_full_twos, t_twos)
    wp.tile_store(out_full_vecs, t_vecs)
    wp.tile_store(out_full_mats, t_mats)
    wp.tile_store(out_full_structs, t_structs)


def test_tile_construction(test, device):
    """Launch test_tile_construction_kernel once and verify every output on the host."""
    # allocation order matches the kernel's output parameter order
    out_dtypes = [float, float, float, float, wp.vec3, wp.mat33, TestStruct]
    outputs = [wp.empty(TILE_M, dtype=dt, device=device) for dt in out_dtypes]
    zeros, ones, arange, full_twos, full_vecs, full_mats, full_structs = outputs

    wp.launch_tiled(
        test_tile_construction_kernel,
        dim=1,
        inputs=[],
        outputs=outputs,
        block_dim=TILE_DIM,
        device=device,
    )

    assert_np_equal(zeros.numpy(), np.zeros(TILE_M, dtype=float))
    assert_np_equal(ones.numpy(), np.ones(TILE_M, dtype=float))
    assert_np_equal(full_twos.numpy(), np.full(TILE_M, 2.0, dtype=float))
    assert_np_equal(full_vecs.numpy(), np.ones((TILE_M, 3), dtype=float))
    assert_np_equal(full_mats.numpy(), np.ones((TILE_M, 3, 3), dtype=float))

    # struct arrays come back as a structured numpy array keyed by field name
    structs_host = full_structs.numpy()
    assert_np_equal(structs_host["x"], np.full(TILE_M, 2.0, dtype=float))
    assert_np_equal(structs_host["y"], np.ones((TILE_M, 3), dtype=float))

    assert_np_equal(arange.numpy(), np.arange(TILE_M, dtype=float))
1356
+
1357
+
1121
1358
  @wp.kernel
1122
1359
  def test_tile_print_kernel():
1123
1360
  # shared tile
@@ -1261,6 +1498,7 @@ add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, device
1261
1498
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
1262
1499
  add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
1263
1500
  add_function_test(TestTile, "test_tile_assign", test_tile_assign, devices=devices)
1501
+ add_function_test(TestTile, "test_tile_where", test_tile_where, devices=devices)
1264
1502
  add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
1265
1503
  add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
1266
1504
  add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
@@ -1269,6 +1507,7 @@ add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad
1269
1507
  add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
1270
1508
  add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
1271
1509
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
1510
+ add_function_test(TestTile, "test_tile_construction", test_tile_construction, devices=devices)
1272
1511
  # add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1273
1512
  # add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1274
1513
  # add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)
@@ -0,0 +1,403 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# Per-thread bitwise update of a single-element ``wp.uint32`` tile.
#
# Each block handles one word ``a[word_idx]``; the 32 threads of the block each
# combine exactly one bit (``bit_idx``) of ``b[word_idx]`` into the staged word,
# so once every thread has applied its update ``s[0]`` equals ``a OP b``.
# Op codes: 0/3 = AND, 1/4 = OR, 2/5 = XOR. Codes 0-2 build the bit mask
# locally; codes 3-5 read it from a tile entry written by a neighbouring thread.
@wp.kernel
def test_tile_atomic_bitwise_scalar_kernel(
    a: wp.array(dtype=wp.uint32), b: wp.array(dtype=wp.uint32), c: wp.array(dtype=wp.uint32), op_type: int
):
    # word_idx: element handled by this block; bit_idx: assumed 0-31, one per thread of the block
    word_idx, bit_idx = wp.tid()
    block_dim = wp.block_dim()
    # The per-bit scheme below only covers a full 32-bit word with exactly 32 threads.
    assert block_dim == 32
    s = wp.tile_zeros(shape=1, dtype=wp.uint32)
    # Stage a[word_idx] in the tile; the updates below modify the tile, and the
    # output array is only assigned from s[0] at the end.
    # NOTE(review): all 32 threads perform this store; the expected results rely on
    # it completing before any thread's update below — confirm Warp orders this.
    s[0] = a[word_idx]
    if op_type < 3:
        # Single-bit mask for this thread's bit.
        bit_mask = wp.uint32(1) << wp.uint32(bit_idx)
        if op_type == 0:
            # AND only this thread's bit with b; all other bits are ANDed with 1 (kept).
            s[0] &= (b[word_idx] & bit_mask) | ~bit_mask
        elif op_type == 1:
            s[0] |= b[word_idx] & bit_mask
        elif op_type == 2:
            s[0] ^= b[word_idx] & bit_mask
    else:
        # Masks come from a tile: thread t writes entry (t + 1) % 32 with the mask
        # for bit (t + 1) % 32, so entry bit_idx (read below) was written by another
        # thread and holds 1 << bit_idx.
        s_bit_mask = wp.tile_zeros(shape=32, dtype=wp.uint32)
        s_bit_mask[(bit_idx + 1) % 32] = wp.uint32(1) << wp.uint32((bit_idx + 1) % 32)
        if op_type == 3:
            s[0] &= (b[word_idx] & s_bit_mask[bit_idx]) | ~s_bit_mask[bit_idx]
        elif op_type == 4:
            s[0] |= b[word_idx] & s_bit_mask[bit_idx]
        elif op_type == 5:
            s[0] ^= b[word_idx] & s_bit_mask[bit_idx]
    # Every thread stores the same combined word to the output element.
    c[word_idx] = s[0]
53
+
54
+
55
# Whole-tile bitwise operators on ``wp.uint32`` data.
#
# Each block loads 32 consecutive elements of ``a`` and ``b`` into tiles and
# stores the combined tile to the same slice of ``c``.
# Op codes 6/7/8 use the in-place forms (``&=``, ``|=``, ``^=``);
# op codes 9/10/11 use the binary forms (``&``, ``|``, ``^``).
@wp.kernel
def test_tile_atomic_bitwise_scalar_tilewise_kernel(
    a: wp.array(dtype=wp.uint32), b: wp.array(dtype=wp.uint32), c: wp.array(dtype=wp.uint32), op_type: int
):
    # batch_idx: index of the 32-element batch handled by this block
    batch_idx, _ = wp.tid()
    block_dim = wp.block_dim()
    assert block_dim == 32
    # Each tile is responsible for a batch of 32 elements
    s1 = wp.tile_load(a, shape=32, offset=batch_idx * 32)
    s2 = wp.tile_load(b, shape=32, offset=batch_idx * 32)
    # Tile-by-tile operations over the whole batch.
    if op_type < 9:
        # In-place operators. Op codes below 6 match no branch, so `a` is copied to `c` unchanged.
        if op_type == 6:
            s1 &= s2
        elif op_type == 7:
            s1 |= s2
        elif op_type == 8:
            s1 ^= s2
        wp.tile_store(c, s1, offset=batch_idx * 32)
    else:
        # Binary operators producing a new tile.
        # NOTE(review): s3 is only assigned for op codes 9-11; other values >= 9 are unsupported.
        if op_type == 9:
            s3 = s1 & s2
        elif op_type == 10:
            s3 = s1 | s2
        elif op_type == 11:
            s3 = s1 ^ s2
        wp.tile_store(c, s3, offset=batch_idx * 32)
82
+
83
+
84
def test_tile_atomic_bitwise_scalar(test, device):
    """Check tile bitwise AND/OR/XOR on ``wp.uint32`` elements against NumPy.

    Op codes 0-5 run ``test_tile_atomic_bitwise_scalar_kernel`` with one block
    per array element. Op codes 6-11 run
    ``test_tile_atomic_bitwise_scalar_tilewise_kernel`` with one block per batch
    of 32 elements. Within each group of three, the op codes select AND, OR and
    XOR in that order, so ``op_type % 3`` indexes the expected result.

    Args:
        test: The test case instance running this check (used for ``subTest``).
        device: Device on which the arrays are allocated and kernels are launched.
    """
    n = 1024  # must be a multiple of 32 for the tilewise kernel's batches
    rng = np.random.default_rng(42)

    # ``endpoint=True`` makes the upper bound inclusive so 0xFFFFFFFF can also be drawn.
    a = rng.integers(0, np.iinfo(np.uint32).max, size=n, dtype=np.uint32, endpoint=True)
    b = rng.integers(0, np.iinfo(np.uint32).max, size=n, dtype=np.uint32, endpoint=True)

    expected = (a & b, a | b, a ^ b)

    with wp.ScopedDevice(device):
        a_wp = wp.array(a, dtype=wp.uint32, device=device)
        b_wp = wp.array(b, dtype=wp.uint32, device=device)
        c_wp = wp.zeros(shape=n, dtype=wp.uint32, device=device)

        for op_type in range(12):
            with test.subTest(op_type=op_type):
                # Clear the output so a launch that writes nothing cannot pass on stale results.
                c_wp.zero_()
                if op_type < 6:
                    kernel, dim = test_tile_atomic_bitwise_scalar_kernel, n
                else:
                    kernel, dim = test_tile_atomic_bitwise_scalar_tilewise_kernel, n // 32
                wp.launch_tiled(kernel, dim=dim, inputs=[a_wp, b_wp, c_wp, op_type], block_dim=32)
                assert_np_equal(c_wp.numpy(), expected[op_type % 3])
137
+
138
+
139
# Per-thread bitwise update of a single-element ``wp.vec3ui`` tile.
#
# Each block handles one vector ``a[word_idx]``; the 32 threads of the block each
# combine exactly one bit position (``bit_idx``) of ``b[word_idx]`` into the staged
# value, so once every thread has applied its update ``s[0]`` equals ``a OP b``.
# Op codes: 0/3 = AND, 1/4 = OR, 2/5 = XOR. Codes 0-2 build the bit mask
# locally; codes 3-5 read it from a tile entry written by a neighbouring thread.
@wp.kernel
def test_tile_atomic_bitwise_vector_kernel(
    a: wp.array(dtype=wp.vec3ui), b: wp.array(dtype=wp.vec3ui), c: wp.array(dtype=wp.vec3ui), op_type: int
):
    # word_idx: element handled by this block; bit_idx: assumed 0-31, one per thread of the block
    word_idx, bit_idx = wp.tid()
    block_dim = wp.block_dim()
    # The per-bit scheme below only covers a full 32-bit component with exactly 32 threads.
    assert block_dim == 32
    s = wp.tile_zeros(shape=1, dtype=wp.vec3ui)
    # Stage a[word_idx] in the tile; the updates below modify the tile, and the
    # output array is only assigned from s[0] at the end.
    # NOTE(review): all 32 threads perform this store; the expected results rely on
    # it completing before any thread's update below — confirm Warp orders this.
    s[0] = a[word_idx]
    if op_type < 3:
        # Presumably vec3ui(scalar) fills all three components, giving the same
        # single-bit mask in each (the expected per-component results depend on it).
        bit_mask = wp.vec3ui(wp.uint32(1)) << wp.vec3ui(wp.uint32(bit_idx))
        if op_type == 0:
            # AND only this thread's bit with b; all other bits are ANDed with 1 (kept).
            s[0] &= (b[word_idx] & bit_mask) | ~bit_mask
        elif op_type == 1:
            s[0] |= b[word_idx] & bit_mask
        elif op_type == 2:
            s[0] ^= b[word_idx] & bit_mask
    else:
        # Masks come from a tile: thread t writes entry (t + 1) % 32 with the mask
        # for bit (t + 1) % 32, so entry bit_idx (read below) was written by another
        # thread and holds the mask for bit bit_idx.
        s_bit_mask = wp.tile_zeros(shape=32, dtype=wp.vec3ui)
        s_bit_mask[(bit_idx + 1) % 32] = wp.vec3ui(wp.uint32(1) << wp.uint32((bit_idx + 1) % 32))
        if op_type == 3:
            s[0] &= (b[word_idx] & s_bit_mask[bit_idx]) | ~s_bit_mask[bit_idx]
        elif op_type == 4:
            s[0] |= b[word_idx] & s_bit_mask[bit_idx]
        elif op_type == 5:
            s[0] ^= b[word_idx] & s_bit_mask[bit_idx]
    # Every thread stores the same combined vector to the output element.
    c[word_idx] = s[0]
168
+
169
+
170
# Whole-tile bitwise operators on ``wp.vec3ui`` data.
#
# Each block loads 32 consecutive elements of ``a`` and ``b`` into tiles and
# stores the combined tile to the same slice of ``c``.
# Op codes 6/7/8 use the in-place forms (``&=``, ``|=``, ``^=``);
# op codes 9/10/11 use the binary forms (``&``, ``|``, ``^``).
@wp.kernel
def test_tile_atomic_bitwise_vector_tilewise_kernel(
    a: wp.array(dtype=wp.vec3ui), b: wp.array(dtype=wp.vec3ui), c: wp.array(dtype=wp.vec3ui), op_type: int
):
    # batch_idx: index of the 32-element batch handled by this block
    batch_idx, _ = wp.tid()
    block_dim = wp.block_dim()
    assert block_dim == 32
    # Each tile is responsible for a batch of 32 elements
    s1 = wp.tile_load(a, shape=32, offset=batch_idx * 32)
    s2 = wp.tile_load(b, shape=32, offset=batch_idx * 32)
    # Tile-by-tile operations over the whole batch.
    if op_type < 9:
        # In-place operators. Op codes below 6 match no branch, so `a` is copied to `c` unchanged.
        if op_type == 6:
            s1 &= s2
        elif op_type == 7:
            s1 |= s2
        elif op_type == 8:
            s1 ^= s2
        wp.tile_store(c, s1, offset=batch_idx * 32)
    else:
        # Binary operators producing a new tile.
        # NOTE(review): s3 is only assigned for op codes 9-11; other values >= 9 are unsupported.
        if op_type == 9:
            s3 = s1 & s2
        elif op_type == 10:
            s3 = s1 | s2
        elif op_type == 11:
            s3 = s1 ^ s2
        wp.tile_store(c, s3, offset=batch_idx * 32)
197
+
198
+
199
def test_tile_atomic_bitwise_vector(test, device):
    """Check tile bitwise AND/OR/XOR on ``wp.vec3ui`` elements against NumPy.

    Op codes 0-5 run ``test_tile_atomic_bitwise_vector_kernel`` with one block
    per array element. Op codes 6-11 run
    ``test_tile_atomic_bitwise_vector_tilewise_kernel`` with one block per batch
    of 32 elements. Within each group of three, the op codes select AND, OR and
    XOR in that order, so ``op_type % 3`` indexes the expected result.

    Args:
        test: The test case instance running this check (used for ``subTest``).
        device: Device on which the arrays are allocated and kernels are launched.
    """
    n = 1024  # must be a multiple of 32 for the tilewise kernel's batches
    rng = np.random.default_rng(42)

    # ``endpoint=True`` makes the upper bound inclusive so 0xFFFFFFFF can also be drawn.
    a = rng.integers(0, np.iinfo(np.uint32).max, size=(n, 3), dtype=np.uint32, endpoint=True)
    b = rng.integers(0, np.iinfo(np.uint32).max, size=(n, 3), dtype=np.uint32, endpoint=True)

    expected = (a & b, a | b, a ^ b)

    with wp.ScopedDevice(device):
        a_wp = wp.array(a, dtype=wp.vec3ui, device=device)
        b_wp = wp.array(b, dtype=wp.vec3ui, device=device)
        c_wp = wp.zeros(shape=n, dtype=wp.vec3ui, device=device)

        for op_type in range(12):
            with test.subTest(op_type=op_type):
                # Clear the output so a launch that writes nothing cannot pass on stale results.
                c_wp.zero_()
                if op_type < 6:
                    kernel, dim = test_tile_atomic_bitwise_vector_kernel, n
                else:
                    kernel, dim = test_tile_atomic_bitwise_vector_tilewise_kernel, n // 32
                wp.launch_tiled(kernel, dim=dim, inputs=[a_wp, b_wp, c_wp, op_type], block_dim=32)
                assert_np_equal(c_wp.numpy(), expected[op_type % 3])
252
+
253
+
254
# 3x3 matrix of unsigned 32-bit integers, used to exercise tile bitwise operators
# on matrix-valued elements.
mat33ui = wp._src.types.matrix(dtype=wp.uint32, shape=(3, 3))
255
+
256
+
257
# Per-thread bitwise update of a single-element ``mat33ui`` tile.
#
# Each block handles one matrix ``a[word_idx]``; the 32 threads of the block each
# combine exactly one bit position (``bit_idx``) of ``b[word_idx]`` into the staged
# value, so once every thread has applied its update ``s[0]`` equals ``a OP b``.
# Op codes: 0/3 = AND, 1/4 = OR, 2/5 = XOR. Codes 0-2 build the bit mask
# locally; codes 3-5 read it from a tile entry written by a neighbouring thread.
@wp.kernel
def test_tile_atomic_bitwise_matrix_kernel(
    a: wp.array(dtype=mat33ui), b: wp.array(dtype=mat33ui), c: wp.array(dtype=mat33ui), op_type: int
):
    # word_idx: element handled by this block; bit_idx: assumed 0-31, one per thread of the block
    word_idx, bit_idx = wp.tid()
    block_dim = wp.block_dim()
    # The per-bit scheme below only covers a full 32-bit entry with exactly 32 threads.
    assert block_dim == 32
    s = wp.tile_zeros(shape=1, dtype=mat33ui)
    # Stage a[word_idx] in the tile; the updates below modify the tile, and the
    # output array is only assigned from s[0] at the end.
    # NOTE(review): all 32 threads perform this store; the expected results rely on
    # it completing before any thread's update below — confirm Warp orders this.
    s[0] = a[word_idx]
    if op_type < 3:
        # Presumably mat33ui(scalar) fills all nine entries, giving the same
        # single-bit mask in each (the expected per-entry results depend on it).
        bit_mask = mat33ui(wp.uint32(1)) << mat33ui(wp.uint32(bit_idx))
        if op_type == 0:
            # AND only this thread's bit with b; all other bits are ANDed with 1 (kept).
            s[0] &= (b[word_idx] & bit_mask) | ~bit_mask
        elif op_type == 1:
            s[0] |= b[word_idx] & bit_mask
        elif op_type == 2:
            s[0] ^= b[word_idx] & bit_mask
    else:
        # Masks come from a tile: thread t writes entry (t + 1) % 32 with the mask
        # for bit (t + 1) % 32, so entry bit_idx (read below) was written by another
        # thread and holds the mask for bit bit_idx.
        s_bit_mask = wp.tile_zeros(shape=32, dtype=mat33ui)
        s_bit_mask[(bit_idx + 1) % 32] = mat33ui(wp.uint32(1) << wp.uint32((bit_idx + 1) % 32))
        if op_type == 3:
            s[0] &= (b[word_idx] & s_bit_mask[bit_idx]) | ~s_bit_mask[bit_idx]
        elif op_type == 4:
            s[0] |= b[word_idx] & s_bit_mask[bit_idx]
        elif op_type == 5:
            s[0] ^= b[word_idx] & s_bit_mask[bit_idx]
    # Every thread stores the same combined matrix to the output element.
    c[word_idx] = s[0]
286
+
287
+
288
# Whole-tile bitwise operators on ``mat33ui`` data.
#
# Each block loads 32 consecutive elements of ``a`` and ``b`` into tiles and
# stores the combined tile to the same slice of ``c``.
# Op codes 6/7/8 use the in-place forms (``&=``, ``|=``, ``^=``);
# op codes 9/10/11 use the binary forms (``&``, ``|``, ``^``).
@wp.kernel
def test_tile_atomic_bitwise_matrix_tilewise_kernel(
    a: wp.array(dtype=mat33ui), b: wp.array(dtype=mat33ui), c: wp.array(dtype=mat33ui), op_type: int
):
    # batch_idx: index of the 32-element batch handled by this block
    batch_idx, _ = wp.tid()
    block_dim = wp.block_dim()
    assert block_dim == 32
    # Each tile is responsible for a batch of 32 elements
    s1 = wp.tile_load(a, shape=32, offset=batch_idx * 32)
    s2 = wp.tile_load(b, shape=32, offset=batch_idx * 32)
    # Tile-by-tile operations over the whole batch.
    if op_type < 9:
        # In-place operators. Op codes below 6 match no branch, so `a` is copied to `c` unchanged.
        if op_type == 6:
            s1 &= s2
        elif op_type == 7:
            s1 |= s2
        elif op_type == 8:
            s1 ^= s2
        wp.tile_store(c, s1, offset=batch_idx * 32)
    else:
        # Binary operators producing a new tile.
        # NOTE(review): s3 is only assigned for op codes 9-11; other values >= 9 are unsupported.
        if op_type == 9:
            s3 = s1 & s2
        elif op_type == 10:
            s3 = s1 | s2
        elif op_type == 11:
            s3 = s1 ^ s2
        wp.tile_store(c, s3, offset=batch_idx * 32)
315
+
316
+
317
def test_tile_atomic_bitwise_matrix(test, device):
    """Check tile bitwise AND/OR/XOR on ``mat33ui`` elements against NumPy.

    Op codes 0-5 run ``test_tile_atomic_bitwise_matrix_kernel`` with one block
    per array element. Op codes 6-11 run
    ``test_tile_atomic_bitwise_matrix_tilewise_kernel`` with one block per batch
    of 32 elements. Within each group of three, the op codes select AND, OR and
    XOR in that order, so ``op_type % 3`` indexes the expected result.

    Args:
        test: The test case instance running this check (used for ``subTest``).
        device: Device on which the arrays are allocated and kernels are launched.
    """
    n = 1024  # must be a multiple of 32 for the tilewise kernel's batches
    rng = np.random.default_rng(42)

    # ``endpoint=True`` makes the upper bound inclusive so 0xFFFFFFFF can also be drawn.
    a = rng.integers(0, np.iinfo(np.uint32).max, size=(n, 3, 3), dtype=np.uint32, endpoint=True)
    b = rng.integers(0, np.iinfo(np.uint32).max, size=(n, 3, 3), dtype=np.uint32, endpoint=True)

    expected = (a & b, a | b, a ^ b)

    with wp.ScopedDevice(device):
        a_wp = wp.array(a, dtype=mat33ui, device=device)
        b_wp = wp.array(b, dtype=mat33ui, device=device)
        c_wp = wp.zeros(shape=n, dtype=mat33ui, device=device)

        for op_type in range(12):
            with test.subTest(op_type=op_type):
                # Clear the output so a launch that writes nothing cannot pass on stale results.
                c_wp.zero_()
                if op_type < 6:
                    kernel, dim = test_tile_atomic_bitwise_matrix_kernel, n
                else:
                    kernel, dim = test_tile_atomic_bitwise_matrix_tilewise_kernel, n // 32
                wp.launch_tiled(kernel, dim=dim, inputs=[a_wp, b_wp, c_wp, op_type], block_dim=32)
                assert_np_equal(c_wp.numpy(), expected[op_type % 3])
370
+
371
+
372
# The tile bitwise tests in this module are registered only on the devices
# returned by get_cuda_test_devices().
devices = get_cuda_test_devices()
373
+
374
+
375
class TestTileAtomicBitwise(unittest.TestCase):
    """Test case for tile bitwise operators; test methods are attached via ``add_function_test``."""

    pass
377
+
378
+
379
# Attach each bitwise test to the test case, using the function's own name as the test name.
for _bitwise_test_fn in (
    test_tile_atomic_bitwise_scalar,
    test_tile_atomic_bitwise_vector,
    test_tile_atomic_bitwise_matrix,
):
    add_function_test(
        TestTileAtomicBitwise,
        _bitwise_test_fn.__name__,
        _bitwise_test_fn,
        devices=devices,
    )
# Do not leave the loop variable behind in the module namespace.
del _bitwise_test_fn
399
+
400
+
401
if __name__ == "__main__":
    # When run directly: clear Warp's kernel cache, then run this module's tests verbosely.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)