warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1495 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any
18
+
19
+ import warp as wp
20
+ import warp.tests.aux_test_name_clash1 as name_clash_module_1
21
+ import warp.tests.aux_test_name_clash2 as name_clash_module_2
22
+ from warp.tests.unittest_utils import *
23
+
24
+ # =======================================================================
25
+
26
+
27
+ @wp.kernel
28
+ def global_kernel(a: wp.array(dtype=int)):
29
+ a[0] = 17
30
+
31
+
32
+ global_kernel_1 = global_kernel
33
+
34
+
35
+ @wp.kernel
36
+ def global_kernel(a: wp.array(dtype=int)):
37
+ a[0] = 42
38
+
39
+
40
+ global_kernel_2 = global_kernel
41
+
42
+
43
+ def test_global_kernel_redefine(test, device):
44
+ """Ensure that referenced kernels remain valid and unique, even when redefined."""
45
+
46
+ with wp.ScopedDevice(device):
47
+ a = wp.zeros(1, dtype=int)
48
+
49
+ wp.launch(global_kernel, dim=1, inputs=[a])
50
+ test.assertEqual(a.numpy()[0], 42)
51
+
52
+ wp.launch(global_kernel_1, dim=1, inputs=[a])
53
+ test.assertEqual(a.numpy()[0], 17)
54
+
55
+ wp.launch(global_kernel_2, dim=1, inputs=[a])
56
+ test.assertEqual(a.numpy()[0], 42)
57
+
58
+
59
+ # =======================================================================
60
+
61
+
62
+ @wp.func
63
+ def global_func():
64
+ return 17
65
+
66
+
67
+ global_func_1 = global_func
68
+
69
+
70
+ @wp.func
71
+ def global_func():
72
+ return 42
73
+
74
+
75
+ global_func_2 = global_func
76
+
77
+
78
+ @wp.kernel
79
+ def global_func_kernel(a: wp.array(dtype=int)):
80
+ a[0] = global_func()
81
+ a[1] = global_func_1()
82
+ a[2] = global_func_2()
83
+
84
+
85
+ def test_global_func_redefine(test, device):
86
+ """Ensure that referenced functions remain valid and unique, even when redefined."""
87
+
88
+ with wp.ScopedDevice(device):
89
+ a = wp.zeros(3, dtype=int)
90
+ wp.launch(global_func_kernel, dim=1, inputs=[a])
91
+ assert_np_equal(a.numpy(), np.array([42, 17, 42]))
92
+
93
+
94
+ # =======================================================================
95
+
96
+
97
+ @wp.struct
98
+ class GlobalStruct:
99
+ v: float
100
+
101
+
102
+ GlobalStruct1 = GlobalStruct
103
+
104
+
105
+ @wp.struct
106
+ class GlobalStruct:
107
+ v: wp.vec2
108
+
109
+
110
+ GlobalStruct2 = GlobalStruct
111
+
112
+
113
+ @wp.kernel
114
+ def global_struct_args_kernel(s0: GlobalStruct, s1: GlobalStruct1, s2: GlobalStruct2, a: wp.array(dtype=float)):
115
+ a[0] = s0.v[0]
116
+ a[1] = s0.v[1]
117
+ a[2] = s1.v
118
+ a[3] = s2.v[0]
119
+ a[4] = s2.v[1]
120
+
121
+
122
+ def test_global_struct_args_redefine(test, device):
123
+ """Ensure that referenced structs remain valid and unique, even when redefined."""
124
+ with wp.ScopedDevice(device):
125
+ s0 = GlobalStruct()
126
+ s1 = GlobalStruct1()
127
+ s2 = GlobalStruct2()
128
+ s0.v = wp.vec2(1.0, 2.0)
129
+ s1.v = 3.0
130
+ s2.v = wp.vec2(4.0, 5.0)
131
+
132
+ a = wp.zeros(5, dtype=float)
133
+
134
+ wp.launch(global_struct_args_kernel, dim=1, inputs=[s0, s1, s2, a])
135
+
136
+ assert_np_equal(a.numpy(), np.array([1, 2, 3, 4, 5], dtype=np.float32))
137
+
138
+
139
+ @wp.kernel
140
+ def global_struct_ctor_kernel(a: wp.array(dtype=float)):
141
+ s0 = GlobalStruct()
142
+ s1 = GlobalStruct1()
143
+ s2 = GlobalStruct2()
144
+ s0.v = wp.vec2(1.0, 2.0)
145
+ s1.v = 3.0
146
+ s2.v = wp.vec2(4.0, 5.0)
147
+ a[0] = s0.v[0]
148
+ a[1] = s0.v[1]
149
+ a[2] = s1.v
150
+ a[3] = s2.v[0]
151
+ a[4] = s2.v[1]
152
+
153
+
154
+ def test_global_struct_ctor_redefine(test, device):
155
+ """Ensure that referenced structs remain valid and unique, even when redefined."""
156
+ with wp.ScopedDevice(device):
157
+ a = wp.zeros(5, dtype=float)
158
+ wp.launch(global_struct_ctor_kernel, dim=1, inputs=[a])
159
+ assert_np_equal(a.numpy(), np.array([1, 2, 3, 4, 5], dtype=np.float32))
160
+
161
+
162
+ # =======================================================================
163
+
164
+
165
+ # "primary" (first) overload
166
+ @wp.func
167
+ def global_func_po(x: int):
168
+ return x * x
169
+
170
+
171
+ # "secondary" overload
172
+ @wp.func
173
+ def global_func_po(x: float):
174
+ return x * x
175
+
176
+
177
+ # redefine primary overload
178
+ @wp.func
179
+ def global_func_po(x: int):
180
+ return x * x * x
181
+
182
+
183
+ @wp.kernel
184
+ def global_overload_primary_kernel(a: wp.array(dtype=float)):
185
+ # use primary (int) overload
186
+ a[0] = float(global_func_po(2))
187
+ # use secondary (float) overload
188
+ a[1] = global_func_po(2.0)
189
+
190
+
191
+ def test_global_overload_primary_redefine(test, device):
192
+ """Ensure that redefining a primary overload works and doesn't affect secondary overloads."""
193
+ with wp.ScopedDevice(device):
194
+ a = wp.zeros(2, dtype=float)
195
+ wp.launch(global_overload_primary_kernel, dim=1, inputs=[a])
196
+ assert_np_equal(a.numpy(), np.array([8, 4], dtype=np.float32))
197
+
198
+
199
+ # =======================================================================
200
+
201
+
202
+ # "primary" (first) overload
203
+ @wp.func
204
+ def global_func_so(x: int):
205
+ return x * x
206
+
207
+
208
+ # "secondary" overload
209
+ @wp.func
210
+ def global_func_so(x: float):
211
+ return x * x
212
+
213
+
214
+ # redefine secondary overload
215
+ @wp.func
216
+ def global_func_so(x: float):
217
+ return x * x * x
218
+
219
+
220
+ @wp.kernel
221
+ def global_overload_secondary_kernel(a: wp.array(dtype=float)):
222
+ # use primary (int) overload
223
+ a[0] = float(global_func_so(2))
224
+ # use secondary (float) overload
225
+ a[1] = global_func_so(2.0)
226
+
227
+
228
+ def test_global_overload_secondary_redefine(test, device):
229
+ """Ensure that redefining a secondary overload works."""
230
+ with wp.ScopedDevice(device):
231
+ a = wp.zeros(2, dtype=float)
232
+ wp.launch(global_overload_secondary_kernel, dim=1, inputs=[a])
233
+ assert_np_equal(a.numpy(), np.array([4, 8], dtype=np.float32))
234
+
235
+
236
+ # =======================================================================
237
+
238
+
239
+ @wp.kernel
240
+ def global_generic_kernel(x: Any, a: wp.array(dtype=Any)):
241
+ a[0] = x * x
242
+
243
+
244
+ global_generic_kernel_1 = global_generic_kernel
245
+
246
+
247
+ @wp.kernel
248
+ def global_generic_kernel(x: Any, a: wp.array(dtype=Any)):
249
+ a[0] = x * x * x
250
+
251
+
252
+ global_generic_kernel_2 = global_generic_kernel
253
+
254
+
255
+ def test_global_generic_kernel_redefine(test, device):
256
+ """Ensure that referenced generic kernels remain valid and unique, even when redefined."""
257
+
258
+ with wp.ScopedDevice(device):
259
+ ai = wp.zeros(1, dtype=int)
260
+ af = wp.zeros(1, dtype=float)
261
+
262
+ wp.launch(global_generic_kernel, dim=1, inputs=[2, ai])
263
+ wp.launch(global_generic_kernel, dim=1, inputs=[2.0, af])
264
+ test.assertEqual(ai.numpy()[0], 8)
265
+ test.assertEqual(af.numpy()[0], 8.0)
266
+
267
+ wp.launch(global_generic_kernel_1, dim=1, inputs=[2, ai])
268
+ wp.launch(global_generic_kernel_1, dim=1, inputs=[2.0, af])
269
+ test.assertEqual(ai.numpy()[0], 4)
270
+ test.assertEqual(af.numpy()[0], 4.0)
271
+
272
+ wp.launch(global_generic_kernel_2, dim=1, inputs=[2, ai])
273
+ wp.launch(global_generic_kernel_2, dim=1, inputs=[2.0, af])
274
+ test.assertEqual(ai.numpy()[0], 8)
275
+ test.assertEqual(af.numpy()[0], 8.0)
276
+
277
+
278
+ # =======================================================================
279
+
280
+
281
+ @wp.func
282
+ def global_generic_func(x: Any):
283
+ return x * x
284
+
285
+
286
+ global_generic_func_1 = global_generic_func
287
+
288
+
289
+ @wp.func
290
+ def global_generic_func(x: Any):
291
+ return x * x * x
292
+
293
+
294
+ global_generic_func_2 = global_generic_func
295
+
296
+
297
+ @wp.kernel
298
+ def global_generic_func_kernel(ai: wp.array(dtype=int), af: wp.array(dtype=float)):
299
+ ai[0] = global_generic_func(2)
300
+ af[0] = global_generic_func(2.0)
301
+
302
+ ai[1] = global_generic_func_1(2)
303
+ af[1] = global_generic_func_1(2.0)
304
+
305
+ ai[2] = global_generic_func_2(2)
306
+ af[2] = global_generic_func_2(2.0)
307
+
308
+
309
+ def test_global_generic_func_redefine(test, device):
310
+ """Ensure that referenced generic functions remain valid and unique, even when redefined."""
311
+
312
+ with wp.ScopedDevice(device):
313
+ ai = wp.zeros(3, dtype=int)
314
+ af = wp.zeros(3, dtype=float)
315
+ wp.launch(global_generic_func_kernel, dim=1, inputs=[ai, af])
316
+ assert_np_equal(ai.numpy(), np.array([8, 4, 8], dtype=np.int32))
317
+ assert_np_equal(af.numpy(), np.array([8, 4, 8], dtype=np.float32))
318
+
319
+
320
+ # =======================================================================
321
+
322
+
323
+ def create_kernel_simple():
324
+ # not a closure
325
+ @wp.kernel
326
+ def k(a: wp.array(dtype=int)):
327
+ a[0] = 17
328
+
329
+ return k
330
+
331
+
332
+ simple_kernel_1 = create_kernel_simple()
333
+ simple_kernel_2 = create_kernel_simple()
334
+
335
+
336
+ def test_create_kernel_simple(test, device):
337
+ """Test creating multiple identical simple (non-closure) kernels."""
338
+ with wp.ScopedDevice(device):
339
+ a = wp.zeros(1, dtype=int)
340
+
341
+ wp.launch(simple_kernel_1, dim=1, inputs=[a])
342
+ test.assertEqual(a.numpy()[0], 17)
343
+
344
+ wp.launch(simple_kernel_2, dim=1, inputs=[a])
345
+ test.assertEqual(a.numpy()[0], 17)
346
+
347
+
348
+ # =======================================================================
349
+
350
+
351
+ def create_func_simple():
352
+ # not a closure
353
+ @wp.func
354
+ def f():
355
+ return 17
356
+
357
+ return f
358
+
359
+
360
+ simple_func_1 = create_func_simple()
361
+ simple_func_2 = create_func_simple()
362
+
363
+
364
+ @wp.kernel
365
+ def simple_func_kernel(a: wp.array(dtype=int)):
366
+ a[0] = simple_func_1()
367
+ a[1] = simple_func_2()
368
+
369
+
370
+ def test_create_func_simple(test, device):
371
+ """Test creating multiple identical simple (non-closure) functions."""
372
+ with wp.ScopedDevice(device):
373
+ a = wp.zeros(2, dtype=int)
374
+ wp.launch(simple_func_kernel, dim=1, inputs=[a])
375
+ assert_np_equal(a.numpy(), np.array([17, 17]))
376
+
377
+
378
+ # =======================================================================
379
+
380
+
381
+ def create_struct_simple():
382
+ @wp.struct
383
+ class S:
384
+ x: int
385
+
386
+ return S
387
+
388
+
389
+ SimpleStruct1 = create_struct_simple()
390
+ SimpleStruct2 = create_struct_simple()
391
+
392
+
393
+ @wp.kernel
394
+ def simple_struct_args_kernel(s1: SimpleStruct1, s2: SimpleStruct2, a: wp.array(dtype=int)):
395
+ a[0] = s1.x
396
+ a[1] = s2.x
397
+
398
+
399
+ def test_create_struct_simple_args(test, device):
400
+ """Test creating multiple identical structs and passing them as arguments."""
401
+ with wp.ScopedDevice(device):
402
+ s1 = SimpleStruct1()
403
+ s2 = SimpleStruct2()
404
+ s1.x = 17
405
+ s2.x = 42
406
+ a = wp.zeros(2, dtype=int)
407
+ wp.launch(simple_struct_args_kernel, dim=1, inputs=[s1, s2, a])
408
+ assert_np_equal(a.numpy(), np.array([17, 42]))
409
+
410
+
411
# Constructs one instance of each struct type inside the kernel,
# fills in its field and writes the field to the output array.
@wp.kernel
def simple_struct_ctor_kernel(a: wp.array(dtype=int)):
    s1 = SimpleStruct1()
    s1.x = 17

    s2 = SimpleStruct2()
    s2.x = 42

    a[0] = s1.x
    a[1] = s2.x
