warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,810 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import sys
17
+ import unittest
18
+ from typing import Tuple
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
@wp.kernel
def test_expect():
    # Exercises the in-kernel expectation builtins wp.expect_eq / wp.expect_neq
    # on scalars, 2-vectors and 2x2 matrices.
    # NOTE(review): `a` is computed but never checked; presumably it only
    # exercises in-place float addition ahead of the expectations.
    a = 1.0
    a += 2.0

    # scalar equality / inequality
    wp.expect_eq(123, 123)
    wp.expect_neq(123, 234)

    # vector equality / inequality
    wp.expect_eq(wp.vec2(1.0, 2.0), wp.vec2(1.0, 2.0))
    wp.expect_neq(wp.vec2(1.0, 2.0), wp.vec2(2.0, 3.0))

    # matrix equality / inequality
    wp.expect_eq(wp.mat22(1.0, 2.0, 3.0, 4.0), wp.mat22(1.0, 2.0, 3.0, 4.0))
    wp.expect_neq(wp.mat22(1.0, 2.0, 3.0, 4.0), wp.mat22(2.0, 3.0, 4.0, 5.0))
37
+
38
+
39
@wp.kernel
def test_rename():
    # Rebinding a name must not alias storage: after `a = b`, reassigning `a`
    # has to leave `b` untouched.
    a = 0
    b = 1

    a = b
    a = 2

    wp.expect_eq(a, 2)
    wp.expect_eq(b, 1)
49
+
50
+
51
@wp.kernel
def test_inplace():
    # In-place augmented assignment on a float local: 1.0 + 2.0 == 3.0.
    a = 1.0
    a += 2.0

    wp.expect_eq(a, 3.0)
57
+
58
+
59
@wp.kernel
def test_constant(c: float):
    # Overwrites a constant-initialized local with an expression that depends on
    # the kernel argument `c`.
    # NOTE(review): the expectation only holds when the launch passes c == 1.0.
    a = 0.0
    a = c + 1.0

    wp.expect_eq(a, 2.0)
65
+
66
+
67
@wp.kernel
def test_dynamic_for_rename(n: int):
    # Fibonacci iteration inside a loop whose trip count `n` is a runtime value.
    # The int(...) constructors declare the locals as dynamic variables, so they
    # may be reassigned inside the dynamic loop (a plain `0` literal would be a
    # constant and cannot be mutated there).
    f0 = int(0.0)
    f1 = int(1.0)

    for _i in range(0, n):
        f = f0 + f1

        # shift the pair forward; exercises renaming of loop-carried variables
        f0 = f1
        f1 = f

    # 89 is the value of f1 after exactly 10 iterations.
    # NOTE(review): the launch is presumably made with n == 10 — confirm at call site.
    wp.expect_eq(f1, 89)
79
+
80
+
81
@wp.kernel
def test_dynamic_for_inplace(n: int):
    # In-place accumulation inside a runtime-length loop; float(...) declares `a`
    # as a dynamic variable so it may be mutated in the loop body.
    a = float(0.0)

    for _i in range(0, n):
        a += 1.0

    # one increment per iteration
    wp.expect_eq(a, float(n))
89
+
90
+
91
@wp.kernel
def test_reassign():
    # Copy-then-modify on floats: updating the copy `f1` must not change `f0`.
    f0 = 1.0
    f1 = f0

    f1 = f1 + 2.0

    wp.expect_eq(f1, 3.0)
    wp.expect_eq(f0, 1.0)
100
+
101
+
102
@wp.kernel
def test_dynamic_reassign(n: int):
    # Same copy-then-modify pattern as test_reassign, but with vec3 values and the
    # update performed inside a runtime-length loop.
    f0 = wp.vec3()
    f1 = f0

    for _i in range(0, n):
        f1 = f1 - wp.vec3(2.0, 0.0, 0.0)

    # NOTE(review): -4.0 assumes wp.vec3() is zero-initialized and that the launch
    # passes n == 2 — confirm at call site.
    wp.expect_eq(f1, wp.vec3(-4.0, 0.0, 0.0))
    # the original value must be unaffected by the loop
    wp.expect_eq(f0, wp.vec3())
