warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,848 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Set
17
+
18
+ import warp as wp
19
+ import warp.fem.operator as operator
20
+ from warp.fem import cache
21
+ from warp.fem.domain import GeometryDomain
22
+ from warp.fem.linalg import basis_coefficient, generalized_inner, generalized_outer
23
+ from warp.fem.quadrature import Quadrature
24
+ from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
25
+ from warp.fem.types import NULL_NODE_INDEX, DofIndex, Sample, get_node_coord, get_node_index_in_element
26
+
27
+ from .field import SpaceField
28
+
29
+
30
class AdjointField(SpaceField):
    """Adjoint of a discrete field with respect to its degrees of freedom.

    Common base of the test and trial fields defined in this module. On construction,
    the ``ElementEvalArg`` struct and the per-sample evaluation functions
    (``eval_inner``, ``eval_grad_inner``, ``eval_div_inner``, their ``outer``
    counterparts and ``at_node``) are generated through :mod:`warp.fem.cache`.

    The generated functions call ``self._get_dof(s)``, which is *not* defined here:
    concrete subclasses must provide it as a Warp function returning the
    degree-of-freedom index to evaluate for a given ``Sample``.
    """

    def __init__(self, space: FunctionSpace, space_partition: SpacePartition):
        """Build the evaluation struct and functions for ``space``.

        Args:
            space: Function space providing the weights, value maps and node basis
                used by the generated evaluation functions.
            space_partition: Partition of the space, forwarded to ``SpaceField``;
                its name is also part of :attr:`name`.
        """
        super().__init__(space, space_partition=space_partition)

        self.node_dof_count = self.space.NODE_DOF_COUNT
        self.value_dof_count = self.space.VALUE_DOF_COUNT

        # The field itself carries no data: its evaluation argument is the space's.
        self.EvalArg = self.space.SpaceArg
        self.ElementEvalArg = self._make_element_eval_arg()

        self.eval_arg_value = self.space.space_arg_value

        self.eval_degree = self._make_eval_degree()
        self.eval_inner = self._make_eval_inner()
        self.eval_grad_inner = self._make_eval_grad_inner()
        self.eval_div_inner = self._make_eval_div_inner()
        self.eval_outer = self._make_eval_outer()
        self.eval_grad_outer = self._make_eval_grad_outer()
        self.eval_div_outer = self._make_eval_div_outer()
        self.at_node = self._make_at_node()

    @property
    def name(self) -> str:
        """Name built from the class name, the space name and the partition name.

        Used as the ``suffix`` of every struct and function generated by this field.
        """
        return f"{self.__class__.__name__}{self.space.name}{self._space_partition.name}"

    def _make_element_eval_arg(self):
        """Return the struct bundling the topology element argument and the field evaluation argument."""

        @cache.dynamic_struct(suffix=self.name)
        class ElementEvalArg:
            elt_arg: self.space.topology.ElementArg
            eval_arg: self.EvalArg

        return ElementEvalArg

    def _make_eval_inner(self):
        """Return the function evaluating the field value at a sample, using the space's *inner* weight and value map."""

        @cache.dynamic_func(suffix=self.name)
        def eval_test_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_weight = self.space.element_inner_weight(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_value(dof_value, node_weight, local_value_map)

        return eval_test_inner

    def _make_eval_grad_inner(self):
        """Return the function evaluating the field gradient with the *inner* weight gradient.

        Returns ``None`` when ``self.space.gradient_valid()`` is false.
        """
        if not self.space.gradient_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_grad_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_inner_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_grad_inner

    def _make_eval_div_inner(self):
        """Return the function evaluating the field divergence with the *inner* weight gradient.

        Returns ``None`` when ``self.space.divergence_valid()`` is false.
        """
        if not self.space.divergence_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_div_inner(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_inner_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_div_inner

    def _make_eval_outer(self):
        """Return the function evaluating the field value at a sample, using the space's *outer* weight and value map."""

        @cache.dynamic_func(suffix=self.name)
        def eval_test_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_weight = self.space.element_outer_weight(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_value(dof_value, node_weight, local_value_map)

        return eval_test_outer

    def _make_eval_grad_outer(self):
        """Return the function evaluating the field gradient with the *outer* weight gradient.

        Returns ``None`` when ``self.space.gradient_valid()`` is false.
        """
        if not self.space.gradient_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_grad_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_outer_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_grad_outer

    def _make_eval_div_outer(self):
        """Return the function evaluating the field divergence with the *outer* weight gradient.

        Returns ``None`` when ``self.space.divergence_valid()`` is false.
        """
        if not self.space.divergence_valid():
            return None

        @cache.dynamic_func(suffix=self.name)
        def eval_div_outer(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            nabla_weight = self.space.element_outer_weight_gradient(
                args.elt_arg,
                args.eval_arg,
                s.element_index,
                s.element_coords,
                get_node_index_in_element(dof),
                s.qp_index,
            )
            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.node_basis_element(get_node_coord(dof))
            return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)

        return eval_div_outer

    def _make_at_node(self):
        """Return the function that relocates a sample onto the node of its current degree of freedom.

        The returned sample keeps the element index, quadrature point index and weight,
        and test/trial dofs of the input; only the element coordinates are replaced by
        ``self.space.node_coords_in_element(...)`` for the dof's node.
        """

        @cache.dynamic_func(suffix=self.name)
        def at_node(args: self.ElementEvalArg, s: Sample):
            dof = self._get_dof(s)
            node_coords = self.space.node_coords_in_element(
                args.elt_arg, args.eval_arg, s.element_index, get_node_index_in_element(dof)
            )
            return Sample(s.element_index, node_coords, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)

        return at_node
201
+
202
+
203
class TestField(AdjointField):
    """Test-function field attached to a space restriction.

    The restriction may have been built for a value type other than that of the test
    function, which allows computations to be shared, provided the node topology is similar.
    """

    def __init__(self, space_restriction: SpaceRestriction, space: FunctionSpace):
        """Validate ``space`` against ``space_restriction`` and build the field.

        Args:
            space_restriction: Restriction supplying the domain and the space partition.
            space: Function space of the test function; replaced by ``space.trace()``
                when the domain dimension is one less than the space dimension.

        Raises:
            ValueError: If the dimensions still differ after the optional trace, or if
                the space topology differs from the restriction's space topology.
        """
        domain_dim = space_restriction.domain.dimension

        # A domain one dimension below the space is handled through the space trace.
        if space.dimension - 1 == domain_dim:
            space = space.trace()

        if space.dimension != domain_dim:
            raise ValueError("Incompatible space and domain dimensions")

        topologies_differ = space.topology != space_restriction.space_topology
        if topologies_differ:
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_restriction.space_partition)

        self.space_restriction = space_restriction
        self.domain = space_restriction.domain

    # Dof selector used by the evaluation functions generated in AdjointField:
    # a test field reads the sample's test dof.
    @wp.func
    def _get_dof(s: Sample):
        return s.test_dof
228
+
229
+
230
class TrialField(AdjointField):
    """Trial-function field defined over a geometry domain."""

    def __init__(
        self,
        space: FunctionSpace,
        space_partition: SpacePartition,
        domain: GeometryDomain,
    ):
        """Validate ``space`` against ``domain`` and ``space_partition`` and build the field.

        Args:
            space: Function space of the trial function; replaced by ``space.trace()``
                when the domain dimension is one less than the space dimension.
            space_partition: Partition whose space topology the space topology must be
                derived from.
            domain: Domain over which the trial function is defined.

        Raises:
            ValueError: If the dimensions still differ after the optional trace, or if
                the space topology is not derived from the partition's space topology.
        """
        domain_dim = domain.dimension

        # A domain one dimension below the space is handled through the space trace.
        if space.dimension - 1 == domain_dim:
            space = space.trace()

        if space.dimension != domain_dim:
            raise ValueError("Incompatible space and domain dimensions")

        topology_ok = space.topology.is_derived_from(space_partition.space_topology)
        if not topology_ok:
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_partition)
        self.domain = domain

    def partition_node_count(self) -> int:
        """Return the node count reported by the associated space partition."""
        partition = self.space_partition
        return partition.node_count()

    # Dof selector used by the evaluation functions generated in AdjointField:
    # a trial field reads the sample's trial dof.
    @wp.func
    def _get_dof(s: Sample):
        return s.trial_dof
258
+
259
+
260
class LocalAdjointField(SpaceField):
    """
    A custom field specially for dispatched assembly.
    Stores adjoint and gradient adjoint at quadrature point locations.

    Each instance wraps a global adjoint (test or trial) field. The degrees of
    freedom it exposes ("Taylor" dofs) are grouped in up to four categories
    (inner value, outer value, inner gradient, outer gradient); which categories
    are present, and at which offsets, is decided by :meth:`notify_operator_usage`.

    Subclasses are expected to provide a ``_get_dof`` Warp function that extracts
    the relevant dof index from a ``Sample``; it is called by every evaluation
    function generated here.
    """

    # Indices of the dof categories inside a ``DofOffsets`` vector
    INNER_DOF = wp.constant(0)
    OUTER_DOF = wp.constant(1)
    INNER_GRAD_DOF = wp.constant(2)
    OUTER_GRAD_DOF = wp.constant(3)
    DOF_TYPE_COUNT = wp.constant(4)

    # Operator -> dof category when inner and outer weights coincide:
    # the "outer" operators share the slots of their "inner" counterparts
    _OP_DOF_MAP_CONTINUOUS = {
        operator.inner: INNER_DOF,
        operator.outer: INNER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: INNER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: INNER_GRAD_DOF,
    }

    # Operator -> dof category when inner and outer weights differ:
    # "outer" operators get their own slots
    _OP_DOF_MAP_DISCONTINUOUS = {
        operator.inner: INNER_DOF,
        operator.outer: OUTER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: OUTER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: OUTER_GRAD_DOF,
    }

    # Integer vector with one entry per dof category (used for both offsets and counts)
    DofOffsets = wp.vec(length=DOF_TYPE_COUNT, dtype=int)

    @wp.struct
    class EvalArg:
        pass

    def __init__(self, field: AdjointField):
        """Wrap ``field``, reusing its space, space partition and domain.

        No operator is considered used until :meth:`notify_operator_usage` is
        called: all Taylor dof counts and offsets start at zero.
        """
        # if not isinstance(field.space, CollocatedFunctionSpace):
        #     raise NotImplementedError("Local assembly only implemented for collocated function spaces")

        super().__init__(field.space, space_partition=field.space_partition)
        self.global_field = field

        self.domain = self.global_field.domain
        self.node_dof_count = self.space.NODE_DOF_COUNT
        self.value_dof_count = self.space.VALUE_DOF_COUNT

        # Must be set before `self.name` is first read (by `_make_element_eval_arg` below)
        self._dof_suffix = ""

        self.ElementEvalArg = self._make_element_eval_arg()
        self.eval_degree = self._make_eval_degree()
        self.at_node = None

        # The field is treated as discontinuous as soon as the space exposes different
        # functions for inner and outer weights, or for their gradients
        self._is_discontinuous = (self.space.element_inner_weight != self.space.element_outer_weight) or (
            self.space.element_inner_weight_gradient != self.space.element_outer_weight_gradient
        )

        self._TAYLOR_DOF_OFFSETS = LocalAdjointField.DofOffsets(0)
        self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
        self.TAYLOR_DOF_COUNT = 0

    def notify_operator_usage(self, ops: Set[operator.Operator]):
        """Rebuild the Taylor dof layout and evaluation functions for the operators in ``ops``.

        Args:
            ops: operators applied to this field; operators that are absent from the
                active operator -> dof category map are ignored.

        Side effects: updates ``TAYLOR_DOF_COUNT``, the per-category offsets and
        counts, the name suffix (and therefore :attr:`name`), and regenerates the
        ``eval_*`` functions.
        """
        # Rebuild degrees-of-freedom offsets based on used operators

        operators_dof_map = (
            LocalAdjointField._OP_DOF_MAP_DISCONTINUOUS
            if self._is_discontinuous
            else LocalAdjointField._OP_DOF_MAP_CONTINUOUS
        )

        # Flag each dof category that is referenced by at least one operator
        dof_counts = LocalAdjointField.DofOffsets(0)
        for op in ops:
            if op in operators_dof_map:
                dof_counts[operators_dof_map[op]] = 1

        # Gradient categories need one slot per cell dimension
        grad_dim = self.geometry.cell_dimension
        dof_counts[LocalAdjointField.INNER_GRAD_DOF] *= grad_dim
        dof_counts[LocalAdjointField.OUTER_GRAD_DOF] *= grad_dim

        # Exclusive prefix sum of the counts gives the first slot of each category
        dof_offsets = LocalAdjointField.DofOffsets(0)
        for k in range(1, LocalAdjointField.DOF_TYPE_COUNT):
            dof_offsets[k] = dof_offsets[k - 1] + dof_counts[k - 1]

        # Total slot count = end of the last category. Index it explicitly rather
        # than through the loop variable, which is only bound if the loop ran.
        last_dof_type = LocalAdjointField.DOF_TYPE_COUNT - 1
        self.TAYLOR_DOF_COUNT = wp.constant(dof_offsets[last_dof_type] + dof_counts[last_dof_type])

        self._TAYLOR_DOF_OFFSETS = dof_offsets
        self._TAYLOR_DOF_COUNTS = dof_counts

        # Encode the layout in the name so generated functions are keyed per layout
        self._dof_suffix = "".join(str(c) for c in dof_counts)

        self._split_dof = self._make_split_dof()

        self.eval_inner = self._make_eval_inner()
        self.eval_grad_inner = self._make_eval_grad_inner()
        self.eval_div_inner = self._make_eval_div_inner()

        if self._is_discontinuous:
            self.eval_outer = self._make_eval_outer()
            self.eval_grad_outer = self._make_eval_grad_outer()
            self.eval_div_outer = self._make_eval_div_outer()
        else:
            # Inner and outer evaluations coincide: reuse the inner functions
            self.eval_outer = self.eval_inner
            self.eval_grad_outer = self.eval_grad_inner
            self.eval_div_outer = self.eval_div_inner

    @property
    def name(self) -> str:
        """Name of the wrapped field followed by ``_Taylor`` and the current dof-count suffix."""
        return f"{self.global_field.name}_Taylor{self._dof_suffix}"

    def eval_arg_value(self, device):
        """Return an empty ``EvalArg``; ``device`` is not used."""
        return LocalAdjointField.EvalArg()

    def _make_element_eval_arg(self):
        """Create the struct bundling the topology element argument with this field's ``EvalArg``."""
        from warp.fem import cache

        @cache.dynamic_struct(suffix=self.name)
        class ElementEvalArg:
            elt_arg: self.space.topology.ElementArg
            eval_arg: self.EvalArg

        return ElementEvalArg

    def _make_split_dof(self):
        """Create the function splitting a ``DofIndex`` into ``(value_dof, taylor_dof)``."""
        TAYLOR_DOF_COUNT = self.TAYLOR_DOF_COUNT

        @cache.dynamic_func(suffix=str(TAYLOR_DOF_COUNT))
        def split_dof(dof_index: DofIndex, dof_begin: int):
            # The "node index in element" part carries the Taylor slot; make it
            # relative to the first slot of the category being evaluated
            taylor_dof = get_node_index_in_element(dof_index) - dof_begin
            value_dof = get_node_coord(dof_index)
            return value_dof, taylor_dof

        return split_dof

    def _make_eval_inner(self):
        """Create the inner value evaluation function."""
        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            # Only the slot at the start of the inner-value range contributes
            return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))

        return eval_test_inner

    def _make_eval_grad_inner(self):
        """Create the inner gradient evaluation function, or return None if gradients are not valid."""
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Slots outside of the inner-gradient range do not contribute
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.gradient_dtype(0.0)

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_inner

    def _make_eval_div_inner(self):
        """Create the inner divergence evaluation function, or return None if divergence is not valid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Slots outside of the inner-gradient range do not contribute
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.divergence_dtype(0.0)

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_inner

    def _make_eval_outer(self):
        """Create the outer value evaluation function."""
        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            # Only the slot at the start of the outer-value range contributes
            return wp.where(taylor_dof == 0, dof_value, self.dtype(0.0))

        return eval_test_outer

    def _make_eval_grad_outer(self):
        """Create the outer gradient evaluation function, or return None if gradients are not valid."""
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Slots outside of the outer-gradient range do not contribute
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.gradient_dtype(0.0)

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_outer

    def _make_eval_div_outer(self):
        """Create the outer divergence evaluation function, or return None if divergence is not valid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Slots outside of the outer-gradient range do not contribute
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return self.divergence_dtype(0.0)

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_outer
502
+
503
+
504
class LocalTestField(LocalAdjointField):
    """Local adjoint field wrapping a test field; its dof index is read from ``Sample.test_dof``."""

    def __init__(self, test_field: TestField):
        super().__init__(test_field)
        # Expose the wrapped field's space restriction directly on the local field
        self.space_restriction = test_field.space_restriction

    @wp.func
    def _get_dof(s: Sample):
        # Dof selector called by the evaluation functions generated in LocalAdjointField
        return s.test_dof
