warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/integrate.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import ast
2
17
  import inspect
3
18
  import textwrap
@@ -19,7 +34,7 @@ from warp.fem.field import (
19
34
  make_restriction,
20
35
  )
21
36
  from warp.fem.field.virtual import make_bilinear_dispatch_kernel, make_linear_dispatch_kernel
22
- from warp.fem.linalg import array_axpy
37
+ from warp.fem.linalg import array_axpy, basis_coefficient
23
38
  from warp.fem.operator import Integrand, Operator, at_node, integrand
24
39
  from warp.fem.quadrature import Quadrature, RegularQuadrature
25
40
  from warp.fem.types import (
@@ -478,7 +493,7 @@ class PassFieldArgsToIntegrand(ast.NodeTransformer):
478
493
  callee = getattr(call.func, "id", None)
479
494
 
480
495
  if callee == self._func_name:
481
- # Replace function arguments with ours generated structs
496
+ # Replace function arguments with our generated structs
482
497
  call.args.clear()
483
498
  for arg in self._arg_names:
484
499
  if arg == self._domain_name:
@@ -561,33 +576,33 @@ def get_integrate_constant_kernel(
561
576
  ):
562
577
  def integrate_kernel_fn(
563
578
  qp_arg: quadrature.Arg,
579
+ qp_element_index_arg: quadrature.ElementIndexArg,
564
580
  domain_arg: domain.ElementArg,
565
581
  domain_index_arg: domain.ElementIndexArg,
566
582
  fields: FieldStruct,
567
583
  values: ValueStruct,
568
584
  result: wp.array(dtype=accumulate_dtype),
569
585
  ):
570
- domain_element_index = wp.tid()
586
+ qp_eval_index = wp.tid()
587
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
588
+ if domain_element_index == NULL_ELEMENT_INDEX:
589
+ return
590
+
571
591
  element_index = domain.element_index(domain_index_arg, domain_element_index)
572
- elem_sum = accumulate_dtype(0.0)
592
+
593
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
594
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
595
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
573
596
 
574
597
  test_dof_index = NULL_DOF_INDEX
575
598
  trial_dof_index = NULL_DOF_INDEX
576
599
 
577
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
578
- for k in range(qp_point_count):
579
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
580
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
581
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
582
-
583
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
584
- vol = domain.element_measure(domain_arg, sample)
600
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
601
+ vol = domain.element_measure(domain_arg, sample)
585
602
 
586
- val = integrand_func(sample, fields, values)
587
-
588
- elem_sum += accumulate_dtype(qp_weight * vol * val)
603
+ val = integrand_func(sample, fields, values)
589
604
 
590
- wp.atomic_add(result, 0, elem_sum)
605
+ wp.atomic_add(result, 0, accumulate_dtype(qp_weight * vol * val))
591
606
 
592
607
  return integrate_kernel_fn
593
608
 
@@ -730,35 +745,35 @@ def get_integrate_linear_local_kernel(
730
745
  ValueStruct: wp.codegen.Struct,
731
746
  test: LocalTestField,
732
747
  ):
733
- TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
734
-
735
748
  def integrate_kernel_fn(
736
749
  qp_arg: quadrature.Arg,
750
+ qp_element_index_arg: quadrature.ElementIndexArg,
737
751
  domain_arg: domain.ElementArg,
738
752
  domain_index_arg: domain.ElementIndexArg,
739
753
  fields: FieldStruct,
740
754
  values: ValueStruct,
741
755
  result: wp.array3d(dtype=float),
742
756
  ):
743
- domain_element_index, taylor_dof, test_dof = wp.tid()
744
- element_index = domain.element_index(domain_index_arg, domain_element_index)
757
+ qp_eval_index, taylor_dof, test_dof = wp.tid()
758
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
745
759
 
746
- trial_dof_index = NULL_DOF_INDEX
747
- test_dof_offset = test_dof * TAYLOR_DOF_COUNT
760
+ if domain_element_index == NULL_ELEMENT_INDEX:
761
+ return
748
762
 
749
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
750
- for qp in range(qp_point_count):
751
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
752
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
753
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
763
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
754
764
 
755
- vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
765
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
766
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
767
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
756
768
 
757
- test_dof_index = DofIndex(qp_index, test_dof_offset + taylor_dof)
769
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
758
770
 
759
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
760
- val = integrand_func(sample, fields, values)
761
- result[qp_index, taylor_dof, test_dof] = qp_weight * vol * val
771
+ trial_dof_index = NULL_DOF_INDEX
772
+ test_dof_index = DofIndex(taylor_dof, test_dof)
773
+
774
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
775
+ val = integrand_func(sample, fields, values)
776
+ result[qp_eval_index, taylor_dof, test_dof] = qp_weight * vol * val
762
777
 
763
778
  return integrate_kernel_fn
764
779
 
@@ -803,10 +818,10 @@ def get_integrate_bilinear_kernel(
803
818
  element_trial_node_count = trial.space.topology.element_node_count(
804
819
  domain_arg, trial_topology_arg, element_index
805
820
  )
806
- qp_point_count = wp.select(
821
+ qp_point_count = wp.where(
807
822
  trial_node < element_trial_node_count,
808
- 0,
809
823
  quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
824
+ 0,
810
825
  )
811
826
 
812
827
  test_dof_index = DofIndex(
@@ -948,36 +963,38 @@ def get_integrate_bilinear_local_kernel(
948
963
 
949
964
  def integrate_kernel_fn(
950
965
  qp_arg: quadrature.Arg,
966
+ qp_element_index_arg: quadrature.ElementIndexArg,
951
967
  domain_arg: domain.ElementArg,
952
968
  domain_index_arg: domain.ElementIndexArg,
953
969
  fields: FieldStruct,
954
970
  values: ValueStruct,
955
971
  result: wp.array4d(dtype=float),
956
972
  ):
957
- domain_element_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
973
+ qp_eval_index, test_dof, trial_dof, trial_taylor_dof = wp.tid()
974
+
975
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
976
+ if domain_element_index == NULL_ELEMENT_INDEX:
977
+ return
978
+
958
979
  element_index = domain.element_index(domain_index_arg, domain_element_index)
959
980
 
960
- test_dof_offset = TEST_TAYLOR_DOF_COUNT * test_dof
961
- trial_dof_offset = TRIAL_TAYLOR_DOF_COUNT * trial_dof
981
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
982
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
983
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
962
984
 
963
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
964
- for k in range(qp_point_count):
965
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
966
- qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
967
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
985
+ vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
986
+ qp_vol = vol * qp_weight
968
987
 
969
- vol = domain.element_measure(domain_arg, make_free_sample(element_index, qp_coords))
970
- qp_vol = vol * qp_weight
988
+ trial_dof_index = DofIndex(trial_taylor_dof, trial_dof)
971
989
 
972
- for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
973
- taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
990
+ for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
991
+ taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
974
992
 
975
- test_dof_index = DofIndex(qp_index, test_dof_offset + test_taylor_dof)
976
- trial_dof_index = DofIndex(qp_index, trial_dof_offset + trial_taylor_dof)
993
+ test_dof_index = DofIndex(test_taylor_dof, test_dof)
977
994
 
978
- sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
979
- val = integrand_func(sample, fields, values)
980
- result[qp_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
995
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
996
+ val = integrand_func(sample, fields, values)
997
+ result[qp_eval_index, test_dof, trial_dof, taylor_dof] = qp_vol * val
981
998
 
982
999
  return integrate_kernel_fn
983
1000
 
@@ -1123,6 +1140,7 @@ def _launch_integrate_kernel(
1123
1140
  output_dtype: type,
1124
1141
  output: Optional[Union[wp.array, BsrMatrix]],
1125
1142
  add_to_output: bool,
1143
+ bsr_options: Optional[Dict[str, Any]],
1126
1144
  device,
1127
1145
  ):
1128
1146
  # Set-up launch arguments
@@ -1160,9 +1178,10 @@ def _launch_integrate_kernel(
1160
1178
 
1161
1179
  wp.launch(
1162
1180
  kernel=kernel,
1163
- dim=domain.element_count(),
1181
+ dim=quadrature.evaluation_point_count(),
1164
1182
  inputs=[
1165
1183
  qp_arg,
1184
+ quadrature.element_index_arg_value(device),
1166
1185
  domain_elt_arg,
1167
1186
  domain_elt_index_arg,
1168
1187
  field_arg_values,
@@ -1264,15 +1283,16 @@ def _launch_integrate_kernel(
1264
1283
  temporary_store=temporary_store,
1265
1284
  device=device,
1266
1285
  requires_grad=output.requires_grad,
1267
- shape=(quadrature.total_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1286
+ shape=(quadrature.evaluation_point_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1268
1287
  dtype=float,
1269
1288
  )
1270
1289
 
1271
1290
  wp.launch(
1272
1291
  kernel=kernel,
1273
- dim=(domain.element_count(), test.TAYLOR_DOF_COUNT, test.value_dof_count),
1292
+ dim=local_result.array.shape,
1274
1293
  inputs=[
1275
1294
  qp_arg,
1295
+ quadrature.element_index_arg_value(device),
1276
1296
  domain_elt_arg,
1277
1297
  domain_elt_index_arg,
1278
1298
  field_arg_values,
@@ -1374,7 +1394,7 @@ def _launch_integrate_kernel(
1374
1394
  device=device,
1375
1395
  requires_grad=False,
1376
1396
  shape=(
1377
- quadrature.total_point_count(),
1397
+ quadrature.evaluation_point_count(),
1378
1398
  test.value_dof_count,
1379
1399
  trial.value_dof_count,
1380
1400
  test.TAYLOR_DOF_COUNT * trial.TAYLOR_DOF_COUNT,
@@ -1384,9 +1404,15 @@ def _launch_integrate_kernel(
1384
1404
 
1385
1405
  wp.launch(
1386
1406
  kernel=kernel,
1387
- dim=(domain.element_count(), test.value_dof_count, trial.value_dof_count, trial.TAYLOR_DOF_COUNT),
1407
+ dim=(
1408
+ quadrature.evaluation_point_count(),
1409
+ test.value_dof_count,
1410
+ trial.value_dof_count,
1411
+ trial.TAYLOR_DOF_COUNT,
1412
+ ),
1388
1413
  inputs=[
1389
1414
  qp_arg,
1415
+ quadrature.element_index_arg_value(device),
1390
1416
  domain_elt_arg,
1391
1417
  domain_elt_index_arg,
1392
1418
  field_arg_values,
@@ -1481,7 +1507,7 @@ def _launch_integrate_kernel(
1481
1507
  else:
1482
1508
  bsr_result = output
1483
1509
 
1484
- bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values)
1510
+ bsr_set_from_triplets(bsr_result, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
1485
1511
 
1486
1512
  # Do not wait for garbage collection
1487
1513
  triplet_values_temp.release()
@@ -1526,8 +1552,9 @@ def integrate(
1526
1552
  device=None,
1527
1553
  temporary_store: Optional[cache.TemporaryStore] = None,
1528
1554
  kernel_options: Optional[Dict[str, Any]] = None,
1529
- assembly: str = None,
1555
+ assembly: Optional[str] = None,
1530
1556
  add: bool = False,
1557
+ bsr_options: Optional[Dict[str, Any]] = None,
1531
1558
  ):
1532
1559
  """
1533
1560
  Integrates a constant, linear or bilinear form, and returns a scalar, array, or sparse matrix, respectively.
@@ -1551,6 +1578,7 @@ def integrate(
1551
1578
  - "dispatch": For linear or bilinear forms, first evaluate the form at quadrature points then dispatch to nodes in a second pass. More efficient for integrands that are expensive to evaluate. Incompatible with `at_node` operator on test or trial functions.
1552
1579
  - `None` (default): Automatically picks a suitable assembly strategy (either "generic" or "dispatch")
1553
1580
  add: If True and `output` is provided, add the integration result to `output` instead of replacing its content
1581
+ bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
1554
1582
  """
1555
1583
  if fields is None:
1556
1584
  fields = {}
@@ -1663,6 +1691,7 @@ def integrate(
1663
1691
  output_dtype=output_dtype,
1664
1692
  output=output,
1665
1693
  add_to_output=add,
1694
+ bsr_options=bsr_options,
1666
1695
  device=device,
1667
1696
  )
1668
1697
 
@@ -1808,53 +1837,128 @@ def get_interpolate_at_quadrature_kernel(
1808
1837
  ):
1809
1838
  def interpolate_at_quadrature_nonvalued_kernel_fn(
1810
1839
  qp_arg: quadrature.Arg,
1840
+ qp_element_index_arg: quadrature.ElementIndexArg,
1811
1841
  domain_arg: quadrature.domain.ElementArg,
1812
1842
  domain_index_arg: quadrature.domain.ElementIndexArg,
1813
1843
  fields: FieldStruct,
1814
1844
  values: ValueStruct,
1815
1845
  result: wp.array(dtype=float),
1816
1846
  ):
1817
- domain_element_index = wp.tid()
1847
+ qp_eval_index = wp.tid()
1848
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1849
+ if domain_element_index == NULL_ELEMENT_INDEX:
1850
+ return
1851
+
1818
1852
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1819
1853
 
1820
1854
  test_dof_index = NULL_DOF_INDEX
1821
1855
  trial_dof_index = NULL_DOF_INDEX
1822
1856
 
1823
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1824
- for k in range(qp_point_count):
1825
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1826
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1827
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1857
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1858
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1859
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1828
1860
 
1829
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1830
- integrand_func(sample, fields, values)
1861
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1862
+ integrand_func(sample, fields, values)
1831
1863
 
1832
1864
  def interpolate_at_quadrature_kernel_fn(
1833
1865
  qp_arg: quadrature.Arg,
1866
+ qp_element_index_arg: quadrature.ElementIndexArg,
1834
1867
  domain_arg: quadrature.domain.ElementArg,
1835
1868
  domain_index_arg: quadrature.domain.ElementIndexArg,
1836
1869
  fields: FieldStruct,
1837
1870
  values: ValueStruct,
1838
1871
  result: wp.array(dtype=value_type),
1839
1872
  ):
1840
- domain_element_index = wp.tid()
1873
+ qp_eval_index = wp.tid()
1874
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1875
+ if domain_element_index == NULL_ELEMENT_INDEX:
1876
+ return
1877
+
1841
1878
  element_index = domain.element_index(domain_index_arg, domain_element_index)
1842
1879
 
1843
1880
  test_dof_index = NULL_DOF_INDEX
1844
1881
  trial_dof_index = NULL_DOF_INDEX
1845
1882
 
1846
- qp_point_count = quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index)
1847
- for k in range(qp_point_count):
1848
- qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, k)
1849
- coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, k)
1850
- qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, k)
1883
+ coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1884
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1885
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1851
1886
 
1852
- sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1853
- result[qp_index] = integrand_func(sample, fields, values)
1887
+ sample = Sample(element_index, coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1888
+ result[qp_index] = integrand_func(sample, fields, values)
1854
1889
 
1855
1890
  return interpolate_at_quadrature_nonvalued_kernel_fn if value_type is None else interpolate_at_quadrature_kernel_fn
1856
1891
 
1857
1892
 
1893
+ def get_interpolate_jacobian_at_quadrature_kernel(
1894
+ integrand_func: wp.Function,
1895
+ domain: GeometryDomain,
1896
+ quadrature: Quadrature,
1897
+ FieldStruct: wp.codegen.Struct,
1898
+ ValueStruct: wp.codegen.Struct,
1899
+ trial: TrialField,
1900
+ value_size: int,
1901
+ value_type: type,
1902
+ ):
1903
+ MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT
1904
+ VALUE_SIZE = wp.constant(value_size)
1905
+
1906
+ def interpolate_jacobian_kernel_fn(
1907
+ qp_arg: quadrature.Arg,
1908
+ qp_element_index_arg: quadrature.ElementIndexArg,
1909
+ domain_arg: domain.ElementArg,
1910
+ domain_index_arg: domain.ElementIndexArg,
1911
+ trial_partition_arg: trial.space_partition.PartitionArg,
1912
+ trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
1913
+ fields: FieldStruct,
1914
+ values: ValueStruct,
1915
+ triplet_rows: wp.array(dtype=int),
1916
+ triplet_cols: wp.array(dtype=int),
1917
+ triplet_values: wp.array3d(dtype=value_type),
1918
+ ):
1919
+ qp_eval_index, trial_node, trial_dof = wp.tid()
1920
+ domain_element_index, qp = quadrature.evaluation_point_element_index(qp_element_index_arg, qp_eval_index)
1921
+
1922
+ if domain_element_index == NULL_ELEMENT_INDEX:
1923
+ return
1924
+
1925
+ element_index = domain.element_index(domain_index_arg, domain_element_index)
1926
+ if qp >= quadrature.point_count(domain_arg, qp_arg, domain_element_index, element_index):
1927
+ return
1928
+
1929
+ element_trial_node_count = trial.space.topology.element_node_count(
1930
+ domain_arg, trial_topology_arg, element_index
1931
+ )
1932
+
1933
+ qp_coords = quadrature.point_coords(domain_arg, qp_arg, domain_element_index, element_index, qp)
1934
+ qp_weight = quadrature.point_weight(domain_arg, qp_arg, domain_element_index, element_index, qp)
1935
+ qp_index = quadrature.point_index(domain_arg, qp_arg, domain_element_index, element_index, qp)
1936
+
1937
+ block_offset = qp_index * MAX_NODES_PER_ELEMENT + trial_node
1938
+
1939
+ test_dof_index = NULL_DOF_INDEX
1940
+ trial_dof_index = DofIndex(trial_node, trial_dof)
1941
+
1942
+ sample = Sample(element_index, qp_coords, qp_index, qp_weight, test_dof_index, trial_dof_index)
1943
+ val = integrand_func(sample, fields, values)
1944
+
1945
+ for k in range(VALUE_SIZE):
1946
+ triplet_values[block_offset, k, trial_dof] = basis_coefficient(val, k)
1947
+
1948
+ if trial_dof == 0:
1949
+ if trial_node < element_trial_node_count:
1950
+ trial_node_index = trial.space_partition.partition_node_index(
1951
+ trial_partition_arg,
1952
+ trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
1953
+ )
1954
+ else:
1955
+ trial_node_index = NULL_NODE_INDEX # will get ignored when converting to bsr
1956
+ triplet_rows[block_offset] = qp_index
1957
+ triplet_cols[block_offset] = trial_node_index
1958
+
1959
+ return interpolate_jacobian_kernel_fn
1960
+
1961
+
1858
1962
  def get_interpolate_free_kernel(
1859
1963
  integrand_func: wp.Function,
1860
1964
  domain: GeometryDomain,
@@ -1924,9 +2028,9 @@ def _generate_interpolate_kernel(
1924
2028
  dest_dtype = dest.dtype if dest else None
1925
2029
  type_str = wp.types.get_type_code(dest_dtype) if dest_dtype else ""
1926
2030
  if quadrature is None:
1927
- kernel_suffix = f"_itp_{field_names}_{type_str}"
2031
+ kernel_suffix = f"_itp_{field_names}_{domain.name}_{type_str}"
1928
2032
  else:
1929
- kernel_suffix = f"_itp_{field_names}_{quadrature.name}_{type_str}"
2033
+ kernel_suffix = f"_itp_{field_names}_{domain.name}_{quadrature.name}_{type_str}"
1930
2034
 
1931
2035
  kernel = cache.get_integrand_kernel(
1932
2036
  integrand=integrand,
@@ -1971,14 +2075,27 @@ def _generate_interpolate_kernel(
1971
2075
  ValueStruct=ValueStruct,
1972
2076
  )
1973
2077
  elif quadrature is not None:
1974
- interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
1975
- integrand_func,
1976
- domain=domain,
1977
- quadrature=quadrature,
1978
- value_type=dest_dtype,
1979
- FieldStruct=FieldStruct,
1980
- ValueStruct=ValueStruct,
1981
- )
2078
+ if arguments.trial_name:
2079
+ trial = arguments.field_args[arguments.trial_name]
2080
+ interpolate_kernel_fn = get_interpolate_jacobian_at_quadrature_kernel(
2081
+ integrand_func,
2082
+ domain=domain,
2083
+ quadrature=quadrature,
2084
+ FieldStruct=FieldStruct,
2085
+ ValueStruct=ValueStruct,
2086
+ trial=trial,
2087
+ value_size=dest.block_shape[0],
2088
+ value_type=dest.scalar_type,
2089
+ )
2090
+ else:
2091
+ interpolate_kernel_fn = get_interpolate_at_quadrature_kernel(
2092
+ integrand_func,
2093
+ domain=domain,
2094
+ quadrature=quadrature,
2095
+ value_type=dest_dtype,
2096
+ FieldStruct=FieldStruct,
2097
+ ValueStruct=ValueStruct,
2098
+ )
1982
2099
  else:
1983
2100
  interpolate_kernel_fn = get_interpolate_free_kernel(
1984
2101
  integrand_func,
@@ -2012,8 +2129,11 @@ def _launch_interpolate_kernel(
2012
2129
  dest: Optional[Union[FieldRestriction, wp.array]],
2013
2130
  quadrature: Optional[Quadrature],
2014
2131
  dim: int,
2132
+ trial: Optional[TrialField],
2015
2133
  fields: Dict[str, FieldLike],
2016
2134
  values: Dict[str, Any],
2135
+ temporary_store: Optional[cache.TemporaryStore],
2136
+ bsr_options: Optional[Dict[str, Any]],
2017
2137
  device,
2018
2138
  ) -> wp.Kernel:
2019
2139
  # Set-up launch arguments
@@ -2044,21 +2164,74 @@ def _launch_interpolate_kernel(
2044
2164
  ],
2045
2165
  device=device,
2046
2166
  )
2047
- elif quadrature is not None:
2048
- qp_arg = quadrature.arg_value(device)
2167
+ return
2168
+
2169
+ if quadrature is None:
2049
2170
  wp.launch(
2050
2171
  kernel=kernel,
2051
- dim=domain.element_count(),
2052
- inputs=[qp_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2172
+ dim=dim,
2173
+ inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2053
2174
  device=device,
2054
2175
  )
2055
- else:
2176
+ return
2177
+
2178
+ qp_arg = quadrature.arg_value(device)
2179
+ qp_element_index_arg = quadrature.element_index_arg_value(device)
2180
+ if trial is None:
2056
2181
  wp.launch(
2057
2182
  kernel=kernel,
2058
- dim=dim,
2059
- inputs=[dim, elt_arg, field_arg_values, value_struct_values, dest],
2183
+ dim=quadrature.evaluation_point_count(),
2184
+ inputs=[qp_arg, qp_element_index_arg, elt_arg, elt_index_arg, field_arg_values, value_struct_values, dest],
2060
2185
  device=device,
2061
2186
  )
2187
+ return
2188
+
2189
+ nnz = quadrature.total_point_count() * trial.space.topology.MAX_NODES_PER_ELEMENT
2190
+
2191
+ if dest.nrow != quadrature.total_point_count() or dest.ncol != trial.space_partition.node_count():
2192
+ raise RuntimeError(
2193
+ f"'dest' matrix must have {quadrature.total_point_count()} rows and {trial.space_partition.node_count()} columns of blocks"
2194
+ )
2195
+ if dest.block_shape[1] != trial.node_dof_count:
2196
+ raise f"'dest' matrix blocks must have {trial.node_dof_count} columns"
2197
+
2198
+ triplet_rows_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2199
+ triplet_cols_temp = cache.borrow_temporary(temporary_store, shape=(nnz,), dtype=int, device=device)
2200
+ triplet_values_temp = cache.borrow_temporary(
2201
+ temporary_store,
2202
+ dtype=dest.scalar_type,
2203
+ shape=(nnz, *dest.block_shape),
2204
+ device=device,
2205
+ )
2206
+ triplet_cols = triplet_cols_temp.array
2207
+ triplet_rows = triplet_rows_temp.array
2208
+ triplet_values = triplet_values_temp.array
2209
+ triplet_rows.fill_(-1)
2210
+ triplet_values.zero_()
2211
+
2212
+ trial_partition_arg = trial.space_partition.partition_arg_value(device)
2213
+ trial_topology_arg = trial.space_partition.space_topology.topo_arg_value(device)
2214
+
2215
+ wp.launch(
2216
+ kernel=kernel,
2217
+ dim=(quadrature.evaluation_point_count(), trial.space.topology.MAX_NODES_PER_ELEMENT, trial.node_dof_count),
2218
+ inputs=[
2219
+ qp_arg,
2220
+ qp_element_index_arg,
2221
+ elt_arg,
2222
+ elt_index_arg,
2223
+ trial_partition_arg,
2224
+ trial_topology_arg,
2225
+ field_arg_values,
2226
+ value_struct_values,
2227
+ triplet_rows,
2228
+ triplet_cols,
2229
+ triplet_values,
2230
+ ],
2231
+ device=device,
2232
+ )
2233
+
2234
+ bsr_set_from_triplets(dest, triplet_rows, triplet_cols, triplet_values, **(bsr_options or {}))
2062
2235
 
2063
2236
 
2064
2237
  @integrand
@@ -2076,6 +2249,8 @@ def interpolate(
2076
2249
  values: Optional[Dict[str, Any]] = None,
2077
2250
  device=None,
2078
2251
  kernel_options: Optional[Dict[str, Any]] = None,
2252
+ temporary_store: Optional[cache.TemporaryStore] = None,
2253
+ bsr_options: Optional[Dict[str, Any]] = None,
2079
2254
  ):
2080
2255
  """
2081
2256
  Interpolates a function at a finite set of sample points and optionally assigns the result to a discrete field or a raw warp array.
@@ -2094,6 +2269,8 @@ def interpolate(
2094
2269
  values: Additional variable values to be passed to the integrand, can be of any type accepted by warp kernel launches. Keys in the dictionary must match integrand parameter names.
2095
2270
  device: Device on which to perform the interpolation
2096
2271
  kernel_options: Overloaded options to be passed to the kernel builder (e.g, ``{"enable_backward": True}``)
2272
+ temporary_store: shared pool from which to allocate temporary arrays
2273
+ bsr_options: Additional options to be passed to the sparse matrix construction algorithm. See :func:`warp.sparse.bsr_set_from_triplets()`
2097
2274
  """
2098
2275
 
2099
2276
  if isinstance(integrand, FieldLike):
@@ -2111,8 +2288,12 @@ def interpolate(
2111
2288
  raise ValueError("integrand must be tagged with @integrand decorator")
2112
2289
 
2113
2290
  arguments = _parse_integrand_arguments(integrand, fields)
2114
- if arguments.test_name or arguments.trial_name:
2115
- raise ValueError("Test or Trial fields should not be used for interpolation")
2291
+ if arguments.test_name:
2292
+ raise ValueError(f"Test field '{arguments.test_name}' maybe not be used for interpolation")
2293
+ if arguments.trial_name and (quadrature is None or not isinstance(dest, BsrMatrix)):
2294
+ raise ValueError(
2295
+ f"Interpolation using trial field '{arguments.trial_name}' requires 'quadrature' to be provided and 'dest' to be a `warp.sparse.BsrMatrix`"
2296
+ )
2116
2297
 
2117
2298
  if isinstance(dest, DiscreteField):
2118
2299
  dest = make_restriction(dest, domain=domain)
@@ -2145,7 +2326,10 @@ def interpolate(
2145
2326
  dest=dest,
2146
2327
  quadrature=quadrature,
2147
2328
  dim=dim,
2329
+ trial=fields.get(arguments.trial_name),
2148
2330
  fields=arguments.field_args,
2149
2331
  values=values,
2332
+ temporary_store=temporary_store,
2333
+ bsr_options=bsr_options,
2150
2334
  device=device,
2151
2335
  )