warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/linalg.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any
2
17
 
3
18
  import warp as wp
@@ -157,11 +172,11 @@ def householder_qr_decomposition(A: Any):
157
172
 
158
173
  for i in range(type(x).length):
159
174
  for k in range(type(x).length):
160
- x[k] = wp.select(k < i, A[k, i], zero)
175
+ x[k] = wp.where(k < i, zero, A[k, i])
161
176
 
162
177
  alpha = wp.length(x) * wp.sign(x[i])
163
178
  x[i] += alpha
164
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
179
+ two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
165
180
 
166
181
  A -= wp.outer(two_over_x_sq * x, x * A)
167
182
  Q -= wp.outer(Q * x, two_over_x_sq * x)
@@ -186,11 +201,11 @@ def householder_make_hessenberg(A: Any):
186
201
 
187
202
  for i in range(1, type(x).length):
188
203
  for k in range(type(x).length):
189
- x[k] = wp.select(k < i, A[k, i - 1], zero)
204
+ x[k] = wp.where(k < i, zero, A[k, i - 1])
190
205
 
191
206
  alpha = wp.length(x) * wp.sign(x[i])
192
207
  x[i] += alpha
193
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
208
+ two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))
194
209
 
195
210
  # apply on both sides
196
211
  A -= wp.outer(two_over_x_sq * x, x * A)
@@ -211,7 +226,7 @@ def solve_triangular(R: Any, b: Any):
211
226
  for i in range(b.length, 0, -1):
212
227
  j = i - 1
213
228
  r = b[j] - wp.dot(R[j], x)
214
- x[j] = wp.select(R[j, j] == zero, r / R[j, j], zero)
229
+ x[j] = wp.where(R[j, j] == zero, zero, r / R[j, j])
215
230
 
216
231
  return x
217
232
 
warp/fem/operator.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any, Callable, Dict, Optional, Set
2
17
 
3
18
  import warp as wp
warp/fem/polynomial.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  from enum import Enum
3
18
 
@@ -1,2 +1,17 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from .pic_quadrature import PicQuadrature
2
17
  from .quadrature import ExplicitQuadrature, NodalQuadrature, Quadrature, RegularQuadrature
@@ -1,9 +1,24 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from typing import Any, Optional, Tuple, Union
2
17
 
3
18
  import warp as wp
4
19
  from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
5
20
  from warp.fem.domain import GeometryDomain
6
- from warp.fem.types import Coords, ElementIndex, make_free_sample
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
7
22
  from warp.fem.utils import compress_node_indices
8
23
 
9
24
  from .quadrature import Quadrature
@@ -53,10 +68,10 @@ class PicQuadrature(Quadrature):
53
68
  def domain(self, domain: GeometryDomain):
54
69
  # Allow changing the quadrature domain as long as underlying geometry and element kind are the same
55
70
  if self.domain is not None and (
56
- domain.geometry != self.domain.geometry or domain.element_kind != self.domain.element_kind
71
+ domain.element_kind != self.domain.element_kind or domain.geometry.base != self.domain.geometry.base
57
72
  ):
58
73
  raise RuntimeError(
59
- "Cannot change the domain to that of a different Geometry and/or using different element kinds."
74
+ "The new domain must use the same base geometry and kind of elements as the current one."
60
75
  )
61
76
 
62
77
  self._domain = domain
@@ -74,11 +89,11 @@ class PicQuadrature(Quadrature):
74
89
  arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
75
90
  arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
76
91
  arg.particle_fraction = self._particle_fraction.to(device)
77
- arg.particle_coords = self._particle_coords.to(device)
92
+ arg.particle_coords = self.particle_coords.to(device)
78
93
  return arg
79
94
 
80
95
  def total_point_count(self):
81
- return self._particle_coords.shape[0]
96
+ return self.particle_coords.shape[0]
82
97
 
83
98
  def active_cell_count(self):
84
99
  """Number of cells containing at least one particle"""
@@ -121,6 +136,12 @@ class PicQuadrature(Quadrature):
121
136
  particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
122
137
  return particle_index
123
138
 
139
+ @wp.func
140
+ def point_evaluation_index(
141
+ elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
142
+ ):
143
+ return qp_arg.cell_particle_offsets[element_index] + index
144
+
124
145
  def fill_element_mask(self, mask: "wp.array(dtype=int)"):