112
+
113
+
114
@wp.kernel
def test_range_static_sum(result: wp.array(dtype=int)):
    # Counts iterations of range() with literal (compile-time) bounds in its
    # one-, two- and three-argument forms. Each form spans 10 iterations.
    # The counts are written to `result[0:3]` for the host side to check.
    a = int(0)
    for _i in range(10):
        a = a + 1

    b = int(0)
    for _i in range(0, 10):
        b = b + 1

    # 0, 2, ..., 18 -> 10 iterations
    c = int(0)
    for _i in range(0, 20, 2):
        c = c + 1

    result[0] = a
    result[1] = b
    result[2] = c
131
+
132
+
133
@wp.kernel
def test_range_dynamic_sum(start: int, end: int, step: int, result: wp.array(dtype=int)):
    # Counts iterations of range() with runtime bounds: one-, two- and
    # three-argument forms, plus a descending range with a negated step.
    # The four counts are written to result[0:4] for the host side to compare.
    a = int(0)
    for _i in range(end):
        a = a + 1

    b = int(0)
    for _i in range(start, end):
        b = b + 1

    c = int(0)
    for _i in range(start, end * step, step):
        c = c + 1

    # descending: from end*step down to (but excluding) start
    d = int(0)
    for _i in range(end * step, start, -step):
        d = d + 1

    result[0] = a
    result[1] = b
    result[2] = c
    result[3] = d
155
+
156
+
157
@wp.kernel
def test_range_dynamic(start: int, end: int, step: int, result: wp.array(dtype=int)):
    # Records every value produced by range(start, end, step), in order, into
    # consecutive slots of `result`.
    # NOTE(review): `result` must hold at least as many entries as the range
    # yields; presumably sized by the caller — no bounds check here.
    output = int(0)
    for i in range(start, end, step):
        result[output] = i
        output += 1
163
+
164
+
165
@wp.kernel
def test_range_dynamic_nested(n: int):
    # Nested runtime-length loops, with accumulation before, inside and after
    # the inner loop.
    sum1 = float(0.0)
    sum2 = float(0.0)
    sum3 = float(0.0)

    for _i in range(n):
        sum1 = sum1 + 1.0
        sum3 = sum3 + 1.0

        for _j in range(n):
            sum2 = sum2 + 1.0
            sum3 = sum3 + 1.0

        # still inside the outer loop: each outer iteration adds n + 2 to sum3
        sum3 = sum3 + 1.0

    wp.expect_eq(sum1, float(n))
    wp.expect_eq(sum2, float(n * n))
    # n outer iterations * (1 + n + 1)
    wp.expect_eq(sum3, float(n * n + 2 * n))
184
+
185
+
186
@wp.kernel
def test_while(n: int):
    # A while-loop with a runtime bound must count up to exactly n.
    i = int(0)

    while i < n:
        i = i + 1

    wp.expect_eq(i, n)
194
+
195
+
196
@wp.kernel
def test_pass(n: int):
    # `pass` in a never-taken branch must compile to a no-op; the else-branch
    # still drives the loop to completion.
    i = int(0)

    while i < n:
        if False:
            pass
        else:
            i = i + 1

    wp.expect_eq(i, n)
207
+
208
+
209
@wp.kernel
def test_break(n: int):
    # `break` inside a runtime-length for-loop stops the counter at 5.
    # NOTE(review): the expectation requires the launch to pass n >= 5.
    a = int(0)

    for _i in range(0, n):
        if a == 5:
            break

        a += 1

    wp.expect_eq(a, 5)
220
+
221
+
222
@wp.kernel
def test_break_early(n: int):
    # An assignment immediately followed by `break` must still be visible after
    # the loop exits.
    # NOTE(review): the break is only reached when the launch passes n > 6.
    a = int(0)

    for i in range(0, n):
        if i > 5:
            a = 1
            break

    wp.expect_eq(a, 1)
232
+
233
+
234
@wp.kernel
def test_break_unroll():
    # Same as test_break_early but with literal loop bounds, so the loop is a
    # candidate for unrolling; the first index past 5 (i.e. 6) is captured.
    a = int(0)

    for i in range(0, 10):
        if i > 5:
            a = i
            break

    wp.expect_eq(a, 6)
