warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,136 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+
24
# For each element, converts f32[tid] to fp16, checks it equals the value
# already stored in f16[tid], then writes the converted value back to f16[tid].
@wp.kernel
def load_store_half(f32: wp.array(dtype=wp.float32), f16: wp.array(dtype=wp.float16)):
    tid = wp.tid()

    # check conversion from f32->f16
    a = wp.float16(f32[tid])
    b = f16[tid]

    wp.expect_eq(a, b)

    # check stores
    f16[tid] = a
36
+
37
+
38
def test_fp16_conversion(test, device):
    """Check host-side fp32/fp16 array construction and in-kernel f32->f16 conversion.

    Both Warp arrays must match NumPy's rounding on creation, and the fp16 array
    must still match after ``load_store_half`` writes converted values into it.
    """
    values = [1.0, 2.0, 3.0, -3.14159]

    expected_f32 = np.array(values, dtype=np.float32)
    expected_f16 = np.array(values, dtype=np.float16)

    arr_f32 = wp.array(values, dtype=wp.float32, device=device)
    arr_f16 = wp.array(values, dtype=wp.float16, device=device)

    # array construction must already agree with NumPy, for both precisions
    for expected, actual in ((expected_f32, arr_f32), (expected_f16, arr_f16)):
        assert_np_equal(expected, actual.numpy())

    wp.launch(load_store_half, dim=len(values), inputs=[arr_f32, arr_f16], device=device)

    # the kernel stored the converted values back into the fp16 array
    assert_np_equal(expected_f16, arr_f16.numpy())
54
+
55
+
56
# Checks that the scalar fp16 kernel argument equals f16_array[0], then stores
# the scalar into f16_array[0]. Only element 0 is read or written;
# test_fp16_kernel_parameter launches it with dim (1,).
@wp.kernel
def value_load_store_half(f16_value: wp.float16, f16_array: wp.array(dtype=wp.float16)):
    wp.expect_eq(f16_value, f16_array[0])

    # check stores
    f16_array[0] = f16_value
62
+
63
+
64
def test_fp16_kernel_parameter(test, device):
    """Test passing fp16 scalars into kernels, both as wp.float16 and as plain Python floats."""
    for value in [1.0, 2.0, 3.0, -3.14159]:
        expected = np.array([value], dtype=np.float16)

        # explicit wp.float16 scalar argument
        target = wp.array([value], dtype=wp.float16, device=device)
        wp.launch(value_load_store_half, (1,), inputs=[wp.float16(value), target], device=device)
        assert_np_equal(expected, target.numpy())

        # plain Python float argument, relying on automatic conversion at launch
        target = wp.array([value], dtype=wp.float16, device=device)
        wp.launch(value_load_store_half, (1,), inputs=[value, target], device=device)
        assert_np_equal(expected, target.numpy())
82
+
83
+
84
# Writes output[tid] = 2 * input[tid]: the fp16 input is converted to fp32,
# multiplied, and converted back to fp16 for the store. test_fp16_grad uses it
# to check gradient flow through these conversions.
@wp.kernel
def mul_half(input: wp.array(dtype=wp.float16), output: wp.array(dtype=wp.float16)):
    tid = wp.tid()

    # convert to compute type fp32
    x = wp.float(input[tid]) * 2.0

    # store back as fp16
    output[tid] = wp.float16(x)
93
+
94
+
95
def test_fp16_grad(test, device):
    """Check that gradients propagate correctly through fp16 arrays.

    ``mul_half`` computes ``y = 2 * x`` in fp32 on fp16 storage. With an upstream
    gradient of ones for every output element, each input gradient must be 2.
    """
    rng = np.random.default_rng(123)

    # checks that gradients are correctly propagated for fp16 arrays,
    # even when intermediate calculations are performed in e.g. fp32
    s = rng.random(size=15).astype(np.float16)

    # named x_arr / y_arr (not input / output) to avoid shadowing the builtin `input`
    x_arr = wp.array(s, dtype=wp.float16, device=device, requires_grad=True)
    y_arr = wp.zeros_like(x_arr)

    tape = wp.Tape()
    with tape:
        wp.launch(mul_half, dim=len(s), inputs=[x_arr, y_arr], device=device)

    # seed the backward pass with dL/dy = 1 for every output element
    ones = wp.array(np.ones(len(y_arr)), dtype=wp.float16, device=device)

    tape.backward(grads={y_arr: ones})

    assert_np_equal(x_arr.grad.numpy(), np.ones(len(s)) * 2.0)