125
146
  """Fills a mask array such that all non-empty elements are set to 1, all empty elements to zero.
126
147
 
@@ -141,7 +162,7 @@ class PicQuadrature(Quadrature):
141
162
  element_mask: wp.array(dtype=int),
142
163
  ):
143
164
  i = wp.tid()
144
- element_mask[i] = wp.select(element_particle_offsets[i] == element_particle_offsets[i + 1], 1, 0)
165
+ element_mask[i] = wp.where(element_particle_offsets[i] == element_particle_offsets[i + 1], 0, 1)
145
166
 
146
167
  @wp.kernel
147
168
  def _compute_uniform_fraction(
@@ -152,9 +173,11 @@ class PicQuadrature(Quadrature):
152
173
  p = wp.tid()
153
174
 
154
175
  cell = cell_index[p]
155
- cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
156
-
157
- cell_fraction[p] = 1.0 / float(cell_particle_count)
176
+ if cell == NULL_ELEMENT_INDEX:
177
+ cell_fraction[p] = 0.0
178
+ else:
179
+ cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
180
+ cell_fraction[p] = 1.0 / float(cell_particle_count)
158
181
 
159
182
  def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
160
183
  if wp.types.is_array(positions):
@@ -174,13 +197,13 @@ class PicQuadrature(Quadrature):
174
197
 
175
198
  device = positions.device
176
199
 
177
- cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
178
- cell_index = cell_index_temp.array
200
+ self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
201
+ self.cell_indices = self._cell_index_temp.array
179
202
 
180
203
  self._particle_coords_temp = borrow_temporary(
181
204
  temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
182
205
  )
183
- self._particle_coords = self._particle_coords_temp.array
206
+ self.particle_coords = self._particle_coords_temp.array
184
207
 
185
208
  wp.launch(
186
209
  dim=positions.shape[0],
@@ -188,25 +211,28 @@ class PicQuadrature(Quadrature):
188
211
  inputs=[
189
212
  self.domain.element_arg_value(device),
190
213
  positions,
191
- cell_index,
192
- self._particle_coords,
214
+ self.cell_indices,
215
+ self.particle_coords,
193
216
  ],
194
217
  device=device,
195
218
  )
196
219
 
197
220
  else:
198
- cell_index, self._particle_coords = positions
199
- if cell_index.shape != self._particle_coords.shape:
221
+ self.cell_indices, self.particle_coords = positions
222
+ if self.cell_indices.shape != self.particle_coords.shape:
200
223
  raise ValueError("Cell index and coordinates arrays must have the same shape")
201
224
 
202
- cell_index_temp = None
225
+ self._cell_index_temp = None
203
226
  self._particle_coords_temp = None
204
227
 
205
228
  self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
206
- self.domain.geometry_element_count(), cell_index, return_unique_nodes=True, temporary_store=temporary_store
229
+ self.domain.geometry_element_count(),
230
+ self.cell_indices,
231
+ return_unique_nodes=True,
232
+ temporary_store=temporary_store,
207
233
  )
208
234
 
209
- self._compute_fraction(cell_index, measures, temporary_store)
235
+ self._compute_fraction(self.cell_indices, measures, temporary_store)
210
236
 
211
237
  def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
212
238
  device = cell_index.device
@@ -245,9 +271,13 @@ class PicQuadrature(Quadrature):
245
271
  cell_fraction: wp.array(dtype=float),
246
272
  ):
247
273
  p = wp.tid()
248
- sample = make_free_sample(cell_index[p], cell_coords[p])
249
274
 
250
- cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
275
+ cell = cell_index[p]
276
+ if cell == NULL_ELEMENT_INDEX:
277
+ cell_fraction[p] = 0.0
278
+ else:
279
+ sample = make_free_sample(cell_index[p], cell_coords[p])
280
+ cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
251
281
 
252
282
  wp.launch(
253
283
  dim=measures.shape[0],
@@ -256,7 +286,7 @@ class PicQuadrature(Quadrature):
256
286
  self.domain.element_arg_value(device),
257
287
  measures,
258
288
  cell_index,
259
- self._particle_coords,
289
+ self.particle_coords,
260
290
  self._particle_fraction,
261
291
  ],
262
292
  device=device,
@@ -1,14 +1,36 @@
1
- from typing import Any
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
2
17
 
3
18
  import warp as wp
4
- from warp.fem import cache, domain
19
+ from warp.fem import cache
20
+ from warp.fem.domain import GeometryDomain
5
21
  from warp.fem.geometry import Element
6
22
  from warp.fem.space import FunctionSpace
7
- from warp.fem.types import Coords, ElementIndex
23
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
8
24
 
9
25
  from ..polynomial import Polynomial
10
26
 
11
27
 
28
+ @wp.struct
29
+ class QuadraturePointElementIndex:
30
+ domain_element_index: ElementIndex
31
+ qp_index_in_element: int
32
+
33
+
12
34
  class Quadrature:
13
35
  """Interface class for quadrature rules"""
14
36
 
@@ -18,7 +40,7 @@ class Quadrature:
18
40
 
19
41
  pass
20
42
 
21
- def __init__(self, domain: domain.GeometryDomain):
43
+ def __init__(self, domain: GeometryDomain):
22
44
  self._domain = domain
23
45
 
24
46
  @property
@@ -30,52 +52,197 @@ class Quadrature:
30
52
  """
