warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,178 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any, NamedTuple
18
+
19
+ import numpy as np
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import add_function_test, assert_np_equal, get_test_devices
23
+
24
+
25
class ScalarFloatValues(NamedTuple):
    """Per-builtin table of values used by ``test_scalar_math``.

    Field order matters: ``test_scalar_math`` enumerates the tuple and passes the
    position as the ``i`` selector of ``scalar_float_kernel`` (0 -> ``wp.degrees``,
    1 -> ``wp.radians``).
    """

    # Value associated with ``wp.degrees``. The test stores either a 1-tuple of
    # kernel inputs or a plain Python float (expected result) here, hence ``Any``.
    degrees: Any = None
    # Value associated with ``wp.radians`` (same conventions as ``degrees``).
    radians: Any = None
28
+
29
+
30
# Apply one scalar angle-conversion builtin to x[0] and write the result to out[0].
# The selector `i` chooses the builtin: 0 -> wp.degrees, 1 -> wp.radians; any other
# value leaves `out` untouched. Launched with dim=1 and recorded on a wp.Tape by
# test_scalar_math so that the builtin's gradient can also be checked.
@wp.kernel
def scalar_float_kernel(
    i: int,
    x: wp.array(dtype=wp.float32),
    out: wp.array(dtype=wp.float32),
):
    if i == 0:
        out[0] = wp.degrees(x[0])
    elif i == 1:
        out[0] = wp.radians(x[0])
+
41
+
42
def test_scalar_math(test, device):
    """Check forward values and gradients of ``wp.degrees`` and ``wp.radians``.

    The position ``mode`` of each field in ``ScalarFloatValues`` selects which builtin
    ``scalar_float_kernel`` applies. Both the output and the recorded gradient are
    compared against reference values with an absolute tolerance of 1e-6.
    """
    inputs = ScalarFloatValues(degrees=(0.123,), radians=(123.0,))
    expected_out = ScalarFloatValues(degrees=7.047381, radians=2.146755)
    expected_grad = ScalarFloatValues(degrees=57.29578, radians=0.017453)

    for mode, (sample, want_out, want_grad) in enumerate(zip(inputs, expected_out, expected_grad)):
        src = wp.array([sample[0]], dtype=wp.float32, requires_grad=True, device=device)
        dst = wp.array([0.0], dtype=wp.float32, requires_grad=True, device=device)

        tape = wp.Tape()
        with tape:
            wp.launch(scalar_float_kernel, dim=1, inputs=[mode, src, dst], device=device)

        assert_np_equal(dst.numpy(), np.array([want_out]), tol=1e-6)

        # `dst` holds a single element, so it is used directly as the scalar loss.
        tape.backward(dst)

        assert_np_equal(tape.gradients[src].numpy(), np.array([want_grad]), tol=1e-6)
59
+
60
+
61
# Generic kernel (element dtype `Any`): for the vector vs[tid], write its L1, L2,
# Huber and pseudo-Huber norms into columns 0, 1, 2 and 3 of row `tid` of `out`.
# `out` therefore needs at least len(vs) rows and 4 columns.
@wp.kernel
def test_vec_norm_kernel(vs: wp.array(dtype=Any), out: wp.array(dtype=float, ndim=2)):
    tid = wp.tid()
    out[tid, 0] = wp.norm_l1(vs[tid])
    out[tid, 1] = wp.norm_l2(vs[tid])
    out[tid, 2] = wp.norm_huber(vs[tid])
    out[tid, 3] = wp.norm_pseudo_huber(vs[tid])