244
+
245
+
246
@wp.kernel
def test_break_while():
    # `break` inside a while-loop: the loop exits as soon as a reaches 6, before
    # the loop condition (a < 10) would have stopped it.
    a = int(0)

    while a < 10:
        if a > 5:
            break
        a += 1

    wp.expect_eq(a, 6)
256
+
257
+
258
@wp.kernel
def test_break_multiple(n: int):
    # Several independent break sites in one loop body; only the one reached
    # first at runtime (i == 5) may take effect. The i == 6 and i == 7 branches
    # must never execute.
    # NOTE(review): requires the launch to pass n > 5.
    a = int(0)

    for i in range(0, n):
        if i == 6:
            a = 1
            break

        if i == 5:
            a = 2
            break

        if i == 7:
            a = 3
            break

    wp.expect_eq(a, 2)
276
+
277
+
278
@wp.kernel
def test_continue(n: int):
    # `continue` in a runtime-length loop skips exactly one increment (i == 5).
    # NOTE(review): the expectation requires the launch to pass n > 5.
    a = int(0)

    for i in range(0, n):
        if i == 5:
            continue

        a += 1

    wp.expect_eq(a, n - 1)
289
+
290
+
291
@wp.kernel
def test_continue_unroll():
    # `continue` in a loop with literal bounds (unrolling candidate): 10
    # iterations minus the skipped i == 5 leaves 9 increments.
    a = int(0)

    for i in range(0, 10):
        if i == 5:
            continue

        a += 1

    wp.expect_eq(a, 9)
302
+
303
+
304
# Module-level Warp constants used as range() bounds in test_range_constant below.
lower = wp.constant(-3)
upper = wp.constant(3)
step = wp.constant(2)


# test unrolling of loops with constant size params
# we can't easily test if unrolling has occurred
# so just verify correctness at this stage
@wp.kernel
def test_range_constant():
    # `s` is initialized from a plain literal and then mutated inside the loops.
    # This only compiles because the loop bounds are constants (no dynamic loop).
    s = 0
    for i in range(upper):
        s += i

    # sum [0, 3): 0 + 1 + 2
    wp.expect_eq(s, 3)

    s = 0
    for i in range(lower, upper):
        s += i

    # sum [-3, 3): -3 + -2 + -1 + 0 + 1 + 2
    wp.expect_eq(s, -3)

    s = 0
    for i in range(lower, upper, step):
        s += i

    # sum of [-3, 3) with step 2: -3 + -1 + 1
    wp.expect_eq(s, -3)
334
+
335
+
336
# Constant trip count for the outer and inner loops of the kernel below.
N = wp.constant(3)


# test a dynamic loop nested between loops expected to be unrolled.
@wp.kernel
def test_range_constant_dynamic_nested(m: int):
    # int(0) declares `s` as a dynamic variable so it can be mutated inside the
    # runtime-length middle loop.
    s = int(0)
    for _i in range(N):
        for _k in range(m):
            for _j in range(N):
                s += 1

    # every combination of the three loop indices is visited exactly once
    wp.expect_eq(s, N * m * N)
349
+
350
+
351
@wp.kernel
def test_range_expression():
    # range() bounds given as expressions that call min(), once with only
    # compile-time operands and once with an operand that depends on wp.tid().
    # Both the wp.min builtin and the Python-style min() spelling are exercised.
    idx = 1
    batch_size = 100

    a = wp.float(0.0)
    c = wp.float(1.0)

    # constant expression with a function
    # range(4, min(8, 100)) -> 4 iterations each
    for _i in range(4 * idx, wp.min(4 * idx + 4, batch_size)):
        a += c

    for _i in range(4 * idx, min(4 * idx + 4, batch_size)):
        a += c

    tid = wp.tid()

    # dynamic expression with a function
    # the upper bound is min(4, tid + 1000) == 4 for any non-negative tid, so
    # these ranges are empty and must not contribute to `a`
    for _i in range(4 * idx, wp.min(4 * idx, tid + 1000)):
        a += c

    for _i in range(4 * idx, min(4 * idx, tid + 1000)):
        a += c

    # 4 + 4 iterations from the constant-bound loops only
    wp.expect_eq(a, 8.0)