419
+
420
+
421
def test_create_struct_simple_ctor(test, device):
    """Construct two identically defined structs inside a kernel and check the stored fields."""
    with wp.ScopedDevice(device):
        out = wp.zeros(2, dtype=int)
        wp.launch(simple_struct_ctor_kernel, dim=1, inputs=[out])

        expected = np.array([17, 42])
        assert_np_equal(out.numpy(), expected)
427
+
428
+
429
+ # =======================================================================
430
+
431
+
432
def create_generic_kernel_simple():
    """Define and return a generic kernel that stores ``x * x`` in ``a[0]``."""

    # not a closure: the body reads nothing from the enclosing scope
    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * x

    return k


# create the same generic kernel twice
simple_generic_kernel_1, simple_generic_kernel_2 = (
    create_generic_kernel_simple(),
    create_generic_kernel_simple(),
)
443
+
444
+
445
def test_create_generic_kernel_simple(test, device):
    """Test creating multiple identical simple (non-closure) generic kernels.

    Each kernel is launched once with an ``int`` and once with a ``float``
    argument, and the squared value is read back from the output arrays.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        for kernel in (simple_generic_kernel_1, simple_generic_kernel_2):
            # Both kernels produce the same value, so clear the outputs first:
            # otherwise the result left behind by the first kernel would
            # satisfy the assertions even if the second kernel wrote nothing.
            ai.zero_()
            af.zero_()

            wp.launch(kernel, dim=1, inputs=[2, ai])
            wp.launch(kernel, dim=1, inputs=[2.0, af])
            test.assertEqual(ai.numpy()[0], 4)
            test.assertEqual(af.numpy()[0], 4.0)
460
+
461
+
462
+ # =======================================================================
463
+
464
+
465
def create_generic_func_simple():
    """Define and return a generic Warp function that squares its argument."""

    # not a closure: the body reads nothing from the enclosing scope
    @wp.func
    def f(x: Any):
        return x * x

    return f


# create the same generic function twice
simple_generic_func_1, simple_generic_func_2 = (
    create_generic_func_simple(),
    create_generic_func_simple(),
)
476
+
477
+
478
# Calls both identically defined generic functions with an int and a float
# argument; slot 0 holds the results of the first function, slot 1 those of
# the second.
@wp.kernel
def simple_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # first function
    ai[0] = simple_generic_func_1(2)
    af[0] = simple_generic_func_1(2.0)

    # second function
    ai[1] = simple_generic_func_2(2)
    af[1] = simple_generic_func_2(2.0)
488
+
489
+
490
def test_create_generic_func_simple(test, device):
    """Check that two identical simple (non-closure) generic functions square ints and floats."""
    with wp.ScopedDevice(device):
        int_out = wp.zeros(2, dtype=int)
        float_out = wp.zeros(2, dtype=float)
        wp.launch(simple_generic_func_kernel, dim=1, inputs=[int_out, float_out])

        expected_int = np.array([4, 4], dtype=np.int32)
        expected_float = np.array([4, 4], dtype=np.float32)
        assert_np_equal(int_out.numpy(), expected_int)
        assert_np_equal(float_out.numpy(), expected_float)
498
+
499
+
500
+ # =======================================================================
501
+
502
+
503
def create_kernel_cond(cond):
    """Return a kernel writing 17 to ``a[0]`` when ``cond`` is truthy, otherwise one writing 42."""
    if not cond:

        @wp.kernel
        def k(a: wp.array(dtype=int)):
            a[0] = 42

        return k

    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = 17

    return k


# one kernel from each branch
cond_kernel_1, cond_kernel_2 = create_kernel_cond(True), create_kernel_cond(False)
520
+
521
+
522
def test_create_kernel_cond(test, device):
    """Check that conditionally created simple (non-closure) kernels each write their own value."""
    with wp.ScopedDevice(device):
        out = wp.zeros(1, dtype=int)

        for kernel, expected in ((cond_kernel_1, 17), (cond_kernel_2, 42)):
            wp.launch(kernel, dim=1, inputs=[out])
            test.assertEqual(out.numpy()[0], expected)
532
+
533
+
534
+ # =======================================================================
535
+
536
+
537
def create_func_cond(cond):
    """Return a Warp function yielding 17 when ``cond`` is truthy, otherwise one yielding 42."""
    if not cond:

        @wp.func
        def f():
            return 42

        return f

    @wp.func
    def f():
        return 17

    return f


# one function from each branch
cond_func_1, cond_func_2 = create_func_cond(True), create_func_cond(False)
554
+
555
+
556
# Stores the result of each conditionally created function in its own slot.
@wp.kernel
def cond_func_kernel(a: wp.array(dtype=int)):
    a[0] = cond_func_1()
    a[1] = cond_func_2()
560
+
561
+
562
def test_create_func_cond(test, device):
    """Check that conditionally created simple (non-closure) functions return their own values."""
    with wp.ScopedDevice(device):
        results = wp.zeros(2, dtype=int)
        wp.launch(cond_func_kernel, dim=1, inputs=[results])

        expected = np.array([17, 42])
        assert_np_equal(results.numpy(), expected)
568
+
569
+
570
+ # =======================================================================
571
+
572
+
573
def create_struct_cond(cond):
    """Return a struct type whose field ``v`` is a ``float`` when ``cond`` is truthy, else a ``wp.vec2``."""
    if not cond:

        @wp.struct
        class S:
            v: wp.vec2

        return S

    @wp.struct
    class S:
        v: float

    return S


# one struct type from each branch
CondStruct1, CondStruct2 = create_struct_cond(True), create_struct_cond(False)
590
+
591
+
592
# Writes the scalar field of ``s1`` followed by both components of the
# vector field of ``s2`` to the output array.
@wp.kernel
def cond_struct_args_kernel(s1: CondStruct1, s2: CondStruct2, a: wp.array(dtype=float)):
    a[0] = s1.v
    a[1] = s2.v[0]
    a[2] = s2.v[1]
597
+
598
+
599
def test_create_struct_cond_args(test, device):
    """Pass instances of two conditionally created, different structs to a kernel."""
    with wp.ScopedDevice(device):
        scalar_struct = CondStruct1()
        scalar_struct.v = 1.0

        vector_struct = CondStruct2()
        vector_struct.v = wp.vec2(2.0, 3.0)

        out = wp.zeros(3, dtype=float)
        wp.launch(cond_struct_args_kernel, dim=1, inputs=[scalar_struct, vector_struct, out])

        expected = np.array([1, 2, 3], dtype=np.float32)
        assert_np_equal(out.numpy(), expected)
609
+
610
+
611
# Constructs one instance of each conditionally created struct type inside
# the kernel, fills in its field and writes the values to the output array.
@wp.kernel
def cond_struct_ctor_kernel(a: wp.array(dtype=float)):
    s1 = CondStruct1()
    s1.v = 1.0

    s2 = CondStruct2()
    s2.v = wp.vec2(2.0, 3.0)

    a[0] = s1.v
    a[1] = s2.v[0]
    a[2] = s2.v[1]
620
+
621
+
622
def test_create_struct_cond_ctor(test, device):
    """Test conditionally creating different structs and constructing them in kernels."""
    with wp.ScopedDevice(device):
        a = wp.zeros(3, dtype=float)
        # the structs are built inside the kernel; only the output array is passed in
        wp.launch(cond_struct_ctor_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([1, 2, 3], dtype=np.float32))
628
+
629
+
630
+ # =======================================================================
631
+
632
+
633
def create_generic_kernel_cond(cond):
    """Return a generic kernel storing ``x`` squared when ``cond`` is truthy, else ``x`` cubed."""
    if not cond:

        @wp.kernel
        def k(x: Any, a: wp.array(dtype=Any)):
            a[0] = x * x * x

        return k

    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * x

    return k


# one generic kernel from each branch
cond_generic_kernel_1, cond_generic_kernel_2 = (
    create_generic_kernel_cond(True),
    create_generic_kernel_cond(False),
)
650
+
651
+
652
def test_create_generic_kernel_cond(test, device):
    """Check conditionally created simple (non-closure) generic kernels with int and float inputs."""
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        # (kernel, expected int result, expected float result) for an input of 2
        cases = (
            (cond_generic_kernel_1, 4, 4.0),
            (cond_generic_kernel_2, 8, 8.0),
        )
        for kernel, expected_int, expected_float in cases:
            wp.launch(kernel, dim=1, inputs=[2, ai])
            wp.launch(kernel, dim=1, inputs=[2.0, af])
            test.assertEqual(ai.numpy()[0], expected_int)
            test.assertEqual(af.numpy()[0], expected_float)
667
+
668
+
669
+ # =======================================================================
670
+
671
+
672
def create_generic_func_cond(cond):
    """Return a generic Warp function squaring its argument when ``cond`` is truthy, else cubing it."""
    if not cond:

        @wp.func
        def f(x: Any):
            return x * x * x

        return f

    @wp.func
    def f(x: Any):
        return x * x

    return f


# one generic function from each branch
cond_generic_func_1, cond_generic_func_2 = (
    create_generic_func_cond(True),
    create_generic_func_cond(False),
)
689
+
690
+
691
# Calls both conditionally created generic functions with an int and a float
# argument; slot 0 holds the results of the first function, slot 1 those of
# the second.
@wp.kernel
def cond_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # first function
    ai[0] = cond_generic_func_1(2)
    af[0] = cond_generic_func_1(2.0)

    # second function
    ai[1] = cond_generic_func_2(2)
    af[1] = cond_generic_func_2(2.0)
701
+
702
+
703
def test_create_generic_func_cond(test, device):
    """Check conditionally created simple (non-closure) generic functions with int and float inputs."""
    with wp.ScopedDevice(device):
        int_out = wp.zeros(2, dtype=int)
        float_out = wp.zeros(2, dtype=float)
        wp.launch(cond_generic_func_kernel, dim=1, inputs=[int_out, float_out])

        expected_int = np.array([4, 8], dtype=np.int32)
        expected_float = np.array([4, 8], dtype=np.float32)
        assert_np_equal(int_out.numpy(), expected_int)
        assert_np_equal(float_out.numpy(), expected_float)
711
+
712
+
713
+ # =======================================================================
714
+
715
+
716
def create_kernel_closure(value: int):
    """Return a kernel that writes the captured ``value`` to ``a[0]``."""

    # closure: the kernel body reads `value` from the enclosing scope
    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = value

    return k


# two closures over different values
closure_kernel_1, closure_kernel_2 = create_kernel_closure(17), create_kernel_closure(42)
727
+
728
+
729
def test_create_kernel_closure(test, device):
    """Check that each kernel closure writes the value it captured."""
    with wp.ScopedDevice(device):
        out = wp.zeros(1, dtype=int)

        for kernel, expected in ((closure_kernel_1, 17), (closure_kernel_2, 42)):
            wp.launch(kernel, dim=1, inputs=[out])
            test.assertEqual(out.numpy()[0], expected)
739
+
740
+
741
+ # =======================================================================
742
+
743
+
744
def create_func_closure(value: int):
    """Return a zero-argument Warp function that yields the captured ``value``."""

    # closure: the function body reads `value` from the enclosing scope
    @wp.func
    def f():
        return value

    return f


# two closures over different values
closure_func_1, closure_func_2 = create_func_closure(17), create_func_closure(42)
755
+
756
+
757
# Stores the result of each function closure in its own slot.
@wp.kernel
def closure_func_kernel(a: wp.array(dtype=int)):
    a[0] = closure_func_1()
    a[1] = closure_func_2()
761
+
762
+
763
def test_create_func_closure(test, device):
    """Check that each function closure returns the value it captured."""
    with wp.ScopedDevice(device):
        results = wp.zeros(2, dtype=int)
        wp.launch(closure_func_kernel, dim=1, inputs=[results])

        expected = np.array([17, 42])
        assert_np_equal(results.numpy(), expected)
769
+
770
+
771
+ # =======================================================================
772
+
773
+
774
def create_func_closure_overload(value: int):
    """Return an overloaded closure ``f`` defined twice in this scope.

    ``f()`` yields the captured ``value`` and ``f(x)`` yields ``value * x``.
    """

    # overload taking no arguments
    @wp.func
    def f():
        return value

    # overload taking one int, defined under the same name
    @wp.func
    def f(x: int):
        return value * x

    # the returned name now refers to the overloaded closure function
    return f


closure_func_overload_1, closure_func_overload_2 = (
    create_func_closure_overload(2),
    create_func_closure_overload(3),
)
789
+
790
+
791
# Exercises both overloads of each closure: slots 0-1 use the first closure,
# slots 2-3 the second.
@wp.kernel
def closure_func_overload_kernel(a: wp.array(dtype=int)):
    # first closure
    a[0] = closure_func_overload_1()
    a[1] = closure_func_overload_1(2)
    # second closure
    a[2] = closure_func_overload_2()
    a[3] = closure_func_overload_2(2)
797
+
798
+
799
def test_create_func_closure_overload(test, device):
    """Check both overloads of two function closures created with different values."""
    with wp.ScopedDevice(device):
        results = wp.zeros(4, dtype=int)
        wp.launch(closure_func_overload_kernel, dim=1, inputs=[results])

        expected = np.array([2, 4, 3, 6])
        assert_np_equal(results.numpy(), expected)
805
+
806
+
807
+ # =======================================================================
808
+
809
+
810
def create_func_closure_overload_selfref(value: int):
    """Return an overloaded closure ``f`` whose one-argument overload calls the zero-argument one.

    ``f()`` yields the captured ``value`` and ``f(x)`` yields ``f() * x``.
    """

    # overload taking no arguments
    @wp.func
    def f():
        return value

    # overload taking one int; it refers to the other overload by the shared name
    @wp.func
    def f(x: int):
        return f() * x

    # the returned name now refers to the overloaded closure function
    return f


closure_func_overload_selfref_1, closure_func_overload_selfref_2 = (
    create_func_closure_overload_selfref(2),
    create_func_closure_overload_selfref(3),
)
826
+
827
+
828
# Exercises both overloads of each self-referencing closure: slots 0-1 use
# the first closure, slots 2-3 the second.
@wp.kernel
def closure_func_overload_selfref_kernel(a: wp.array(dtype=int)):
    # first closure
    a[0] = closure_func_overload_selfref_1()
    a[1] = closure_func_overload_selfref_1(2)
    # second closure
    a[2] = closure_func_overload_selfref_2()
    a[3] = closure_func_overload_selfref_2(2)
834
+
835
+
836
def test_create_func_closure_overload_selfref(test, device):
    """Check overloaded function closures in which one overload calls another."""
    with wp.ScopedDevice(device):
        results = wp.zeros(4, dtype=int)
        wp.launch(closure_func_overload_selfref_kernel, dim=1, inputs=[results])

        expected = np.array([2, 4, 3, 6])
        assert_np_equal(results.numpy(), expected)
842
+
843
+
844
+ # =======================================================================
845
+
846
+
847
def create_func_closure_nonoverload(dtype, value):
    """Return a Warp function taking one ``dtype`` argument and multiplying it by the captured ``value``."""

    @wp.func
    def f(x: dtype):
        return x * value

    return f


# Functions created in different scopes should NOT be overloads of each other:
# creating a new function with the same signature must not replace an earlier one.
# The creation order below (int, float, int, float) is kept as is.
(
    closure_func_nonoverload_1,
    closure_func_nonoverload_2,
    closure_func_nonoverload_3,
    closure_func_nonoverload_4,
) = (
    create_func_closure_nonoverload(int, 2),
    create_func_closure_nonoverload(float, 2.0),
    create_func_closure_nonoverload(int, 3),
    create_func_closure_nonoverload(float, 3.0),
)
861
+
862
+
863
# Calls all four separately created functions; the int results go to ``ai``
# and the float results to ``af``.
@wp.kernel
def closure_func_nonoverload_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # functions capturing 2 and 2.0
    ai[0] = closure_func_nonoverload_1(2)
    af[0] = closure_func_nonoverload_2(2.0)
    # functions capturing 3 and 3.0
    ai[1] = closure_func_nonoverload_3(2)
    af[1] = closure_func_nonoverload_4(2.0)
872
+
873
+
874
def test_create_func_closure_nonoverload(test, device):
    """Check function closures that must not overload each other (overloads are grouped by scope, not globally)."""
    with wp.ScopedDevice(device):
        int_out = wp.zeros(2, dtype=int)
        float_out = wp.zeros(2, dtype=float)
        wp.launch(closure_func_nonoverload_kernel, dim=1, inputs=[int_out, float_out])

        expected_int = np.array([4, 6], dtype=np.int32)
        expected_float = np.array([4, 6], dtype=np.float32)
        assert_np_equal(int_out.numpy(), expected_int)
        assert_np_equal(float_out.numpy(), expected_float)
882
+
883
+
884
+ # =======================================================================
885
+
886
+
887
def create_fk_closure(a, b):
    """Return a function closure over ``a`` and a kernel closure that adds ``b`` to that function's result."""

    # closure over the factory argument `a`
    @wp.func
    def f():
        return a

    # closure over `f` and `b`; note that the kernel's own parameter is also
    # named `a`, so inside the kernel body `a` is the output array
    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = f() + b

    return f, k