68
+
69
+
70
def test_vec_norm(test, device):
    """Check Warp's vector norm builtins against host-side reference implementations.

    ``test_vec_norm_kernel`` writes, per input vector, a row of four values: L1 norm,
    L2 norm, Huber norm and pseudo-Huber norm. The references below use ``delta = 1.0``.
    """

    # ground-truth implementations from SciPy
    def huber(delta, x):
        if x <= delta:
            return 0.5 * x**2
        else:
            return delta * (x - 0.5 * delta)

    def pseudo_huber(delta, x):
        return delta**2 * (np.sqrt(1 + (x / delta) ** 2) - 1)

    v0 = wp.vec3(-2.0, -1.0, -3.0)
    v1 = wp.vec3(2.0, 1.0, 3.0)
    v2 = wp.vec3(0.0, 0.0, 0.0)
    vectors = [v0, v1, v2]

    xs = wp.array(vectors, dtype=wp.vec3, requires_grad=True, device=device)
    out = wp.empty((len(xs), 4), dtype=wp.float32, requires_grad=True, device=device)

    wp.launch(test_vec_norm_kernel, dim=len(xs), inputs=[xs], outputs=[out], device=device)

    # Copy the results to the host once, rather than once per checked row.
    out_host = out.numpy()

    for i, x in enumerate(vectors):
        length = wp.length(x)
        expected = np.array(
            [
                np.linalg.norm(x, ord=1),
                np.linalg.norm(x, ord=2),
                huber(1.0, length),
                # note SciPy defines the Pseudo-Huber loss slightly differently
                pseudo_huber(1.0, length) + 1.0,
            ]
        )
        assert_np_equal(out_host[i], expected, tol=1e-6)
104
+
105
+
106
devices = get_test_devices()


class TestMath(unittest.TestCase):
    """Python-scope tests (no kernel launches) for arbitrary-size ``wp.vec``/``wp.mat`` types."""

    def test_vec_type(self):
        vec5 = wp.vec(length=5, dtype=float)
        v = vec5()
        w = vec5()
        a = vec5(1.0)
        b = vec5(0.0, 0.0, 0.0, 0.0, 0.0)
        c = vec5(0.0)

        v[0] = 1.0
        # `.x` aliases component 0, so this overwrites the value set just above.
        v.x = 0.0
        v[1:] = [1.0, 1.0, 1.0, 1.0]

        w[0] = 1.0
        w[1:] = [0.0, 0.0, 0.0, 0.0]

        self.assertEqual(v[0], w[1], "vec setter error")
        self.assertEqual(v.x, w.y, "vec setter error")

        for x in v[1:]:
            self.assertEqual(x, 1.0, "vec slicing error")

        # A single scalar argument fills every component (as `c == b` also shows for 0.0).
        self.assertEqual(a, vec5(1.0, 1.0, 1.0, 1.0, 1.0), "vec scalar constructor error")
        self.assertEqual(b, c, "vec equality error")

        self.assertEqual(str(v), "[0.0, 1.0, 1.0, 1.0, 1.0]", "vec to string error")

    def test_mat_type(self):
        mat55 = wp.mat(shape=(5, 5), dtype=float)
        m1 = mat55()
        m2 = mat55()

        # Build the identity element by element.
        for i in range(5):
            for j in range(5):
                if i == j:
                    m1[i, j] = 1.0
                else:
                    m1[i, j] = 0.0

        # Fill every row with ones using row assignment.
        for i in range(5):
            m2[i] = [1.0, 1.0, 1.0, 1.0, 1.0]

        # A single scalar argument fills every entry.
        a = mat55(1.0)
        # fmt: off
        b = mat55(
            1.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 1.0,
        )
        # fmt: on

        self.assertEqual(m1, b, "mat element setting error")
        self.assertEqual(m2, a, "mat row setting error")
        self.assertEqual(m1[0, 0], 1.0, "mat element getting error")
        self.assertEqual(m2[0], [1.0, 1.0, 1.0, 1.0, 1.0], "mat row getting error")
        self.assertEqual(
            str(b),
            "[[1.0, 0.0, 0.0, 0.0, 0.0],\n [0.0, 1.0, 0.0, 0.0, 0.0],\n [0.0, 0.0, 1.0, 0.0, 0.0],\n [0.0, 0.0, 0.0, 1.0, 0.0],\n [0.0, 0.0, 0.0, 0.0, 1.0]]",
            "mat to string error",
        )
170
+
171
+
172
# Attach the device-parameterized test functions to TestMath; add_function_test
# presumably generates one test method per entry in `devices` (TODO confirm).
add_function_test(
    TestMath,
    "test_scalar_math",
    test_scalar_math,
    devices=devices,
)
add_function_test(
    TestMath,
    "test_vec_norm",
    test_vec_norm,
    devices=devices,
)


if __name__ == "__main__":
    # Clear Warp's kernel cache before running the suite directly.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