31
53
  Value of the argument to be passed to device
32
54
  """
33
- arg = RegularQuadrature.Arg()
55
+ arg = Quadrature.Arg()
34
56
  return arg
35
57
 
36
58
  def total_point_count(self):
37
- """Total number of quadrature points over the domain"""
59
+ """Number of unique quadrature points that can be indexed by this rule.
60
+ Returns a number such that `point_index()` is always smaller than this number.
61
+ """
38
62
  raise NotImplementedError()
39
63
 
64
+ def evaluation_point_count(self):
65
+ """Number of quadrature points that needs to be evaluated, mostly for internal purposes.
66
+ If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
67
+ (e.g, nodal quadrature), `evaluation_point_count` may be different than `total_point_count()`.
68
+ Returns a number such that `evaluation_point_index()` is always smaller than this number.
69
+ """
70
+ return self.total_point_count()
71
+
40
72
  def max_points_per_element(self):
41
73
  """Maximum number of points per element if known, or ``None`` otherwise"""
42
74
  return None
43
75
 
44
76
  @staticmethod
45
- def point_count(elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex):
77
+ def point_count(
78
+ elt_arg: "GeometryDomain.ElementArg",
79
+ qp_arg: Arg,
80
+ domain_element_index: ElementIndex,
81
+ geo_element_index: ElementIndex,
82
+ ):
46
83
  """Number of quadrature points for a given element"""
47
84
  raise NotImplementedError()
48
85
 
49
86
  @staticmethod
50
87
  def point_coords(
51
- elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
88
+ elt_arg: "GeometryDomain.ElementArg",
89
+ qp_arg: Arg,
90
+ domain_element_index: ElementIndex,
91
+ geo_element_index: ElementIndex,
92
+ element_qp_index: int,
52
93
  ):
53
94
  """Coordinates in element of the element's qp_index'th quadrature point"""
54
95
  raise NotImplementedError()
55
96
 
56
97
  @staticmethod
57
98
  def point_weight(
58
- elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
99
+ elt_arg: "GeometryDomain.ElementArg",
100
+ qp_arg: Arg,
101
+ domain_element_index: ElementIndex,
102
+ geo_element_index: ElementIndex,
103
+ element_qp_index: int,
59
104
  ):
60
105
  """Weight of the element's qp_index'th quadrature point"""
61
106
  raise NotImplementedError()
62
107
 
63
108
  @staticmethod
64
109
  def point_index(
65
- elt_arg: "domain.GeometryDomain.ElementArg",
110
+ elt_arg: "GeometryDomain.ElementArg",
111
+ qp_arg: Arg,
112
+ domain_element_index: ElementIndex,
113
+ geo_element_index: ElementIndex,
114
+ element_qp_index: int,
115
+ ):
116
+ """
117
+ Global index of the element's qp_index'th quadrature point.
118
+ May be shared among elements.
119
+ This is what determines `qp_index` in integrands' `Sample` arguments.
120
+ """
121
+ raise NotImplementedError()
122
+
123
+ @staticmethod
124
+ def point_evaluation_index(
125
+ elt_arg: "GeometryDomain.ElementArg",
66
126
  qp_arg: Arg,
67
127
  domain_element_index: ElementIndex,
68
128
  geo_element_index: ElementIndex,
69
129
  element_qp_index: int,
70
130
  ):
71
- """Global index of the element's qp_index'th quadrature point"""
131
+ """Quadrature point index according to evaluation order.
132
+ Quadrature points for distinct elements must have different evaluation indices.
133
+ Mostly for internal/parallelization purposes.
134
+ """
72
135
  raise NotImplementedError()
73
136
 
74
137
  def __str__(self) -> str:
75
138
  return self.name
76
139
 
140
+ # By default cache the mapping from evaluation point indices to domain elements
141
+
142
+ ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)
143
+
144
+ @cache.cached_arg_value
145
+ def element_index_arg_value(self, device):
146
+ """Builds a map from quadrature point evaluation indices to their index in the element to which they belong"""
147
+
148
+ @cache.dynamic_kernel(f"{self.name}{self.domain.name}")
149
+ def quadrature_point_element_indices(
150
+ qp_arg: self.Arg,
151
+ domain_arg: self.domain.ElementArg,
152
+ domain_index_arg: self.domain.ElementIndexArg,
153
+ result: wp.array(dtype=QuadraturePointElementIndex),
154
+ ):
155
+ domain_element_index = wp.tid()
156
+ element_index = self.domain.element_index(domain_index_arg, domain_element_index)
157
+
158
+ qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
159
+ for k in range(qp_point_count):
160
+ qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
161
+ result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)
162
+
163
+ null_qp_index = QuadraturePointElementIndex()
164
+ null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
165
+ result = wp.full(
166
+ value=null_qp_index,
167
+ shape=(self.evaluation_point_count()),
168
+ dtype=QuadraturePointElementIndex,
169
+ device=device,
170
+ )
171
+ wp.launch(
172
+ quadrature_point_element_indices,
173
+ device=result.device,
174
+ dim=self.domain.element_count(),
175
+ inputs=[
176
+ self.arg_value(result.device),
177
+ self.domain.element_arg_value(result.device),
178
+ self.domain.element_index_arg_value(result.device),
179
+ result,
180
+ ],
181
+ )
182
+
183
+ return result
184
+
185
+ @wp.func
186
+ def evaluation_point_element_index(
187
+ element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
188
+ qp_eval_index: QuadraturePointIndex,
189
+ ):
190
+ """Maps from quadrature point evaluation indices to their index in the element to which they belong
191
+ If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index
192
+ """
193
+
194
+ element_index = element_index_arg[qp_eval_index]
195
+ return element_index.domain_element_index, element_index.qp_index_in_element
196
+
197
+
198
+ class _QuadratureWithRegularEvaluationPoints(Quadrature):
199
+ """Helper subclass for quadrature formulas which use a uniform number of
200
+ evaluations points per element. Avoids building explicit mapping"""
201
+
202
+ def __init__(self, domain: GeometryDomain, N: int):
203
+ super().__init__(domain)
204
+ self._EVALUATION_POINTS_PER_ELEMENT = N
205
+
206
+ self.point_evaluation_index = self._make_regular_point_evaluation_index()
207
+ self.evaluation_point_element_index = self._make_regular_evaluation_point_element_index()
77
208
 
78
- class RegularQuadrature(Quadrature):
209
+ ElementIndexArg = Quadrature.Arg
210
+ element_index_arg_value = Quadrature.arg_value
211
+
212
+ def evaluation_point_count(self):
213
+ return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT
214
+
215
+ def _make_regular_point_evaluation_index(self):
216
+ N = self._EVALUATION_POINTS_PER_ELEMENT
217
+
218
+ @cache.dynamic_func(suffix=f"{self.name}")
219
+ def evaluation_point_index(
220
+ elt_arg: self.domain.ElementArg,
221
+ qp_arg: self.Arg,
222
+ domain_element_index: ElementIndex,
223
+ element_index: ElementIndex,
224
+ qp_index: int,
225
+ ):
226
+ return N * domain_element_index + qp_index
227
+
228
+ return evaluation_point_index
229
+
230
+ def _make_regular_evaluation_point_element_index(self):
231
+ N = self._EVALUATION_POINTS_PER_ELEMENT
232
+
233
+ @cache.dynamic_func(suffix=f"{N}")
234
+ def quadrature_evaluation_point_element_index(
235
+ qp_arg: Quadrature.Arg,
236
+ qp_index: QuadraturePointIndex,
237
+ ):
238
+ domain_element_index = qp_index // N
239
+ index_in_element = qp_index - domain_element_index * N
240
+ return domain_element_index, index_in_element
241
+
242
+ return quadrature_evaluation_point_element_index
243
+
244
+
245
+ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
79
246
  """Regular quadrature formula, using a constant set of quadrature points per element"""
80
247
 
81
248
  @wp.struct
@@ -112,16 +279,15 @@ class RegularQuadrature(Quadrature):
112
279
 
113
280
  def __init__(
114
281
  self,
115
- domain: domain.GeometryDomain,
282
+ domain: GeometryDomain,
116
283
  order: int,
117
284
  family: Polynomial = None,
118
285
  ):
119
- super().__init__(domain)
120
-
286
+ self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
121
287
  self.family = family
122
288
  self.order = order
123
289
 
124
- self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
290
+ super().__init__(domain, self._formula.count)
125
291
 
126
292
  self.point_count = self._make_point_count()
127
293
  self.point_index = self._make_point_index()
@@ -212,17 +378,18 @@ class NodalQuadrature(Quadrature):
212
378
  any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
213
379
  """