376
+
377
+
378
def test_unresolved_func(test, device):
    """Check that a kernel calling an undefined ``wp`` function fails to load with a clear error.

    The offending kernel lives in the auxiliary module
    ``warp.tests.aux_test_unresolved_func``. It must be kept in a separate module,
    otherwise the current module would fail to load as well.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device to launch the kernel on.
    """
    module_name = "warp.tests.aux_test_unresolved_func"

    try:
        # kernel with unresolved function must be in a separate module, otherwise the current module would fail to load
        from warp.tests.aux_test_unresolved_func import unresolved_func_kernel

        # ensure that an appropriate exception is raised when the bad module gets loaded
        with test.assertRaisesRegex(RuntimeError, "Could not find function wp.missing_func"):
            wp.launch(unresolved_func_kernel, dim=1, inputs=[], device=device)
    finally:
        # Remove all references to the bad module so that subsequent calls to
        # wp.force_load() won't try to load it unless it is explicitly re-imported.
        # This runs even when the assertion above fails, so one failure cannot
        # cascade into unrelated tests. pop(..., None) is used so that a missing
        # entry never masks the original error.
        wp.context.user_modules.pop(module_name, None)
        sys.modules.pop(module_name, None)
390
+
391
+
392
def test_unresolved_symbol(test, device):
    """Check that a kernel referencing an undefined symbol fails to load with a clear error.

    The offending kernel lives in the auxiliary module
    ``warp.tests.aux_test_unresolved_symbol``. It must be kept in a separate module,
    otherwise the current module would fail to load as well.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device to launch the kernel on.
    """
    module_name = "warp.tests.aux_test_unresolved_symbol"

    try:
        # kernel with unresolved symbol must be in a separate module, otherwise the current module would fail to load
        from warp.tests.aux_test_unresolved_symbol import unresolved_symbol_kernel

        # ensure that an appropriate exception is raised when the bad module gets loaded
        with test.assertRaisesRegex(KeyError, "Referencing undefined symbol: missing_symbol"):
            wp.launch(unresolved_symbol_kernel, dim=1, inputs=[], device=device)
    finally:
        # Remove all references to the bad module so that subsequent calls to
        # wp.force_load() won't try to load it unless it is explicitly re-imported.
        # This runs even when the assertion above fails, so one failure cannot
        # cascade into unrelated tests. pop(..., None) is used so that a missing
        # entry never masks the original error.
        wp.context.user_modules.pop(module_name, None)
        sys.modules.pop(module_name, None)
404
+
405
+
406
def test_error_global_var(test, device):
    """Kernels must reject a ``wp.array`` captured from the enclosing Python scope.

    Three kernel functions use the captured array ``arr`` in different ways:
    indexing it, assigning it, and passing it to a builtin. Each launch is
    expected to raise ``TypeError`` naming the rejected external reference type.
    """
    arr = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)

    def kernel_1_fn(out: wp.array(dtype=float)):
        out[0] = arr[0]

    def kernel_2_fn(out: wp.array(dtype=float)):
        out[0] = arr

    def kernel_3_fn(out: wp.array(dtype=float)):
        out[0] = wp.lower_bound(arr, 2.0)

    out = wp.empty_like(arr)
    expected_error = r"Invalid external reference type: <class 'warp.types.array'>"

    # Each kernel is built and launched in turn, in the same order as declared above.
    for kernel_fn in (kernel_1_fn, kernel_2_fn, kernel_3_fn):
        kernel = wp.Kernel(func=kernel_fn)
        with test.assertRaisesRegex(TypeError, expected_error):
            wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device)
431
+
432
+
433
def test_error_collection_construct(test, device):
    """Python list/tuple/dict literals inside kernels must raise a descriptive ``RuntimeError``."""

    def kernel_1_fn():
        x = [1.0, 2.0, 3.0]

    def kernel_2_fn():
        x = (1.0, 2.0, 3.0)

    def kernel_3_fn():
        x = {"a": 1.0, "b": 2.0, "c": 3.0}

    def kernel_4_fn():
        wp.length((1.0, 2.0, 3.0))

    # (kernel function, expected error regex), checked in this order
    cases = (
        # list literal assigned to a local
        (
            kernel_1_fn,
            r"List constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` for small collections instead.",
        ),
        # tuple literal assigned to a local
        (
            kernel_2_fn,
            r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` for small collections instead.",
        ),
        # dict literal assigned to a local
        (kernel_3_fn, r"Construct `ast.Dict` not supported in kernels."),
        # tuple literal passed as a builtin argument
        (
            kernel_4_fn,
            r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` instead.",
        ),
    )

    for kernel_fn, pattern in cases:
        kernel = wp.Kernel(func=kernel_fn)
        with test.assertRaisesRegex(RuntimeError, pattern):
            wp.launch(kernel, dim=1, device=device)
