warp-lang 1.10.0 (py3-none-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1021 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import Any, ClassVar, Dict, Optional, Set
18
+
19
+ import warp as wp
20
+ import warp._src.fem.operator as operator
21
+ from warp._src.fem import cache
22
+ from warp._src.fem.domain import GeometryDomain
23
+ from warp._src.fem.linalg import basis_coefficient, generalized_inner, generalized_outer
24
+ from warp._src.fem.quadrature import Quadrature
25
+ from warp._src.fem.space import FunctionSpace, SpacePartition, SpaceRestriction
26
+ from warp._src.fem.types import (
27
+ NULL_ELEMENT_INDEX,
28
+ NULL_NODE_INDEX,
29
+ DofIndex,
30
+ ElementIndex,
31
+ NodeElementIndex,
32
+ Sample,
33
+ get_node_coord,
34
+ get_node_index_in_element,
35
+ )
36
+ from warp._src.fem.utils import type_zero_element
37
+
38
+ from .field import SpaceField
39
+
40
# NOTE(review): presumably read by Warp to override the module name recorded for the
# structs/functions defined in this file, which lives under `warp._src.fem.field` but
# is named after the public `warp.fem.field.virtual` path — confirm against Warp's
# module-registration logic before relying on this.
_wp_module_name_ = "warp.fem.field.virtual"
41
+
42
+
43
+ class AdjointField(SpaceField):
44
+ """Adjoint of a discrete field with respect to its degrees of freedom"""
45
+
46
+ _dynamic_attribute_constructors: ClassVar = {
47
+ "EvalArg": lambda obj: obj._make_eval_arg(),
48
+ "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
49
+ "eval_degree": lambda obj: obj._make_eval_degree(),
50
+ "eval_inner": lambda obj: obj._make_eval_inner(),
51
+ "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(),
52
+ "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
53
+ "eval_outer": lambda obj: obj._make_eval_outer(),
54
+ "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(),
55
+ "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
56
+ "node_count": lambda obj: obj._make_node_count(),
57
+ "at_node": lambda obj: obj._make_at_node(),
58
+ "node_index": lambda obj: obj._make_node_index(),
59
+ }
60
+
61
+ def __init__(self, space: FunctionSpace, space_partition: SpacePartition, domain: GeometryDomain):
62
+ super().__init__(space, space_partition=space_partition)
63
+
64
+ self.node_dof_count = self.space.NODE_DOF_COUNT
65
+ self.value_dof_count = self.space.VALUE_DOF_COUNT
66
+ self.domain = domain
67
+
68
+ cache.setup_dynamic_attributes(self)
69
+
70
+ @cached_property
71
+ def name(self) -> str:
72
+ return f"{self.__class__.__name__}{self.space.name}{self._space_partition.name}"
73
+
74
+ @cache.cached_arg_value
75
+ def eval_arg_value(self, device):
76
+ return super().eval_arg_value(device)
77
+
78
+ def fill_eval_arg(self, arg, device):
79
+ self.space.fill_space_arg(arg.space_arg, device)
80
+ self.space.topology.fill_topo_arg(arg.topo_arg, device)
81
+
82
+ def rebind(self, space: FunctionSpace, space_partition: SpacePartition, domain: GeometryDomain):
83
+ """Rebind the field to a new space partition, space and domain.
84
+ The new space topology and space must be of similar types as the current ones
85
+ """
86
+
87
+ if (
88
+ space_partition.space_topology.name != self.space_partition.space_topology.name
89
+ or space.name != self.space.name
90
+ ):
91
+ raise ValueError("Incompatible space and/or space partition")
92
+
93
+ self._space = space
94
+ self._space_partition = space_partition
95
+ self.domain = domain
96
+
97
+ self.eval_arg_value.invalidate(self)
98
+
99
+ def _make_eval_arg(self):
100
+ @cache.dynamic_struct(suffix=self.name)
101
+ class EvalArg:
102
+ space_arg: self.space.SpaceArg
103
+ topo_arg: self.space.topology.TopologyArg
104
+
105
+ return EvalArg
106
+
107
+ def _make_element_eval_arg(self):
108
+ @cache.dynamic_struct(suffix=self.name)
109
+ class ElementEvalArg:
110
+ elt_arg: self.space.topology.ElementArg
111
+ eval_arg: self.EvalArg
112
+
113
+ return ElementEvalArg
114
+
115
+ def _make_eval_inner(self):
116
+ @cache.dynamic_func(suffix=self.name)
117
+ def eval_test_inner(args: self.ElementEvalArg, s: Sample):
118
+ dof = self._get_dof(s)
119
+ node_weight = self.space.element_inner_weight(
120
+ args.elt_arg,
121
+ args.eval_arg.space_arg,
122
+ s.element_index,
123
+ s.element_coords,
124
+ get_node_index_in_element(dof),
125
+ s.qp_index,
126
+ )
127
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
128
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
129
+ return self.space.space_value(dof_value, node_weight, local_value_map)
130
+
131
+ return eval_test_inner
132
+
133
+ def _make_eval_grad_inner(self):
134
+ if not self.space.gradient_valid():
135
+ return None
136
+
137
+ @cache.dynamic_func(suffix=self.name)
138
+ def eval_grad_inner(args: self.ElementEvalArg, s: Sample):
139
+ dof = self._get_dof(s)
140
+ nabla_weight = self.space.element_inner_weight_gradient(
141
+ args.elt_arg,
142
+ args.eval_arg.space_arg,
143
+ s.element_index,
144
+ s.element_coords,
145
+ get_node_index_in_element(dof),
146
+ s.qp_index,
147
+ )
148
+ grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
149
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
150
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
151
+ return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)
152
+
153
+ return eval_grad_inner
154
+
155
+ def _make_eval_div_inner(self):
156
+ if not self.space.divergence_valid():
157
+ return None
158
+
159
+ @cache.dynamic_func(suffix=self.name)
160
+ def eval_div_inner(args: self.ElementEvalArg, s: Sample):
161
+ dof = self._get_dof(s)
162
+ nabla_weight = self.space.element_inner_weight_gradient(
163
+ args.elt_arg,
164
+ args.eval_arg.space_arg,
165
+ s.element_index,
166
+ s.element_coords,
167
+ get_node_index_in_element(dof),
168
+ s.qp_index,
169
+ )
170
+ grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
171
+ local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
172
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
173
+ return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)
174
+
175
+ return eval_div_inner
176
+
177
+ def _make_eval_outer(self):
178
+ @cache.dynamic_func(suffix=self.name)
179
+ def eval_test_outer(args: self.ElementEvalArg, s: Sample):
180
+ dof = self._get_dof(s)
181
+ node_weight = self.space.element_outer_weight(
182
+ args.elt_arg,
183
+ args.eval_arg.space_arg,
184
+ s.element_index,
185
+ s.element_coords,
186
+ get_node_index_in_element(dof),
187
+ s.qp_index,
188
+ )
189
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
190
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
191
+ return self.space.space_value(dof_value, node_weight, local_value_map)
192
+
193
+ return eval_test_outer
194
+
195
+ def _make_eval_grad_outer(self):
196
+ if not self.space.gradient_valid():
197
+ return None
198
+
199
+ @cache.dynamic_func(suffix=self.name)
200
+ def eval_grad_outer(args: self.ElementEvalArg, s: Sample):
201
+ dof = self._get_dof(s)
202
+ nabla_weight = self.space.element_outer_weight_gradient(
203
+ args.elt_arg,
204
+ args.eval_arg.space_arg,
205
+ s.element_index,
206
+ s.element_coords,
207
+ get_node_index_in_element(dof),
208
+ s.qp_index,
209
+ )
210
+ grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
211
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
212
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
213
+ return self.space.space_gradient(dof_value, nabla_weight, local_value_map, grad_transform)
214
+
215
+ return eval_grad_outer
216
+
217
+ def _make_eval_div_outer(self):
218
+ if not self.space.divergence_valid():
219
+ return None
220
+
221
+ @cache.dynamic_func(suffix=self.name)
222
+ def eval_div_outer(args: self.ElementEvalArg, s: Sample):
223
+ dof = self._get_dof(s)
224
+ nabla_weight = self.space.element_outer_weight_gradient(
225
+ args.elt_arg,
226
+ args.eval_arg.space_arg,
227
+ s.element_index,
228
+ s.element_coords,
229
+ get_node_index_in_element(dof),
230
+ s.qp_index,
231
+ )
232
+ grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
233
+ local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
234
+ dof_value = self.space.node_basis_element(get_node_coord(dof))
235
+ return self.space.space_divergence(dof_value, nabla_weight, local_value_map, grad_transform)
236
+
237
+ return eval_div_outer
238
+
239
def _make_at_node(self):
    """Build the warp function that moves a sample onto the node carrying its adjoint dof.

    The returned sample keeps the element index, quadrature point index/weight and the
    test/trial dofs of ``s``. Only the element coordinates are replaced by those of the node.
    """

    @cache.dynamic_func(suffix=self.name)
    def at_node(args: self.ElementEvalArg, s: Sample):
        node_in_element = get_node_index_in_element(self._get_dof(s))
        coords_at_node = self.space.node_coords_in_element(
            args.elt_arg, args.eval_arg.space_arg, s.element_index, node_in_element
        )
        return Sample(s.element_index, coords_at_node, s.qp_index, s.qp_weight, s.test_dof, s.trial_dof)

    return at_node
