warp-lang 1.10.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,665 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import cached_property
17
+ from typing import Any, ClassVar, Optional
18
+
19
+ import warp as wp
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.domain import GeometryDomain
22
+ from warp._src.fem.geometry import Element
23
+ from warp._src.fem.space.function_space import FunctionSpace
24
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
25
+
26
+ from ..polynomial import Polynomial
27
+
28
# NOTE(review): presumably makes Warp register the kernels/functions defined in this private
# module (warp/_src/fem/quadrature/quadrature.py) under the public
# ``warp.fem.quadrature.quadrature`` module name — confirm against Warp's module handling.
_wp_module_name_ = "warp.fem.quadrature.quadrature"
29
+
30
+
31
@wp.struct
class QuadraturePointElementIndex:
    """Location of a quadrature evaluation point: the domain element it belongs to and its index in that element.

    This is the element type of the map built by ``Quadrature.element_index_arg_value``. Entries of that map
    whose evaluation index is not produced by any element keep ``NULL_ELEMENT_INDEX`` as ``domain_element_index``.
    """

    # Index of the owning element within the domain (``NULL_ELEMENT_INDEX`` for unused entries)
    domain_element_index: ElementIndex
    # Position of the point among the quadrature points of that element (0 .. point_count - 1)
    qp_index_in_element: int