fk_closure_func_1, fk_closure_kernel_1 = create_fk_closure(10, 7)
fk_closure_func_2, fk_closure_kernel_2 = create_fk_closure(40, 2)
903
+
904
+
905
# Reuses both generated function closures in a separately defined kernel
# and stores the sum of their results.
@wp.kernel
def fk_closure_combine_kernel(a: wp.array(dtype=int)):
    a[0] = fk_closure_func_1() + fk_closure_func_2()
909
+
910
+
911
def test_create_fk_closure(test, device):
    """Check function and kernel closures created together, then the functions reused in another kernel."""
    with wp.ScopedDevice(device):
        out = wp.zeros(1, dtype=int)

        # (kernel, value expected in out[0] after the launch)
        cases = (
            (fk_closure_kernel_1, 17),
            (fk_closure_kernel_2, 42),
            (fk_closure_combine_kernel, 50),
        )
        for kernel, expected in cases:
            wp.launch(kernel, dim=1, inputs=[out])
            test.assertEqual(out.numpy()[0], expected)
924
+
925
+
926
+ # =======================================================================
927
+
928
+
929
def create_generic_kernel_closure(value):
    """Return a generic kernel that multiplies ``x`` by the captured ``value`` converted to the type of ``x``."""

    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * type(x)(value)

    return k