249
+
250
def _make_node_index(self):
    """Build the warp function returning the topology node index of the sample's adjoint dof.

    The node is located from the element containing the sample and the dof's node index within it.
    """

    @cache.dynamic_func(suffix=self.name)
    def node_index(args: self.ElementEvalArg, s: Sample):
        node_in_element = get_node_index_in_element(self._get_dof(s))
        return self.space.topology.element_node_index(
            args.elt_arg, args.eval_arg.topo_arg, s.element_index, node_in_element
        )

    return node_index
260
+
261
def _make_node_count(self):
    """Build the warp function returning the node count of the element containing the sample."""

    @cache.dynamic_func(suffix=self.name)
    def node_count(args: self.ElementEvalArg, s: Sample):
        elt_index = s.element_index
        return self.space.topology.element_node_count(args.elt_arg, args.eval_arg.topo_arg, elt_index)

    return node_count
267
+
268
+
269
class TestField(AdjointField):
    """Adjoint field defined over a space restriction, usable as a test function.

    To share computations, the space restriction may have been built for a value type
    other than the test function's, as long as both rely on a similar node topology.
    """

    def __init__(self, space: FunctionSpace, space_restriction: SpaceRestriction):
        domain = space_restriction.domain

        # A domain one dimension below the space is handled through the space's trace
        if domain.dimension == space.dimension - 1:
            space = space.trace()

        if domain.dimension != space.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        if space.topology != space_restriction.space_topology:
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_restriction.space_partition, domain)

        self.space_restriction = space_restriction

    def rebind(self, space: FunctionSpace, space_restriction: SpaceRestriction):
        """Rebind the test field to a new space and space restriction.

        Both must be of a type similar to the currently bound ones.
        """
        partition = space_restriction.space_partition
        super().rebind(space, partition, space_restriction.domain)
        self.space_restriction = space_restriction

    @wp.func
    def _get_dof(s: Sample):
        # Test fields read the test degree of freedom stored on the sample
        return s.test_dof