469
+
470
+
471
def test_error_unmatched_arguments(test, device):
    """Builtins called with operands of different types must report the offending types."""

    def kernel_1_fn():
        a = 1 * 1.0

    def kernel_2_fn():
        x = wp.dot(wp.vec2(1.0, 2.0), wp.vec2h(wp.float16(1.0), wp.float16(2.0)))

    expectations = [
        # int * float scalar multiplication
        (kernel_1_fn, r"Input types must be the same, got \['int32', 'float32'\]"),
        # dot product of a float32 vector with a float16 vector
        (
            kernel_2_fn,
            r"Input types must be exactly the same, got \[\"vector\(length=2, dtype=<class 'warp.types.float32'>\)\", \"vector\(length=2, dtype=<class 'warp.types.float16'>\)\"\]",
        ),
    ]

    for kernel_fn, pattern in expectations:
        kernel = wp.Kernel(func=kernel_fn)
        with test.assertRaisesRegex(RuntimeError, pattern):
            wp.launch(kernel, dim=1, device=device)
488
+
489
+
490
def test_error_mutating_constant_in_dynamic_loop(test, device):
    """Check codegen handling of locals mutated inside static and dynamic loops.

    1. Accumulating into a local initialised from a literal (``my_constant = 0.0``)
       inside a loop bounded by a kernel argument must raise ``WarpCodegenError``.
    2. Deeply nested loops mixing closure-constant bounds and argument bounds must
       compile and write every output column.
    3. A static loop followed by another loop (and the reverse order) over the same
       matrix must both compile, and the later loop's writes must win.
    """

    # Loop bound ``n`` is a kernel argument; mutating ``my_constant`` inside it is
    # expected to be rejected by codegen.
    @wp.kernel
    def dynamic_loop_kernel(n: int, input: wp.array(dtype=float)):
        my_constant = 0.0
        for i in range(n):
            my_constant += input[i]

    inputs = wp.array([1.0, 2.0, 3.0], dtype=float, device=device)
    # The error message suggests declaring the variable via ``float(...)`` instead.
    with test.assertRaisesRegex(
        wp.codegen.WarpCodegenError,
        r"Error mutating a constant my_constant inside a dynamic loop, use the following syntax\: pi = float\(3\.141\) to declare a dynamic variable",
    ):
        wp.launch(dynamic_loop_kernel, dim=1, inputs=[3, inputs], device=device)

    # the following nested loop must not raise an error
    # const_a / const_b are captured from this Python scope; dyn_a/b/c are kernel args.
    const_a = 7
    const_b = 5

    @wp.kernel
    def mixed_dyn_static_loop_kernel(dyn_a: int, dyn_b: int, dyn_c: int, output: wp.array(dtype=float, ndim=2)):
        tid = wp.tid()
        for i in range(const_a + 1):
            for j in range(dyn_a + 1):
                for k in range(dyn_b + 1):
                    for l in range(const_b + 1):
                        for m in range(dyn_c + 1):
                            # coeff spans 0 .. const_a + const_b + dyn_a + dyn_b + dyn_c,
                            # so every column of the row is written.
                            coeff = i + j + k + l + m
                            output[tid, coeff] = 1.0

    dyn_a, dyn_b, dyn_c = 3, 4, 5
    num_threads = 10
    # wp.empty is fine here: the kernel writes every element (see coeff range above).
    output = wp.empty([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1], dtype=float, device=device)
    wp.launch(
        mixed_dyn_static_loop_kernel,
        num_threads,
        inputs=[
            dyn_a,
            dyn_b,
            dyn_c,
        ],
        outputs=[output],
        device=device,
    )
    assert_np_equal(output.numpy(), np.ones([num_threads, const_a + const_b + dyn_a + dyn_b + dyn_c + 1]))

    # Literal-bounded loop zeroes the matrix, then a loop bounded by a local
    # (``dim + 1``; treated as dynamic per the kernel name) sets all entries to 1.
    @wp.kernel
    def static_then_dynamic_loop_kernel(mats: wp.array(dtype=wp.mat33d)):
        tid = wp.tid()
        mat = wp.mat33d()
        for i in range(3):
            for j in range(3):
                mat[i, j] = wp.float64(0.0)

        dim = 2
        for i in range(dim + 1):
            for j in range(dim + 1):
                mat[i, j] = wp.float64(1.0)

        mats[tid] = mat

    mats = wp.empty(1, dtype=wp.mat33d, device=device)
    wp.launch(static_then_dynamic_loop_kernel, dim=1, inputs=[mats], device=device)
    # The last loop wrote 1.0 everywhere.
    assert_np_equal(mats.numpy(), np.ones((1, 3, 3)))

    # Same two loops in the opposite order: the literal-bounded loop runs last.
    @wp.kernel
    def dynamic_then_static_loop_kernel(mats: wp.array(dtype=wp.mat33d)):
        tid = wp.tid()
        mat = wp.mat33d()

        dim = 2
        for i in range(dim + 1):
            for j in range(dim + 1):
                mat[i, j] = wp.float64(1.0)

        for i in range(3):
            for j in range(3):
                mat[i, j] = wp.float64(0.0)

        mats[tid] = mat

    mats = wp.empty(1, dtype=wp.mat33d, device=device)
    wp.launch(dynamic_then_static_loop_kernel, dim=1, inputs=[mats], device=device)
    # The last loop wrote 0.0 everywhere.
    assert_np_equal(mats.numpy(), np.zeros((1, 3, 3)))