116
+
117
+
118
class TestFp16(unittest.TestCase):
    """Container for the fp16 tests; test methods are attached via add_function_test."""

    pass
120
+
121
+
122
devices = ["cpu"] if wp.is_cpu_available() else []
# Only CUDA devices with arch >= 70 are included.
# NOTE(review): presumably the fp16 paths need compute capability 7.0+ — confirm.
for cuda_device in get_selected_cuda_test_devices():
    if cuda_device.arch < 70:
        continue
    devices.append(cuda_device)

add_function_test(TestFp16, "test_fp16_conversion", test_fp16_conversion, devices=devices)
add_function_test(TestFp16, "test_fp16_grad", test_fp16_grad, devices=devices)
add_function_test(TestFp16, "test_fp16_kernel_parameter", test_fp16_kernel_parameter, devices=devices)
132
+
133
+
134
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so kernels are
    # rebuilt rather than reused from a previous run).
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,454 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import unittest
18
+ from typing import Any, Tuple
19
+
20
+ import numpy as np
21
+
22
+ import warp as wp
23
+ from warp.tests.unittest_utils import *
24
+
25
+
26
# User function returning x * x. It is used directly by the closure-capture
# test and indirectly through `cube`.
@wp.func
def sqr(x: float):
    return x * x
29
+
30
+
31
# test nested user function calls
# and explicit return type hints:
# `cube` calls the user function `sqr` and declares a `-> float` return type.
@wp.func
def cube(x: float) -> float:
    return sqr(x) * x
36
+
37
+
38
# `custom` is defined three times with different parameter types (int, float,
# vec3) to test overload resolution. This int overload returns x + 1.
@wp.func
def custom(x: int):
    return x + 1
41
+
42
+
43
# float overload of the overloaded user function `custom`: returns x + 1.0.
@wp.func
def custom(x: float):
    return x + 1.0
46
+
47
+
48
# vec3 overload of the overloaded user function `custom`: adds 1.0 to the x
# component only.
@wp.func
def custom(x: wp.vec3):
    return x + wp.vec3(1.0, 0.0, 0.0)
51
+
52
+
53
# User function with no return statement. It rebinds its parameter x to
# x + (0, 1, 0) and checks the result in-kernel. test_overload_func calls it
# with (1, 0, 0), hence the expected (1, 1, 0).
@wp.func
def noreturn(x: wp.vec3):
    x = x + wp.vec3(0.0, 1.0, 0.0)

    wp.expect_eq(x, wp.vec3(1.0, 1.0, 0.0))
58
+
59
+
60
# Kernel exercising overload resolution of the user function `custom` with
# int, float and vec3 arguments, plus a call to the non-returning `noreturn`.
@wp.kernel
def test_overload_func():
    # tests overloading a custom @wp.func

    i = custom(1)
    f = custom(1.0)
    v = custom(wp.vec3(1.0, 0.0, 0.0))

    # each overload adds one (to the x component for vectors)
    wp.expect_eq(i, 2)
    wp.expect_eq(f, 2.0)
    wp.expect_eq(v, wp.vec3(2.0, 0.0, 0.0))

    noreturn(wp.vec3(1.0, 0.0, 0.0))
73
+
74
+
75
@wp.func
def foo(x: int):
    # This shouldn't be picked up: `foo` is redefined below with the same
    # signature, and test_override_func expects the later body (x * 3).
    return x * 2
79
+
80
+
81
# Redefinition of `foo` with an identical signature. This body should replace
# the earlier one, so test_override_func expects foo(1) == 3.
@wp.func
def foo(x: int):
    return x * 3
84
+
85
+
86
# Checks that the later definition of `foo` (x * 3) is the one that gets called.
@wp.kernel
def test_override_func():
    i = foo(1)
    wp.expect_eq(i, 3)