# two generic closures over different values
generic_closure_kernel_1, generic_closure_kernel_2 = (
    create_generic_kernel_closure(2),
    create_generic_kernel_closure(3),
)
939
+
940
+
941
def test_create_generic_kernel_closure(test, device):
    """Check generic closure kernels with int and float inputs."""
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        # (kernel, expected int result, expected float result) for an input of 2
        cases = (
            (generic_closure_kernel_1, 4, 4.0),
            (generic_closure_kernel_2, 6, 6.0),
        )
        for kernel, expected_int, expected_float in cases:
            wp.launch(kernel, dim=1, inputs=[2, ai])
            wp.launch(kernel, dim=1, inputs=[2.0, af])
            test.assertEqual(ai.numpy()[0], expected_int)
            test.assertEqual(af.numpy()[0], expected_float)
956
+
957
+
958
+ # =======================================================================
959
+
960
+
961
def create_generic_kernel_overload_closure(value, dtype):
    """Define a generic closure kernel and return only its overload for ``dtype``.

    The overload is requested for the argument types
    ``[dtype, wp.array(dtype=dtype)]``; the generic kernel itself is not returned.
    """

    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * type(x)(value)

    arg_types = [dtype, wp.array(dtype=dtype)]
    return wp.overload(k, arg_types)