301
+
302
+
303
class TrialField(AdjointField):
    """Adjoint field defined over a geometry domain, usable as a trial function."""

    def __init__(
        self,
        space: FunctionSpace,
        space_partition: SpacePartition,
        domain: GeometryDomain,
    ):
        # A domain one dimension below the space is handled through the space's trace
        if space.dimension - 1 == domain.dimension:
            space = space.trace()

        if space.dimension != domain.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        topologies_compatible = space.topology.is_derived_from(space_partition.space_topology)
        if not topologies_compatible:
            raise ValueError("Incompatible space and space partition topologies")

        super().__init__(space, space_partition, domain)

    def partition_node_count(self) -> int:
        """Return the number of nodes in the space partition this field is bound to."""
        partition = self.space_partition
        return partition.node_count()

    @wp.func
    def _get_dof(s: Sample):
        # Trial fields read the trial degree of freedom stored on the sample
        return s.trial_dof
330
+
331
+
332
class LocalAdjointField(SpaceField):
    """
    A custom field specially for dispatched assembly.
    Stores adjoint and gradient adjoint at quadrature point locations.

    Each value dof of the underlying space is expanded into up to four groups of "Taylor"
    dofs: inner value, outer value, inner gradient and outer gradient. Only the groups
    required by the operators reported through :meth:`notify_operator_usage` are allocated.
    Until then, all group counts are zero.
    """

    # Indices of the Taylor dof groups, used to index ``DofOffsets`` vectors
    INNER_DOF = wp.constant(0)
    OUTER_DOF = wp.constant(1)
    INNER_GRAD_DOF = wp.constant(2)
    OUTER_GRAD_DOF = wp.constant(3)
    DOF_TYPE_COUNT = wp.constant(4)

    # Operator -> Taylor dof group when inner and outer weights coincide:
    # ``outer`` reuses the inner value group
    _OP_DOF_MAP_CONTINUOUS: ClassVar[Dict[operator.Operator, int]] = {
        operator.inner: INNER_DOF,
        operator.outer: INNER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: OUTER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: OUTER_GRAD_DOF,
    }

    # Operator -> Taylor dof group when inner and outer weights differ
    _OP_DOF_MAP_DISCONTINUOUS: ClassVar[Dict[operator.Operator, int]] = {
        operator.inner: INNER_DOF,
        operator.outer: OUTER_DOF,
        operator.grad: INNER_GRAD_DOF,
        operator.grad_outer: OUTER_GRAD_DOF,
        operator.div: INNER_GRAD_DOF,
        operator.div_outer: OUTER_GRAD_DOF,
    }

    # One integer per Taylor dof group (used both for counts and offsets)
    DofOffsets = wp.vec(length=DOF_TYPE_COUNT, dtype=int)

    @wp.struct
    class EvalArg:
        pass

    _dynamic_attribute_constructors: ClassVar = {
        "ElementEvalArg": lambda obj: obj._make_element_eval_arg(),
        "eval_degree": lambda obj: obj._make_eval_degree(),
        "_split_dof": lambda obj: obj._make_split_dof(),
        "eval_inner": lambda obj: obj._make_eval_inner(),
        "eval_grad_inner": lambda obj: obj._make_eval_grad_inner(),
        "eval_div_inner": lambda obj: obj._make_eval_div_inner(),
        "eval_outer": lambda obj: obj._make_eval_outer(),
        "eval_grad_outer": lambda obj: obj._make_eval_grad_outer(),
        "eval_div_outer": lambda obj: obj._make_eval_div_outer(),
    }

    def __init__(self, field: AdjointField):
        """Wrap the global adjoint ``field`` for local (dispatched) assembly.

        Args:
            field: Global test or trial field; its space, partition and domain are reused.
        """
        # NOTE(review): local assembly is not restricted to collocated function spaces here;
        # confirm that other space types are supported.
        super().__init__(field.space, space_partition=field.space_partition)
        self.global_field = field

        self.domain = self.global_field.domain
        self.node_dof_count = self.space.NODE_DOF_COUNT
        self.value_dof_count = self.space.VALUE_DOF_COUNT

        self._dof_suffix = ""
        self.at_node = None

        # Treated as discontinuous when inner and outer weights (or weight gradients) differ
        self._is_discontinuous = (self.space.element_inner_weight != self.space.element_outer_weight) or (
            self.space.element_inner_weight_gradient != self.space.element_outer_weight_gradient
        )

        # No Taylor dof group is used until notify_operator_usage() is called
        self._TAYLOR_DOF_OFFSETS = LocalAdjointField.DofOffsets(0)
        self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
        self.TAYLOR_DOF_COUNT = 0

        cache.setup_dynamic_attributes(self)

    def notify_operator_usage(self, ops: Set[operator.Operator]):
        """Rebuild the Taylor dof layout from the operators applied to this field.

        Each used operator enables its dof group: one dof for value groups and
        ``geometry.cell_dimension`` dofs for gradient groups. Offsets are the exclusive
        prefix sum of the counts, and dynamic attributes are regenerated afterwards.

        Args:
            ops: Operators applied to this field in the integrand.
        """
        operators_dof_map = (
            LocalAdjointField._OP_DOF_MAP_DISCONTINUOUS
            if self._is_discontinuous
            else LocalAdjointField._OP_DOF_MAP_CONTINUOUS
        )

        dof_counts = LocalAdjointField.DofOffsets(0)
        for op in ops:
            if op in operators_dof_map:
                dof_counts[operators_dof_map[op]] = 1

        grad_dim = self.geometry.cell_dimension
        dof_counts[LocalAdjointField.INNER_GRAD_DOF] *= grad_dim
        dof_counts[LocalAdjointField.OUTER_GRAD_DOF] *= grad_dim

        dof_offsets = LocalAdjointField.DofOffsets(0)
        for k in range(1, LocalAdjointField.DOF_TYPE_COUNT):
            dof_offsets[k] = dof_offsets[k - 1] + dof_counts[k - 1]

        # Total count = end of the last group (explicit index instead of the leaked loop variable)
        last_group = LocalAdjointField.DOF_TYPE_COUNT - 1
        self.TAYLOR_DOF_COUNT = wp.constant(dof_offsets[last_group] + dof_counts[last_group])

        self._TAYLOR_DOF_OFFSETS = dof_offsets
        self._TAYLOR_DOF_COUNTS = dof_counts

        # Group counts are part of the name so that differently-laid-out fields get distinct code
        self._dof_suffix = "".join(str(c) for c in dof_counts)
        cache.setup_dynamic_attributes(self)

    @property
    def name(self) -> str:
        return f"{self.global_field.name}_Taylor{self._dof_suffix}"

    def fill_eval_arg(self, arg, device):
        """No-op: ``EvalArg`` has no fields to fill."""
        pass

    def _make_element_eval_arg(self):
        """Build the per-element argument struct (topology element arg + empty eval arg)."""

        @cache.dynamic_struct(suffix=self.name)
        class ElementEvalArg:
            elt_arg: self.space.topology.ElementArg
            eval_arg: self.EvalArg

        return ElementEvalArg

    def _make_split_dof(self):
        """Build a warp function splitting a dof index into (value dof, Taylor dof relative to ``dof_begin``)."""
        TAYLOR_DOF_COUNT = self.TAYLOR_DOF_COUNT

        @cache.dynamic_func(suffix=str(TAYLOR_DOF_COUNT))
        def split_dof(dof_index: DofIndex, dof_begin: int):
            # The node-in-element slot of the dof index carries the Taylor dof
            taylor_dof = get_node_index_in_element(dof_index) - dof_begin
            value_dof = get_node_coord(dof_index)
            return value_dof, taylor_dof

        return split_dof

    def _make_eval_inner(self):
        """Build the inner-value evaluator: the value basis element for the first dof of the inner group, zero otherwise."""
        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF])
        zero_element = type_zero_element(self.dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return wp.where(taylor_dof == 0, dof_value, zero_element())

        return eval_test_inner

    def _make_eval_grad_inner(self):
        """Build the inner-gradient evaluator, or return ``None`` if gradients are not valid.

        Each Taylor dof of the inner gradient group selects one row of the reference gradient transform.
        """
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])
        zero_element = type_zero_element(self.gradient_dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Dof belongs to another group
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return zero_element()

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_inner

    def _make_eval_div_inner(self):
        """Build the inner-divergence evaluator, or return ``None`` if divergence is not valid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF])
        zero_element = type_zero_element(self.divergence_dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_inner(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Dof belongs to another group
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return zero_element()

            grad_transform = self.space.element_inner_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_inner(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_inner

    def _make_eval_outer(self):
        """Build the outer-value evaluator; reuses ``eval_inner`` when the space is not discontinuous."""
        if not self._is_discontinuous:
            # eval_inner is constructed before eval_outer (see _dynamic_attribute_constructors order)
            return self.eval_inner

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF])
        zero_element = type_zero_element(self.dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return wp.where(taylor_dof == 0, dof_value, zero_element())

        return eval_test_outer

    def _make_eval_grad_outer(self):
        """Build the outer-gradient evaluator, or return ``None`` if gradients are not valid."""
        if not self.gradient_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])
        zero_element = type_zero_element(self.gradient_dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_nabla_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Dof belongs to another group
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return zero_element()

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_outer(dof_value, grad_transform[taylor_dof])

        return eval_nabla_test_outer

    def _make_eval_div_outer(self):
        """Build the outer-divergence evaluator, or return ``None`` if divergence is not valid."""
        if not self.divergence_valid():
            return None

        DOF_BEGIN = wp.constant(self._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF])
        DOF_COUNT = wp.constant(self._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF])
        zero_element = type_zero_element(self.divergence_dtype)

        @cache.dynamic_func(suffix=self.name)
        def eval_div_test_outer(args: self.ElementEvalArg, s: Sample):
            value_dof, taylor_dof = self._split_dof(self._get_dof(s), DOF_BEGIN)

            # Dof belongs to another group
            if taylor_dof < 0 or taylor_dof >= DOF_COUNT:
                return zero_element()

            grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
            local_value_map = self.space.local_value_map_outer(args.elt_arg, s.element_index, s.element_coords)
            dof_value = self.space.value_basis_element(value_dof, local_value_map)
            return generalized_inner(dof_value, grad_transform[taylor_dof])

        return eval_div_test_outer
580
+
581
+
582
class LocalTestField(LocalAdjointField):
    """Local (dispatched-assembly) counterpart of a :class:`TestField`."""

    def __init__(self, test_field: TestField):
        super().__init__(test_field)
        # The global field's restriction is reused when building dispatch kernels
        self.space_restriction = test_field.space_restriction

    @wp.func
    def _get_dof(s: Sample):
        # Local test fields read the test dof carried by the sample
        return s.test_dof
590
+
591
+
592
class LocalTrialField(LocalAdjointField):
    """Local (dispatched-assembly) counterpart of a :class:`TrialField`."""

    def __init__(self, trial_field: TrialField):
        super().__init__(trial_field)

    @wp.func
    def _get_dof(s: Sample):
        # Local trial fields read the trial dof carried by the sample
        return s.trial_dof
599
+
600
+
601
def make_linear_dispatch_kernel(
    test: LocalTestField,
    quadrature: Quadrature,
    accumulate_dtype: type,
    tile_size: int = 1,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """Build the kernel that contracts quadrature-point linear-form values with test basis weights.

    ``local_result`` is indexed as ``[qp_eval_index, taylor_dof, test_value_dof]``.
    For every node of the test space restriction, the kernel:
    - iterates over the quadrature points of the elements containing that node;
    - weights those values with the node's inner/outer basis weights and weight gradients;
    - accumulates the sum into ``result[node_index, node_dof]``.

    Args:
        test: Local test field whose Taylor dof layout was set by ``notify_operator_usage``.
        quadrature: Quadrature rule the integrand was evaluated with.
        accumulate_dtype: Scalar type used for per-node accumulation.
        tile_size: Number of lanes cooperating on one node. Lanes stride over quadrature
            points and are reduced with ``wp.tile_sum`` when greater than 1.
        kernel_options: Options forwarded to ``cache.dynamic_kernel``.

    Returns:
        The generated warp kernel.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Number of Taylor dofs in each group (0 when the group is unused)
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    # First Taylor dof of each group
    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Value dofs per node dof; a test value dof is test_node_dof * TEST_NODE_DOF_DIM + val_dof
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
    TEST_NODE_DOF_COUNT = test.node_dof_count

    res_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=accumulate_dtype)
    qp_vec = cache.cached_vec_type(length=test.node_dof_count, dtype=float)

    @cache.dynamic_func(f"{test.name}_{quadrature.name}")
    def next_qp(
        qp: int,
        elem_offset: int,
        qp_point_count: int,
        element_index: ElementIndex,
        test_element_index: NodeElementIndex,
        element_end: int,
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: space_restriction.NodeArg,
    ):
        # While ``qp`` is past the current element's quadrature points, advance to the next
        # element of the node's range and rebase ``qp`` onto it
        while qp >= qp_point_count and elem_offset < element_end:
            # Next element
            elem_offset += 1

            if elem_offset < element_end:
                qp -= qp_point_count
                test_element_index = space_restriction.node_element_index(test_arg, elem_offset)
                element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
                qp_point_count = quadrature.point_count(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index
                )

        return qp, elem_offset, qp_point_count, element_index, test_element_index

    @cache.dynamic_kernel(
        (test.name, quadrature.name, cache.pod_type_key(accumulate_dtype), tile_size),
        kernel_options=kernel_options,
    )
    def dispatch_linear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        local_result: wp.array3d(dtype=Any),
        result: wp.array2d(dtype=Any),
    ):
        local_node_index, lane = wp.tid()

        node_index = space_restriction.node_partition_index(test_arg, local_node_index)
        if node_index == NULL_NODE_INDEX:
            return

        element_beg, element_end = space_restriction.node_element_range(test_arg, node_index)

        val_sum = res_vec()

        # Start "before" the first element with zero points so next_qp moves to element_beg;
        # each lane starts at quadrature point ``lane`` and strides by ``tile_size``
        elem_offset = element_beg - 1
        qp_point_count = int(0)
        qp = lane
        test_element_index = NodeElementIndex()
        element_index = ElementIndex(NULL_ELEMENT_INDEX)

        while elem_offset < element_end:
            qp, elem_offset, qp_point_count, element_index, test_element_index = next_qp(
                qp,
                elem_offset,
                qp_point_count,
                element_index,
                test_element_index,
                element_end,
                qp_arg,
                domain_arg,
                domain_index_arg,
                test_arg,
            )

            if qp < qp_point_count:
                qp_index = quadrature.point_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
                )
                qp_eval_index = quadrature.point_evaluation_index(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
                )
                coords = quadrature.point_coords(
                    domain_arg, qp_arg, test_element_index.domain_element_index, element_index, qp
                )

                qp_result = local_result[qp_eval_index]

                qp_sum = qp_vec()

                # Only the Taylor dof groups used by the integrand are compiled in
                if wp.static(0 != TEST_INNER_COUNT):
                    w = test.space.element_inner_weight(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for test_node_dof in range(TEST_NODE_DOF_COUNT):
                        for val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                            qp_sum[test_node_dof] += (
                                basis_coefficient(w, val_dof) * qp_result[TEST_INNER_BEGIN, test_dof]
                            )

                if wp.static(0 != TEST_OUTER_COUNT):
                    w = test.space.element_outer_weight(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for test_node_dof in range(TEST_NODE_DOF_COUNT):
                        for val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                            qp_sum[test_node_dof] += (
                                basis_coefficient(w, val_dof) * qp_result[TEST_OUTER_BEGIN, test_dof]
                            )

                if wp.static(0 != TEST_INNER_GRAD_COUNT):
                    w_grad = test.space.element_inner_weight_gradient(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for test_node_dof in range(TEST_NODE_DOF_COUNT):
                        for val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                            # One Taylor dof per gradient component
                            for grad_dof in range(TEST_INNER_GRAD_COUNT):
                                qp_sum[test_node_dof] += (
                                    basis_coefficient(w_grad, val_dof, grad_dof)
                                    * qp_result[grad_dof + TEST_INNER_GRAD_BEGIN, test_dof]
                                )

                if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                    w_grad = test.space.element_outer_weight_gradient(
                        domain_arg,
                        test_space_arg,
                        element_index,
                        coords,
                        test_element_index.node_index_in_element,
                        qp_index,
                    )
                    for test_node_dof in range(TEST_NODE_DOF_COUNT):
                        for val_dof in range(TEST_NODE_DOF_DIM):
                            test_dof = test_node_dof * TEST_NODE_DOF_DIM + val_dof
                            # One Taylor dof per gradient component
                            for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                                qp_sum[test_node_dof] += (
                                    basis_coefficient(w_grad, val_dof, grad_dof)
                                    * qp_result[grad_dof + TEST_OUTER_GRAD_BEGIN, test_dof]
                                )

                val_sum += res_vec(qp_sum)
                qp += wp.static(tile_size)

        if wp.static(tile_size == 1):
            for test_node_dof in range(TEST_NODE_DOF_COUNT):
                result[node_index, test_node_dof] += result.dtype(val_sum[test_node_dof])
        else:
            # Reduce per-lane partial sums, then lanes write disjoint node dofs
            t_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]
            for test_node_dof in range(lane, TEST_NODE_DOF_COUNT, wp.static(tile_size)):
                result[node_index, test_node_dof] += result.dtype(t_sum[test_node_dof])

    return dispatch_linear_kernel_fn
793
+
794
+
795
def make_bilinear_dispatch_kernel(
    test: LocalTestField,
    trial: LocalTrialField,
    quadrature: Quadrature,
    accumulate_dtype: type,
    tile_size: int = 1,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """Build the kernel assembling element-local bilinear-form blocks into (row, column, block) triplets.

    Each thread is identified by ``(test_node_offset, trial_node, lane)`` and writes the
    triplet at ``test_node_offset * MAX_NODES_PER_ELEMENT + trial_node``.
    - ``local_result`` is indexed as ``[test_dof, trial_dof, qp_eval_index, taylor_dof]``,
      where ``taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof``.
    - It is contracted with the test and trial Taylor weights at every quadrature point.
    - Slots whose ``trial_node`` exceeds the element's trial node count get
      ``NULL_NODE_INDEX`` rows/columns.

    Args:
        test: Local test field whose Taylor dof layout was set by ``notify_operator_usage``.
        trial: Local trial field whose Taylor dof layout was set by ``notify_operator_usage``.
        quadrature: Quadrature rule the integrand was evaluated with.
        accumulate_dtype: Scalar type used for block accumulation.
        tile_size: Number of lanes cooperating on one block. Lanes stride over
            (quadrature point, value dof pair) combinations and are reduced with
            ``wp.tile_sum`` when greater than 1.
        kernel_options: Options forwarded to ``cache.dynamic_kernel``.

    Returns:
        The generated warp kernel.
    """
    global_test: TestField = test.global_field
    space_restriction = global_test.space_restriction
    domain = global_test.domain

    # Number of Taylor dofs in each group (0 when the group is unused)
    TEST_INNER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_COUNT = test._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    # First Taylor dof of each group
    TEST_INNER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TEST_OUTER_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TEST_INNER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TEST_OUTER_GRAD_BEGIN = test._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    TRIAL_INNER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_COUNT = trial._TAYLOR_DOF_COUNTS[LocalAdjointField.OUTER_GRAD_DOF]

    TRIAL_INNER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_DOF]
    TRIAL_OUTER_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_DOF]
    TRIAL_INNER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.INNER_GRAD_DOF]
    TRIAL_OUTER_GRAD_BEGIN = trial._TAYLOR_DOF_OFFSETS[LocalAdjointField.OUTER_GRAD_DOF]

    # Value dofs per node dof, for test, trial and their pairs
    TEST_NODE_DOF_DIM = test.value_dof_count // test.node_dof_count
    TRIAL_NODE_DOF_DIM = trial.value_dof_count // trial.node_dof_count
    TEST_TRIAL_NODE_DOF_DIM = TEST_NODE_DOF_DIM * TRIAL_NODE_DOF_DIM

    TEST_NODE_DOF_COUNT = test.node_dof_count
    TRIAL_NODE_DOF_COUNT = trial.node_dof_count
    TEST_TAYLOR_DOF_COUNT = test.TAYLOR_DOF_COUNT
    TRIAL_TAYLOR_DOF_COUNT = trial.TAYLOR_DOF_COUNT

    MAX_NODES_PER_ELEMENT = trial.space.topology.MAX_NODES_PER_ELEMENT

    trial_dof_vec = cache.cached_vec_type(length=trial.TAYLOR_DOF_COUNT, dtype=float)
    test_dof_vec = cache.cached_vec_type(length=test.TAYLOR_DOF_COUNT, dtype=float)

    val_t = cache.cached_mat_type(shape=(test.node_dof_count, trial.node_dof_count), dtype=accumulate_dtype)

    @cache.dynamic_kernel(
        (trial.name, test.name, quadrature.name, cache.pod_type_key(accumulate_dtype), tile_size),
        kernel_options=kernel_options,
    )
    def dispatch_bilinear_kernel_fn(
        qp_arg: quadrature.Arg,
        domain_arg: domain.ElementArg,
        domain_index_arg: domain.ElementIndexArg,
        test_arg: test.space_restriction.NodeArg,
        test_space_arg: test.space.SpaceArg,
        trial_partition_arg: trial.space_partition.PartitionArg,
        trial_topology_arg: trial.space_partition.space_topology.TopologyArg,
        trial_space_arg: trial.space.SpaceArg,
        local_result: wp.array4d(dtype=float),
        triplet_rows: wp.array(dtype=int),
        triplet_cols: wp.array(dtype=int),
        triplet_values: wp.array3d(dtype=Any),
    ):
        test_node_offset, trial_node, lane = wp.tid()

        test_node_index = space_restriction.node_partition_index_from_element_offset(test_arg, test_node_offset)

        test_element_index = space_restriction.node_element_index(test_arg, test_node_offset)
        element_index = domain.element_index(domain_index_arg, test_element_index.domain_element_index)
        test_node = test_element_index.node_index_in_element

        if element_index == NULL_ELEMENT_INDEX:
            element_trial_node_count = 0
        else:
            element_trial_node_count = trial.space.topology.element_node_count(
                domain_arg, trial_topology_arg, element_index
            )

        # Unused triplet slot: mark it so it gets ignored when converting to bsr
        if trial_node >= element_trial_node_count:
            block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
            triplet_rows[block_offset] = NULL_NODE_INDEX
            triplet_cols[block_offset] = NULL_NODE_INDEX
            return

        qp_point_count = quadrature.point_count(
            domain_arg, qp_arg, test_element_index.domain_element_index, element_index
        )
        qp_dof_count = qp_point_count * TEST_TRIAL_NODE_DOF_DIM

        val_sum = val_t()

        # Each iteration handles one (quadrature point, test value dof, trial value dof) combination
        for dof in range(lane, qp_dof_count, wp.static(tile_size)):
            k = dof // TEST_TRIAL_NODE_DOF_DIM
            test_trial_val_dof = dof - k * TEST_TRIAL_NODE_DOF_DIM
            test_val_dof = test_trial_val_dof // TRIAL_NODE_DOF_DIM
            trial_val_dof = test_trial_val_dof - test_val_dof * TRIAL_NODE_DOF_DIM

            qp_index = quadrature.point_index(
                domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
            )
            qp_eval_index = quadrature.point_evaluation_index(
                domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
            )
            coords = quadrature.point_coords(
                domain_arg, qp_arg, test_element_index.domain_element_index, element_index, k
            )

            # test shape functions
            w_test = test_dof_vec()

            if wp.static(0 != TEST_INNER_COUNT):
                w_test_inner = test.space.element_inner_weight(
                    domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                )
                w_test[TEST_INNER_BEGIN] = basis_coefficient(w_test_inner, test_val_dof)

            if wp.static(0 != TEST_OUTER_COUNT):
                w_test_outer = test.space.element_outer_weight(
                    domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                )
                w_test[TEST_OUTER_BEGIN] = basis_coefficient(w_test_outer, test_val_dof)

            if wp.static(0 != TEST_INNER_GRAD_COUNT):
                w_test_grad_inner = test.space.element_inner_weight_gradient(
                    domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                )
                for grad_dof in range(TEST_INNER_GRAD_COUNT):
                    w_test[TEST_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
                        w_test_grad_inner, test_val_dof, grad_dof
                    )

            if wp.static(0 != TEST_OUTER_GRAD_COUNT):
                w_test_grad_outer = test.space.element_outer_weight_gradient(
                    domain_arg, test_space_arg, element_index, coords, test_node, qp_index
                )
                for grad_dof in range(TEST_OUTER_GRAD_COUNT):
                    w_test[TEST_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
                        w_test_grad_outer, test_val_dof, grad_dof
                    )

            # trial shape functions
            w_trial = trial_dof_vec()

            if wp.static(0 != TRIAL_INNER_COUNT):
                w_trial_inner = trial.space.element_inner_weight(
                    domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                )
                w_trial[TRIAL_INNER_BEGIN] = basis_coefficient(w_trial_inner, trial_val_dof)

            if wp.static(0 != TRIAL_OUTER_COUNT):
                w_trial_outer = trial.space.element_outer_weight(
                    domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                )
                w_trial[TRIAL_OUTER_BEGIN] = basis_coefficient(w_trial_outer, trial_val_dof)

            if wp.static(0 != TRIAL_INNER_GRAD_COUNT):
                w_trial_grad_inner = trial.space.element_inner_weight_gradient(
                    domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                )
                for grad_dof in range(TRIAL_INNER_GRAD_COUNT):
                    w_trial[TRIAL_INNER_GRAD_BEGIN + grad_dof] = basis_coefficient(
                        w_trial_grad_inner, trial_val_dof, grad_dof
                    )

            if wp.static(0 != TRIAL_OUTER_GRAD_COUNT):
                w_trial_grad_outer = trial.space.element_outer_weight_gradient(
                    domain_arg, trial_space_arg, element_index, coords, trial_node, qp_index
                )
                for grad_dof in range(TRIAL_OUTER_GRAD_COUNT):
                    w_trial[TRIAL_OUTER_GRAD_BEGIN + grad_dof] = basis_coefficient(
                        w_trial_grad_outer, trial_val_dof, grad_dof
                    )

            # triple product test @ qp @ trial
            for test_node_dof in range(TEST_NODE_DOF_COUNT):
                test_dof = test_node_dof * TEST_NODE_DOF_DIM + test_val_dof
                for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
                    dof_res = float(0.0)
                    trial_dof = trial_node_dof * TRIAL_NODE_DOF_DIM + trial_val_dof

                    for test_taylor_dof in range(TEST_TAYLOR_DOF_COUNT):
                        test_res = float(0.0)
                        for trial_taylor_dof in range(TRIAL_TAYLOR_DOF_COUNT):
                            taylor_dof = test_taylor_dof * TRIAL_TAYLOR_DOF_COUNT + trial_taylor_dof
                            test_res += (
                                local_result[test_dof, trial_dof, qp_eval_index, taylor_dof] * w_trial[trial_taylor_dof]
                            )
                        dof_res += w_test[test_taylor_dof] * test_res

                    val_sum[test_node_dof, trial_node_dof] += accumulate_dtype(dof_res)

        # write block value
        block_offset = test_node_offset * MAX_NODES_PER_ELEMENT + trial_node
        # The comparison must be inside wp.static so that only one branch is compiled,
        # consistent with the ``wp.static(tile_size == 1)`` check of the linear kernel
        if wp.static(tile_size > 1):
            # Reduce per-lane partial sums, then lanes write disjoint block entries
            val_sum = wp.tile_sum(wp.tile(val_sum, preserve_type=True))[0]

            for dof in range(lane, wp.static(TEST_NODE_DOF_COUNT * TRIAL_NODE_DOF_COUNT), wp.static(tile_size)):
                test_node_dof = dof // TRIAL_NODE_DOF_COUNT
                trial_node_dof = dof - TRIAL_NODE_DOF_COUNT * test_node_dof

                triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
                    val_sum[test_node_dof, trial_node_dof]
                )
        else:
            for test_node_dof in range(TEST_NODE_DOF_COUNT):
                for trial_node_dof in range(TRIAL_NODE_DOF_COUNT):
                    triplet_values[block_offset, test_node_dof, trial_node_dof] = triplet_values.dtype(
                        val_sum[test_node_dof, trial_node_dof]
                    )

        # Set row and column indices
        if lane == 0:
            # trial_node < element_trial_node_count is guaranteed by the early return above
            trial_node_index = trial.space_partition.partition_node_index(
                trial_partition_arg,
                trial.space.topology.element_node_index(domain_arg, trial_topology_arg, element_index, trial_node),
            )

            triplet_rows[block_offset] = test_node_index
            triplet_cols[block_offset] = trial_node_index

    return dispatch_bilinear_kernel_fn