512
+
513
+
514
class LocalTrialField(LocalAdjointField):
    """Local adjoint field wrapping a trial field; its dof index is read from ``Sample.trial_dof``."""

    def __init__(self, trial_field: TrialField):
        super().__init__(trial_field)

    @wp.func
    def _get_dof(s: Sample):
        # Dof selector called by the evaluation functions generated in LocalAdjointField
        return s.trial_dof
521
+
522
+
523
def make_linear_dispatch_kernel(test: LocalTestField, quadrature: Quadrature, accumulate_dtype: type):
    """Build the kernel that gathers per-quadrature-point values of a linear form into per-node results.

    Args:
        test: local test field; its Taylor dof counts and offsets select which
            terms are compiled into the kernel
        quadrature: quadrature providing the evaluation points of each element
        accumulate_dtype: scalar type used for the per-thread running sum

    Returns:
        The kernel created through ``cache.dynamic_kernel``, keyed on the test field
        name, the quadrature name and the type code of ``accumulate_dtype``.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Number of Taylor slots per dof category (zero means the category is unused)
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    # First Taylor slot of each dof category
    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Number of value dofs handled per node dof
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count

    @cache.dynamic_kernel(f"{test.name}_{quadrature.name}_{wp.types.get_type_code(accumulate_dtype)}")
    def dispatch_linear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        local_result: wp.array3d(dtype=Any),
        result: wp.array2d(dtype=Any),
    ):
        # One thread per (node of the restriction, node dof) pair
        local_node_index, test_node_dof = wp.tid()
        node_index = space_restriction.node_partition_index(test_arg, local_node_index)
        element_beg, element_end = space_restriction.node_element_range(test_arg, node_index)

        val_sum = accumulate_dtype(0.0)

        # Visit every element listed for this node by the space restriction
        for n in range(element_beg, element_end):
            test_element_index = space_restriction.node_element_index(test_arg, n)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)

            qp_point_count = quadrature.point_count(
                domain_arg, qp_arg, test_element_index.domain_element_index, element_index
            )
            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                qp_eval_index = quadrature.point_evaluation_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                # Slice of the local results for this quadrature point,
                # indexed below as [taylor slot, test dof]
                qp_result = local_result[qp_eval_index]

                qp_sum = float(0.0)

                # Each wp.static condition is evaluated when the kernel is generated,
                # from the Taylor dof counts captured above

                # Inner value term
                if wp.static(0 != TEST_INNER_COUNT):
                    w = test.space.element_inner_weight(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]

                # Outer value term
                if wp.static(0 != TEST_OUTER_COUNT):
                    w = test.space.element_outer_weight(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        qp_sum += basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]

                # Inner gradient term: one Taylor slot per gradient component
                if wp.static(0 != TEST_INNER_GRAD_COUNT):
                    w_grad = test.space.element_inner_weight_gradient(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        for grad_dof in range(TEST_INNER_GRAD_COUNT):
                            qp_sum += (
                                basis_coefficient(w_grad, val_dof, grad_dof)
                                * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
                            )

                # Outer gradient term: one Taylor slot per gradient component
                if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                    w_grad = test.space.element_outer_weight_gradient(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for val_dof in range(TEST_NODE_DOF_DIM):
                        test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                        for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                            qp_sum += (
                                basis_coefficient(w_grad, val_dof, grad_dof)
                                * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
                            )

                # Per-point sum is computed in float, then accumulated in accumulate_dtype
                val_sum += accumulate_dtype(qp_sum)

        # Added to the existing content of `result`, converted to its dtype
        result[node_index, test_node_dof] += result.dtype(val_sum)

    return dispatch_linear_kernel_fn
643
+
644
+
645
def make_bilinear_dispatch_kernel(
    test: LocalTestField, trial: LocalTrialField, quadrature: Quadrature, accumulate_dtype: type
):
    """Build the kernel that gathers per-quadrature-point values of a bilinear form into matrix triplets.

    Args:
        test: local test field; its Taylor dof counts and offsets select which
            test-side terms are compiled into the kernel
        trial: local trial field; same role for the trial-side terms
        quadrature: quadrature providing the evaluation points of each element
        accumulate_dtype: scalar type used for the per-thread running sum

    Returns:
        The kernel created through ``cache.dynamic_kernel``, keyed on the trial and
        test field names, the quadrature name and the type code of ``accumulate_dtype``.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Number of Taylor slots per dof category (zero means the category is unused)
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    # First Taylor slot of each test dof category
    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    TRIAL_INNER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    # First Taylor slot of each trial dof category
    TRIAL_INNER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Number of value dofs handled per node dof
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
    TRIAL_NODE_DOF_DIM = trial.value_dof_count // trial.node_dof_count

    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    # Float vector with one component per trial Taylor slot
    trial_dof_vec = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)

    @cache.dynamic_kernel(f"{trial.name}_{test.name}_{quadrature.name}{wp.types.get_type_code(accumulate_dtype)}")
    def dispatch_bilinear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        trial_space_arg: trial.space.SpaceArg,
        local_result: wp.array4d(dtype=trial_dof_vec),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=Any),
    ):
        # One thread per (test node, test node dof, trial node dof, trial node in element)
        test_local_node_index, test_node_dof, trial_node_dof, trial_node = wp.tid()

        test_node_index = space_restriction.node_partition_index(test_arg, test_local_node_index)
        element_beg, element_end = space_restriction.node_element_range(test_arg, test_node_index)

        # Visit every element listed for the test node by the space restriction
        for element in range(element_beg, element_end):
            test_element_index = space_restriction.node_element_index(test_arg, element)
            element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
            test_node = test_element_index.node_index_in_element

            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )

            # Threads whose trial node does not exist in this element iterate over
            # zero quadrature points, so their accumulated value stays zero
            qp_point_count = wp.where(
                trial_node < element_trial_node_count,
                quadrature.point_count(domain_arg, qp_arg, test_element_index.domain_element_index, element_index),
                0,
            )

            val_sum = accumulate_dtype(0.0)

            for k in range(qp_point_count):
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                qp_eval_index = quadrature.point_evaluation_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
                )

                # Slice of the local results for this quadrature point, indexed below as
                # [test dof, trial dof, test taylor slot]; each entry is a trial_dof_vec
                qp_result = local_result[qp_eval_index]
                trial_result = float(0.0)

                # Each wp.static condition is evaluated when the kernel is generated,
                # from the Taylor dof counts captured above. Weights are only
                # fetched for the dof categories that are in use.
                if wp.static(0 != TEST_INNER_COUNT):
                    w_test_inner = test.space.element_inner_weight(
                        domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                    )

                if wp.static(0 != TEST_OUTER_COUNT):
                    w_test_outer = test.space.element_outer_weight(
                        domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                    )

                if wp.static(0 != TEST_INNER_GRAD_COUNT):
                    w_test_grad_inner = test.space.element_inner_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                    )

                if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                    w_test_grad_outer = test.space.element_outer_weight_gradient(
                        domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                    )

                if wp.static(0 != TRIAL_INNER_COUNT):
                    w_trial_inner = trial.space.element_inner_weight(
                        domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                    )

                if wp.static(0 != TRIAL_OUTER_COUNT):
                    w_trial_outer = trial.space.element_outer_weight(
                        domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                    )

                if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
                    w_trial_grad_inner = trial.space.element_inner_weight_gradient(
                        domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                    )

                if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
                    w_trial_grad_outer = trial.space.element_outer_weight_gradient(
                        domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                    )

                for trial_val_dof in range(TRIAL_NODE_DOF_DIM):
                    trial_dof = trial_node_dof * TRIAL_NODE_DOF_DIM + trial_val_dof

                    # First contract the test side: the result is a vector with one
                    # component per trial Taylor slot
                    test_result = trial_dof_vec(0.0)

                    if wp.static(0 != TEST_INNER_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            test_result += (
                                basis_coefficient(w_test_inner, test_val_dof)
                                * qp_result[test_dof, trial_dof, TEST_INNER_BEGIN]
                            )

                    if wp.static(0 != TEST_OUTER_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            test_result += (
                                basis_coefficient(w_test_outer, test_val_dof)
                                * qp_result[test_dof, trial_dof, TEST_OUTER_BEGIN]
                            )

                    if wp.static(0 != TEST_INNER_GRAD_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            for grad_dof in range(TEST_INNER_GRAD_COUNT):
                                test_result += (
                                    basis_coefficient(w_test_grad_inner, test_val_dof, grad_dof)
                                    * qp_result[test_dof, trial_dof, grad_dof + TEST_INNER_GRAD_BEGIN]
                                )

                    if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                        for test_val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                            for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                                test_result += (
                                    basis_coefficient(w_test_grad_outer, test_val_dof, grad_dof)
                                    * qp_result[test_dof, trial_dof, grad_dof + TEST_OUTER_GRAD_BEGIN]
                                )

                    # Then contract the trial side, reading the matching Taylor slots
                    if wp.static(0 != TRIAL_INNER_COUNT):
                        trial_result += basis_coefficient(w_trial_inner, trial_val_dof) * test_result[TRIAL_INNER_BEGIN]

                    if wp.static(0 != TRIAL_OUTER_COUNT):
                        trial_result += basis_coefficient(w_trial_outer, trial_val_dof) * test_result[TRIAL_OUTER_BEGIN]

                    if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
                        for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
                            trial_result += (
                                basis_coefficient(w_trial_grad_inner, trial_val_dof, grad_dof)
                                * test_result[grad_dof + TRIAL_INNER_GRAD_BEGIN]
                            )

                    if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
                        for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
                            trial_result += (
                                basis_coefficient(w_trial_grad_outer, trial_val_dof, grad_dof)
                                * test_result[grad_dof + TRIAL_OUTER_GRAD_BEGIN]
                            )

                # Per-point sum is computed in float, then accumulated in accumulate_dtype
                val_sum += accumulate_dtype(trial_result)

            # One triplet block per (node-element entry, trial node) pair;
            # the value is overwritten, not accumulated
            block_offset = element * MAX_NODES_PER_ELEMENT + trial_node
            triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(val_sum)

            # Set row and column indices
            # (written once per block, by the thread handling the first dof pair)
            if test_node_dof == 0 and trial_node_dof == 0:
                if trial_node < element_trial_node_count:
                    trial_node_index = trial.space_partition.partition_node_index(
                        trial_partition_arg,
                        trial.space.topology.element_node_index(
                            domain_arg, trial_topology_arg, element_index, trial_node
                        ),
                    )
                else:
                    trial_node_index = NULL_NODE_INDEX  # will get ignored when converting to bsr

                triplet_rows[block_offset] = test_node_index
                triplet_cols[block_offset] = trial_node_index

    return dispatch_bilinear_kernel_fn