# int overloads
generic_closure_kernel_overload_i1 = create_generic_kernel_overload_closure(2, int)
generic_closure_kernel_overload_i2 = create_generic_kernel_overload_closure(3, int)
# float overloads
generic_closure_kernel_overload_f1 = create_generic_kernel_overload_closure(2, float)
generic_closure_kernel_overload_f2 = create_generic_kernel_overload_closure(3, float)
974
+
975
+
976
def test_create_generic_kernel_overload_closure(test, device):
    """Check overloads returned from generic closure kernels whose generic kernels are not kept."""
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        # (int overload, float overload, expected int, expected float) for an input of 2
        cases = (
            (generic_closure_kernel_overload_i1, generic_closure_kernel_overload_f1, 4, 4.0),
            (generic_closure_kernel_overload_i2, generic_closure_kernel_overload_f2, 6, 6.0),
        )
        for int_kernel, float_kernel, expected_int, expected_float in cases:
            wp.launch(int_kernel, dim=1, inputs=[2, ai])
            wp.launch(float_kernel, dim=1, inputs=[2.0, af])
            test.assertEqual(ai.numpy()[0], expected_int)
            test.assertEqual(af.numpy()[0], expected_float)
991
+
992
+
993
+ # =======================================================================
994
+
995
+
996
def create_generic_func_closure(value):
    """Return a generic Warp function multiplying ``x`` by the captured ``value`` converted to the type of ``x``."""

    @wp.func
    def f(x: Any):
        return x * type(x)(value)

    return f