warp/tests/test_mlp.py ADDED
@@ -0,0 +1,282 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# Activation function handed to wp.mlp by mlp_kernel. tanh matches the np.tanh
# reference in test_mlp and the torch.nn.Tanh layer built by create_mlp.
@wp.func
def mlp_activation(z: float):
    return wp.tanh(z)
27
+
28
+
29
# Evaluate one dense layer for one batch column per thread: wp.mlp is checked by
# test_mlp against tanh(weights @ x + bias), with x of shape (inputs, batch), y of shape
# (outputs, batch), and the kernel launched with dim equal to the batch size.
@wp.kernel
def mlp_kernel(
    weights: wp.array2d(dtype=float),
    bias: wp.array(dtype=float),
    x: wp.array2d(dtype=float),
    y: wp.array2d(dtype=float),
):
    wp.mlp(weights, bias, mlp_activation, wp.tid(), x, y)
37
+
38
+
39
+ @wp.kernel
40
+ def loss_kernel(x: wp.array2d(dtype=float), loss: wp.array(dtype=float)):
41
+ i, j = wp.tid()
42
+
43
+ wp.atomic_add(loss, 0, x[i, j] * x[i, j])
44
+
45
+
46
def test_mlp(test, device):
    """Compare ``wp.mlp`` with a NumPy reference ``tanh(W @ x + b)`` for a single dense layer."""
    rng = np.random.default_rng(123)

    m, n = 10, 200  # output features, input features
    batches = 20000

    # The draw order (weights, bias, inputs) fixes the data produced by the seeded RNG.
    w_host = rng.random(size=(m, n)) * 0.5 - 0.5
    b_host = rng.random(size=m) * 0.5 - 0.5
    x_host = rng.random(size=(n, batches))

    weights = wp.array(w_host, dtype=float, device=device)
    bias = wp.array(b_host, dtype=float, device=device)
    x = wp.array(x_host, dtype=float, device=device)
    y = wp.zeros(shape=(m, batches), device=device)

    with wp.ScopedTimer("warp", active=False):
        wp.launch(mlp_kernel, dim=batches, inputs=[weights, bias, x, y], device=device)
        wp.synchronize()

    # A*x + b
    with wp.ScopedTimer("numpy", active=False):
        pre_activation = weights.numpy().reshape(m, n) @ x.numpy().reshape(-1, batches)
        expect = np.tanh(pre_activation + bias.numpy().reshape(m, 1))

    assert_np_equal(y.numpy().reshape(-1, batches), expect, tol=1.0e-6)
71
+
72
+
73
def create_mlp(m, n):
    """Build a deterministic single-layer ``Linear -> Tanh`` PyTorch network.

    Args:
        m: Number of input features.
        n: Number of output (hidden) features.

    Returns:
        A ``torch.nn.Module``. Its parameters are initialized right after
        ``torch.manual_seed(0)``, so repeated calls yield identical weights. Note that
        this reseeds PyTorch's global RNG as a side effect.
    """
    import torch

    # Seed before constructing the layer so parameter initialization is reproducible.
    torch.manual_seed(0)

    class FeedForward(torch.nn.Module):
        def __init__(self, input_size, hidden_size):
            super().__init__()

            self.input_size = input_size
            self.hidden_size = hidden_size
            self.fc1 = torch.nn.Linear(self.input_size, self.hidden_size)
            self.act = torch.nn.Tanh()

        def forward(self, x):
            out = self.fc1(x)
            out = self.act(out)
            return out

    return FeedForward(m, n)