90
+
91
+
92
def test_func_closure_capture(test, device):
    """Launch kernels built from a closure that captures a user function.

    ``make_closure_kernel`` wraps ``func`` (a ``@wp.func`` from the enclosing
    scope) in a new kernel. Each kernel applies it to ``data`` and compares the
    result with ``expected``.
    """

    def make_closure_kernel(func):
        def closure_kernel_fn(data: wp.array(dtype=float), expected: float):
            f = func(data[wp.tid()])
            wp.expect_eq(f, expected)

        return wp.Kernel(func=closure_kernel_fn)

    # build both kernels before any launch, pairing each with its expected value
    cases = (
        (make_closure_kernel(sqr), 4.0),  # 2.0 squared
        (make_closure_kernel(cube), 8.0),  # 2.0 cubed
    )

    data = wp.array([2.0], dtype=float, device=device)
    for closure_kernel, expected_value in cases:
        wp.launch(closure_kernel, dim=data.shape, inputs=[data, expected_value], device=device)
109
+
110
+
111
# Ignores its three int32 arguments and returns the constant 1.0, with an
# explicit `-> wp.float32` return hint. Its result is fed to wp.lerp in
# test_return_kernel.
@wp.func
def test_func(param1: wp.int32, param2: wp.int32, param3: wp.int32) -> wp.float32:
    return 1.0
114
+
115
+
116
# Each thread writes lerp(test_func(...), test_func(...), 0.5) to element tid.
# Both endpoints are test_func's constant 1.0, so the stored value is 1.0.
@wp.kernel
def test_return_kernel(test_data: wp.array(dtype=wp.float32)):
    tid = wp.tid()
    test_data[tid] = wp.lerp(test_func(0, 1, 2), test_func(0, 1, 2), 0.5)
120
+
121
+
122
def test_return_func(test, device):
    """Launch test_return_kernel and verify the values it writes.

    test_func always returns 1.0, so lerp(1.0, 1.0, 0.5) must be 1.0 in every
    element of the 100-element output.
    """
    test_data = wp.zeros(100, dtype=wp.float32, device=device)
    wp.launch(kernel=test_return_kernel, dim=test_data.size, inputs=[test_data], device=device)

    # check the results, not just that the kernel compiled and launched
    assert_np_equal(test_data.numpy(), np.ones(test_data.size, dtype=np.float32))
125
+
126
+
127
# Returns four values: sum, difference, product and quotient of a and b.
# There is no guard against b == 0; callers choose inputs with nonzero b.
@wp.func
def multi_valued_func(a: wp.float32, b: wp.float32):
    return a + b, a - b, a * b, a / b
130
+
131
+
132
def test_multi_valued_func(test, device):
    """Unpack a four-valued user function inside a kernel and check each value element-wise."""

    @wp.kernel
    def test_multi_valued_kernel(test_data1: wp.array(dtype=wp.float32), test_data2: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        d1, d2 = test_data1[tid], test_data2[tid]
        a, b, c, d = multi_valued_func(d1, d2)
        wp.expect_eq(a, d1 + d2)
        wp.expect_eq(b, d1 - d2)
        wp.expect_eq(c, d1 * d2)
        wp.expect_eq(d, d1 / d2)

    count = 100
    # ascending 0..99 and descending 100..1, so the divisor is never zero
    ascending = wp.array(np.arange(count), dtype=wp.float32, device=device)
    descending = wp.array(np.arange(count, 0, -1), dtype=wp.float32, device=device)
    wp.launch(
        kernel=test_multi_valued_kernel,
        dim=ascending.size,
        inputs=[ascending, descending],
        device=device,
    )
146
+
147
+
148
# Exercises wp.expect_near with its default tolerance and with an explicit one.
@wp.kernel
def test_func_defaults():
    # test default as expected
    # NOTE(review): the operands differ by ~1e-6, so this relies on the default
    # tolerance of wp.expect_near accepting that difference — confirm.
    wp.expect_near(1.0, 1.0 + 1.0e-6)

    # test that changing tolerance still works
    # (explicit third argument; |1.0 - 1.1| is within 0.5)
    wp.expect_near(1.0, 1.1, 0.5)
155
+
156
+
157
# User function sharing its name with a Warp builtin (`sign`). It always
# returns 123.0; test_builtin_shadowing checks that this user definition wins.
@wp.func
def sign(x: float):
    return 123.0
160
+
161
+
162
# Passes only if `sign` resolves to the module's user function (always 123.0)
# rather than a builtin sign function.
@wp.kernel
def test_builtin_shadowing():
    wp.expect_eq(sign(1.23), 123.0)
165
+
166
+
167
# User function with default parameter values (a=123, b=234) that returns a + b.
# It is called both from a kernel and directly from Python in
# test_user_func_with_defaults.
@wp.func
def user_func_with_defaults(a: int = 123, b: int = 234) -> int:
    return a + b
170
+
171
+
172
# Calls user_func_with_defaults with several mixes of positional, keyword and
# omitted arguments, and checks each sum in-kernel.
@wp.kernel
def user_func_with_defaults_kernel():
    # both defaults: 123 + 234
    a = user_func_with_defaults()
    wp.expect_eq(a, 357)

    # positional a, default b: 111 + 234
    b = user_func_with_defaults(111)
    wp.expect_eq(b, 345)

    # both positional: 111 + 222
    c = user_func_with_defaults(111, 222)
    wp.expect_eq(c, 333)

    # keyword a, default b: 111 + 234
    d = user_func_with_defaults(a=111)
    wp.expect_eq(d, 345)

    # default a, keyword b: 123 + 111
    e = user_func_with_defaults(b=111)
    wp.expect_eq(e, 234)
188
+
189
+
190
def test_user_func_with_defaults(test, device):
    """Check default and keyword arguments of a user function, in a kernel and from Python.

    The kernel launch performs the in-kernel checks. The remaining calls invoke
    the same ``@wp.func`` directly from Python and must give the same sums.
    """
    wp.launch(user_func_with_defaults_kernel, dim=1, device=device)

    # TestCase assertions instead of bare `assert`, which `python -O` strips
    test.assertEqual(user_func_with_defaults(), 357)  # 123 + 234
    test.assertEqual(user_func_with_defaults(111), 345)  # 111 + 234
    test.assertEqual(user_func_with_defaults(111, 222), 333)  # 111 + 222
    test.assertEqual(user_func_with_defaults(a=111), 345)  # 111 + 234
    test.assertEqual(user_func_with_defaults(b=111), 234)  # 123 + 111
207
+
208
+
209
# Returns the pair (a + a, b * b) and declares it with a Tuple[int, float]
# return hint.
@wp.func
def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]:
    return a + a, b * b