# two generic closures over different values
generic_closure_func_1, generic_closure_func_2 = (
    create_generic_func_closure(2),
    create_generic_func_closure(3),
)
1006
+
1007
+
1008
# Calls both generic function closures with an int and a float argument;
# slot 0 holds the results of the first closure, slot 1 those of the second.
@wp.kernel
def closure_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # first closure
    ai[0] = generic_closure_func_1(2)
    af[0] = generic_closure_func_1(2.0)

    # second closure
    ai[1] = generic_closure_func_2(2)
    af[1] = generic_closure_func_2(2.0)
1018
+
1019
+
1020
def test_create_generic_func_closure(test, device):
    """Check generic function closures with int and float inputs."""
    with wp.ScopedDevice(device):
        int_out = wp.zeros(2, dtype=int)
        float_out = wp.zeros(2, dtype=float)
        wp.launch(closure_generic_func_kernel, dim=1, inputs=[int_out, float_out])

        expected_int = np.array([4, 6], dtype=np.int32)
        expected_float = np.array([4, 6], dtype=np.float32)
        assert_np_equal(int_out.numpy(), expected_int)
        assert_np_equal(float_out.numpy(), expected_float)
1028
+
1029
+
1030
+ # =======================================================================
1031
+
1032
+
1033
def create_generic_func_closure_overload(value):
    """Return an overloaded generic closure ``f`` defined twice in this scope.

    ``f(x)`` multiplies ``x`` by the captured ``value`` converted to the type
    of ``x``; ``f(x, y)`` forwards ``x + y`` to the one-argument overload.
    """

    # one-argument overload
    @wp.func
    def f(x: Any):
        return x * type(x)(value)

    # two-argument overload, defined under the same name
    @wp.func
    def f(x: Any, y: Any):
        return f(x + y)

    # the returned name now refers to the overloaded generic closure function
    return f


generic_closure_func_overload_1, generic_closure_func_overload_2 = (
    create_generic_func_closure_overload(2),
    create_generic_func_closure_overload(3),
)
1048
+
1049
+
1050
# Exercises both overloads of both generic closures, first with int and then
# with float arguments. The expected value is noted next to each call.
@wp.kernel
def generic_closure_func_overload_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # int arguments
    ai[0] = generic_closure_func_overload_1(1)  # 1 * 2 = 2
    ai[1] = generic_closure_func_overload_2(1)  # 1 * 3 = 3
    ai[2] = generic_closure_func_overload_1(1, 2)  # (1 + 2) * 2 = 6
    ai[3] = generic_closure_func_overload_2(1, 2)  # (1 + 2) * 3 = 9

    # float arguments
    af[0] = generic_closure_func_overload_1(1.0)  # 1 * 2 = 2
    af[1] = generic_closure_func_overload_2(1.0)  # 1 * 3 = 3
    af[2] = generic_closure_func_overload_1(1.0, 2.0)  # (1 + 2) * 2 = 6
    af[3] = generic_closure_func_overload_2(1.0, 2.0)  # (1 + 2) * 3 = 9
1064
+
1065
+
1066
def test_create_generic_func_closure_overload(test, device):
    """Check overloaded generic function closures with int and float inputs."""
    with wp.ScopedDevice(device):
        int_out = wp.zeros(4, dtype=int)
        float_out = wp.zeros(4, dtype=float)
        wp.launch(generic_closure_func_overload_kernel, dim=1, inputs=[int_out, float_out])

        expected_int = np.array([2, 3, 6, 9], dtype=np.int32)
        expected_float = np.array([2, 3, 6, 9], dtype=np.float32)
        assert_np_equal(int_out.numpy(), expected_int)
        assert_np_equal(float_out.numpy(), expected_float)
1074
+
1075
+
1076
+ # =======================================================================
1077
+
1078
+
1079
def create_type_closure_scalar(scalar_type):
    """Return a kernel that converts its float input to the captured ``scalar_type`` and checks the result.

    The converted value is cast back to ``float`` and compared against
    ``expected`` with ``wp.expect_eq``.
    """

    @wp.kernel
    def k(input: float, expected: float):
        x = scalar_type(input)
        wp.expect_eq(float(x), expected)

    return k


# one kernel per captured scalar type
type_closure_kernel_int = create_type_closure_scalar(int)
type_closure_kernel_float = create_type_closure_scalar(float)
type_closure_kernel_uint8 = create_type_closure_scalar(wp.uint8)
1091
+
1092
+
1093
def test_type_closure_scalar(test, device):
    """Launch kernels that capture a scalar type; the comparison itself happens inside each kernel."""
    with wp.ScopedDevice(device):
        # (kernel, [input, expected]) pairs
        launches = (
            (type_closure_kernel_int, [-1.5, -1.0]),
            (type_closure_kernel_float, [-1.5, -1.5]),
            # FIXME: a problem with type conversions breaks this case
            # (type_closure_kernel_uint8, [-1.5, 255.0]),
        )
        for kernel, kernel_inputs in launches:
            wp.launch(kernel, dim=1, inputs=kernel_inputs)
1100
+
1101
+
1102
+ # =======================================================================
1103
+
1104
+
1105
def create_type_closure_vector(vec_type):
    """Return a kernel that builds ``vec_type(1.0)`` and checks its squared length against ``expected``."""

    @wp.kernel
    def k(expected: float):
        v = vec_type(1.0)
        wp.expect_eq(wp.length_sq(v), expected)

    return k