93
+
94
+
95
def create_golden():
    """Regenerate ``assets/mlp_golden.npy`` from a PyTorch reference MLP.

    Runs ``create_mlp(32, 16)`` on a seeded random batch of 64 inputs and backpropagates
    the loss ``sum(y**2)``. It then saves the parameters, inputs, outputs, their gradients
    and the loss as a pickled dict next to this file (read back by ``load_golden``).
    """
    # Imported explicitly rather than relying on the wildcard test-utils import.
    import os

    import torch

    rng = np.random.default_rng(123)

    input_size = 32
    hidden_size = 16
    batch_size = 64

    network = create_mlp(input_size, hidden_size)

    x = torch.Tensor(rng.random(size=(batch_size, input_size)))
    x.requires_grad = True

    y = network.forward(x)
    # y is a non-leaf tensor; retain_grad() keeps its gradient for the golden file.
    y.retain_grad()

    # inner product of the flattened output with itself == sum of squares
    loss = torch.inner(y.flatten(), y.flatten())
    loss.backward(retain_graph=True)

    results = {}
    results["weights"] = network.fc1.weight.cpu().detach().numpy()
    results["weights_grad"] = network.fc1.weight.grad.cpu().detach().numpy()
    results["bias"] = network.fc1.bias.cpu().detach().numpy()
    results["bias_grad"] = network.fc1.bias.grad.cpu().detach().numpy()
    results["x"] = x.cpu().detach().numpy()
    results["x_grad"] = x.grad.cpu().detach().numpy()
    results["y"] = y.cpu().detach().numpy()
    results["y_grad"] = y.grad.cpu().detach().numpy()
    results["loss"] = loss.cpu().detach().numpy()

    np.save(os.path.join(os.path.dirname(__file__), "assets/mlp_golden.npy"), results, allow_pickle=True)
127
+
128
+
129
def load_golden():
    """Load the reference MLP data stored in ``assets/mlp_golden.npy`` next to this file.

    Returns:
        dict: The pickled dictionary saved by ``create_golden`` (keys such as
        ``"weights"``, ``"bias"``, ``"x"``, ``"y"``, their ``*_grad`` variants and ``"loss"``).
    """
    # Imported explicitly rather than relying on the wildcard test-utils import.
    import os

    # The file holds a pickled dict saved via np.save, so allow_pickle is required.
    # Only load trusted files that ship with the package.
    path = os.path.join(os.path.dirname(__file__), "assets/mlp_golden.npy")
    return np.load(path, allow_pickle=True).item()
131
+
132
+
133
def test_mlp_grad(test, device):
    """Check ``wp.mlp`` forward values and gradients against stored PyTorch golden data."""
    # uncomment to re-build golden files
    # create_golden()

    golden = load_golden()

    # PyTorch stores the batch along rows while the Warp kernel expects it along
    # columns, hence the transposes on the batched quantities.
    ref_weights = golden["weights"]
    ref_weights_grad = golden["weights_grad"]
    ref_bias = golden["bias"]
    ref_bias_grad = golden["bias_grad"]
    ref_x, ref_x_grad = golden["x"].T, golden["x_grad"].T
    ref_y, ref_y_grad = golden["y"].T, golden["y_grad"].T
    ref_loss = golden["loss"].T

    weights = wp.array(ref_weights, dtype=float, device=device, requires_grad=True)
    bias = wp.array(ref_bias, dtype=float, device=device, requires_grad=True)

    x = wp.array(ref_x, dtype=float, device=device, requires_grad=True)
    # y takes its shape from the reference output, then starts from zero.
    y = wp.array(ref_y, dtype=float, device=device, requires_grad=True)
    y.zero_()

    loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)

    m, n = ref_weights.shape[0], ref_weights.shape[1]
    batch = ref_x.shape[1]

    tape = wp.Tape()
    with tape:
        wp.launch(mlp_kernel, dim=batch, inputs=[weights, bias, x, y], device=device)
        wp.launch(loss_kernel, dim=y.shape, inputs=[y, loss], device=device)

    tape.backward(loss=loss)

    tol = 1.0e-1
    grads = tape.gradients

    # forward pass
    assert_np_equal(y.numpy().reshape(-1, batch), ref_y, tol=tol)
    assert_np_equal(loss.numpy(), ref_loss, tol=tol)

    # backward pass
    assert_np_equal(grads[weights].numpy().reshape(m, n), ref_weights_grad, tol=tol)
    assert_np_equal(grads[bias].numpy(), ref_bias_grad, tol=tol)
    assert_np_equal(grads[x].numpy().reshape(n, batch), ref_x_grad, tol=tol)
    assert_np_equal(grads[y].numpy().reshape(m, batch), ref_y_grad, tol=tol)