212
+
213
+
214
# Unpacks both return values: 123 + 123 = 246 and 234.0 * 234.0 = 54756.0.
@wp.kernel
def test_user_func_return_multiple_values():
    a, b = user_func_return_multiple_values(123, 234.0)
    wp.expect_eq(a, 246)
    wp.expect_eq(b, 54756.0)
219
+
220
+
221
# Generic user function (array dtype `Any`) returning b[i] * 2.0.
# test_user_func_overload_resolution uses it with both vec3 and float arrays.
@wp.func
def user_func_overload(
    b: wp.array(dtype=Any),
    i: int,
):
    return b[i] * 2.0
227
+
228
+
229
# Generic kernel: a[i] = user_func_overload(b, i), i.e. a[i] = b[i] * 2.0,
# for whatever dtype the launched arrays have.
@wp.kernel
def user_func_overload_resolution_kernel(
    a: wp.array(dtype=Any),
    b: wp.array(dtype=Any),
):
    i = wp.tid()
    a[i] = user_func_overload(b, i)
236
+
237
+
238
def test_user_func_overload_resolution(test, device):
    """Check that a generic user function resolves correctly for vec3 and float arrays.

    Each launch computes ``a[i] = b[i] * 2.0`` on the device under test.
    """
    # allocate on, and launch to, the device under test (not the default device)
    a0 = wp.array((1, 2, 3), dtype=wp.vec3, device=device)
    b0 = wp.array((2, 3, 4), dtype=wp.vec3, device=device)

    a1 = wp.array((5,), dtype=float, device=device)
    b1 = wp.array((6,), dtype=float, device=device)

    wp.launch(user_func_overload_resolution_kernel, a0.shape, (a0, b0), device=device)
    wp.launch(user_func_overload_resolution_kernel, a1.shape, (a1, b1), device=device)

    assert_np_equal(a0.numpy()[0], (4, 6, 8))
    # TestCase assertion instead of bare `assert`, which `python -O` strips
    test.assertEqual(a1.numpy()[0], 12)
250
+
251
+
252
# User function with an explicit `-> None` return annotation and an empty body.
@wp.func
def user_func_return_none() -> None:
    pass
255
+
256
+
257
# Kernel that itself carries a `-> None` annotation and calls a user function
# also annotated `-> None`.
@wp.kernel
def test_return_annotation_none() -> None:
    user_func_return_none()
