warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,591 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import warp as wp
19
+ from warp.fem import cache
20
+ from warp.fem.domain import GeometryDomain
21
+ from warp.fem.geometry import Element
22
+ from warp.fem.space import FunctionSpace
23
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
24
+
25
+ from ..polynomial import Polynomial
26
+
27
+
28
@wp.struct
class QuadraturePointElementIndex:
    """Location of a quadrature point relative to the domain element that owns it.

    Entries of this type are written by ``Quadrature.element_index_arg_value``, one per
    quadrature point evaluation index, and read back by ``Quadrature.evaluation_point_element_index``.
    """

    # Index of the owning element within the integration domain; ``NULL_ELEMENT_INDEX`` for
    # evaluation indices that are not associated with any quadrature point
    domain_element_index: ElementIndex
    # Index of the quadrature point within its element (the ``element_qp_index`` of the point functions)
    qp_index_in_element: int
32
+
33
+
34
class Quadrature:
    """Interface class for quadrature rules.

    Concrete rules must provide a ``name`` attribute (used by :meth:`__str__` and to key the kernel
    generated in :meth:`element_index_arg_value`), :meth:`total_point_count`, and implementations of
    the ``point_*`` device functions declared below, which raise ``NotImplementedError`` here.
    """

    @wp.struct
    class Arg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, domain: GeometryDomain):
        self._domain = domain

    @property
    def domain(self):
        """Domain over which this quadrature is defined"""
        return self._domain

    def arg_value(self, device) -> "Arg":
        """
        Value of the argument to be passed to device functions.

        The base rule carries no data, so this returns an empty :class:`Quadrature.Arg`
        regardless of ``device``.
        """
        arg = Quadrature.Arg()
        return arg

    def total_point_count(self):
        """Number of unique quadrature points that can be indexed by this rule.
        Returns a number such that `point_index()` is always smaller than this number.
        """
        raise NotImplementedError()

    def evaluation_point_count(self):
        """Number of quadrature points that need to be evaluated, mostly for internal purposes.
        If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
        (e.g., nodal quadrature), `evaluation_point_count()` may be different than `total_point_count()`.
        Returns a number such that `point_evaluation_index()` is always smaller than this number.

        The default implementation returns `total_point_count()`.
        """
        return self.total_point_count()

    def max_points_per_element(self):
        """Maximum number of points per element if known, or ``None`` otherwise"""
        return None

    @staticmethod
    def point_count(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
    ):
        """Number of quadrature points for a given element"""
        raise NotImplementedError()

    @staticmethod
    def point_coords(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Coordinates in element of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_weight(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Weight of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """
        Global index of the element's qp_index'th quadrature point.
        May be shared among elements.
        This is what determines `qp_index` in integrands' `Sample` arguments.
        """
        raise NotImplementedError()

    @staticmethod
    def point_evaluation_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Quadrature point index according to evaluation order.
        Quadrature points for distinct elements must have different evaluation indices.
        Mostly for internal/parallelization purposes.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        """Returns the rule's ``name`` (which must be provided by subclasses)"""
        return self.name

    # By default cache the mapping from evaluation point indices to domain elements

    ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)

    @cache.cached_arg_value
    def element_index_arg_value(self, device):
        """Builds a map from quadrature point evaluation indices to their index in the element to which they belong.

        Returns an array of ``evaluation_point_count()`` entries of :class:`QuadraturePointElementIndex`.
        Entries whose evaluation index is not produced by any element keep ``NULL_ELEMENT_INDEX``
        as their domain element index.
        """

        @cache.dynamic_kernel(f"{self.name}{self.domain.name}")
        def quadrature_point_element_indices(
            qp_arg: self.Arg,
            domain_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            result: wp.array(dtype=QuadraturePointElementIndex),
        ):
            domain_element_index = wp.tid()
            element_index = self.domain.element_index(domain_index_arg, domain_element_index)

            qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
            for k in range(qp_point_count):
                qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
                result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)

        # Pre-fill with a null entry so that evaluation indices not written by the kernel
        # (sparse indexing schemes) map to NULL_ELEMENT_INDEX
        null_qp_index = QuadraturePointElementIndex()
        null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
        result = wp.full(
            value=null_qp_index,
            shape=self.evaluation_point_count(),
            dtype=QuadraturePointElementIndex,
            device=device,
        )
        # One thread per domain element scatters that element's quadrature points into `result`
        wp.launch(
            quadrature_point_element_indices,
            device=result.device,
            dim=self.domain.element_count(),
            inputs=[
                self.arg_value(result.device),
                self.domain.element_arg_value(result.device),
                self.domain.element_index_arg_value(result.device),
                result,
            ],
        )

        return result

    @wp.func
    def evaluation_point_element_index(
        element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
        qp_eval_index: QuadraturePointIndex,
    ):
        """Maps from quadrature point evaluation indices to their index in the element to which they belong.
        If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index.
        """

        element_index = element_index_arg[qp_eval_index]
        return element_index.domain_element_index, element_index.qp_index_in_element