178
+
179
+
180
def profile_mlp_torch():
    """Time forward and backward passes of the PyTorch reference MLP for growing batches.

    This is a manual profiling helper and is not registered as a test. Batch sizes are the
    powers of two ``2**0`` through ``2**(steps - 1)``.
    """
    import torch

    def _synchronize():
        # The network and inputs here are created on the CPU. torch.cuda.synchronize()
        # raises when CUDA is unavailable, so only wait on the GPU when one exists.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    rng = np.random.default_rng(123)

    m = 128
    n = 64

    steps = 20

    for i in range(steps):
        b = 2**i

        network = create_mlp(m, n)

        x = torch.Tensor(rng.random(size=(b, m)))

        with wp.ScopedTimer("torch_forward" + str(b)):
            network.forward(x)
            _synchronize()

    for i in range(steps):
        b = 2**i

        network = create_mlp(m, n)

        x = torch.Tensor(rng.random(size=(b, m)))
        y = network.forward(x)

        loss = torch.norm(y)

        # run once to alloc all gradients
        loss.backward(retain_graph=True)

        with wp.ScopedTimer("torch-backward" + str(b)):
            loss.backward()
            _synchronize()
217
+
218
+
219
def profile_mlp_warp(device):
    """Time forward and backward passes of ``mlp_kernel`` for growing batch sizes.

    This is a manual profiling helper and is not registered as a test. Batch sizes are the
    powers of two ``2**0`` through ``2**(steps - 1)``.

    Args:
        device: Warp device to run on (e.g. ``"cuda"``).
    """
    rng = np.random.default_rng(123)

    m = 128
    n = 64

    steps = 20

    for i in range(steps):
        b = 2**i

        weights = wp.array(rng.random(size=(m, n)) * 0.5 - 0.5, dtype=float, device=device)
        bias = wp.array(rng.random(size=m) * 0.5 - 0.5, dtype=float, device=device)

        x = wp.array(rng.random(size=(n, b)), dtype=float, device=device)
        y = wp.zeros(shape=(m, b), device=device)

        with wp.ScopedTimer("warp-forward" + str(b)):
            wp.launch(mlp_kernel, dim=b, inputs=[weights, bias, x, y], device=device)
            wp.synchronize()

    for i in range(steps):
        b = 2**i

        weights = wp.array(rng.random(size=(m, n)) * 0.5 - 0.5, dtype=float, device=device, requires_grad=True)
        bias = wp.array(rng.random(size=m) * 0.5 - 0.5, dtype=float, device=device, requires_grad=True)

        x = wp.array(rng.random(size=(n, b)), dtype=float, device=device, requires_grad=True)
        y = wp.zeros(shape=(m, b), device=device, requires_grad=True)

        # As in test_mlp_grad, the scalar loss needs requires_grad=True so that
        # Tape.backward() can seed its gradient.
        loss = wp.zeros(1, dtype=float, device=device, requires_grad=True)

        tape = wp.Tape()
        with tape:
            wp.launch(mlp_kernel, dim=b, inputs=[weights, bias, x, y], device=device)
            # loss_kernel takes an array2d and unpacks a 2D thread id, so launch it over
            # y's 2D shape; a flattened 1D view does not match its signature.
            wp.launch(loss_kernel, dim=y.shape, inputs=[y, loss], device=device)

        # run backward once to ensure all adjoints are allocated
        tape.backward(loss)
        wp.synchronize()

        with wp.ScopedTimer("warp-backward" + str(b)):
            tape.backward(loss)
            wp.synchronize()
263
+
264
+
265
# Manual profiling entry points; uncomment to run them (they are not tests):
# profile_mlp_warp("cuda")
# profile_mlp_torch()


devices = get_test_devices()


class TestMLP(unittest.TestCase):
    """Test case filled in by the ``add_function_test`` calls that follow it."""


# Register the device-parameterized MLP tests on TestMLP.
add_function_test(
    TestMLP,
    "test_mlp",
    test_mlp,
    devices=devices,
    check_output=False,
)
add_function_test(
    TestMLP,
    "test_mlp_grad",
    test_mlp_grad,
    devices=devices,
    check_output=False,
)


if __name__ == "__main__":
    # Clear Warp's kernel cache before running the suite directly.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=False)