36
+
37
class Quadrature:
    """Interface class for quadrature rules.

    A quadrature rule is defined over a :class:`GeometryDomain`. Concrete rules provide an ``Arg`` struct
    (obtained through :meth:`arg_value` / :meth:`fill_arg`) and the static device-side accessors
    :meth:`point_count`, :meth:`point_coords`, :meth:`point_weight`, :meth:`point_index` and
    :meth:`point_evaluation_index`, which all raise ``NotImplementedError`` in this base class.
    """

    @wp.struct
    class Arg:
        """Structure containing arguments to be passed to device functions"""

        pass

    def __init__(self, domain: GeometryDomain):
        self._domain = domain

    @property
    def domain(self):
        """Domain over which this quadrature is defined"""
        return self._domain

    @cache.cached_arg_value
    def arg_value(self, device) -> "Arg":
        """
        Value of the argument to be passed to device.

        The default implementation creates a fresh ``Arg`` and populates it with :meth:`fill_arg`.
        Subclasses must override at least one of ``arg_value`` and ``fill_arg``, since each default
        implementation delegates to the other.
        """
        arg = self.Arg()
        self.fill_arg(arg, device)
        return arg

    def fill_arg(self, arg: Arg, device):
        """
        Fill the argument with the value of the argument to be passed to device.

        The default implementation copies the value returned by :meth:`arg_value` into ``arg``.

        Raises:
            NotImplementedError: If ``arg_value`` has not been overridden, as the default ``arg_value``
                and ``fill_arg`` would otherwise call each other indefinitely.
        """
        # ``self.arg_value`` is normally a bound method: a new object that is never identical to the
        # function stored on the class. Unwrap it before the identity check, otherwise the guard can
        # never trigger and a missing override ends in unbounded recursion instead of this error.
        arg_value_impl = getattr(self.arg_value, "__func__", self.arg_value)
        if arg_value_impl is __class__.arg_value:
            raise NotImplementedError()
        arg.assign(self.arg_value(device))

    def total_point_count(self):
        """Number of unique quadrature points that can be indexed by this rule.
        Returns a number such that `point_index()` is always smaller than this number.
        """
        raise NotImplementedError()

    def evaluation_point_count(self):
        """Number of quadrature points that need to be evaluated, mostly for internal purposes.
        If the indexing scheme is sparse, or if a quadrature point is shared among multiple elements
        (e.g., nodal quadrature), `evaluation_point_count` may be different than `total_point_count()`.
        Returns a number such that `point_evaluation_index()` is always smaller than this number.
        """
        return self.total_point_count()

    def max_points_per_element(self):
        """Maximum number of points per element if known, or ``None`` otherwise"""
        return None

    @staticmethod
    def point_count(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
    ):
        """Number of quadrature points for a given element.

        Args:
            elt_arg: Element arguments of the domain.
            qp_arg: Arguments of this quadrature (see :meth:`arg_value`).
            domain_element_index: Index of the element within the domain.
            geo_element_index: Index of the element within the underlying geometry.
        """
        raise NotImplementedError()

    @staticmethod
    def point_coords(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Coordinates in element of the element's qp_index'th quadrature point.

        ``element_qp_index`` ranges over ``[0, point_count(...))`` for the given element.
        """
        raise NotImplementedError()

    @staticmethod
    def point_weight(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Weight of the element's qp_index'th quadrature point.

        ``element_qp_index`` ranges over ``[0, point_count(...))`` for the given element.
        """
        raise NotImplementedError()

    @staticmethod
    def point_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """
        Global index of the element's qp_index'th quadrature point.
        May be shared among elements.
        This is what determines `qp_index` in integrands' `Sample` arguments.
        Must be smaller than `total_point_count()`.
        """
        raise NotImplementedError()

    @staticmethod
    def point_evaluation_index(
        elt_arg: "GeometryDomain.ElementArg",
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        geo_element_index: ElementIndex,
        element_qp_index: int,
    ):
        """Quadrature point index according to evaluation order.
        Quadrature points for distinct elements must have different evaluation indices.
        Must be smaller than `evaluation_point_count()`.
        Mostly for internal/parallelization purposes.
        """
        raise NotImplementedError()

    def __str__(self) -> str:
        # ``name`` is not defined on this base class; concrete subclasses are expected to provide it
        return self.name

    # By default cache the mapping from evaluation point indices to domain elements

    ElementIndexArg = wp.array(dtype=QuadraturePointElementIndex)

    @cache.cached_arg_value
    def element_index_arg_value(self, device):
        """Builds a map from quadrature point evaluation indices to their index in the element to which they belong.

        Returns an array with ``evaluation_point_count()`` :class:`QuadraturePointElementIndex` entries.
        Entries whose evaluation index is not produced by any element keep ``NULL_ELEMENT_INDEX``
        as their ``domain_element_index``.
        """

        @cache.dynamic_kernel(f"{self.name}{self.domain.name}")
        def quadrature_point_element_indices(
            qp_arg: self.Arg,
            domain_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            result: wp.array(dtype=QuadraturePointElementIndex),
        ):
            # One thread per domain element
            domain_element_index = wp.tid()
            element_index = self.domain.element_index(domain_index_arg, domain_element_index)
            if element_index == NULL_ELEMENT_INDEX:
                return

            # Record (element, local index) for every evaluation point owned by this element
            qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
            for k in range(qp_point_count):
                qp_eval_index = self.point_evaluation_index(domain_arg, qp_arg, domain_element_index, element_index, k)
                result[qp_eval_index] = QuadraturePointElementIndex(domain_element_index, k)

        # Pre-fill with a null entry so evaluation indices not written by the kernel are flagged as unused
        null_qp_index = QuadraturePointElementIndex()
        null_qp_index.domain_element_index = NULL_ELEMENT_INDEX
        result = wp.full(
            value=null_qp_index,
            shape=self.evaluation_point_count(),
            dtype=QuadraturePointElementIndex,
            device=device,
        )
        wp.launch(
            quadrature_point_element_indices,
            device=result.device,
            dim=self.domain.element_count(),
            inputs=[
                self.arg_value(result.device),
                self.domain.element_arg_value(result.device),
                self.domain.element_index_arg_value(result.device),
                result,
            ],
        )

        return result

    @wp.func
    def evaluation_point_element_index(
        element_index_arg: wp.array(dtype=QuadraturePointElementIndex),
        qp_eval_index: QuadraturePointIndex,
    ):
        """Maps from quadrature point evaluation indices to their index in the element to which they belong
        If the quadrature point does not exist, should return NULL_ELEMENT_INDEX as the domain element index

        Returns a ``(domain_element_index, qp_index_in_element)`` pair read from the map built by
        ``element_index_arg_value``.
        """

        element_index = element_index_arg[qp_eval_index]
        return element_index.domain_element_index, element_index.qp_index_in_element
211
+
212
+
213
class _QuadratureWithRegularEvaluationPoints(Quadrature):
    """Base class for quadrature rules that evaluate the same number of points in every element.

    Each domain element owns a contiguous run of ``N`` evaluation indices, so the mapping between
    evaluation indices and (domain element, index in element) pairs is computed arithmetically and
    no explicit mapping array has to be built.
    """

    _dynamic_attribute_constructors: ClassVar = {
        "point_evaluation_index": lambda obj: obj._make_regular_point_evaluation_index(),
        "evaluation_point_element_index": lambda obj: obj._make_regular_evaluation_point_element_index(),
    }

    def __init__(self, domain: GeometryDomain, N: int):
        super().__init__(domain)
        # Number of evaluation slots reserved for each domain element
        self._EVALUATION_POINTS_PER_ELEMENT = N

        # Instantiate the per-instance device functions listed in _dynamic_attribute_constructors
        cache.setup_dynamic_attributes(self, cls=__class__)

    # No mapping array is needed, so the element-index argument is simply an empty struct
    ElementIndexArg = Quadrature.Arg

    def element_index_arg_value(self, device):
        empty_element_index_arg = Quadrature.Arg()
        return empty_element_index_arg

    def evaluation_point_count(self):
        element_total = self.domain.element_count()
        points_per_element = self._EVALUATION_POINTS_PER_ELEMENT
        return element_total * points_per_element

    def _make_regular_point_evaluation_index(self):
        """Build the device function mapping (element, local point index) to an evaluation index."""
        points_per_element = self._EVALUATION_POINTS_PER_ELEMENT

        @cache.dynamic_func(suffix=f"{self.name}")
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            # Points of element ``e`` occupy the contiguous slots [e * N, (e + 1) * N)
            return points_per_element * domain_element_index + qp_index

        return evaluation_point_index

    def _make_regular_evaluation_point_element_index(self):
        """Build the device function mapping an evaluation index back to (element, local point index)."""
        points_per_element = self._EVALUATION_POINTS_PER_ELEMENT

        @cache.dynamic_func(suffix=f"{points_per_element}")
        def quadrature_evaluation_point_element_index(
            qp_arg: Quadrature.Arg,
            qp_index: QuadraturePointIndex,
        ):
            owning_element = qp_index // points_per_element
            rank_in_element = qp_index - owning_element * points_per_element
            return owning_element, rank_in_element

        return quadrature_evaluation_point_element_index
264
+
265
+
266
class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
    """Regular quadrature formula, using a constant set of quadrature points per element.

    Points and weights are instantiated by the prototype of the domain's reference element
    for the requested ``order`` and ``family``. They are shared between all instances using the
    same (reference element, order, family) combination.

    Args:
        domain: Domain of definition of the quadrature formula
        order: Order forwarded to the reference element's quadrature instantiation
        family: Polynomial family forwarded to the reference element's quadrature instantiation
    """

    @wp.struct
    class Arg:
        # Quadrature points and weights used to be passed as Warp constants,
        # but this tended to incur register spilling for high point counts
        points: wp.array(dtype=Coords)
        weights: wp.array(dtype=float)

    # Cache common formulas so we do not have to do a host-to-device transfer for each call
    class CachedFormula:
        """Host-side points and weights for one (element, order, family) combination."""

        # Shared by all instances; keyed by ``(element.value, order, family)`` (see ``get``)
        _cache: ClassVar = {}

        def __init__(self, element: Element, order: int, family: Optional[Polynomial]):
            self.points, self.weights = element.prototype.instantiate_quadrature(order, family)
            self.count = wp.constant(len(self.points))

        @cache.cached_arg_value
        def arg_value(self, device):
            """Returns a :class:`RegularQuadrature.Arg` holding copies of the points and weights on ``device``."""
            arg = RegularQuadrature.Arg()

            # Pause graph capture while we copy from host:
            # we want the cached result to be available outside of the graph
            graph = wp.context.capture_pause() if device.is_capturing else None
            try:
                arg.points = wp.array(self.points, device=device, dtype=Coords)
                arg.weights = wp.array(self.weights, device=device, dtype=float)
            finally:
                # Resume even if an allocation or copy failed, so the caller's capture is not left paused
                if graph is not None:
                    wp.context.capture_resume(graph)
            return arg

        def fill_arg(self, arg: "RegularQuadrature.Arg", device):
            arg.assign(self.arg_value(device))

        @staticmethod
        def get(element: Element, order: int, family: Optional[Polynomial]):
            """Returns the shared formula for ``(element, order, family)``, creating it on first use."""
            key = (element.value, order, family)
            try:
                return RegularQuadrature.CachedFormula._cache[key]
            except KeyError:
                quadrature = RegularQuadrature.CachedFormula(element, order, family)
                RegularQuadrature.CachedFormula._cache[key] = quadrature
                return quadrature

    _dynamic_attribute_constructors: ClassVar = {
        "point_count": lambda obj: obj._make_point_count(),
        "point_index": lambda obj: obj._make_point_index(),
        "point_coords": lambda obj: obj._make_point_coords(),
        "point_weight": lambda obj: obj._make_point_weight(),
    }

    def __init__(
        self,
        domain: GeometryDomain,
        order: int,
        family: Optional[Polynomial] = None,
    ):
        self._formula = RegularQuadrature.CachedFormula.get(domain.reference_element(), order, family)
        self.family = family
        self.order = order

        super().__init__(domain, self._formula.count)

        cache.setup_dynamic_attributes(self)

    @cached_property
    def name(self):
        return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"

    def total_point_count(self):
        return self._formula.count * self.domain.element_count()

    def max_points_per_element(self):
        return self._formula.count

    @property
    def points(self):
        """Host-side quadrature point coordinates of the shared formula."""
        return self._formula.points

    @property
    def weights(self):
        """Host-side quadrature weights of the shared formula."""
        return self._formula.weights

    def fill_arg(self, arg: "RegularQuadrature.Arg", device):
        self._formula.fill_arg(arg, device)

    def _make_point_count(self):
        N = self._formula.count

        # Every element has the same number of points
        @cache.dynamic_func(suffix=self.name)
        def point_count(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
        ):
            return N

        return point_count

    def _make_point_coords(self):
        # Points are identical for every element, so only the local index is used
        @cache.dynamic_func(suffix=self.name)
        def point_coords(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return qp_arg.points[qp_index]

        return point_coords

    def _make_point_weight(self):
        # Weights are identical for every element, so only the local index is used
        @cache.dynamic_func(suffix=self.name)
        def point_weight(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return qp_arg.weights[qp_index]

        return point_weight

    def _make_point_index(self):
        N = self._formula.count

        # Points are numbered contiguously per domain element
        @cache.dynamic_func(suffix=self.name)
        def point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return N * domain_element_index + qp_index

        return point_index
411
+
412
+
413
class NodalQuadrature(Quadrature):
    """Quadrature whose points are the nodes of a function space

    Note that in contrast to the `assembly="nodal"` flag for :func:`integrate`, using this quadrature does not imply
    any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
    Point coordinates and weights are queried from the space, and point indices are the space's node indices.
    """

    _dynamic_attribute_constructors: ClassVar = {
        "Arg": lambda obj: obj._make_arg(),
        "point_count": lambda obj: obj._make_point_count(),
        "point_index": lambda obj: obj._make_point_index(),
        "point_coords": lambda obj: obj._make_point_coords(),
        "point_weight": lambda obj: obj._make_point_weight(),
        "point_evaluation_index": lambda obj: obj._make_point_evaluation_index(),
    }

    def __init__(
        self,
        domain: Optional[GeometryDomain],
        space: Optional[FunctionSpace],
    ):
        self._space = space

        super().__init__(domain)

        cache.setup_dynamic_attributes(self)

    @cached_property
    def name(self):
        return f"{type(self).__name__}_{self._space.name}"

    def total_point_count(self):
        # One quadrature point per space node
        return self._space.node_count()

    def max_points_per_element(self):
        return self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_arg(self):
        # Bundles the space and topology arguments needed to query nodes from device code
        @cache.dynamic_struct(suffix=self.name)
        class Arg:
            space_arg: self._space.SpaceArg
            topo_arg: self._space.topology.TopologyArg

        return Arg

    def fill_arg(self, arg: "NodalQuadrature.Arg", device):
        space = self._space
        space.fill_space_arg(arg.space_arg, device)
        space.topology.fill_topo_arg(arg.topo_arg, device)

    def _make_point_count(self):
        # Number of points in an element is its number of nodes
        @cache.dynamic_func(suffix=self.name)
        def point_count(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
        ):
            return self._space.topology.element_node_count(elt_arg, qp_arg.topo_arg, element_index)

        return point_count

    def _make_point_coords(self):
        # Point coordinates are the element-local coordinates of the node
        @cache.dynamic_func(suffix=self.name)
        def point_coords(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_coords

    def _make_point_weight(self):
        # Point weights are the space's node quadrature weights
        @cache.dynamic_func(suffix=self.name)
        def point_weight(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_weight

    def _make_point_index(self):
        # Point indices are the global node indices of the topology
        @cache.dynamic_func(suffix=self.name)
        def point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)

        return point_index

    def evaluation_point_count(self):
        # Evaluation points are laid out with a fixed stride of MAX_NODES_PER_ELEMENT per element
        element_count = self.domain.element_count()
        return element_count * self._space.topology.MAX_NODES_PER_ELEMENT

    def _make_point_evaluation_index(self):
        nodes_per_element = self._space.topology.MAX_NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def evaluation_point_index(
            elt_arg: self.domain.ElementArg,
            qp_arg: self.Arg,
            domain_element_index: ElementIndex,
            element_index: ElementIndex,
            qp_index: int,
        ):
            return nodes_per_element * domain_element_index + qp_index

        return evaluation_point_index
531
+
532
+
533
class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
    """Quadrature using explicit per-cell points and weights.

    The number of quadrature points per cell is assumed to be constant and deduced from the shape of the points and weights arrays.
    Quadrature points may be provided for either the whole geometry or just the domain's elements.

    Args:
        domain: Domain of definition of the quadrature formula
        points: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
        weights: 2d array of shape ``(domain.element_count(), points_per_cell)`` or ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.

    Raises:
        ValueError: If ``points`` and ``weights`` do not share the same two-dimensional shape.
        NotImplementedError: If the number of rows matches neither ``domain.geometry_element_count()``
            nor ``domain.element_count()``.

    See also: :class:`PicQuadrature`
    """

    @wp.struct
    class Arg:
        points_per_cell: int
        points: wp.array2d(dtype=Coords)
        weights: wp.array2d(dtype=float)

    def __init__(
        self,
        domain: GeometryDomain,
        points: "wp.array2d(dtype=Coords)",
        weights: "wp.array2d(dtype=float)",
    ):
        if points.shape != weights.shape:
            raise ValueError("Points and weights arrays must have the same shape")
        if len(points.shape) != 2:
            # Fail with an explicit message rather than an IndexError on ``points.shape[1]`` below
            raise ValueError("Points and weights arrays must be two-dimensional")

        # Rows are indexed by geometry element when they cover the whole geometry, by domain element otherwise.
        # The geometry test comes first, so geometry indexing wins when both counts coincide.
        row_count = points.shape[0]
        if row_count == domain.geometry_element_count():
            self._whole_geo = True
            self.point_index = ExplicitQuadrature._point_index_geo
            self.point_coords = ExplicitQuadrature._point_coords_geo
            self.point_weight = ExplicitQuadrature._point_weight_geo
        elif row_count == domain.element_count():
            self._whole_geo = False
            self.point_index = ExplicitQuadrature._point_index_domain
            self.point_coords = ExplicitQuadrature._point_coords_domain
            self.point_weight = ExplicitQuadrature._point_weight_domain
        else:
            raise NotImplementedError(
                "The number of rows of points and weights must match the element count of either the domain or the geometry"
            )

        self._points_per_cell = points.shape[1]

        super().__init__(domain, self._points_per_cell)
        self._points = points
        self._weights = weights

    @cached_property
    def name(self):
        # The inherited ``point_evaluation_index`` builder embeds ``self.domain.ElementArg`` in its signature
        # and uses this name as its cache suffix, so the domain name is included to keep quadratures on
        # different domains from sharing one cached function.
        # NOTE(review): assumes the cache key is derived from the suffix, not from the closure contents.
        return f"{self.__class__.__name__}_{self.domain.name}_{self._whole_geo}_{self._points_per_cell}"

    def total_point_count(self):
        return self._weights.size

    def max_points_per_element(self):
        return self._points_per_cell

    def fill_arg(self, arg: "ExplicitQuadrature.Arg", device):
        arg.points_per_cell = self._points_per_cell
        arg.points = self._points.to(device)
        arg.weights = self._weights.to(device)

    @wp.func
    def point_count(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
    ):
        return qp_arg.points.shape[1]

    # Variants used when point rows are indexed by domain element

    @wp.func
    def _point_coords_domain(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.points[domain_element_index, qp_index]

    @wp.func
    def _point_weight_domain(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.weights[domain_element_index, qp_index]

    @wp.func
    def _point_index_domain(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.points_per_cell * domain_element_index + qp_index

    # Variants used when point rows are indexed by geometry element

    @wp.func
    def _point_coords_geo(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.points[element_index, qp_index]

    @wp.func
    def _point_weight_geo(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.weights[element_index, qp_index]

    @wp.func
    def _point_index_geo(
        elt_arg: Any,
        qp_arg: Arg,
        domain_element_index: ElementIndex,
        element_index: ElementIndex,
        qp_index: int,
    ):
        return qp_arg.points_per_cell * element_index + qp_index