warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -87,11 +87,11 @@ def get_select_kernel2(dtype):
87
87
  def test_arrays(test, device, dtype):
88
88
  rng = np.random.default_rng(123)
89
89
 
90
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
91
- vec2 = wp.types.vector(length=2, dtype=wptype)
92
- vec3 = wp.types.vector(length=3, dtype=wptype)
93
- vec4 = wp.types.vector(length=4, dtype=wptype)
94
- vec5 = wp.types.vector(length=5, dtype=wptype)
90
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
91
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
92
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
93
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
94
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
95
95
 
96
96
  v2_np = randvals(rng, (10, 2), dtype)
97
97
  v3_np = randvals(rng, (10, 3), dtype)
@@ -108,9 +108,9 @@ def test_arrays(test, device, dtype):
108
108
  assert_np_equal(v4.numpy(), v4_np, tol=1.0e-6)
109
109
  assert_np_equal(v5.numpy(), v5_np, tol=1.0e-6)
110
110
 
111
- vec2 = wp.types.vector(length=2, dtype=wptype)
112
- vec3 = wp.types.vector(length=3, dtype=wptype)
113
- vec4 = wp.types.vector(length=4, dtype=wptype)
111
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
112
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
113
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
114
114
 
115
115
  v2 = wp.array(v2_np, dtype=vec2, requires_grad=True, device=device)
116
116
  v3 = wp.array(v3_np, dtype=vec3, requires_grad=True, device=device)
@@ -125,8 +125,8 @@ def test_components(test, device, dtype):
125
125
  # test accessing vector components from Python - this is especially important
126
126
  # for float16, which requires special handling internally
127
127
 
128
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
129
- vec3 = wp.types.vector(length=3, dtype=wptype)
128
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
129
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
130
130
 
131
131
  v = vec3(1, 2, 3)
132
132
 
@@ -184,10 +184,10 @@ def test_components(test, device, dtype):
184
184
 
185
185
 
186
186
  def test_py_arithmetic_ops(test, device, dtype):
187
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
187
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
188
188
 
189
189
  def make_vec(*args):
190
- if wptype in wp.types.int_types:
190
+ if wptype in wp._src.types.int_types:
191
191
  # Cast to the correct integer type to simulate wrapping.
192
192
  return tuple(wptype._type_(x).value for x in args)
193
193
 
@@ -219,11 +219,11 @@ def test_constructors(test, device, dtype, register_kernels=False):
219
219
  np.float64: 1.0e-8,
220
220
  }.get(dtype, 0)
221
221
 
222
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
223
- vec2 = wp.types.vector(length=2, dtype=wptype)
224
- vec3 = wp.types.vector(length=3, dtype=wptype)
225
- vec4 = wp.types.vector(length=4, dtype=wptype)
226
- vec5 = wp.types.vector(length=5, dtype=wptype)
222
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
223
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
224
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
225
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
226
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
227
227
 
228
228
  def check_scalar_constructor(
229
229
  input: wp.array(dtype=wptype),
@@ -463,7 +463,7 @@ def test_anon_type_instance(test, device, dtype, register_kernels=False):
463
463
  np.float64: 1.0e-8,
464
464
  }.get(dtype, 0)
465
465
 
466
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
466
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
467
467
 
468
468
  def check_scalar_init(
469
469
  input: wp.array(dtype=wptype),
@@ -586,11 +586,11 @@ def test_indexing(test, device, dtype, register_kernels=False):
586
586
  np.float64: 1.0e-8,
587
587
  }.get(dtype, 0)
588
588
 
589
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
590
- vec2 = wp.types.vector(length=2, dtype=wptype)
591
- vec3 = wp.types.vector(length=3, dtype=wptype)
592
- vec4 = wp.types.vector(length=4, dtype=wptype)
593
- vec5 = wp.types.vector(length=5, dtype=wptype)
589
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
590
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
591
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
592
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
593
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
594
594
 
