warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350) hide show
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,7 @@ import numpy as np
20
20
  import warp as wp
21
21
  from warp.tests.unittest_utils import *
22
22
 
23
- wp.init() # For wp.context.runtime.core.wp_is_mathdx_enabled()
23
+ wp.init() # For wp._src.context.runtime.core.wp_is_mathdx_enabled()
24
24
 
25
25
  TILE_M = wp.constant(8)
26
26
  TILE_N = wp.constant(4)
@@ -490,7 +490,7 @@ def test_tile_upper_solve(L: wp.array2d(dtype=float), y: wp.array(dtype=float),
490
490
 
491
491
 
492
492
  def test_tile_cholesky_singular_matrices(test, device):
493
- if not wp.context.runtime.core.wp_is_mathdx_enabled():
493
+ if not wp._src.context.runtime.core.wp_is_mathdx_enabled():
494
494
  test.skipTest("MathDx is not enabled")
495
495
 
496
496
  rng = np.random.default_rng(42)
@@ -527,8 +527,11 @@ cuda_devices = get_cuda_test_devices()
527
527
 
528
528
 
529
529
  @unittest.skipUnless(
530
- not wp.context.runtime.core.wp_is_mathdx_enabled()
531
- or (wp.context.runtime.core.wp_is_mathdx_enabled() and wp.context.runtime.core.wp_cuda_toolkit_version() >= 12060),
530
+ not wp._src.context.runtime.core.wp_is_mathdx_enabled()
531
+ or (
532
+ wp._src.context.runtime.core.wp_is_mathdx_enabled()
533
+ and wp._src.context.runtime.core.wp_cuda_toolkit_version() >= 12060
534
+ ),
532
535
  "MathDx is not enabled or is enabled but CUDA toolkit version is less than 12.6",
533
536
  )
534
537
  class TestTileCholesky(unittest.TestCase):
@@ -40,6 +40,7 @@ def tile_load_1d_kernel(
40
40
  input: wp.array1d(dtype=float),
41
41
  out_full: wp.array1d(dtype=float),
42
42
  out_padded: wp.array1d(dtype=float),
43
+ out_sliced: wp.array1d(dtype=float),
43
44
  out_offset: wp.array1d(dtype=float),
44
45
  ):
45
46
  full0 = wp.tile_load(input, TILE_M)
@@ -50,8 +51,13 @@ def tile_load_1d_kernel(
50
51
  padded1 = wp.tile_load(input, shape=TILE_M, offset=TILE_OFFSET)
51
52
  padded2 = wp.tile_load(input, shape=(TILE_M,), offset=(TILE_OFFSET,))
52
53
 
54
+ sliced0 = wp.tile_load(input[::2], TILE_M)
55
+ sliced1 = wp.tile_load(input[::2], shape=TILE_M)
56
+ sliced2 = wp.tile_load(input[::2], shape=(TILE_M,))
57
+
53
58
  wp.tile_store(out_full, full0)
54
59
  wp.tile_store(out_padded, padded0)
60
+ wp.tile_store(out_sliced, sliced0)
55
61
  wp.tile_store(out_offset, full0, offset=(TILE_OFFSET,))
56
62
 
57
63
 
@@ -60,13 +66,16 @@ def tile_load_2d_kernel(
60
66
  input: wp.array2d(dtype=float),
61
67
  out_full: wp.array2d(dtype=float),
62
68
  out_padded: wp.array2d(dtype=float),
69
+ out_sliced: wp.array2d(dtype=float),
63
70
  out_offset: wp.array2d(dtype=float),
64
71
  ):
65
72
  full0 = wp.tile_load(input, shape=(TILE_M, TILE_N))
66
73
  padded0 = wp.tile_load(input, shape=(TILE_M, TILE_N), offset=(TILE_OFFSET, TILE_OFFSET))
74
+ sliced0 = wp.tile_load(input[::2, ::2], shape=(TILE_M, TILE_N))
67
75
 
68
76
  wp.tile_store(out_full, full0)
69
77
  wp.tile_store(out_padded, padded0)
78
+ wp.tile_store(out_sliced, sliced0)
70
79
  wp.tile_store(out_offset, full0, offset=(TILE_OFFSET, TILE_OFFSET))
71
80
 
72
81
 
@@ -75,13 +84,16 @@ def tile_load_3d_kernel(
75
84
  input: wp.array3d(dtype=float),
76
85
  out_full: wp.array3d(dtype=float),
77
86
  out_padded: wp.array3d(dtype=float),
87
+ out_sliced: wp.array3d(dtype=float),
78
88
  out_offset: wp.array3d(dtype=float),
79
89
  ):
80
90
  full0 = wp.tile_load(input, shape=(TILE_M, TILE_N, TILE_O))
81
91
  padded0 = wp.tile_load(input, shape=(TILE_M, TILE_N, TILE_O), offset=(TILE_OFFSET, TILE_OFFSET, TILE_OFFSET))
92
+ sliced0 = wp.tile_load(input[::2, ::2, ::2], shape=(TILE_M, TILE_N, TILE_O))
82
93
 
83
94
  wp.tile_store(out_full, full0)
84
95
  wp.tile_store(out_padded, padded0)
96
+ wp.tile_store(out_sliced, sliced0)
85
97
  wp.tile_store(out_offset, full0, offset=(TILE_OFFSET, TILE_OFFSET, TILE_OFFSET))
86
98
 
87
99
 
@@ -90,15 +102,18 @@ def tile_load_4d_kernel(
90
102
  input: wp.array4d(dtype=float),
91
103
  out_full: wp.array4d(dtype=float),
92
104
  out_padded: wp.array4d(dtype=float),
105
+ out_sliced: wp.array4d(dtype=float),
93
106
  out_offset: wp.array4d(dtype=float),
94
107
  ):
95
108
  full0 = wp.tile_load(input, shape=(TILE_M, TILE_N, TILE_O, TILE_P))
96
109
  padded0 = wp.tile_load(
97
110
  input, shape=(TILE_M, TILE_N, TILE_O, TILE_P), offset=(TILE_OFFSET, TILE_OFFSET, TILE_OFFSET, TILE_OFFSET)
98
111
  )
112
+ sliced0 = wp.tile_load(input[::2, ::2, ::2, ::2], shape=(TILE_M, TILE_N, TILE_O, TILE_P))
99
113
 
100
114
  wp.tile_store(out_full, full0)
101
115
  wp.tile_store(out_padded, padded0)
116
+ wp.tile_store(out_sliced, sliced0)
102
117
  wp.tile_store(out_offset, full0, offset=(TILE_OFFSET, TILE_OFFSET, TILE_OFFSET, TILE_OFFSET))
103
118
 
104
119
 
@@ -112,13 +127,14 @@ def test_tile_load(kernel, ndim):
112
127
  input = wp.array(rng.random(shape), dtype=float, requires_grad=True, device=device)
113
128
  output_full = wp.zeros(shape, dtype=float, device=device)
114
129
  output_padded = wp.zeros(shape, dtype=float, device=device)
130
+ output_sliced = wp.zeros(shape, dtype=float, device=device)
115
131
  output_offset = wp.zeros(shape, dtype=float, device=device)
116
132
 
117
133
  with wp.Tape() as tape:
118
134
  wp.launch_tiled(
119
135
  kernel,
120
136
  dim=[1],
121
- inputs=[input, output_full, output_padded, output_offset],
137
+ inputs=[input, output_full, output_padded, output_sliced, output_offset],
122
138
  block_dim=TILE_DIM,
123
139
  device=device,
124
140
  )
@@ -134,8 +150,16 @@ def test_tile_load(kernel, ndim):
134
150
  ref_offset = np.zeros_like(ref_full)
135
151
  ref_offset[src_slice] = ref_full[dest_slice]
136
152
 
153
+ # construct a slice for the source/dest sliced arrays
154
+ src_slice = tuple(slice(0, dim, 2) for dim in shape)
155
+ dest_slice = tuple(slice(0, (dim + 1) // 2) for dim in shape)
156
+
157
+ ref_sliced = np.zeros_like(ref_full)
158
+ ref_sliced[dest_slice] = ref_full[src_slice]
159
+
137
160
  assert_np_equal(output_full.numpy(), ref_full)
138
161
  assert_np_equal(output_padded.numpy(), ref_padded)
162
+ assert_np_equal(output_sliced.numpy(), ref_sliced)
139
163
  assert_np_equal(output_offset.numpy(), ref_offset)
140
164
 
141
165
  output_full.grad = wp.ones_like(output_full)
@@ -570,7 +594,7 @@ def test_tile_assign(kernel, ndim):
570
594
  input = wp.array(rng.random(shape), dtype=float, requires_grad=True, device=device)
571
595
  output = wp.zeros_like(input)
572
596
 
573
- with wp.Tape() as tape:
597
+ with wp.Tape():
574
598
  wp.launch(
575
599
  kernel,
576
600
  dim=shape,
@@ -21,7 +21,7 @@ import numpy as np
21
21
  import warp as wp
22
22
  from warp.tests.unittest_utils import *
23
23
 
24
- wp.init() # For wp.context.runtime.core.wp_is_mathdx_enabled()
24
+ wp.init() # For wp._src.context.runtime.core.wp_is_mathdx_enabled()
25
25
 
26
26
  TILE_M = wp.constant(8)
27
27
  TILE_N = wp.constant(4)
@@ -92,7 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
92
92
  wp.tile_store(gy, xy)
93
93
 
94
94
 
95
- @unittest.skipUnless(wp.context.runtime.core.wp_is_mathdx_enabled(), "Warp was not built with MathDx support")
95
+ @unittest.skipUnless(wp._src.context.runtime.core.wp_is_mathdx_enabled(), "Warp was not built with MathDx support")
96
96
  def test_tile_math_fft(test, device, wp_dtype):
97
97
  np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
98
98
  np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
@@ -113,7 +113,7 @@ def test_tile_math_fft(test, device, wp_dtype):
113
113
  X_c64 = X.view(np_cplx_dtype).reshape(fft_size, fft_size)
114
114
  Y_c64 = np.fft.fft(X_c64, axis=-1)
115
115
 
116
- with wp.Tape() as tape:
116
+ with wp.Tape():
117
117
  wp.launch_tiled(kernel, dim=[1, 1], inputs=[X_wp, Y_wp], block_dim=TILE_DIM, device=device)
118
118
 
119
119
  Y_wp_c64 = Y_wp.numpy().view(np_cplx_dtype).reshape(fft_size, fft_size)
@@ -60,7 +60,7 @@ def test_tile_grouped_gemm(test, device):
60
60
  B_wp = wp.array(B, requires_grad=True, device=device)
61
61
  C_wp = wp.zeros((batch_count, TILE_M, TILE_N), requires_grad=True, device=device)
62
62
 
63
- with wp.Tape() as tape:
63
+ with wp.Tape():
64
64
  wp.launch_tiled(
65
65
  tile_grouped_gemm, dim=[batch_count], inputs=[A_wp, B_wp, C_wp], block_dim=TILE_DIM, device=device
66
66
  )
@@ -43,7 +43,7 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
43
43
  def test_multi_layer_nn(test, device):
44
44
  import torch as tc
45
45
 
46
- if device.is_cuda and not wp.context.runtime.core.wp_is_mathdx_enabled():
46
+ if device.is_cuda and not wp._src.context.runtime.core.wp_is_mathdx_enabled():
47
47
  test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
48
48
 
49
49
  NUM_FREQ = wp.constant(8)
@@ -63,7 +63,7 @@ def test_multi_layer_nn(test, device):
63
63
  NUM_THREADS = 32
64
64
 
65
65
  dtype = wp.float16
66
- npdtype = wp.types.warp_type_to_np_dtype[dtype]
66
+ npdtype = wp._src.types.warp_type_to_np_dtype[dtype]
67
67
 
68
68
  @wp.func
69
69
  def relu(x: dtype):
@@ -188,7 +188,6 @@ def test_multi_layer_nn(test, device):
188
188
  optimizer_inputs = [p.flatten() for p in params]
189
189
  optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)
190
190
 
191
- num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
192
191
  max_epochs = 30
193
192
 
194
193
  # create randomized batch indices
@@ -288,7 +287,6 @@ def test_single_layer_nn(test, device):
288
287
  import torch as tc
289
288
 
290
289
  DIM_IN = 8
291
- DIM_HID = 32
292
290
  DIM_OUT = 16
293
291
 
294
292
  NUM_BLOCKS = 56
@@ -79,7 +79,7 @@ def tile_sum_to_shared_kernel(input: wp.array2d(dtype=float), output: wp.array(d
79
79
 
80
80
  a = wp.tile_load(input[i], shape=TILE_DIM)
81
81
  s = wp.tile_sum(a)
82
- v = s[0] # force shared storage for s
82
+ v = s[0]
83
83
  wp.tile_store(output, s * 0.5, offset=i)
84
84
 
85
85
 
@@ -142,7 +142,7 @@ def test_tile_reduce_min(test, device):
142
142
  input_wp = wp.array(input, requires_grad=True, device=device)
143
143
  output_wp = wp.zeros(batch_count, requires_grad=True, device=device)
144
144
 
145
- with wp.Tape() as tape:
145
+ with wp.Tape():
146
146
  wp.launch_tiled(
147
147
  tile_min_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
148
148
  )
@@ -190,7 +190,7 @@ def test_tile_reduce_argmin(test, device):
190
190
  input_wp = wp.array(input, requires_grad=True, device=device)
191
191
  output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)
192
192
 
193
- with wp.Tape() as tape:
193
+ with wp.Tape():
194
194
  wp.launch_tiled(
195
195
  tile_argmin_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
196
196
  )
@@ -231,7 +231,7 @@ def test_tile_reduce_max(test, device):
231
231
  input_wp = wp.array(input, requires_grad=True, device=device)
232
232
  output_wp = wp.zeros(batch_count, requires_grad=True, device=device)
233
233
 
234
- with wp.Tape() as tape:
234
+ with wp.Tape():
235
235
  wp.launch_tiled(
236
236
  tile_max_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
237
237
  )
@@ -264,7 +264,7 @@ def test_tile_reduce_argmax(test, device):
264
264
  input_wp = wp.array(input, requires_grad=True, device=device)
265
265
  output_wp = wp.zeros(batch_count, dtype=wp.int32, requires_grad=True, device=device)
266
266
 
267
- with wp.Tape() as tape:
267
+ with wp.Tape():
268
268
  wp.launch_tiled(
269
269
  tile_argmax_kernel, dim=[batch_count], inputs=[input_wp, output_wp], block_dim=TILE_DIM, device=device
270
270
  )
@@ -297,7 +297,7 @@ def test_tile_reduce_custom(test, device):
297
297
  input_wp = wp.array(input, requires_grad=True, device=device)
298
298
  output_wp = wp.zeros(batch_count, requires_grad=True, device=device)
299
299
 
300
- with wp.Tape() as tape:
300
+ with wp.Tape():
301
301
  wp.launch_tiled(
302
302
  tile_reduce_custom_kernel,
303
303
  dim=[batch_count],
@@ -333,7 +333,7 @@ def test_tile_scan_inclusive(test, device):
333
333
  input_wp = wp.array2d(input, requires_grad=True, device=device)
334
334
  output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
335
335
 
336
- with wp.Tape() as tape:
336
+ with wp.Tape():
337
337
  wp.launch_tiled(
338
338
  create_tile_scan_inclusive_kernel(N),
339
339
  dim=[batch_count],
@@ -369,7 +369,7 @@ def test_tile_scan_exclusive(test, device):
369
369
  input_wp = wp.array2d(input, requires_grad=True, device=device)
370
370
  output_wp = wp.zeros_like(input_wp, requires_grad=True, device=device)
371
371
 
372
- with wp.Tape() as tape:
372
+ with wp.Tape():
373
373
  wp.launch_tiled(
374
374
  create_tile_scan_exclusive_kernel(N),
375
375
  dim=[batch_count],
@@ -501,12 +501,210 @@ def test_tile_reduce_simt(test, device):
501
501
 
502
502
  output = wp.zeros(shape=1, dtype=int, requires_grad=True, device=device)
503
503
 
504
- with wp.Tape() as tape:
504
+ with wp.Tape():
505
505
  wp.launch(tile_reduce_simt_kernel, dim=N, inputs=[output], block_dim=TILE_DIM, device=device)
506
506
 
507
507
  test.assertEqual(output.numpy()[0], np.sum(np.arange(N)))
508
508
 
509
509
 
510
+ # Tier 1: axis size <= 32
511
+ @wp.kernel
512
+ def tile_reduce_axis_tier1_sum_axis0_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
513
+ a = wp.tile_load(x, shape=(32, 64), storage="shared")
514
+ b = wp.tile_sum(a, axis=0)
515
+ wp.tile_store(y, b)
516
+
517
+
518
+ @wp.kernel
519
+ def tile_reduce_axis_tier1_prod_axis1_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
520
+ a = wp.tile_load(x, shape=(32, 8), storage="shared")
521
+ b = wp.tile_reduce(wp.mul, a, axis=1)
522
+ wp.tile_store(y, b)
523
+
524
+
525
+ @wp.kernel
526
+ def tile_reduce_axis_tier1_sum_axis2_kernel(x: wp.array3d(dtype=float), y: wp.array2d(dtype=float)):
527
+ a = wp.tile_load(x, shape=(8, 8, 16), storage="shared")
528
+ b = wp.tile_sum(a, axis=2)
529
+ wp.tile_store(y, b)
530
+
531
+
532
+ # Tier 2: 32 < axis size <= 256
533
+ @wp.kernel
534
+ def tile_reduce_axis_tier2_sum_axis0_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
535
+ a = wp.tile_load(x, shape=(200, 32), storage="shared")
536
+ b = wp.tile_sum(a, axis=0)
537
+ wp.tile_store(y, b)
538
+
539
+
540
+ @wp.kernel
541
+ def tile_reduce_axis_tier2_prod_axis1_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
542
+ a = wp.tile_load(x, shape=(16, 64), storage="shared")
543
+ b = wp.tile_reduce(wp.mul, a, axis=1)
544
+ wp.tile_store(y, b)
545
+
546
+
547
+ @wp.kernel
548
+ def tile_reduce_axis_tier2_sum_axis2_kernel(x: wp.array3d(dtype=float), y: wp.array2d(dtype=float)):
549
+ a = wp.tile_load(x, shape=(8, 8, 128), storage="shared")
550
+ b = wp.tile_sum(a, axis=2)
551
+ wp.tile_store(y, b)
552
+
553
+
554
+ # Tier 3: axis size > 256
555
+ @wp.kernel
556
+ def tile_reduce_axis_tier3_sum_axis0_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
557
+ a = wp.tile_load(x, shape=(400, 16), storage="shared")
558
+ b = wp.tile_sum(a, axis=0)
559
+ wp.tile_store(y, b)
560
+
561
+
562
+ @wp.kernel
563
+ def tile_reduce_axis_tier3_prod_axis1_kernel(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
564
+ a = wp.tile_load(x, shape=(8, 300), storage="shared")
565
+ b = wp.tile_reduce(wp.mul, a, axis=1)
566
+ wp.tile_store(y, b)
567
+
568
+
569
+ @wp.kernel
570
+ def tile_reduce_axis_tier3_sum_axis2_kernel(x: wp.array3d(dtype=float), y: wp.array2d(dtype=float)):
571
+ a = wp.tile_load(x, shape=(4, 4, 384), storage="shared")
572
+ b = wp.tile_sum(a, axis=2)
573
+ wp.tile_store(y, b)
574
+
575
+
576
+ def test_tile_reduce_axis_tier1(test, device):
577
+ # 2D sum: axis=0, size 32 (forward and backward)
578
+ x = wp.ones((32, 64), dtype=float, requires_grad=True, device=device)
579
+ y = wp.zeros(64, dtype=float, requires_grad=True, device=device)
580
+
581
+ with wp.Tape() as tape:
582
+ wp.launch_tiled(
583
+ tile_reduce_axis_tier1_sum_axis0_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
584
+ )
585
+
586
+ y.grad = wp.ones_like(y)
587
+ tape.backward()
588
+
589
+ assert_np_equal(y.numpy(), np.ones(64, dtype=float) * 32.0)
590
+ assert_np_equal(x.grad.numpy(), np.ones((32, 64), dtype=float))
591
+
592
+ # 2D product: axis=1, size 8
593
+ rng = np.random.default_rng(42)
594
+ x_np = rng.random((32, 8), dtype=np.float32) * 0.1 + 1.0
595
+ x = wp.array(x_np, dtype=float, device=device)
596
+ y = wp.zeros(32, dtype=float, device=device)
597
+
598
+ wp.launch_tiled(
599
+ tile_reduce_axis_tier1_prod_axis1_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
600
+ )
601
+
602
+ assert_np_equal(y.numpy(), np.prod(x_np, axis=1), tol=1e-3)
603
+
604
+ # 3D sum: axis=2, size 16 (forward and backward)
605
+ x = wp.ones((8, 8, 16), dtype=float, requires_grad=True, device=device)
606
+ y = wp.zeros((8, 8), dtype=float, requires_grad=True, device=device)
607
+
608
+ with wp.Tape() as tape:
609
+ wp.launch_tiled(
610
+ tile_reduce_axis_tier1_sum_axis2_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
611
+ )
612
+
613
+ y.grad = wp.ones_like(y)
614
+ tape.backward()
615
+
616
+ assert_np_equal(y.numpy(), np.ones((8, 8), dtype=float) * 16.0)
617
+ assert_np_equal(x.grad.numpy(), np.ones((8, 8, 16), dtype=float))
618
+
619
+
620
+ def test_tile_reduce_axis_tier2(test, device):
621
+ # 2D sum: axis=0, size 200 (forward and backward)
622
+ x = wp.ones((200, 32), dtype=float, requires_grad=True, device=device)
623
+ y = wp.zeros(32, dtype=float, requires_grad=True, device=device)
624
+
625
+ with wp.Tape() as tape:
626
+ wp.launch_tiled(
627
+ tile_reduce_axis_tier2_sum_axis0_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
628
+ )
629
+
630
+ y.grad = wp.ones_like(y)
631
+ tape.backward()
632
+
633
+ assert_np_equal(y.numpy(), np.ones(32, dtype=float) * 200.0)
634
+ assert_np_equal(x.grad.numpy(), np.ones((200, 32), dtype=float))
635
+
636
+ # 2D product: axis=1, size 64
637
+ rng = np.random.default_rng(42)
638
+ x_np = rng.random((16, 64), dtype=np.float32) * 0.05 + 1.0
639
+ x = wp.array(x_np, dtype=float, device=device)
640
+ y = wp.zeros(16, dtype=float, device=device)
641
+
642
+ wp.launch_tiled(
643
+ tile_reduce_axis_tier2_prod_axis1_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
644
+ )
645
+
646
+ assert_np_equal(y.numpy(), np.prod(x_np, axis=1), tol=1e-2)
647
+
648
+ # 3D sum: axis=2, size 128 (forward and backward)
649
+ x = wp.ones((8, 8, 128), dtype=float, requires_grad=True, device=device)
650
+ y = wp.zeros((8, 8), dtype=float, requires_grad=True, device=device)
651
+
652
+ with wp.Tape() as tape:
653
+ wp.launch_tiled(
654
+ tile_reduce_axis_tier2_sum_axis2_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
655
+ )
656
+
657
+ y.grad = wp.ones_like(y)
658
+ tape.backward()
659
+
660
+ assert_np_equal(y.numpy(), np.ones((8, 8), dtype=float) * 128.0)
661
+ assert_np_equal(x.grad.numpy(), np.ones((8, 8, 128), dtype=float))
662
+
663
+
664
+ def test_tile_reduce_axis_tier3(test, device):
665
+ # 2D sum: axis=0, size 400 (forward and backward)
666
+ x = wp.ones((400, 16), dtype=float, requires_grad=True, device=device)
667
+ y = wp.zeros(16, dtype=float, requires_grad=True, device=device)
668
+
669
+ with wp.Tape() as tape:
670
+ wp.launch_tiled(
671
+ tile_reduce_axis_tier3_sum_axis0_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
672
+ )
673
+
674
+ y.grad = wp.ones_like(y)
675
+ tape.backward()
676
+
677
+ assert_np_equal(y.numpy(), np.ones(16, dtype=float) * 400.0, tol=2e-4)
678
+ assert_np_equal(x.grad.numpy(), np.ones((400, 16), dtype=float))
679
+
680
+ # 2D product: axis=1, size 300
681
+ rng = np.random.default_rng(42)
682
+ x_np = rng.random((8, 300), dtype=np.float32) * 0.01 + 1.0
683
+ x = wp.array(x_np, dtype=float, device=device)
684
+ y = wp.zeros(8, dtype=float, device=device)
685
+
686
+ wp.launch_tiled(
687
+ tile_reduce_axis_tier3_prod_axis1_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
688
+ )
689
+
690
+ assert_np_equal(y.numpy(), np.prod(x_np, axis=1), tol=1e-1)
691
+
692
+ # 3D sum: axis=2, size 384 (forward and backward)
693
+ x = wp.ones((4, 4, 384), dtype=float, requires_grad=True, device=device)
694
+ y = wp.zeros((4, 4), dtype=float, requires_grad=True, device=device)
695
+
696
+ with wp.Tape() as tape:
697
+ wp.launch_tiled(
698
+ tile_reduce_axis_tier3_sum_axis2_kernel, dim=[1], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device
699
+ )
700
+
701
+ y.grad = wp.ones_like(y)
702
+ tape.backward()
703
+
704
+ assert_np_equal(y.numpy(), np.ones((4, 4), dtype=float) * 384.0, tol=2e-4)
705
+ assert_np_equal(x.grad.numpy(), np.ones((4, 4, 384), dtype=float))
706
+
707
+
510
708
  @wp.kernel
511
709
  def tile_untile_kernel(output: wp.array(dtype=int)):
512
710
  # thread index
@@ -525,7 +723,7 @@ def test_tile_untile(test, device):
525
723
 
526
724
  output = wp.zeros(shape=N, dtype=int, requires_grad=True, device=device)
527
725
 
528
- with wp.Tape() as tape:
726
+ with wp.Tape():
529
727
  wp.launch(tile_untile_kernel, dim=N, inputs=[output], block_dim=TILE_DIM, device=device)
530
728
 
531
729
  assert_np_equal(output.numpy(), np.arange(N) * 2)
@@ -549,7 +747,7 @@ def test_tile_untile_scalar(test, device):
549
747
 
550
748
  output = wp.zeros(shape=N, dtype=int, requires_grad=True, device=device)
551
749
 
552
- with wp.Tape() as tape:
750
+ with wp.Tape():
553
751
  wp.launch(tile_untile_kernel, dim=N, inputs=[output], block_dim=TILE_DIM, device=device)
554
752
 
555
753
  assert_np_equal(output.numpy(), np.arange(N) * 2)
@@ -594,7 +792,7 @@ def tile_ones_kernel(out: wp.array(dtype=float)):
594
792
  def test_tile_ones(test, device):
595
793
  output = wp.zeros(1, dtype=float, device=device)
596
794
 
597
- with wp.Tape() as tape:
795
+ with wp.Tape():
598
796
  wp.launch_tiled(tile_ones_kernel, dim=[1], inputs=[output], block_dim=TILE_DIM, device=device)
599
797
 
600
798
  test.assertAlmostEqual(output.numpy()[0], 256.0)
@@ -622,7 +820,7 @@ def test_tile_arange(test, device):
622
820
 
623
821
  output = wp.zeros(shape=(5, N), dtype=int, device=device)
624
822
 
625
- with wp.Tape() as tape:
823
+ with wp.Tape():
626
824
  wp.launch_tiled(tile_arange_kernel, dim=[1], inputs=[output], block_dim=TILE_DIM, device=device)
627
825
 
628
826
  assert_np_equal(output.numpy()[0], np.arange(17))
@@ -734,6 +932,9 @@ add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_cu
734
932
  add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
735
933
  add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_grouped_sum, devices=devices)
736
934
  add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
935
+ add_function_test(TestTileReduce, "test_tile_reduce_axis_tier1", test_tile_reduce_axis_tier1, devices=devices)
936
+ add_function_test(TestTileReduce, "test_tile_reduce_axis_tier2", test_tile_reduce_axis_tier2, devices=devices)
937
+ add_function_test(TestTileReduce, "test_tile_reduce_axis_tier3", test_tile_reduce_axis_tier3, devices=devices)
737
938
  add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)
738
939
  add_function_test(TestTileReduce, "test_tile_arange", test_tile_arange, devices=devices)
739
940
  add_function_test(TestTileReduce, "test_tile_untile_scalar", test_tile_untile_scalar, devices=devices)
@@ -113,17 +113,12 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
113
113
  from warp.tests.interop.test_dlpack import TestDLPack
114
114
  from warp.tests.interop.test_jax import TestJax
115
115
  from warp.tests.interop.test_torch import TestTorch
116
- from warp.tests.sim.test_cloth import TestCloth
117
- from warp.tests.sim.test_collision import TestCollision
118
- from warp.tests.sim.test_coloring import TestColoring
119
- from warp.tests.sim.test_model import TestModel
120
- from warp.tests.sim.test_sim_grad import TestSimGradients
121
- from warp.tests.sim.test_sim_kinematics import TestSimKinematics
122
116
  from warp.tests.test_adam import TestAdam
123
117
  from warp.tests.test_arithmetic import TestArithmetic
124
118
  from warp.tests.test_array import TestArray
125
119
  from warp.tests.test_array_reduce import TestArrayReduce
126
120
  from warp.tests.test_atomic import TestAtomic
121
+ from warp.tests.test_atomic_bitwise import TestAtomicBitwise
127
122
  from warp.tests.test_atomic_cas import TestAtomicCAS
128
123
  from warp.tests.test_bool import TestBool
129
124
  from warp.tests.test_builtins_resolution import TestBuiltinsResolution
@@ -142,7 +137,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
142
137
  TestFemDiffusionExamples,
143
138
  TestFemExamples,
144
139
  TestOptimExamples,
145
- TestSimExamples,
146
140
  )
147
141
  from warp.tests.test_fabricarray import TestFabricArray
148
142
  from warp.tests.test_fast_math import TestFastMath
@@ -198,7 +192,9 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
198
192
  from warp.tests.test_vec_lite import TestVecLite
199
193
  from warp.tests.test_vec_scalar_ops import TestVecScalarOps
200
194
  from warp.tests.test_verify_fp import TestVerifyFP
195
+ from warp.tests.test_version import TestVersion
201
196
  from warp.tests.tile.test_tile import TestTile
197
+ from warp.tests.tile.test_tile_atomic_bitwise import TestTileAtomicBitwise
202
198
  from warp.tests.tile.test_tile_cholesky import TestTileCholesky
203
199
  from warp.tests.tile.test_tile_load import TestTileLoad
204
200
  from warp.tests.tile.test_tile_mathdx import TestTileMathDx
@@ -215,16 +211,14 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
215
211
  TestArrayReduce,
216
212
  TestAsync,
217
213
  TestAtomic,
214
+ TestAtomicBitwise,
218
215
  TestAtomicCAS,
219
216
  TestBool,
220
217
  TestBuiltinsResolution,
221
218
  TestBvh,
222
219
  TestClosestPointEdgeEdgeMethods,
223
- TestCloth,
224
220
  TestCodeGen,
225
221
  TestCodeGenInstancing,
226
- TestCollision,
227
- TestColoring,
228
222
  TestConditional,
229
223
  TestConstants,
230
224
  TestContext,
@@ -237,7 +231,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
237
231
  TestFemDiffusionExamples,
238
232
  TestFemExamples,
239
233
  TestOptimExamples,
240
- TestSimExamples,
241
234
  TestFabricArray,
242
235
  TestFastMath,
243
236
  TestFem,
@@ -272,7 +265,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
272
265
  TestMeshQueryAABBMethods,
273
266
  TestMeshQueryPoint,
274
267
  TestMeshQueryRay,
275
- TestModel,
276
268
  TestModuleHashing,
277
269
  TestModuleLite,
278
270
  TestMultiGPU,
@@ -289,8 +281,6 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
289
281
  TestRounding,
290
282
  TestRunlengthEncode,
291
283
  TestScalarOps,
292
- TestSimGradients,
293
- TestSimKinematics,
294
284
  TestSmoothstep,
295
285
  TestSnippets,
296
286
  TestSparse,
@@ -301,6 +291,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
301
291
  TestStruct,
302
292
  TestTape,
303
293
  TestTile,
294
+ TestTileAtomicBitwise,
304
295
  TestTileCholesky,
305
296
  TestTileLoad,
306
297
  TestTileMathDx,
@@ -319,6 +310,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
319
310
  TestVecLite,
320
311
  TestVecScalarOps,
321
312
  TestVerifyFP,
313
+ TestVersion,
322
314
  TestVolume,
323
315
  TestVolumeWrite,
324
316
  ]
@@ -376,24 +376,25 @@ def write_junit_results(
376
376
  )
377
377
 
378
378
  for test_data in test_records:
379
- test = test_data[0]
380
- test_duration = test_data[1]
381
- test_status = test_data[2]
379
+ test_classname = test_data[0]
380
+ test_methodname = test_data[1]
381
+ test_duration = test_data[2]
382
+ test_status = test_data[3]
382
383
 
383
384
  test_case = ET.SubElement(
384
- root, "testcase", classname=test.__class__.__name__, name=test._testMethodName, time=f"{test_duration:.3f}"
385
+ root, "testcase", classname=test_classname, name=test_methodname, time=f"{test_duration:.3f}"
385
386
  )
386
387
 
387
388
  if test_status == "FAIL":
388
- failure = ET.SubElement(test_case, "failure", message=str(test_data[3]))
389
- failure.text = str(test_data[4]) # Stacktrace
389
+ failure = ET.SubElement(test_case, "failure", message=str(test_data[4]))
390
+ failure.text = str(test_data[5]) # Stacktrace
390
391
  elif test_status == "ERROR":
391
392
  error = ET.SubElement(test_case, "error")
392
- error.text = str(test_data[4]) # Stacktrace
393
+ error.text = str(test_data[5]) # Stacktrace
393
394
  elif test_status == "SKIP":
394
395
  skip = ET.SubElement(test_case, "skipped")
395
396
  # Set the skip reason
396
- skip.set("message", str(test_data[3]))
397
+ skip.set("message", str(test_data[4]))
397
398
 
398
399
  tree = ET.ElementTree(root)
399
400
 
@@ -425,7 +426,7 @@ class ParallelJunitTestResult(unittest.TextTestResult):
425
426
 
426
427
  def _record_test(self, test, code, message=None, details=None):
427
428
  duration = round((time.perf_counter_ns() - self.start_time) * 1e-9, 3) # [s]
428
- self.test_record.append((test, duration, code, message, details))
429
+ self.test_record.append((test.__class__.__name__, test._testMethodName, duration, code, message, details))
429
430
 
430
431
  def addSuccess(self, test):
431
432
  super(unittest.TextTestResult, self).addSuccess(test)