# one kernel per captured vector type
type_closure_kernel_vec2, type_closure_kernel_vec3 = (
    create_type_closure_vector(wp.vec2),
    create_type_closure_vector(wp.vec3),
)
1116
+
1117
+
1118
def test_type_closure_vector(test, device):
    """Launch kernels that capture a vector type; the comparison itself happens inside each kernel."""
    with wp.ScopedDevice(device):
        # (kernel, expected squared length) pairs
        launches = (
            (type_closure_kernel_vec2, 2.0),
            (type_closure_kernel_vec3, 3.0),
        )
        for kernel, expected in launches:
            wp.launch(kernel, dim=1, inputs=[expected])
1122
+
1123
+
1124
+ # =======================================================================
1125
+
1126
+
1127
# Struct with a scalar field; used as one of the captured struct types below.
@wp.struct
class ClosureStruct1:
    v: float
1130
+
1131
+
1132
# Struct with a vector field; used as one of the captured struct types below.
@wp.struct
class ClosureStruct2:
    v: wp.vec2
1135
+
1136
+
1137
# Two definitions under one name, distinguished by the struct type of the
# argument: 17.0 for ClosureStruct1 and 42.0 for ClosureStruct2.
@wp.func
def closure_struct_func(s: ClosureStruct1):
    return 17.0


@wp.func
def closure_struct_func(s: ClosureStruct2):
    return 42.0
1145
+
1146
+
1147
def create_type_closure_struct(struct_type):
    """Return a kernel that constructs the captured ``struct_type`` and checks ``closure_struct_func`` on it."""

    @wp.kernel
    def k(expected: float):
        s = struct_type()
        result = closure_struct_func(s)
        wp.expect_eq(result, expected)

    return k


# one kernel per captured struct type
type_closure_kernel_struct1, type_closure_kernel_struct2 = (
    create_type_closure_struct(ClosureStruct1),
    create_type_closure_struct(ClosureStruct2),
)
1159
+
1160
+
1161
def test_type_closure_struct(test, device):
    """Launch kernels that capture a struct type; the comparison itself happens inside each kernel."""
    with wp.ScopedDevice(device):
        # (kernel, value expected from closure_struct_func) pairs
        launches = (
            (type_closure_kernel_struct1, 17.0),
            (type_closure_kernel_struct2, 42.0),
        )
        for kernel, expected in launches:
            wp.launch(kernel, dim=1, inputs=[expected])
1165
+
1166
+
1167
+ # =======================================================================
1168
+
1169
+
1170
# Calls the identically named functions of both helper modules:
# slots 0-1 hold ``same_func`` results, slots 2-3 ``different_func`` results.
@wp.kernel
def name_clash_func_kernel(a: wp.array(dtype=int)):
    # same_func from each module
    a[0] = name_clash_module_1.same_func()
    a[1] = name_clash_module_2.same_func()
    # different_func from each module
    a[2] = name_clash_module_1.different_func()
    a[3] = name_clash_module_2.different_func()
1176
+
1177
+
1178
def test_name_clash_func(test, device):
    """Check identically named functions taken from different modules."""
    with wp.ScopedDevice(device):
        results = wp.zeros(4, dtype=int)
        wp.launch(name_clash_func_kernel, dim=1, inputs=[results])

        expected = np.array([99, 99, 17, 42])
        assert_np_equal(results.numpy(), expected)
1184
+
1185
+
1186
+ # =======================================================================
1187
+
1188
+
1189
# Receives identically named structs from both helper modules as arguments
# and copies their fields into the output array.
@wp.kernel
def name_clash_structs_args_kernel(
    s1: name_clash_module_1.SameStruct,
    s2: name_clash_module_2.SameStruct,
    d1: name_clash_module_1.DifferentStruct,
    d2: name_clash_module_2.DifferentStruct,
    a: wp.array(dtype=float),
):
    # SameStruct fields
    a[0] = s1.x
    a[1] = s2.x
    # DifferentStruct fields: a scalar in module 1, two vector components in module 2
    a[2] = d1.v
    a[3] = d2.v[0]
    a[4] = d2.v[1]
1202
+
1203
+
1204
def test_name_clash_struct_args(test, device):
    """Pass identically named structs from different modules to a kernel as arguments."""
    with wp.ScopedDevice(device):
        same_1 = name_clash_module_1.SameStruct()
        same_1.x = 1.0

        same_2 = name_clash_module_2.SameStruct()
        same_2.x = 2.0

        different_1 = name_clash_module_1.DifferentStruct()
        different_1.v = 3.0

        different_2 = name_clash_module_2.DifferentStruct()
        different_2.v = wp.vec2(4.0, 5.0)

        out = wp.zeros(5, dtype=float)
        wp.launch(
            name_clash_structs_args_kernel,
            dim=1,
            inputs=[same_1, same_2, different_1, different_2, out],
        )

        expected = np.array([1, 2, 3, 4, 5], dtype=np.float32)
        assert_np_equal(out.numpy(), expected)
1217
+
1218
+
1219
+ # =======================================================================
1220
+
1221
+
1222
# Constructs identically named structs from both helper modules inside the
# kernel, fills in their fields and copies them into the output array.
@wp.kernel
def name_clash_structs_ctor_kernel(
    a: wp.array(dtype=float),
):
    # SameStruct from each module
    s1 = name_clash_module_1.SameStruct()
    s1.x = 1.0
    s2 = name_clash_module_2.SameStruct()
    s2.x = 2.0

    # DifferentStruct from each module
    d1 = name_clash_module_1.DifferentStruct()
    d1.v = 3.0
    d2 = name_clash_module_2.DifferentStruct()
    d2.v = wp.vec2(4.0, 5.0)

    a[0] = s1.x
    a[1] = s2.x
    a[2] = d1.v
    a[3] = d2.v[0]
    a[4] = d2.v[1]
1241
+
1242
+
1243
def test_name_clash_struct_ctor(test, device):
    """Construct identically named structs from different modules inside a kernel."""
    with wp.ScopedDevice(device):
        out = wp.zeros(5, dtype=float)
        wp.launch(name_clash_structs_ctor_kernel, dim=1, inputs=[out])

        expected = np.array([1, 2, 3, 4, 5], dtype=np.float32)
        assert_np_equal(out.numpy(), expected)
1248
+
1249
+
1250
+ # =======================================================================
1251
+
1252
+
1253
def test_create_kernel_loop(test, device):
    """
    Define the same kernel repeatedly inside a loop and launch each copy.

    Because the kernel never changes, the module hash stays the same and the
    module shouldn't be reloaded. The test makes sure that kernel hooks are
    still found for every new duplicate kernel.
    """

    num_iterations = 5

    with wp.ScopedDevice(device):
        for _ in range(num_iterations):
            # the definition has to live inside the loop so that a new,
            # identical kernel is created on every pass
            @wp.kernel
            def k():
                pass

            wp.launch(k, dim=1)
            wp.synchronize_device()