196
+
197
+
198
class _QuadratureWithRegularEvaluationPoints(Quadrature):
    """Helper subclass for quadrature formulas which use a uniform number of
    evaluation points per element. Avoids building an explicit mapping from
    evaluation points to elements.

    Point ``qp_index`` of domain element ``domain_element_index`` gets the evaluation index
    ``N * domain_element_index + qp_index``, where ``N`` is the number of points per element.
    Subclasses must make ``self.name`` valid before calling this constructor, because it is
    used to key the generated evaluation index function.
    """

    def __init__(self, domain: GeometryDomain, N: int):
        super().__init__(domain)
        # Uniform number of evaluation points per element
        self._EVALUATION_POINTS_PER_ELEMENT = N

        # Replace the base-class evaluation index functions with closed-form versions,
        # so that no mapping array needs to be built
        self.point_evaluation_index = self._make_regular_point_evaluation_index()
        self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()

    # The closed-form mapping needs no data: reuse the empty base argument type and value
    ElementIndexArg = Quadrature.Arg
    element_index_arg_value = Quadrature.arg_value

    def evaluation_point_count(self):
        """Total number of evaluation points: ``N`` points for each element of the domain"""
        return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT

    def _make_regular_point_evaluation_index(self):
        """Generates the device function returning ``N * domain_element_index + qp_index``"""
        N = self._EVALUATION_POINTS_PER_ELEMENT

        @cache.dynamic_func(suffix=f"{self.name}")
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return N * domain_element_index + qp_index

        return evaluation_point_index

    def _make_regular_evaluation_point_element_index(self):
        """Generates the device function inverting the regular evaluation index:
        given an evaluation index, returns ``(domain_element_index, index_in_element)``."""
        N = self._EVALUATION_POINTS_PER_ELEMENT

        # Keyed on N alone: the function only depends on N, and its argument is the
        # empty base `Quadrature.Arg` (see ElementIndexArg above)
        @cache.dynamic_func(suffix=f"{N}")
        def quadrature_evaluation_point_element_index(
            qp_arg: Quadrature.Arg,
            qp_index: QuadraturePointIndex,
        ):
            domain_element_index = qp_index // N
            index_in_element = qp_index - domain_element_index * N
            return domain_element_index, index_in_element

        return quadrature_evaluation_point_element_index