260
+
261
+
262
# Devices to run this module's tests on (presumably passed to add_function_test
# in the registrations further down the file).
devices = get_test_devices()
263
+
264
+
265
class TestFunc(unittest.TestCase):
    """Host-side checks for calling user-defined and native Warp functions from Python.

    The device-parametrized kernel and function tests defined earlier in this file are
    attached to this class by the ``add_kernel_test`` / ``add_function_test`` calls
    that follow it.
    """

    def test_user_func_export(self):
        # tests calling overloaded user-defined functions from Python
        # (`custom` is defined earlier in this file; the assertions expect each
        # overload -- int, float and vec3 -- to double its argument)
        i = custom(1)
        f = custom(1.0)
        v = custom(wp.vec3(1.0, 0.0, 0.0))

        self.assertEqual(i, 2)
        self.assertEqual(f, 2.0)
        assert_np_equal(np.array([*v]), np.array([2.0, 0.0, 0.0]))

    def test_native_func_export(self):
        # tests calling native functions from Python
        # NOTE: variables (q, v, m, t) are reassigned repeatedly; each assertion depends
        # on the statements immediately before it, so keep the order intact.

        # Quaternion construction round-trips its components.
        q = wp.quat(0.0, 0.0, 0.0, 1.0)
        assert_np_equal(np.array([*q]), np.array([0.0, 0.0, 0.0, 1.0]))

        # 2.0 rad about +x: expected components are (sin(1), 0, 0, cos(1)).
        r = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), 2.0)
        assert_np_equal(np.array([*r]), np.array([0.8414709568023682, 0.0, 0.0, 0.5403022170066833]), tol=1.0e-3)

        # normalize() followed by scalar multiplication, for quat and vec2/vec3/vec4.
        q = wp.quat(1.0, 2.0, 3.0, 4.0)
        q = wp.normalize(q) * 2.0
        assert_np_equal(
            np.array([*q]),
            np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
            tol=1.0e-3,
        )

        v2 = wp.vec2(1.0, 2.0)
        v2 = wp.normalize(v2) * 2.0
        assert_np_equal(np.array([*v2]), np.array([0.4472135901451111, 0.8944271802902222]) * 2.0, tol=1.0e-3)

        v3 = wp.vec3(1.0, 2.0, 3.0)
        v3 = wp.normalize(v3) * 2.0
        assert_np_equal(
            np.array([*v3]), np.array([0.26726123690605164, 0.5345224738121033, 0.8017836809158325]) * 2.0, tol=1.0e-3
        )

        v4 = wp.vec4(1.0, 2.0, 3.0, 4.0)
        v4 = wp.normalize(v4) * 2.0
        assert_np_equal(
            np.array([*v4]),
            np.array([0.18257418274879456, 0.3651483654975891, 0.547722578048706, 0.7302967309951782]) * 2.0,
            tol=1.0e-3,
        )

        # Binary, in-place and unary operators on vec2.
        v = wp.vec2(0.0)
        v += wp.vec2(1.0, 1.0)
        assert v == wp.vec2(1.0, 1.0)
        v -= wp.vec2(1.0, 1.0)
        assert v == wp.vec2(0.0, 0.0)
        v = wp.vec2(2.0, 2.0) - wp.vec2(1.0, 1.0)
        assert v == wp.vec2(1.0, 1.0)
        v *= 2.0
        assert v == wp.vec2(2.0, 2.0)
        v = v * 2.0
        assert v == wp.vec2(4.0, 4.0)
        v = v / 2.0
        assert v == wp.vec2(2.0, 2.0)
        v /= 2.0
        assert v == wp.vec2(1.0, 1.0)
        v = -v
        assert v == wp.vec2(-1.0, -1.0)
        v = +v
        assert v == wp.vec2(-1.0, -1.0)

        # Matrix addition, 2D indexing and string formatting.
        m22 = wp.mat22(1.0, 2.0, 3.0, 4.0)
        m22 = m22 + m22

        self.assertEqual(m22[1, 1], 8.0)
        self.assertEqual(str(m22), "[[2.0, 4.0],\n [6.0, 8.0]]")

        # `*` between transforms: the rotation part of the expected result is the
        # quaternion product q * q of the (unnormalized) rotation (4, 5, 6, 7).
        t = wp.transform(
            wp.vec3(1.0, 2.0, 3.0),
            wp.quat(4.0, 5.0, 6.0, 7.0),
        )
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(
            t * wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0)
        )
        self.assertSequenceEqual(
            t * wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0)), (396.0, 432.0, 720.0, 56.0, 70.0, 84.0, -28.0)
        )

        # Transform construction forms. An omitted p defaults to zeros and an omitted q
        # to (0, 0, 0, 1); p/q may be positional, keyword (either order), seven scalars,
        # another transform, or an unpacked transform.
        t = wp.transform()
        self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0))

        t = wp.transform(p=(1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(q=(4.0, 5.0, 6.0, 7.0), p=(1.0, 2.0, 3.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform((1.0, 2.0, 3.0), q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(p=(1.0, 2.0, 3.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0))

        t = wp.transform(q=(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(p=wp.vec3(1.0, 2.0, 3.0), q=wp.quat(4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        t = wp.transform(*wp.transform(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(t, (1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0))

        # Element-wise transform addition and subtraction.
        # NOTE(review): this local is presumably equivalent to `wp.transformf`, which the
        # next line uses directly -- the two spellings are mixed below.
        transformf = wp.types.transformation(dtype=float)

        t = wp.transformf((1.0, 2.0, 3.0), (4.0, 5.0, 6.0, 7.0))
        self.assertSequenceEqual(
            t + transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
            (3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0),
        )
        self.assertSequenceEqual(
            t - transformf((2.0, 3.0, 4.0), (5.0, 6.0, 7.0, 8.0)),
            (-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0),
        )

        # Scalar builtin on a plain Python float.
        f = wp.sin(math.pi * 0.5)
        self.assertAlmostEqual(f, 1.0, places=3)

        # Binary, in-place and unary operators on mat22, ending with a matrix product:
        # [[-1, -1], [-1, -1]] squared is [[2, 2], [2, 2]].
        m = wp.mat22(0.0, 0.0, 0.0, 0.0)
        m += wp.mat22(1.0, 1.0, 1.0, 1.0)
        assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
        m -= wp.mat22(1.0, 1.0, 1.0, 1.0)
        assert m == wp.mat22(0.0, 0.0, 0.0, 0.0)
        m = wp.mat22(2.0, 2.0, 2.0, 2.0) - wp.mat22(1.0, 1.0, 1.0, 1.0)
        assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
        m *= 2.0
        assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)
        m = m * 2.0
        assert m == wp.mat22(4.0, 4.0, 4.0, 4.0)
        m = m / 2.0
        assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)
        m /= 2.0
        assert m == wp.mat22(1.0, 1.0, 1.0, 1.0)
        m = -m
        assert m == wp.mat22(-1.0, -1.0, -1.0, -1.0)
        m = +m
        assert m == wp.mat22(-1.0, -1.0, -1.0, -1.0)
        m = m * m
        assert m == wp.mat22(2.0, 2.0, 2.0, 2.0)

    def test_native_function_error_resolution(self):
        # Multiplying a float32 matrix by a float64 matrix has no matching `mul`
        # overload; the exact RuntimeError message (anchored regex) is checked.
        a = wp.mat22f(1.0, 2.0, 3.0, 4.0)
        b = wp.mat22d(1.0, 2.0, 3.0, 4.0)
        with self.assertRaisesRegex(
            RuntimeError,
            r"^Couldn't find a function 'mul' compatible with " r"the arguments 'mat22f, mat22d'$",
        ):
            a * b
427
+
428
+
429
# Kernel tests: every kernel is launched with dim=1 on each device in `devices`.
for _test_name, _test_kernel in (
    ("test_overload_func", test_overload_func),
    ("test_override_func", test_override_func),
    ("test_func_defaults", test_func_defaults),
    ("test_builtin_shadowing", test_builtin_shadowing),
    ("test_user_func_return_multiple_values", test_user_func_return_multiple_values),
    ("test_return_annotation_none", test_return_annotation_none),
):
    add_kernel_test(TestFunc, kernel=_test_kernel, name=_test_name, dim=1, devices=devices)

# Host-side test functions taking (test, device), registered for each device in `devices`.
for _test_name, _test_func in (
    ("test_return_func", test_return_func),
    ("test_func_closure_capture", test_func_closure_capture),
    ("test_multi_valued_func", test_multi_valued_func),
    ("test_user_func_with_defaults", test_user_func_with_defaults),
    ("test_user_func_overload_resolution", test_user_func_overload_resolution),
):
    add_function_test(TestFunc, func=_test_func, name=_test_name, devices=devices)
450
+
451
+
452
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so kernels are rebuilt
    # from the current source rather than loaded from a previous build).
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -0,0 +1,98 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This is what we are actually testing.
17
+ from __future__ import annotations
18
+
19
+ import unittest
20
+
21
+ import warp as wp
22
+ from warp.tests.unittest_utils import *
23
+
24
+
25
@wp.struct
class FooData:
    # Two-field Warp struct. Under `from __future__ import annotations` these field
    # annotations are strings. `create_kernel_3` uses this type as a kernel parameter
    # through the `Foo.Data` alias.
    x: float
    y: float
29
+
30
+
31
class Foo:
    # Plain (non-Warp) holder object. `Data` aliases the struct type and `compute`
    # is a no-op Warp function. `create_kernel_3` reaches both through an instance:
    # `foo.Data` in a parameter annotation and `foo.compute()` in the kernel body.
    Data = FooData

    @wp.func
    def compute():
        # Intentionally empty: calling it from the kernel is what makes `foo` a
        # closure variable of the kernel function.
        pass
37
+
38
+
39
@wp.kernel
def kernel_1(
    out: wp.array(dtype=float),
):
    # Never writes to `out`. With postponed annotations the parameter's type is the
    # string "wp.array(dtype=float)"; this kernel appears to check only that such a
    # signature builds and launches.
    tid = wp.tid()
44
+
45
+
46
@wp.kernel
def kernel_2(
    out: wp.array(dtype=float),
):
    # Writes the constant 1.23 into out[tid] for every launched thread.
    tid = wp.tid()
    out[tid] = 1.23
52
+
53
+
54
def create_kernel_3(foo: Foo):
    """Build a kernel whose ``data`` parameter is annotated with ``foo.Data``.

    Because this module uses ``from __future__ import annotations``, that annotation
    is the string ``"foo.Data"``. It can only be resolved through ``fn``'s closure
    (see the comment inside ``fn``).

    Args:
        foo: Holder whose ``Data`` attribute is the struct type of ``data``.

    Returns:
        A ``wp.Kernel`` wrapping ``fn``, which writes ``data.x + data.y`` to ``out[tid]``.
    """

    def fn(
        data: foo.Data,
        out: wp.array(dtype=float),
    ):
        tid = wp.tid()

        # Referencing a variable in a type hint like `foo.Data` isn't officially
        # accepted by Python, but it's still used in some places (e.g. `warp.fem`).
        # It works only because the variable is also referenced within the function
        # body, which promotes it to a closure variable. Without that, it wouldn't be
        # possible to resolve `foo` and evaluate the `foo.Data` string to its
        # corresponding type.
        foo.compute()

        out[tid] = data.x + data.y

    return wp.Kernel(func=fn)
72
+
73
+
74
def test_future_annotations(test, device):
    # Builds and launches kernels whose signatures were written under
    # `from __future__ import annotations` (every annotation is a string, including
    # the closure-dependent `foo.Data` in `create_kernel_3`), then checks that each
    # kernel that writes `out` produced the expected value.
    foo = Foo()
    foo_data = FooData()
    foo_data.x = 1.23
    foo_data.y = 2.34

    # Allocate and launch on the device under test; `device` may be None, in which
    # case both calls fall back to the default device.
    out = wp.empty(1, dtype=float, device=device)

    kernel_3 = create_kernel_3(foo)

    # kernel_1 does not write `out`, so only the later launches are checked.
    wp.launch(kernel_1, dim=out.shape, outputs=(out,), device=device)
    wp.launch(kernel_2, dim=out.shape, outputs=(out,), device=device)
    test.assertAlmostEqual(float(out.numpy()[0]), 1.23, places=5)

    wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,), device=device)
    # kernel_3 overwrites the element with data.x + data.y (float32 arithmetic).
    test.assertAlmostEqual(float(out.numpy()[0]), 1.23 + 2.34, places=5)
87
+
88
+
89
class TestFutureAnnotations(unittest.TestCase):
    """Container for the postponed-annotation tests attached via ``add_function_test``."""

    pass
91
+
92
+
93
# Registered without an explicit device list, matching the original call.
add_function_test(
    TestFutureAnnotations,
    name="test_future_annotations",
    func=test_future_annotations,
)
94
+
95
+
96
if __name__ == "__main__":
    # Clear Warp's kernel cache before running (presumably so kernels are rebuilt
    # from the current source rather than loaded from a previous build).
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)