595
595
  def check_indexing(
596
596
  v2: wp.array(dtype=vec2),
@@ -691,11 +691,11 @@ def test_indexing(test, device, dtype, register_kernels=False):
691
691
 
692
692
 
693
693
  def test_equality(test, device, dtype, register_kernels=False):
694
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
695
- vec2 = wp.types.vector(length=2, dtype=wptype)
696
- vec3 = wp.types.vector(length=3, dtype=wptype)
697
- vec4 = wp.types.vector(length=4, dtype=wptype)
698
- vec5 = wp.types.vector(length=5, dtype=wptype)
694
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
695
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
696
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
697
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
698
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
699
699
 
700
700
  def check_unsigned_equality(
701
701
  v20: wp.array(dtype=vec2),
@@ -821,11 +821,11 @@ def test_scalar_multiplication(test, device, dtype, register_kernels=False):
821
821
  np.float64: 1.0e-8,
822
822
  }.get(dtype, 0)
823
823
 
824
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
825
- vec2 = wp.types.vector(length=2, dtype=wptype)
826
- vec3 = wp.types.vector(length=3, dtype=wptype)
827
- vec4 = wp.types.vector(length=4, dtype=wptype)
828
- vec5 = wp.types.vector(length=5, dtype=wptype)
824
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
825
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
826
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
827
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
828
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
829
829
 
830
830
  def check_mul(
831
831
  s: wp.array(dtype=wptype),
@@ -953,11 +953,11 @@ def test_scalar_multiplication_rightmul(test, device, dtype, register_kernels=Fa
953
953
  np.float64: 1.0e-8,
954
954
  }.get(dtype, 0)
955
955
 
956
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
957
- vec2 = wp.types.vector(length=2, dtype=wptype)
958
- vec3 = wp.types.vector(length=3, dtype=wptype)
959
- vec4 = wp.types.vector(length=4, dtype=wptype)
960
- vec5 = wp.types.vector(length=5, dtype=wptype)
956
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
957
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
958
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
959
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
960
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
961
961
 
962
962
  def check_rightmul(
963
963
  s: wp.array(dtype=wptype),
@@ -1085,11 +1085,11 @@ def test_cw_multiplication(test, device, dtype, register_kernels=False):
1085
1085
  np.float64: 1.0e-8,
1086
1086
  }.get(dtype, 0)
1087
1087
 
1088
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1089
- vec2 = wp.types.vector(length=2, dtype=wptype)
1090
- vec3 = wp.types.vector(length=3, dtype=wptype)
1091
- vec4 = wp.types.vector(length=4, dtype=wptype)
1092
- vec5 = wp.types.vector(length=5, dtype=wptype)
1088
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1089
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1090
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1091
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1092
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1093
1093
 
1094
1094
  def check_cw_mul(
1095
1095
  s2: wp.array(dtype=vec2),
@@ -1230,11 +1230,11 @@ def test_scalar_division(test, device, dtype, register_kernels=False):
1230
1230
  np.float64: 1.0e-8,
1231
1231
  }.get(dtype, 0)
1232
1232
 
1233
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1234
- vec2 = wp.types.vector(length=2, dtype=wptype)
1235
- vec3 = wp.types.vector(length=3, dtype=wptype)
1236
- vec4 = wp.types.vector(length=4, dtype=wptype)
1237
- vec5 = wp.types.vector(length=5, dtype=wptype)
1233
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1234
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1235
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1236
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1237
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1238
1238
 
1239
1239
  def check_div(
1240
1240
  s: wp.array(dtype=wptype),
@@ -1386,11 +1386,11 @@ def test_cw_division(test, device, dtype, register_kernels=False):
1386
1386
  np.float64: 1.0e-8,
1387
1387
  }.get(dtype, 0)
1388
1388
 
1389
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1390
- vec2 = wp.types.vector(length=2, dtype=wptype)
1391
- vec3 = wp.types.vector(length=3, dtype=wptype)
1392
- vec4 = wp.types.vector(length=4, dtype=wptype)
1393
- vec5 = wp.types.vector(length=5, dtype=wptype)
1389
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1390
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1391
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1392
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1393
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1394
1394
 
1395
1395
  def check_cw_div(
1396
1396
  s2: wp.array(dtype=vec2),
@@ -1554,11 +1554,11 @@ def test_addition(test, device, dtype, register_kernels=False):
1554
1554
  np.float64: 1.0e-8,
1555
1555
  }.get(dtype, 0)
1556
1556
 
1557
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1558
- vec2 = wp.types.vector(length=2, dtype=wptype)
1559
- vec3 = wp.types.vector(length=3, dtype=wptype)
1560
- vec4 = wp.types.vector(length=4, dtype=wptype)
1561
- vec5 = wp.types.vector(length=5, dtype=wptype)
1557
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1558
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1559
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1560
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1561
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1562
1562
 
1563
1563
  def check_add(
1564
1564
  s2: wp.array(dtype=vec2),
@@ -1695,11 +1695,11 @@ def test_dotproduct(test, device, dtype, register_kernels=False):
1695
1695
  np.float64: 1.0e-8,
1696
1696
  }.get(dtype, 0)
1697
1697
 
1698
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1699
- vec2 = wp.types.vector(length=2, dtype=wptype)
1700
- vec3 = wp.types.vector(length=3, dtype=wptype)
1701
- vec4 = wp.types.vector(length=4, dtype=wptype)
1702
- vec5 = wp.types.vector(length=5, dtype=wptype)
1698
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1699
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1700
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1701
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1702
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1703
1703
 
1704
1704
  def check_dot(
1705
1705
  s2: wp.array(dtype=vec2),
@@ -1816,11 +1816,11 @@ def test_modulo(test, device, dtype, register_kernels=False):
1816
1816
  np.float64: 1.0e-8,
1817
1817
  }.get(dtype, 0)
1818
1818
 
1819
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1820
- vec2 = wp.types.vector(length=2, dtype=wptype)
1821
- vec3 = wp.types.vector(length=3, dtype=wptype)
1822
- vec4 = wp.types.vector(length=4, dtype=wptype)
1823
- vec5 = wp.types.vector(length=5, dtype=wptype)
1819
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1820
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1821
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1822
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1823
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1824
1824
 
1825
1825
  def check_mod(
1826
1826
  s2: wp.array(dtype=vec2),
@@ -1942,19 +1942,19 @@ def test_modulo(test, device, dtype, register_kernels=False):
1942
1942
 
1943
1943
 
1944
1944
  def test_equivalent_types(test, device, dtype, register_kernels=False):
1945
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1945
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
1946
1946
 
1947
1947
  # vector types
1948
- vec2 = wp.types.vector(length=2, dtype=wptype)
1949
- vec3 = wp.types.vector(length=3, dtype=wptype)
1950
- vec4 = wp.types.vector(length=4, dtype=wptype)
1951
- vec5 = wp.types.vector(length=5, dtype=wptype)
1948
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
1949
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
1950
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
1951
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
1952
1952
 
1953
1953
  # vector types equivalent to the above
1954
- vec2_equiv = wp.types.vector(length=2, dtype=wptype)
1955
- vec3_equiv = wp.types.vector(length=3, dtype=wptype)
1956
- vec4_equiv = wp.types.vector(length=4, dtype=wptype)
1957
- vec5_equiv = wp.types.vector(length=5, dtype=wptype)
1954
+ vec2_equiv = wp._src.types.vector(length=2, dtype=wptype)
1955
+ vec3_equiv = wp._src.types.vector(length=3, dtype=wptype)
1956
+ vec4_equiv = wp._src.types.vector(length=4, dtype=wptype)
1957
+ vec5_equiv = wp._src.types.vector(length=5, dtype=wptype)
1958
1958
 
1959
1959
  # declare kernel with original types
1960
1960
  def check_equivalence(
@@ -2021,11 +2021,11 @@ def test_conversions(test, device, dtype, register_kernels=False):
2021
2021
 
2022
2022
 
2023
2023
  def test_constants(test, device, dtype, register_kernels=False):
2024
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2025
- vec2 = wp.types.vector(length=2, dtype=wptype)
2026
- vec3 = wp.types.vector(length=3, dtype=wptype)
2027
- vec4 = wp.types.vector(length=4, dtype=wptype)
2028
- vec5 = wp.types.vector(length=5, dtype=wptype)
2024
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
2025
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
2026
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
2027
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
2028
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
2029
2029
 
2030
2030
  cv2 = wp.constant(vec2(1, 2))
2031
2031
  cv3 = wp.constant(vec3(1, 2, 3))
@@ -2047,11 +2047,11 @@ def test_constants(test, device, dtype, register_kernels=False):
2047
2047
 
2048
2048
 
2049
2049
  def test_abs(test, device, dtype, register_kernels=False):
2050
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2051
- vec2 = wp.types.vector(length=2, dtype=wptype)
2052
- vec3 = wp.types.vector(length=3, dtype=wptype)
2053
- vec4 = wp.types.vector(length=4, dtype=wptype)
2054
- vec5 = wp.types.vector(length=5, dtype=wptype)
2050
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
2051
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
2052
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
2053
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
2054
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
2055
2055
 
2056
2056
  def check_vector_abs():
2057
2057
  res2 = wp.abs(vec2(wptype(-1), wptype(2)))
@@ -2075,11 +2075,11 @@ def test_abs(test, device, dtype, register_kernels=False):
2075
2075
 
2076
2076
 
2077
2077
  def test_sign(test, device, dtype, register_kernels=False):
2078
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2079
- vec2 = wp.types.vector(length=2, dtype=wptype)
2080
- vec3 = wp.types.vector(length=3, dtype=wptype)
2081
- vec4 = wp.types.vector(length=4, dtype=wptype)
2082
- vec5 = wp.types.vector(length=5, dtype=wptype)
2078
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
2079
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
2080
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
2081
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
2082
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
2083
2083
 
2084
2084
  def check_vector_sign():
2085
2085
  res2 = wp.sign(vec2(wptype(-1), wptype(2)))
@@ -2113,11 +2113,11 @@ def test_minmax(test, device, dtype, register_kernels=False):
2113
2113
  np.float16: 1.0e-2,
2114
2114
  }.get(dtype, 0)
2115
2115
 
2116
- wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2117
- vec2 = wp.types.vector(length=2, dtype=wptype)
2118
- vec3 = wp.types.vector(length=3, dtype=wptype)
2119
- vec4 = wp.types.vector(length=4, dtype=wptype)
2120
- vec5 = wp.types.vector(length=5, dtype=wptype)
2116
+ wptype = wp._src.types.np_dtype_to_warp_type[np.dtype(dtype)]
2117
+ vec2 = wp._src.types.vector(length=2, dtype=wptype)
2118
+ vec3 = wp._src.types.vector(length=3, dtype=wptype)
2119
+ vec4 = wp._src.types.vector(length=4, dtype=wptype)
2120
+ vec5 = wp._src.types.vector(length=5, dtype=wptype)
2121
2121
 
2122
2122
  # \TODO: Also not quite sure why: this kernel compiles incredibly
2123
2123
  # slowly though...
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ import unittest
18
+
19
+ import warp as wp
20
+ from warp._src.context import get_warp_clang_version, get_warp_version
21
+
22
+
23
+ class TestVersion(unittest.TestCase):
24
+ """Tests for native library version verification using string comparison."""
25
+
26
+ def test_get_warp_version_returns_string(self):
27
+ """Test that get_warp_version() returns a string."""
28
+ version = get_warp_version()
29
+ self.assertIsInstance(version, str)
30
+ self.assertRegex(version, r"^\d+\.\d+\.\d+")
31
+
32
+ def test_get_warp_clang_version_returns_string(self):
33
+ """Test that get_warp_clang_version() returns a string."""
34
+ version = get_warp_clang_version()
35
+ self.assertIsInstance(version, str)
36
+ self.assertRegex(version, r"^\d+\.\d+\.\d+")
37
+
38
+ def test_dll_versions_match_python_version(self):
39
+ """Test that native library versions match Python package version exactly."""
40
+ python_version = wp._src.config.version
41
+ warp_version = get_warp_version()
42
+ warp_clang_version = get_warp_clang_version()
43
+
44
+ self.assertEqual(
45
+ warp_version, python_version, f"Warp library version {warp_version} != Python version {python_version}"
46
+ )
47
+ self.assertEqual(
48
+ warp_clang_version,
49
+ python_version,
50
+ f"warp-clang library version {warp_clang_version} != Python version {python_version}",
51
+ )
52
+
53
+ def test_version_format_validation(self):
54
+ """Test that versions follow PEP 440 versioning format."""
55
+ # PEP 440 pattern: [EPOCH!]MAJOR.MINOR.PATCH[{a|b|rc}N][.postN][.devN][+LOCAL]
56
+ # Examples: 1.10.0, 1.10.0.dev0, 1.10.0.dev20251017, 1.10.0a1, 1.10.0rc1, 1!1.10.0
57
+ pep440_pattern = (
58
+ r"^(\d+!)?" # Optional epoch
59
+ r"\d+\.\d+\.\d+" # Major.minor.patch
60
+ r"(\.?(a|alpha|b|beta|rc)\.?\d+)?" # Optional alpha/beta/rc (with or without dots)
61
+ r"(\.post\d+)?" # Optional post release
62
+ r"(\.dev\d+)?" # Optional dev release
63
+ r"(\+[a-zA-Z0-9.]+)?$" # Optional local version identifier
64
+ )
65
+
66
+ warp_version = get_warp_version()
67
+ warp_clang_version = get_warp_clang_version()
68
+
69
+ self.assertIsNotNone(re.match(pep440_pattern, warp_version))
70
+ self.assertIsNotNone(re.match(pep440_pattern, warp_clang_version))
71
+
72
+
73
+ if __name__ == "__main__":
74
+ wp.clear_kernel_cache()
75
+ unittest.main(verbosity=2)
@@ -738,6 +738,28 @@ def test_tile_extract_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=flo
738
738
  wp.atomic_add(b, i, j, wp.tile_extract(tile, x, y))
739
739
 
740
740
 
741
+ @wp.kernel
742
+ def test_tile_extract_vec_kernel(x: wp.array(dtype=wp.vec3), y: wp.array(dtype=float)):
743
+ i = wp.tid()
744
+
745
+ tile = wp.tile_load(x, shape=(TILE_M))
746
+
747
+ a = tile[i][1]
748
+
749
+ y[i] = a
750
+
751
+
752
+ @wp.kernel
753
+ def test_tile_extract_mat_kernel(x: wp.array(dtype=wp.mat33), y: wp.array(dtype=float)):
754
+ i = wp.tid()
755
+
756
+ tile = wp.tile_load(x, shape=(TILE_M))
757
+
758
+ a = tile[i][1, 1]
759
+
760
+ y[i] = a
761
+
762
+
741
763
  def test_tile_extract(test, device):
742
764
  block_dim = 16
743
765
 
@@ -763,6 +785,40 @@ def test_tile_extract(test, device):
763
785
  expected_grad = np.ones_like(input)
764
786
  assert_np_equal(a.grad.numpy(), expected_grad)
765
787
 
788
+ # vector element test
789
+ x = wp.ones(TILE_M, dtype=wp.vec3, requires_grad=True, device=device)
790
+ y = wp.zeros(TILE_M, dtype=float, requires_grad=True, device=device)
791
+
792
+ with wp.Tape() as tape:
793
+ wp.launch(test_tile_extract_vec_kernel, dim=[TILE_M], inputs=[x, y], block_dim=TILE_DIM, device=device)
794
+
795
+ y.grad = wp.ones_like(y)
796
+
797
+ tape.backward()
798
+
799
+ x_grad_np = np.zeros((TILE_M, 3), dtype=float)
800
+ x_grad_np[:, 1] = 1.0
801
+
802
+ assert_np_equal(x.grad.numpy(), x_grad_np)
803
+ assert_np_equal(y.numpy(), np.ones(TILE_M, dtype=float))
804
+
805
+ # matrix element test
806
+ x = wp.ones(TILE_M, dtype=wp.mat33, requires_grad=True, device=device)
807
+ y = wp.zeros(TILE_M, dtype=float, requires_grad=True, device=device)
808
+
809
+ with wp.Tape() as tape:
810
+ wp.launch(test_tile_extract_mat_kernel, dim=[TILE_M], inputs=[x, y], block_dim=TILE_DIM, device=device)
811
+
812
+ y.grad = wp.ones_like(y)
813
+
814
+ tape.backward()
815
+
816
+ x_grad_np = np.zeros((TILE_M, 3, 3), dtype=float)
817
+ x_grad_np[:, 1, 1] = 1.0
818
+
819
+ assert_np_equal(y.numpy(), np.ones(TILE_M, dtype=float))
820
+ assert_np_equal(x.grad.numpy(), x_grad_np)
821
+
766
822
 
767
823
  @wp.kernel(module="unique")
768
824
  def test_tile_extract_repeated_kernel(a: wp.array2d(dtype=float), b: wp.array2d(dtype=float)):
@@ -822,6 +878,28 @@ def test_tile_assign_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
822
878
  wp.tile_atomic_add(y, a, offset=(0,))
823
879
 
824
880
 
881
+ @wp.kernel
882
+ def test_tile_assign_vec_kernel(x: wp.array(dtype=float), y: wp.array(dtype=wp.vec3)):
883
+ i = wp.tid()
884
+
885
+ a = wp.tile_zeros(shape=(TILE_M,), dtype=wp.vec3)
886
+
887
+ a[i][1] = x[i]
888
+
889
+ wp.tile_atomic_add(y, a, offset=(0,))
890
+
891
+
892
+ @wp.kernel
893
+ def test_tile_assign_mat_kernel(x: wp.array(dtype=float), y: wp.array(dtype=wp.mat33)):
894
+ i = wp.tid()
895
+
896
+ a = wp.tile_zeros(shape=(TILE_M,), dtype=wp.mat33)
897
+
898
+ a[i][1, 1] = x[i]
899
+
900
+ wp.tile_atomic_add(y, a, offset=(0,))
901
+
902
+
825
903
  def test_tile_assign(test, device):
826
904
  x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
827
905
  y = wp.zeros(TILE_M, dtype=float, device=device, requires_grad=True)
@@ -836,6 +914,40 @@ def test_tile_assign(test, device):
836
914
  assert_np_equal(y.numpy(), np.full(TILE_M, 2.0, dtype=np.float32))
837
915
  assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
838
916
 
917
+ # vector element test
918
+ x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
919
+ y = wp.zeros(TILE_M, dtype=wp.vec3, device=device, requires_grad=True)
920
+
921
+ tape = wp.Tape()
922
+ with tape:
923
+ wp.launch(test_tile_assign_vec_kernel, dim=[TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
924
+
925
+ y.grad = wp.ones_like(y)
926
+ tape.backward()
927
+
928
+ y_np = np.zeros((TILE_M, 3), dtype=float)
929
+ y_np[:, 1] = 2.0
930
+
931
+ assert_np_equal(y.numpy(), y_np)
932
+ assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
933
+
934
+ # matrix element test
935
+ x = wp.full(TILE_M, 2.0, dtype=float, device=device, requires_grad=True)
936
+ y = wp.zeros(TILE_M, dtype=wp.mat33, device=device, requires_grad=True)
937
+
938
+ tape = wp.Tape()
939
+ with tape:
940
+ wp.launch(test_tile_assign_mat_kernel, dim=[TILE_M], inputs=[x], outputs=[y], block_dim=TILE_DIM, device=device)
941
+
942
+ y.grad = wp.ones_like(y)
943
+ tape.backward()
944
+
945
+ y_np = np.zeros((TILE_M, 3, 3), dtype=float)
946
+ y_np[:, 1, 1] = 2.0
947
+
948
+ assert_np_equal(y.numpy(), y_np)
949
+ assert_np_equal(x.grad.numpy(), np.full(TILE_M, 1.0, dtype=np.float32))
950
+
839
951
 
840
952
  @wp.kernel
841
953
  def test_tile_where_kernel(select: int, x: wp.array(dtype=float), y: wp.array(dtype=float), z: wp.array(dtype=float)):
@@ -1178,6 +1290,71 @@ def test_tile_len(test, device):
1178
1290
  test.assertEqual(out.numpy()[0], TILE_M)
1179
1291
 
1180
1292
 
1293
+ @wp.struct
1294
+ class TestStruct:
1295
+ x: wp.float32
1296
+ y: wp.vec3
1297
+
1298
+
1299
+ @wp.kernel
1300
+ def test_tile_construction_kernel(
1301
+ out_zeros: wp.array(dtype=float),
1302
+ out_ones: wp.array(dtype=float),
1303
+ out_arange: wp.array(dtype=float),
1304
+ out_full_twos: wp.array(dtype=float),
1305
+ out_full_vecs: wp.array(dtype=wp.vec3),
1306
+ out_full_mats: wp.array(dtype=wp.mat33),
1307
+ out_full_structs: wp.array(dtype=TestStruct),
1308
+ ):
1309
+ zeros = wp.tile_zeros(TILE_M, dtype=float)
1310
+ ones = wp.tile_ones(TILE_M, dtype=float)
1311
+ arange = wp.tile_arange(TILE_M, dtype=float)
1312
+ full_twos = wp.tile_full(TILE_M, value=2.0, dtype=float)
1313
+ full_vecs = wp.tile_full(TILE_M, value=wp.vec3(1.0), dtype=wp.vec3)
1314
+ full_mats = wp.tile_full(TILE_M, value=wp.mat33(1.0), dtype=wp.mat33)
1315
+
1316
+ ts = TestStruct()
1317
+ ts.x = wp.float32(2.0)
1318
+ ts.y = wp.vec3(1.0)
1319
+ full_structs = wp.tile_full(TILE_M, value=ts, dtype=TestStruct)
1320
+
1321
+ wp.tile_store(out_zeros, zeros)
1322
+ wp.tile_store(out_ones, ones)
1323
+ wp.tile_store(out_arange, arange)
1324
+ wp.tile_store(out_full_twos, full_twos)
1325
+ wp.tile_store(out_full_vecs, full_vecs)
1326
+ wp.tile_store(out_full_mats, full_mats)
1327
+ wp.tile_store(out_full_structs, full_structs)
1328
+
1329
+
1330
+ def test_tile_construction(test, device):
1331
+ zeros = wp.empty(TILE_M, dtype=float, device=device)
1332
+ ones = wp.empty(TILE_M, dtype=float, device=device)
1333
+ arange = wp.empty(TILE_M, dtype=float, device=device)
1334
+ full_twos = wp.empty(TILE_M, dtype=float, device=device)
1335
+ full_vecs = wp.empty(TILE_M, dtype=wp.vec3, device=device)
1336
+ full_mats = wp.empty(TILE_M, dtype=wp.mat33, device=device)
1337
+ full_structs = wp.empty(TILE_M, dtype=TestStruct, device=device)
1338
+
1339
+ wp.launch_tiled(
1340
+ test_tile_construction_kernel,
1341
+ dim=1,
1342
+ inputs=[],
1343
+ outputs=[zeros, ones, arange, full_twos, full_vecs, full_mats, full_structs],
1344
+ block_dim=TILE_DIM,
1345
+ device=device,
1346
+ )
1347
+
1348
+ assert_np_equal(zeros.numpy(), np.zeros(TILE_M, dtype=float))
1349
+ assert_np_equal(ones.numpy(), np.ones(TILE_M, dtype=float))
1350
+ assert_np_equal(full_twos.numpy(), np.full(TILE_M, 2.0, dtype=float))
1351
+ assert_np_equal(full_vecs.numpy(), np.ones((TILE_M, 3), dtype=float))
1352
+ assert_np_equal(full_mats.numpy(), np.ones((TILE_M, 3, 3), dtype=float))
1353
+ assert_np_equal(full_structs.numpy()["x"], np.full(TILE_M, 2.0, dtype=float))
1354
+ assert_np_equal(full_structs.numpy()["y"], np.ones((TILE_M, 3), dtype=float))
1355
+ assert_np_equal(arange.numpy(), np.arange(TILE_M, dtype=float))
1356
+
1357
+
1181
1358
  @wp.kernel
1182
1359
  def test_tile_print_kernel():
1183
1360
  # shared tile
@@ -1330,6 +1507,7 @@ add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad
1330
1507
  add_function_test(TestTile, "test_tile_squeeze", test_tile_squeeze, devices=devices)
1331
1508
  add_function_test(TestTile, "test_tile_reshape", test_tile_reshape, devices=devices)
1332
1509
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
1510
+ add_function_test(TestTile, "test_tile_construction", test_tile_construction, devices=devices)
1333
1511
  # add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
1334
1512
  # add_function_test(TestTile, "test_tile_inplace", test_tile_inplace, devices=devices)
1335
1513
  # add_function_test(TestTile, "test_tile_astype", test_tile_astype, devices=devices)