243
+
244
+
245
+ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
246
+ """Regular quadrature formula, using a constant set of quadrature points per element"""
247
+
248
+ @wp.struct
249
+ class Arg:
250
+ # Quadrature points and weights used to be passed as Warp constants,
251
+ # but this tended to incur register spilling for high point counts
252
+ points: wp.array(dtype=Coords)
253
+ weights: wp.array(dtype=float)
254
+
255
+ # Cache common formulas so we do dot have to do h2d transfer for each call
256
+ class CachedFormula:
257
+ _cache = {}
258
+
259
+ def __init__(self, element: Element, order: int, family: Polynomial):
260
+ self.points, self.weights = element.instantiate_quadrature(order, family)
261
+ self.count = wp.constant(len(self.points))
262
+
263
+ @cache.cached_arg_value
264
+ def arg_value(self, device):
265
+ arg = RegularQuadrature.Arg()
266
+ arg.points = wp.array(self.points, device=device, dtype=Coords)
267
+ arg.weights = wp.array(self.weights, device=device, dtype=float)
268
+ return arg
269
+
270
+ @staticmethod
271
+ def get(element: Element, order: int, family: Polynomial):
272
+ key = (element.__class__.__name__, order, family)
273
+ try:
274
+ return RegularQuadrature.CachedFormula._cache[key]
275
+ except KeyError:
276
+ quadrature = RegularQuadrature.CachedFormula(element, order, family)
277
+ RegularQuadrature.CachedFormula._cache[key] = quadrature
278
+ return quadrature
279
+
280
+ def __init__(
281
+ self,
282
+ domain: GeometryDomain,
283
+ order: int,
284
+ family: Polynomial = None,
285
+ ):
286
+ self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
287
+ self.family = family
288
+ self.order = order
289
+
290
+ super().__init__(domain, self._formula.count)
291
+
292
+ self.point_count = self._make_point_count()
293
+ self.point_index = self._make_point_index()
294
+ self.point_coords = self._make_point_coords()
295
+ self.point_weight = self._make_point_weight()
296
+
297
+ @property
298
+ def name(self):
299
+ return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
300
+
301
+ def total_point_count(self):
302
+ return self._formula.count * self.domain.element_count()
303
+
304
+ def max_points_per_element(self):
305
+ return self._formula.count
306
+
307
+ @property
308
+ def points(self):
309
+ return self._formula.points
310
+
311
+ @property
312
+ def weights(self):
313
+ return self._formula.weights
314
+
315
+ def arg_value(self, device):
316
+ return self._formula.arg_value(device)
317
+
318
+ def _make_point_count(self):
319
+ N = self._formula.count
320
+
321
+ @cache.dynamic_func(suffix=self.name)
322
+ def point_count(
323
+ elt_arg: self.domain.ElementArg,
324
+ qp_arg: self.Arg,
325
+ domain_element_index: ElementIndex,
326
+ element_index: ElementIndex,
327
+ ):
328
+ return N
329
+
330
+ return point_count
331
+
332
def _make_point_coords(self):
    """Build the device function returning the coordinates of the ``qp_index``-th quadrature point.

    The coordinates are looked up in ``qp_arg.points`` by ``qp_index`` alone, so every
    element reuses the same set of points.
    """

    @cache.dynamic_func(suffix=self.name)
    def point_coords(
        elt_arg: self.domain.ElementArg,
        qp_arg: self.Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        # Element indices are ignored: the same formula applies to every element
        return qp_arg.points[qp_index]

    return point_coords
344
+
345
def _make_point_weight(self):
    """Build the device function returning the weight of the ``qp_index``-th quadrature point.

    The weights are looked up in ``qp_arg.weights`` by ``qp_index`` alone, so every
    element reuses the same set of weights.
    """

    @cache.dynamic_func(suffix=self.name)
    def point_weight(
        elt_arg: self.domain.ElementArg,
        qp_arg: self.Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        # Element indices are ignored: the same formula applies to every element
        return qp_arg.weights[qp_index]

    return point_weight
357
+
358
def _make_point_index(self):
    """Build the device function mapping (element, local point) to a global quadrature point index.

    Points are numbered element by element: the points of domain element ``e`` occupy
    the contiguous range ``[N * e, N * (e + 1))``, where ``N`` is the formula's point count.
    """
    N = self._formula.count

    @cache.dynamic_func(suffix=self.name)
    def point_index(
        elt_arg: self.domain.ElementArg,
        qp_arg: self.Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        # Indexed by the domain-local element index, matching total_point_count = N * domain.element_count()
        return N * domain_element_index + qp_index

    return point_index
372
+
373
+
374
class NodalQuadrature(Quadrature):
    """Quadrature using space node points as quadrature points

    Note that in contrast to the `nodal=True` flag for :func:`integrate`, using this quadrature does not imply
    any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.

    Quadrature points are the nodes of ``space``. Their coordinates and weights come from the
    space, and their global indices are the space's node indices. A node shared by several
    elements therefore maps to the same global point index.

    Args:
        domain: Domain of integration. It is annotated ``Optional``, but the device-function factories
            dereference ``self.domain.ElementArg``. NOTE(review): presumably a ``None`` domain is
            resolved by the base initializer; confirm.
        space: Function space whose nodes and node quadrature weights define the points.
    """

    def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
        # Stored first: ``name`` (used as the suffix of every generated function) reads it
        self._space = space

        super().__init__(domain)

        # The argument struct must exist before the device functions, whose signatures reference ``self.Arg``
        self.Arg = self._make_arg()

        self.point_count = self._make_point_count()
        self.point_index = self._make_point_index()
        self.point_coords = self._make_point_coords()
        self.point_weight = self._make_point_weight()
        self.point_evaluation_index = self._make_point_evaluation_index()

    @property
    def name(self):
        """Identifier derived from the function space, used as the generated-function suffix."""
        return f"{self.__class__.__name__}_{self._space.name}"

    def total_point_count(self):
        """Number of distinct quadrature points, i.e. the number of nodes of the space."""
        return self._space.node_count()

    def max_points_per_element(self):
        """Upper bound on the number of points per element, taken from the space topology."""
        return self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_arg(self):
        """Create the device argument struct bundling the space and topology arguments."""

        @cache.dynamic_struct(suffix=self.name)
        class Arg:
            # Argument of the function space (node coordinates / weights lookups)
            space_arg: self._space.SpaceArg
            # Argument of the space topology (node counts / indices lookups)
            topo_arg: self._space.topology.TopologyArg

        return Arg

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill an ``Arg`` instance with the space and topology argument values for ``device``."""
        arg = self.Arg()
        arg.space_arg = self._space.space_arg_value(device)
        arg.topo_arg = self._space.topology.topo_arg_value(device)
        return arg

    def _make_point_count(self):
        """Build the device function returning the number of nodes of an element."""

        @cache.dynamic_func(suffix=self.name)
        def point_count(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
        ):
            # Uses the geometry element index; the count may vary per element
            return self._space.topology.element_node_count(elt_arg, qp_arg.topo_arg, element_index)

        return point_count

    def _make_point_coords(self):
        """Build the device function returning the element coordinates of the element's ``qp_index``-th node."""

        @cache.dynamic_func(suffix=self.name)
        def point_coords(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_coords

    def _make_point_weight(self):
        """Build the device function returning the space's quadrature weight for the element's ``qp_index``-th node."""

        @cache.dynamic_func(suffix=self.name)
        def point_weight(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_weight

    def _make_point_index(self):
        """Build the device function returning the global node index of the element's ``qp_index``-th node."""

        @cache.dynamic_func(suffix=self.name)
        def point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            # Global point index == node index, so points are shared between adjacent elements
            return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)

        return point_index

    def evaluation_point_count(self):
        """Number of per-element evaluation slots: one per possible node of every domain element.

        This differs from ``total_point_count`` because evaluation slots are not shared between elements.
        """
        return self.domain.element_count() * self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_point_evaluation_index(self):
        """Build the device function mapping (domain element, local node) to an evaluation slot.

        Slots are laid out element by element with a fixed stride of ``MAX_NODES_PER_ELEMENT``.
        """
        N = self._space.topology.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            # Uses the domain-local element index, consistent with evaluation_point_count()
            return N * domain_element_index + qp_index

        return evaluation_point_index
487
+
488
+
489
class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
    """Quadrature using explicit per-cell points and weights.

    The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
    Quadrature points may be provided for either the whole geometry or just the domain's elements.

    Args:
        domain: Domain of definition of the quadrature formula
        points: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
        weights: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.

    Raises:
        ValueError: if ``points`` and ``weights`` do not have the same shape.
        NotImplementedError: if the number of rows matches neither the domain nor the geometry element count.

    When the domain and geometry element counts are equal, the arrays are interpreted as
    geometry-indexed (the geometry branch is tested first).

    See also: :class:`PicQuadrature`
    """

    @wp.struct
    class Arg:
        # Number of points stored per row (constant across cells)
        points_per_cell: int
        # Point coordinates, one row per (domain or geometry) element
        points: wp.array2d(dtype=Coords)
        # Point weights, same layout as ``points``
        weights: wp.array2d(dtype=float)

    def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
        if points.shape != weights.shape:
            raise ValueError("Points and weights arrays must have the same shape")

        # Choose row indexing: geometry element index (whole geometry) or domain element index.
        # The geometry case wins if both counts coincide.
        if points.shape[0] == domain.geometry_element_count():
            self.point_index = ExplicitQuadrature._point_index_geo
            self.point_coords = ExplicitQuadrature._point_coords_geo
            self.point_weight = ExplicitQuadrature._point_weight_geo
        elif points.shape[0] == domain.element_count():
            self.point_index = ExplicitQuadrature._point_index_domain
            self.point_coords = ExplicitQuadrature._point_coords_domain
            self.point_weight = ExplicitQuadrature._point_weight_domain
        else:
            raise NotImplementedError(
                "The number of rows of points and weights must match the element count of either the domain or the geometry"
            )

        self._points_per_cell = points.shape[1]

        # Read by ``name``; set before the base initializer in case it already uses ``self.name``
        # (NOTE(review): not verifiable from this chunk).
        self._whole_geo = points.shape[0] == domain.geometry_element_count()

        super().__init__(domain, self._points_per_cell)
        self._points = points
        self._weights = weights

    @property
    def name(self):
        """Identifier distinguishing geometry-indexed from domain-indexed layouts."""
        return f"{self.__class__.__name__}_{self._whole_geo}"

    def total_point_count(self):
        """Total number of stored points.

        In whole-geometry mode this includes rows of geometry elements outside the domain,
        consistent with the geometry-indexed ``point_index``.
        """
        return self._weights.size

    def max_points_per_element(self):
        """Constant number of points per cell (second dimension of the arrays)."""
        return self._points_per_cell

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill an ``Arg`` with the point count and the point/weight arrays moved to ``device``."""
        arg = self.Arg()
        arg.points_per_cell = self._points_per_cell
        arg.points = self._points.to(device)
        arg.weights = self._weights.to(device)

        return arg

    # Number of points in an element: the row width of the points array (same for every element)
    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
        return qp_arg.points.shape[1]

    # Domain-indexed layout: rows are addressed by the domain-local element index
    @wp.func
    def _point_coords_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points[domain_element_index, qp_index]

    @wp.func
    def _point_weight_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.weights[domain_element_index, qp_index]

    # Global index == flattened (row, column) position in the domain-indexed arrays
    @wp.func
    def _point_index_domain(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points_per_cell * domain_element_index + qp_index

    # Geometry-indexed layout: rows are addressed by the geometry element index
    @wp.func
    def _point_coords_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points[element_index, qp_index]

    @wp.func
    def _point_weight_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.weights[element_index, qp_index]

    # Global index == flattened (row, column) position in the geometry-indexed arrays
    @wp.func
    def _point_index_geo(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
    ):
        return qp_arg.points_per_cell * element_index + qp_index