@@ -0,0 +1,258 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # TODO: add more tests for kernels and generics
17
+
18
+ import os
19
+ import tempfile
20
+ import unittest
21
+ from importlib import util
22
+
23
+ import warp as wp
24
+ from warp.tests.unittest_utils import *
25
+
26
# Source snippets for the module-hashing tests below. Each string is written to a
# temporary file and imported by ``load_code_as_module``. Within each family:
#   *_1 is the reference source,
#   *_2 has exactly the same source text as *_1 (only the module name it is loaded
#       under differs), so its hash must match *_1,
#   *_3 changes only the body of the first overload, and
#   *_4 changes only the body of the second overload,
# so both *_3 and *_4 must hash differently from *_1.
# NOTE(review): these strings are the data under test; their exact text determines
# the hashes, so keep each *_1 / *_2 pair textually identical when editing.

# Reference: two overloads of ``fn`` (no-arg and int-arg), both called from kernel ``k``.
FUNC_OVERLOAD_1 = """# -*- coding: utf-8 -*-
import warp as wp

@wp.func
def fn():
    wp.print(17)

@wp.func
def fn(value: int):
    wp.print(value)

@wp.kernel
def k():
    print(fn())
    print(fn(99))
"""

# Identical source to FUNC_OVERLOAD_1 -> expected to produce the same hash.
FUNC_OVERLOAD_2 = """# -*- coding: utf-8 -*-
import warp as wp

@wp.func
def fn():
    wp.print(17)

@wp.func
def fn(value: int):
    wp.print(value)

@wp.kernel
def k():
    print(fn())
    print(fn(99))
"""

# First overload's body differs (prints 42 instead of 17) -> hash must differ from FUNC_OVERLOAD_1.
FUNC_OVERLOAD_3 = """# -*- coding: utf-8 -*-
import warp as wp

@wp.func
def fn():
    wp.print(42)

@wp.func
def fn(value: int):
    wp.print(value)

@wp.kernel
def k():
    print(fn())
    print(fn(99))
"""

# Second overload's body differs (prints value + 1) -> hash must differ from FUNC_OVERLOAD_1.
FUNC_OVERLOAD_4 = """# -*- coding: utf-8 -*-
import warp as wp

@wp.func
def fn():
    wp.print(17)

@wp.func
def fn(value: int):
    wp.print(value + 1)

@wp.kernel
def k():
    print(fn())
    print(fn(99))
"""

# Reference: two ``typing.Any`` (generic) overloads of ``generic_fn``, both called from kernel ``k``.
FUNC_GENERIC_1 = """# -*- coding: utf-8 -*-
import warp as wp

from typing import Any

@wp.func
def generic_fn(x: Any):
    return x * x

@wp.func
def generic_fn(x: Any, y: Any):
    return x * y

@wp.kernel
def k():
    print(generic_fn(17))
    print(generic_fn(17, 42))
"""

# Identical source to FUNC_GENERIC_1 -> expected to produce the same hash.
FUNC_GENERIC_2 = """# -*- coding: utf-8 -*-
import warp as wp

from typing import Any

@wp.func
def generic_fn(x: Any):
    return x * x

@wp.func
def generic_fn(x: Any, y: Any):
    return x * y

@wp.kernel
def k():
    print(generic_fn(17))
    print(generic_fn(17, 42))
"""

# First overload's body differs (x + x instead of x * x) -> hash must differ from FUNC_GENERIC_1.
FUNC_GENERIC_3 = """# -*- coding: utf-8 -*-
import warp as wp

from typing import Any

@wp.func
def generic_fn(x: Any):
    return x + x

@wp.func
def generic_fn(x: Any, y: Any):
    return x * y

@wp.kernel
def k():
    print(generic_fn(17))
    print(generic_fn(17, 42))
"""