573
+
574
+
575
def test_error_return_annotation_mismatch(test, device):
    """Check that ``wp.func`` return values are validated against the annotation.

    The cases cover four mismatches:

    - a scalar dtype mismatch (``int8`` returned for ``int16``);
    - two values returned for a scalar annotation;
    - a tuple whose element types differ from ``Tuple[int, int]``;
    - a tuple with fewer elements than the annotation declares.

    Each must raise ``WarpCodegenError`` when its kernel is launched.
    """

    @wp.func
    def foo_1(x: wp.int32) -> wp.int16:
        return wp.int8(x)

    def kernel_1_fn():
        x = foo_1(123)

    @wp.func
    def foo_2(x: int) -> int:
        return (x + x, x * x)

    def kernel_2_fn():
        x = foo_2(123)

    @wp.func
    def foo_3(x: int) -> Tuple[int, int]:
        return (x, 1.23)

    def kernel_3_fn():
        x, y = foo_3(123)

    @wp.func
    def foo_4(x: int) -> Tuple[int, int, int]:
        return (x + x, x * x)

    def kernel_4_fn():
        x, y, z = foo_4(123)

    # The expected messages embed the offending ``foo_N`` name, so those
    # functions must keep their names.
    cases = [
        (
            kernel_1_fn,
            r"The function `foo_1` has its return type annotated as `int16` but the code returns a value of type `int8`.",
        ),
        (
            kernel_2_fn,
            r"The function `foo_2` has its return type annotated as `int` but the code returns 2 values.",
        ),
        (
            kernel_3_fn,
            r"The function `foo_3` has its return type annotated as `Tuple\[int, int\]` but the code returns a tuple with types `\(int32, float32\)`.",
        ),
        (
            kernel_4_fn,
            r"The function `foo_4` has its return type annotated as a tuple of 3 elements but the code returns 2 values.",
        ),
    ]

    for fn, message in cases:
        kernel = wp.Kernel(func=fn)
        with test.assertRaisesRegex(wp.codegen.WarpCodegenError, message):
            wp.launch(kernel, dim=1, device=device)