214
380
 
215
- def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
216
- super().__init__(domain)
217
-
381
+ def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
218
382
  self._space = space
219
383
 
384
+ super().__init__(domain)
385
+
220
386
  self.Arg = self._make_arg()
221
387
 
222
388
  self.point_count = self._make_point_count()
223
389
  self.point_index = self._make_point_index()
224
390
  self.point_coords = self._make_point_coords()
225
391
  self.point_weight = self._make_point_weight()
392
+ self.point_evaluation_index = self._make_point_evaluation_index()
226
393
 
227
394
  @property
228
395
  def name(self):
@@ -300,8 +467,26 @@ class NodalQuadrature(Quadrature):
300
467
 
301
468
  return point_index
302
469
 
470
+ def evaluation_point_count(self):
471
+ return self.domain.element_count() * self._space.topology.MAX_NODES_PER_ELEMENT
303
472
 
304
- class ExplicitQuadrature(Quadrature):
473
+ def _make_point_evaluation_index(self):
474
+ N = self._space.topology.MAX_NODES_PER_ELEMENT
475
+
476
+ @cache.dynamic_func(suffix=self.name)
477
+ def evaluation_point_index(
478
+ elt_arg: self.domain.ElementArg,
479
+ qp_arg: self.Arg,
480
+ domain_element_index: ElementIndex,
481
+ element_index: ElementIndex,
482
+ qp_index: int,
483
+ ):
484
+ return N * domain_element_index + qp_index
485
+
486
+ return evaluation_point_index
487
+
488
+
489
+ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
305
490
  """Quadrature using explicit per-cell points and weights.
306
491
 
307
492
  The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
@@ -321,11 +506,7 @@ class ExplicitQuadrature(Quadrature):
321
506
  points: wp.array2d(dtype=Coords)
322
507
  weights: wp.array2d(dtype=float)
323
508
 
324
- def __init__(
325
- self, domain: domain.GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"
326
- ):
327
- super().__init__(domain)
328
-
509
+ def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
329
510
  if points.shape != weights.shape:
330
511
  raise ValueError("Points and weights arrays must have the same shape")
331
512
 
@@ -343,7 +524,10 @@ class ExplicitQuadrature(Quadrature):
343
524
  )
344
525
 
345
526
  self._points_per_cell = points.shape[1]
527
+
346
528
  self._whole_geo = points.shape[0] == domain.geometry_element_count()
529
+
530
+ super().__init__(domain, self._points_per_cell)
347
531
  self._points = points
348
532
  self._weights = weights
349
533
 
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  # isort: skip_file
2
17
 
3
18
  from enum import Enum
@@ -97,7 +112,7 @@ def make_polynomial_basis_space(
97
112
  the constructed basis space
98
113
  """
99
114
 
100
- base_geo = geo.base if isinstance(geo, _geometry.DeformedGeometry) else geo
115
+ base_geo = geo.base
101
116
 
102
117
  if element_basis is None:
103
118
  element_basis = ElementBasis.LAGRANGE