# Second overload's body differs (x + y instead of x * y) -> hash must differ from FUNC_GENERIC_1.
FUNC_GENERIC_4 = """# -*- coding: utf-8 -*-
import warp as wp

from typing import Any

@wp.func
def generic_fn(x: Any):
    return x * x

@wp.func
def generic_fn(x: Any, y: Any):
    return x + y

@wp.kernel
def k():
    print(generic_fn(17))
    print(generic_fn(17, 42))
"""
175
+
176
+
177
def load_code_as_module(code, name):
    """Execute ``code`` as a fresh Python module and return its Warp module.

    The source is written to a temporary ``.py`` file and imported under ``name``
    through :mod:`importlib.util`; the module is not inserted into ``sys.modules``.
    The temporary file is deleted afterwards, even if executing the code raises.

    Args:
        code: Python source text to execute.
        name: Name given to the loaded Python module.

    Returns:
        The Warp module returned by ``wp.get_module`` for the loaded module's name.
    """
    fd, file_path = tempfile.mkstemp(suffix=".py")

    try:
        # Write explicitly as UTF-8 so the bytes on disk agree with the
        # ``# -*- coding: utf-8 -*-`` declaration in the test sources, independent
        # of the platform's default locale encoding.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(code)

        spec = util.spec_from_file_location(name, file_path)
        module = util.module_from_spec(spec)
        spec.loader.exec_module(module)
    finally:
        # The file is only needed while the module executes.
        os.remove(file_path)

    return wp.get_module(module.__name__)
191
+
192
+
193
def _assert_variant_hashes(test, sources, name_prefix):
    """Load four source variants as modules and check their hashes against variant 1.

    Args:
        test: The ``unittest.TestCase`` used for assertions.
        sources: Exactly four source strings. Variant 2 must hash the same as
            variant 1; variants 3 and 4 must each hash differently from variant 1.
        name_prefix: Modules are loaded as ``f"{name_prefix}_1"`` .. ``f"{name_prefix}_4"``.
    """
    # Load every module first, then hash them in order (same sequence as the original tests).
    modules = [load_code_as_module(src, f"{name_prefix}_{i}") for i, src in enumerate(sources, start=1)]
    hash1, hash2, hash3, hash4 = [m.hash_module() for m in modules]

    test.assertEqual(hash2, hash1, msg=f"{name_prefix}_2 should hash the same as {name_prefix}_1")
    test.assertNotEqual(hash3, hash1, msg=f"{name_prefix}_3 should hash differently from {name_prefix}_1")
    test.assertNotEqual(hash4, hash1, msg=f"{name_prefix}_4 should hash differently from {name_prefix}_1")


def test_function_overload_hashing(test, device):
    """Check that module hashes track edits to individual ``wp.func`` overloads.

    ``device`` is unused: ``hash_module()`` is called without a device.
    """
    _assert_variant_hashes(
        test,
        (FUNC_OVERLOAD_1, FUNC_OVERLOAD_2, FUNC_OVERLOAD_3, FUNC_OVERLOAD_4),
        "func_overload",
    )


def test_function_generic_overload_hashing(test, device):
    """Check that module hashes track edits to individual generic (``Any``) overloads.

    ``device`` is unused: ``hash_module()`` is called without a device.
    """
    _assert_variant_hashes(
        test,
        (FUNC_GENERIC_1, FUNC_GENERIC_2, FUNC_GENERIC_3, FUNC_GENERIC_4),
        "func_generic",
    )
223
+
224
+
225
# Minimal source containing a single empty kernel; used by ``test_module_load`` to
# check that loading a module leaves its hash unchanged.
SIMPLE_MODULE = """# -*- coding: utf-8 -*-
import warp as wp

@wp.kernel
def k():
    pass
"""
232
+
233
+
234
def test_module_load(test, device):
    """Verify that calling ``load`` on a module leaves its hash unchanged."""
    module = load_code_as_module(SIMPLE_MODULE, "simple_module")

    hash_before_load = module.hash_module()
    module.load(device)

    # Arguments evaluate left to right, so the second hash is taken after the load.
    test.assertEqual(hash_before_load, module.hash_module())
243
+
244
+
245
class TestModuleHashing(unittest.TestCase):
    """Test case whose methods are attached by the ``add_function_test`` calls below."""


devices = get_test_devices()

# The two hashing checks are registered without a device list; only the
# load test is run across the test devices.
add_function_test(TestModuleHashing, "test_function_overload_hashing", test_function_overload_hashing)
add_function_test(
    TestModuleHashing,
    "test_function_generic_overload_hashing",
    test_function_generic_overload_hashing,
)
add_function_test(TestModuleHashing, "test_module_load", test_module_load, devices=devices)


if __name__ == "__main__":
    # Start from an empty kernel cache before running the suite.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)