631
+
632
+
633
# Builtins must accept positional, keyword, mixed, and reordered-keyword arguments
# and produce the same result in every form.
@wp.kernel
def test_call_syntax():
    # pow(2, 4) == 16 regardless of how x / y are passed.
    expected_pow = 16.0
    wp.expect_eq(wp.pow(2.0, 4.0), expected_pow)
    wp.expect_eq(wp.pow(x=2.0, y=4.0), expected_pow)
    wp.expect_eq(wp.pow(2.0, y=4.0), expected_pow)
    wp.expect_eq(wp.pow(y=4.0, x=2.0), expected_pow)

    # Transform matrix from pos/rot/scale: the expected value is diag(scale) with
    # pos in the last column, i.e. ``rot`` below contributes no rotation.
    expected_matrix = wp.mat44(2.0, 0.0, 0.0, 1.0, 0.0, 3.0, 0.0, 2.0, 0.0, 0.0, 4.0, 3.0, 0.0, 0.0, 0.0, 1.0)
    pos = wp.vec3(1.0, 2.0, 3.0)
    rot = wp.quat(0.0, 0.0, 0.0, 1.0)
    scale = wp.vec3(2.0, 3.0, 4.0)
    wp.expect_eq(wp.matrix(pos, rot, scale, wp.float32), expected_matrix)
    wp.expect_eq(wp.matrix(pos=pos, rot=rot, scale=scale, dtype=wp.float32), expected_matrix)
    wp.expect_eq(wp.matrix(pos, rot, scale, dtype=wp.float32), expected_matrix)
    wp.expect_eq(wp.matrix(rot=rot, pos=pos, dtype=wp.float32, scale=scale), expected_matrix)
649
+
650
+
651
# test shadowing builtin functions
# This module-level wp.func deliberately reuses the name of Python's builtin ``sum``;
# test_shadow_builtin checks that kernels resolve ``sum`` to this function.
@wp.func
def sum(a: wp.vec3) -> float:
    # Sum of the three vector components.
    return a[0] + a[1] + a[2]
655
+
656
+
657
# Calls the module-level ``sum`` wp.func (which shadows Python's builtin ``sum``)
# and expects 3.0 for wp.vec3(1.0), i.e. a vector with all three components 1.0.
@wp.kernel
def test_shadow_builtin():
    wp.expect_eq(sum(wp.vec3(1.0)), 3.0)
660
+
661
+
662
# Minimal Warp struct used by test_while_condition_eval as a loop-condition holder.
@wp.struct
class Iterator:
    # Loop-continuation flag read by the ``while`` condition.
    valid: wp.bool
665
+
666
+
667
# The while-condition reads a struct field that the loop body clears. The kernel only
# terminates if the condition is re-evaluated after each iteration, so the loop
# must run exactly once. Compiled with enable_backward=False.
@wp.kernel(enable_backward=False)
def test_while_condition_eval():
    it = Iterator()
    it.valid = True
    while it.valid:
        it.valid = False
673
+
674
+
675
class TestCodeGen(unittest.TestCase):
    """Container test case; its test methods are attached at module level via add_kernel_test / add_function_test."""
677
+
678
+
679
devices = get_test_devices()