1269
+
1270
+
1271
+ # =======================================================================
1272
+
1273
+
1274
def test_module_mark_modified(test, device):
    """Check that ``Module.mark_modified()`` forces the module to be rehashed and reloaded.

    The kernel compares the captured constant ``C`` with the value passed at
    launch, so the second launch only succeeds if the new value of ``C`` is
    picked up.
    """

    with wp.ScopedDevice(device):

        # `C` is bound further down, before the first launch
        @wp.kernel
        def k(expected: int):
            wp.expect_eq(C, expected)

        # first value of the constant
        C = 17
        wp.launch(k, dim=1, inputs=[17])
        wp.synchronize_device()

        # rebind the constant and force rehashing on the next launch
        C = 42
        k.module.mark_modified()

        wp.launch(k, dim=1, inputs=[42])
        wp.synchronize_device()
1293
+
1294
+
1295
+ # =======================================================================
1296
+
1297
+
1298
def test_garbage_collection(test, device):
    """Check that dynamically generated kernels nobody references anymore are not kept in the module."""

    # helper module with a known kernel count
    import warp.tests.aux_test_instancing_gc as gc_test_module

    with wp.ScopedDevice(device):
        result = wp.zeros(1, dtype=int)

        for value in range(10):
            # build a kernel unique to this iteration; rebinding `kernel`
            # drops this function's reference to the previous one
            kernel = gc_test_module.create_kernel_closure(value)

            # import gc
            # gc.collect()

            # earlier kernels are no longer referenced here, so they should
            # have been garbage-collected and must not show up in the module
            kernel.module.load(device=device)
            live_count = len(kernel.module.live_kernels)
            test.assertEqual(live_count, 1)

            # the remaining kernel must still work
            wp.launch(kernel, dim=1, inputs=[result])
            test.assertEqual(result.numpy()[0], value)
1322
+
1323
+
1324
+ # =======================================================================
1325
+
1326
+
1327
class TestCodeGenInstancing(unittest.TestCase):
    """Empty test case; the test methods are attached below via ``add_function_test``."""
1329
+
1330
+
1331
devices = get_test_devices()

# global redefinitions with retained references
for _func, _name in (
    (test_global_kernel_redefine, "test_global_kernel_redefine"),
    (test_global_func_redefine, "test_global_func_redefine"),
    (test_global_struct_args_redefine, "test_global_struct_args_redefine"),
    (test_global_struct_ctor_redefine, "test_global_struct_ctor_redefine"),
    (test_global_overload_primary_redefine, "test_global_overload_primary_redefine"),
    (test_global_overload_secondary_redefine, "test_global_overload_secondary_redefine"),
    (test_global_generic_kernel_redefine, "test_global_generic_kernel_redefine"),
    (test_global_generic_func_redefine, "test_global_generic_func_redefine"),
):
    add_function_test(TestCodeGenInstancing, func=_func, name=_name, devices=devices)
del _func, _name
1376
+
1377
# create identical simple kernels, functions, and structs
for _func, _name in (
    (test_create_kernel_simple, "test_create_kernel_simple"),
    (test_create_func_simple, "test_create_func_simple"),
    (test_create_struct_simple_args, "test_create_struct_simple_args"),
    (test_create_struct_simple_ctor, "test_create_struct_simple_ctor"),
    (test_create_generic_kernel_simple, "test_create_generic_kernel_simple"),
    (test_create_generic_func_simple, "test_create_generic_func_simple"),
):
    add_function_test(TestCodeGenInstancing, func=_func, name=_name, devices=devices)
del _func, _name
1397
+
1398
# create different simple kernels, functions, and structs
for _func, _name in (
    (test_create_kernel_cond, "test_create_kernel_cond"),
    (test_create_func_cond, "test_create_func_cond"),
    (test_create_struct_cond_args, "test_create_struct_cond_args"),
    (test_create_struct_cond_ctor, "test_create_struct_cond_ctor"),
    (test_create_generic_kernel_cond, "test_create_generic_kernel_cond"),
    (test_create_generic_func_cond, "test_create_generic_func_cond"),
):
    add_function_test(TestCodeGenInstancing, func=_func, name=_name, devices=devices)
del _func, _name
1413
+
1414
# closure kernels and functions
for _func, _name in (
    (test_create_kernel_closure, "test_create_kernel_closure"),
    (test_create_func_closure, "test_create_func_closure"),
    (test_create_func_closure_overload, "test_create_func_closure_overload"),
    (test_create_func_closure_overload_selfref, "test_create_func_closure_overload_selfref"),
    (test_create_func_closure_nonoverload, "test_create_func_closure_nonoverload"),
    (test_create_fk_closure, "test_create_fk_closure"),
    (test_create_generic_kernel_closure, "test_create_generic_kernel_closure"),
    (test_create_generic_kernel_overload_closure, "test_create_generic_kernel_overload_closure"),
    (test_create_generic_func_closure, "test_create_generic_func_closure"),
    (test_create_generic_func_closure_overload, "test_create_generic_func_closure_overload"),
):
    add_function_test(TestCodeGenInstancing, func=_func, name=_name, devices=devices)
del _func, _name
1464
+
1465
+ # type closures: scalar, vector, and struct variants (per the test names)
1466
+ add_function_test(
1467
+ TestCodeGenInstancing, func=test_type_closure_scalar, name="test_type_closure_scalar", devices=devices
1468
+ )
1469
+ add_function_test(
1470
+ TestCodeGenInstancing, func=test_type_closure_vector, name="test_type_closure_vector", devices=devices
1471
+ )
1472
+ add_function_test(
1473
+ TestCodeGenInstancing, func=test_type_closure_struct, name="test_type_closure_struct", devices=devices
1474
+ )
1475
+
1476
+ # test name clashes between modules: func, struct_args, and struct_ctor variants (per the test names)
1477
+ add_function_test(TestCodeGenInstancing, func=test_name_clash_func, name="test_name_clash_func", devices=devices)
1478
+ add_function_test(
1479
+ TestCodeGenInstancing, func=test_name_clash_struct_args, name="test_name_clash_struct_args", devices=devices
1480
+ )
1481
+ add_function_test(
1482
+ TestCodeGenInstancing, func=test_name_clash_struct_ctor, name="test_name_clash_struct_ctor", devices=devices
1483
+ )
1484
+
1485
+ # miscellaneous tests: kernel creation in a loop, module mark-modified, and garbage collection (per the test names)
1486
+ add_function_test(TestCodeGenInstancing, func=test_create_kernel_loop, name="test_create_kernel_loop", devices=devices)
1487
+ add_function_test(
1488
+ TestCodeGenInstancing, func=test_module_mark_modified, name="test_module_mark_modified", devices=devices
1489
+ )
1490
+ add_function_test(TestCodeGenInstancing, func=test_garbage_collection, name="test_garbage_collection", devices=devices)
1491
+
1492
+
1493
+ if __name__ == "__main__":  # script entry point: runs only when this file is executed directly
1494
+ wp.clear_kernel_cache()  # clear the kernel cache before the run; presumably so kernels are rebuilt rather than reused -- confirm
1495
+ unittest.main(verbosity=2)  # run this module's tests; verbosity=2 reports each test individually