# Registration table for TestCodeGen. Each entry is (registrar, keyword arguments)
# and is applied in the order listed; every registrar call additionally receives
# ``devices=devices``.
_registrations = (
    (add_kernel_test, dict(name="test_expect", kernel=test_expect, dim=1)),
    (add_kernel_test, dict(name="test_inplace", kernel=test_inplace, dim=1)),
    (add_kernel_test, dict(name="test_rename", kernel=test_rename, dim=1)),
    (add_kernel_test, dict(name="test_constant", kernel=test_constant, inputs=[1.0], dim=1)),
    (
        add_kernel_test,
        dict(name="test_dynamic_for_rename", kernel=test_dynamic_for_rename, inputs=[10], dim=1),
    ),
    (
        add_kernel_test,
        dict(name="test_dynamic_for_inplace", kernel=test_dynamic_for_inplace, inputs=[10], dim=1),
    ),
    (add_kernel_test, dict(name="test_reassign", kernel=test_reassign, dim=1)),
    (add_kernel_test, dict(name="test_dynamic_reassign", kernel=test_dynamic_reassign, inputs=[2], dim=1)),
    # range(start, end, step) with runtime bounds, forward and reverse
    (
        add_kernel_test,
        dict(name="test_range_dynamic_forward", kernel=test_range_dynamic, dim=1, inputs=[0, 4, 1], expect=[0, 1, 2, 3]),
    ),
    (
        add_kernel_test,
        dict(name="test_range_dynamic_reverse", kernel=test_range_dynamic, dim=1, inputs=[4, 0, -1], expect=[4, 3, 2, 1]),
    ),
    (
        add_kernel_test,
        dict(
            name="test_range_dynamic_forward_step", kernel=test_range_dynamic, dim=1, inputs=[0, 8, 2], expect=[0, 2, 4, 6]
        ),
    ),
    (
        add_kernel_test,
        dict(
            name="test_range_dynamic_reverse_step", kernel=test_range_dynamic, dim=1, inputs=[8, 0, -2], expect=[8, 6, 4, 2]
        ),
    ),
    (
        add_kernel_test,
        dict(name="test_range_static_sum", kernel=test_range_static_sum, dim=1, expect=[10, 10, 10]),
    ),
    (
        add_kernel_test,
        dict(
            name="test_range_dynamic_sum", kernel=test_range_dynamic_sum, dim=1, inputs=[0, 10, 2], expect=[10, 10, 10, 10]
        ),
    ),
    (
        add_kernel_test,
        dict(
            name="test_range_dynamic_sum_zero", kernel=test_range_dynamic_sum, dim=1, inputs=[0, 0, 1], expect=[0, 0, 0, 0]
        ),
    ),
    (add_kernel_test, dict(name="test_range_constant", kernel=test_range_constant, dim=1)),
    (
        add_kernel_test,
        dict(
            name="test_range_constant_dynamic_nested", kernel=test_range_constant_dynamic_nested, dim=1, inputs=[10]
        ),
    ),
    (add_kernel_test, dict(name="test_range_dynamic_nested", kernel=test_range_dynamic_nested, dim=1, inputs=[4])),
    (add_kernel_test, dict(name="test_range_expression", kernel=test_range_expression, dim=1)),
    # while / pass / break / continue
    (add_kernel_test, dict(name="test_while_zero", kernel=test_while, dim=1, inputs=[0])),
    (add_kernel_test, dict(name="test_while_positive", kernel=test_while, dim=1, inputs=[16])),
    (add_kernel_test, dict(name="test_pass", kernel=test_pass, dim=1, inputs=[16])),
    (add_kernel_test, dict(name="test_break", kernel=test_break, dim=1, inputs=[10])),
    (add_kernel_test, dict(name="test_break_early", kernel=test_break_early, dim=1, inputs=[10])),
    (add_kernel_test, dict(name="test_break_unroll", kernel=test_break_unroll, dim=1)),
    (add_kernel_test, dict(name="test_break_while", kernel=test_break_while, dim=1)),
    (add_kernel_test, dict(name="test_break_multiple", kernel=test_break_multiple, dim=1, inputs=[10])),
    (add_kernel_test, dict(name="test_continue", kernel=test_continue, dim=1, inputs=[10])),
    (add_kernel_test, dict(name="test_continue_unroll", kernel=test_continue_unroll, dim=1)),
    # host-side Python test functions
    (add_function_test, dict(func=test_unresolved_func, name="test_unresolved_func")),
    (add_function_test, dict(func=test_unresolved_symbol, name="test_unresolved_symbol")),
    (add_function_test, dict(func=test_error_global_var, name="test_error_global_var")),
    (add_function_test, dict(func=test_error_collection_construct, name="test_error_collection_construct")),
    (add_function_test, dict(func=test_error_unmatched_arguments, name="test_error_unmatched_arguments")),
    (
        add_function_test,
        dict(
            func=test_error_mutating_constant_in_dynamic_loop,
            name="test_error_mutating_constant_in_dynamic_loop",
        ),
    ),
    (
        add_function_test,
        dict(
            func=test_error_return_annotation_mismatch,
            name="test_error_return_annotation_mismatch",
        ),
    ),
    (add_kernel_test, dict(name="test_call_syntax", kernel=test_call_syntax, dim=1)),
    (add_kernel_test, dict(name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1)),
    (add_kernel_test, dict(name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1)),
)

for _register, _options in _registrations:
    _register(TestCodeGen, devices=devices, **_options)

# Drop the helper names so the module namespace matches the unrolled registrations.
del _register, _options, _registrations
806
+
807
+
808
if __name__ == "__main__":
    # Clear Warp's kernel cache first (presumably so kernels are rebuilt from source
    # for this run), then run the suite verbosely, stopping at the first failure.
    wp.clear_kernel_cache()
    unittest.main(failfast=True, verbosity=2)