warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/fem/polynomial.py ADDED
@@ -0,0 +1,229 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from enum import Enum
18
+
19
+ import numpy as np
20
+
21
+
22
class Polynomial(Enum):
    """Families of 1D interpolation nodes that can be placed over an interval.

    Each member's value is the short string identifier of the family; it is
    also what ``str(member)`` returns.
    """

    GAUSS_LEGENDRE = "GL"
    """Gauss--Legendre nodes; the interval endpoints are not part of the node set"""

    LOBATTO_GAUSS_LEGENDRE = "LGL"
    """Lobatto--Gauss--Legendre nodes; the interval endpoints are part of the node set"""

    EQUISPACED_CLOSED = "closed"
    """Uniformly spaced nodes; the interval endpoints are part of the node set"""

    EQUISPACED_OPEN = "open"
    """Uniformly spaced nodes; the interval endpoints are not part of the node set"""

    def __str__(self):
        # Print the short identifier ("GL", "LGL", ...) rather than "Polynomial.NAME".
        return self.value
39
+
40
+
41
def is_closed(family: Polynomial):
    """Tell whether the nodes of ``family`` include the interval endpoints.

    Returns ``True`` for the Lobatto--Gauss--Legendre and closed equispaced
    families, ``False`` otherwise.
    """
    closed_families = (
        Polynomial.LOBATTO_GAUSS_LEGENDRE,
        Polynomial.EQUISPACED_CLOSED,
    )
    return family in closed_families
44
+
45
+
46
def _gauss_legendre_quadrature_1d(n: int):
    """Build the ``n``-point Gauss--Legendre rule on the unit interval.

    The nodes and weights are tabulated on [-1, 1] and then mapped affinely
    to [0, 1].

    Args:
        n: number of quadrature points; 1 through 5 are tabulated.

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``. Nodes are
        returned in tabulation order, which is not sorted.

    Raises:
        NotImplementedError: if ``n`` is not one of the tabulated counts.
    """
    if n == 1:
        ref_nodes, ref_weights = [0.0], [2.0]
    elif n == 2:
        root = math.sqrt(1.0 / 3)
        ref_nodes, ref_weights = [-root, root], [1.0, 1.0]
    elif n == 3:
        root = math.sqrt(3.0 / 5.0)
        ref_nodes = [0.0, -root, root]
        ref_weights = [8.0 / 9.0, 5.0 / 9.0, 5.0 / 9.0]
    elif n == 4:
        spread = 2.0 / 7.0 * math.sqrt(6.0 / 5.0)
        inner = math.sqrt(3.0 / 7.0 - spread)
        outer = math.sqrt(3.0 / 7.0 + spread)
        sqrt30 = math.sqrt(30.0)
        w_inner = (18.0 + sqrt30) / 36.0
        w_outer = (18.0 - sqrt30) / 36.0
        ref_nodes = [inner, -inner, outer, -outer]
        ref_weights = [w_inner, w_inner, w_outer, w_outer]
    elif n == 5:
        spread = 2.0 * math.sqrt(10.0 / 7.0)
        inner = 1.0 / 3.0 * math.sqrt(5.0 - spread)
        outer = 1.0 / 3.0 * math.sqrt(5.0 + spread)
        scaled_sqrt70 = 13.0 * math.sqrt(70.0)
        w_inner = (322.0 + scaled_sqrt70) / 900.0
        w_outer = (322.0 - scaled_sqrt70) / 900.0
        ref_nodes = [0.0, inner, -inner, outer, -outer]
        ref_weights = [128.0 / 225.0, w_inner, w_inner, w_outer, w_outer]
    else:
        raise NotImplementedError

    # Affine map [-1, 1] -> [0, 1]: halve the weights, halve and shift the nodes.
    return 0.5 * np.array(ref_nodes) + 0.5, 0.5 * np.array(ref_weights)
78
+
79
+
80
def _lobatto_gauss_legendre_quadrature_1d(n: int):
    """Build the ``n``-point Lobatto--Gauss--Legendre rule on the unit interval.

    The nodes (which include both endpoints) and weights are tabulated on
    [-1, 1] and then mapped affinely to [0, 1].

    Args:
        n: number of quadrature points; 2 through 5 are tabulated.

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``, with nodes
        in increasing order.

    Raises:
        NotImplementedError: if ``n`` is not one of the tabulated counts.
    """
    if n == 2:
        ref_nodes, ref_weights = [-1.0, 1.0], [1.0, 1.0]
    elif n == 3:
        ref_nodes = [-1.0, 0.0, 1.0]
        ref_weights = [1.0 / 3.0, 4.0 / 3.0, 1.0 / 3.0]
    elif n == 4:
        sqrt5 = math.sqrt(5.0)
        ref_nodes = [-1.0, -1.0 / sqrt5, 1.0 / sqrt5, 1.0]
        ref_weights = [1.0 / 6.0, 5.0 / 6.0, 5.0 / 6.0, 1.0 / 6.0]
    elif n == 5:
        interior = math.sqrt(3.0 / 7.0)
        ref_nodes = [-1.0, -interior, 0.0, interior, 1.0]
        ref_weights = [1.0 / 10.0, 49.0 / 90.0, 32.0 / 45.0, 49.0 / 90.0, 1.0 / 10.0]
    else:
        raise NotImplementedError

    # Affine map [-1, 1] -> [0, 1]: halve the weights, halve and shift the nodes.
    return 0.5 * np.array(ref_nodes) + 0.5, 0.5 * np.array(ref_weights)
101
+
102
+
103
def _uniform_open_quadrature_1d(n: int):
    """Build a low-order rule on ``n`` equispaced interior nodes of [0, 1].

    Nodes sit at ``k / (n + 1)`` for ``k = 1..n``. Every node gets the weight
    of one cell of width ``1 / (n + 1)``; the first and last nodes additionally
    absorb the half-cell between them and the interval boundary.

    Args:
        n: number of quadrature points, at least 1.

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``; the
        weights sum to 1.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"Open uniform quadrature requires at least 1 point, got {n}")

    step = 1.0 / (n + 1)
    coords = np.linspace(step, 1.0 - step, n)
    weights = np.full(n, 1.0 / (n + 1))

    if n == 1:
        # The single node is both the first and the last one: it absorbs both
        # boundary half-cells, so it carries the whole interval.
        weights[0] = 1.0
    else:
        # Boundaries have 3/2 the weight
        weights[0] = 1.5 / (n + 1)
        weights[-1] = 1.5 / (n + 1)

    return coords, weights
113
+
114
+
115
def _uniform_closed_quadrature_1d(n: int):
    """Build the composite trapezoidal rule on ``n`` equispaced nodes of [0, 1].

    Nodes are ``k / (n - 1)`` for ``k = 0..n-1``, so both endpoints are
    included. Interior nodes carry the weight of one full cell, the two
    endpoints half of that.

    Args:
        n: number of quadrature points, at least 2 (both endpoints are nodes).

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``; the
        weights sum to 1.

    Raises:
        ValueError: if ``n`` is smaller than 2.
    """
    if n < 2:
        # Without this guard n == 1 fails with a bare ZeroDivisionError and
        # n == 0 with an IndexError, neither of which names the actual problem.
        raise ValueError(f"Closed uniform quadrature requires at least 2 points, got {n}")

    coords = np.linspace(0.0, 1.0, n)
    weights = np.full(n, 1.0 / (n - 1))

    # Boundaries have half the weight
    weights[0] = 0.5 / (n - 1)
    weights[-1] = 0.5 / (n - 1)

    return coords, weights
124
+
125
+
126
def _open_newton_cotes_quadrature_1d(n: int):
    """Build the ``n``-point open Newton--Cotes rule on the unit interval.

    Nodes are the ``n`` equispaced interior points ``k / (n + 1)``,
    ``k = 1..n``; the endpoints are not nodes.

    Args:
        n: number of quadrature points; 1 through 7 are tabulated.

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``.

    Raises:
        NotImplementedError: if ``n`` is not one of the tabulated counts.
    """
    spacing = 1.0 / (n + 1)
    coords = np.linspace(spacing, 1.0 - spacing, n)

    # Weisstein, Eric W. "Newton-Cotes Formulas." From MathWorld--A Wolfram Web Resource.
    # https://mathworld.wolfram.com/Newton-CotesFormulas.html
    # Each entry is (numerators, common denominator).
    tabulated = {
        1: ([1.0], 1.0),
        2: ([0.5, 0.5], 1.0),
        3: ([2.0, -1.0, 2.0], 3.0),
        4: ([11.0, 1.0, 1.0, 11.0], 24.0),
        5: ([11.0, -14.0, 26.0, -14.0, 11.0], 20.0),
        6: ([611.0, -453.0, 562.0, 562.0, -453.0, 611.0], 1440.0),
        7: ([460.0, -954.0, 2196.0, -2459.0, 2196.0, -954.0, 460.0], 945.0),
    }

    entry = tabulated.get(n)
    if entry is None:
        raise NotImplementedError

    numerators, denominator = entry
    return coords, np.array(numerators) / denominator
151
+
152
+
153
def _closed_newton_cotes_quadrature_1d(n: int):
    """Build the ``n``-point closed Newton--Cotes rule on the unit interval.

    Nodes are the ``n`` equispaced points ``k / (n - 1)``, ``k = 0..n-1``, so
    both endpoints are nodes.

    Args:
        n: number of quadrature points; 2 through 8 are tabulated.

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length ``n``.

    Raises:
        NotImplementedError: if ``n`` is not one of the tabulated counts.
    """
    coords = np.linspace(0.0, 1.0, n)

    # OEIS: A093735, A093736
    # Each entry is (numerators, denominators); a scalar denominator applies
    # to every numerator. Values are for an interval of length n - 1.
    tabulated = {
        2: ((1, 1), 2),
        3: ((1, 4, 1), 3),
        4: ((3, 9, 9, 3), 8),
        5: ((14, 64, 24, 64, 14), 45),
        6: ((95, 125, 125, 125, 125, 95), (288, 96, 144, 144, 96, 288)),
        7: ((41, 54, 27, 68, 27, 54, 41), (140, 35, 140, 35, 140, 35, 140)),
        8: (
            (5257, 25039, 343, 20923, 20923, 343, 25039, 5257),
            (17280, 17280, 640, 17280, 17280, 640, 17280, 17280),
        ),
    }

    entry = tabulated.get(n)
    if entry is None:
        raise NotImplementedError

    numerators, denominators = entry
    unnormalized = np.array(numerators, dtype=float) / np.array(denominators, dtype=float)

    # Normalize with interval length
    return coords, unnormalized / (n - 1)
204
+
205
+
206
def quadrature_1d(point_count: int, family: Polynomial):
    """Compute a 1D quadrature rule on [0, 1] for a node family.

    Args:
        point_count: number of quadrature points requested.
        family: node family selecting which rule is built; the equispaced
            families map to the closed and open Newton--Cotes rules.

    Returns:
        Tuple ``(coords, weights)`` as produced by the family's builder.

    Raises:
        NotImplementedError: if ``family`` is not a known family, or if the
            selected builder has no rule tabulated for ``point_count``.
    """
    builders = (
        (Polynomial.GAUSS_LEGENDRE, _gauss_legendre_quadrature_1d),
        (Polynomial.LOBATTO_GAUSS_LEGENDRE, _lobatto_gauss_legendre_quadrature_1d),
        (Polynomial.EQUISPACED_CLOSED, _closed_newton_cotes_quadrature_1d),
        (Polynomial.EQUISPACED_OPEN, _open_newton_cotes_quadrature_1d),
    )

    for candidate, build_rule in builders:
        if family == candidate:
            return build_rule(point_count)

    raise NotImplementedError
219
+
220
+
221
def lagrange_scales(coords: np.ndarray):
    """Return the scaling factors for Lagrange polynomials with roots at coords

    For node ``i`` the factor is ``1 / prod_{j != i} (coords[i] - coords[j])``,
    i.e. the constant that makes the ``i``-th Lagrange polynomial equal to 1
    at its own node.

    Args:
        coords: 1D sequence of distinct node coordinates. Integer and
            list-like inputs are converted to floating point so the
            reciprocal factors are not truncated.

    Returns:
        NumPy array with one scaling factor per node. Floating-point input
        keeps its dtype; any other input yields ``float64``.
    """
    coords = np.asarray(coords)
    if not np.issubdtype(coords.dtype, np.inexact):
        # np.empty_like on an integer array would allocate an integer result
        # and silently truncate every factor (e.g. 0.5 -> 0).
        coords = coords.astype(float)

    lagrange_scale = np.empty_like(coords)
    for i, node in enumerate(coords):
        deltas = node - coords
        # Neutral element in place of the zero self-difference.
        deltas[i] = 1.0
        lagrange_scale[i] = 1.0 / np.prod(deltas)

    return lagrange_scale
warp/fem/quadrature/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .pic_quadrature import PicQuadrature
17
+ from .quadrature import ExplicitQuadrature, NodalQuadrature, Quadrature, RegularQuadrature
@@ -0,0 +1,299 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional, Tuple, Union
17
+
18
+ import warp as wp
19
+ from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
20
+ from warp.fem.domain import GeometryDomain
21
+ from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, make_free_sample
22
+ from warp.fem.utils import compress_node_indices
23
+
24
+ from .quadrature import Quadrature
25
+
26
+
27
class PicQuadrature(Quadrature):
    """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.

    Useful for Particle-In-Cell and derived methods.

    Args:
        domain: Underlying domain for the quadrature
        positions: Either an array containing the world positions of all particles, or a tuple of arrays containing
            the cell indices and coordinates for each particle. Note that the former requires the underlying geometry to
            define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
        measures: Array containing the measure (area/volume) of each particle, used to define the integration weights.
            If ``None``, defaults to the cell measure divided by the number of particles in the cell.
        requires_grad: Whether gradients should be allocated for the computed quantities
        temporary_store: shared pool from which to allocate temporary arrays
    """

    def __init__(
        self,
        domain: GeometryDomain,
        positions: Union[
            "wp.array(dtype=wp.vecXd)",
            Tuple[
                "wp.array(dtype=ElementIndex)",
                "wp.array(dtype=Coords)",
            ],
        ],
        measures: Optional["wp.array(dtype=float)"] = None,
        requires_grad: bool = False,
        temporary_store: Optional[TemporaryStore] = None,
    ):
        super().__init__(domain)

        self._requires_grad = requires_grad
        self._bin_particles(positions, measures, temporary_store)
        # Lazily computed by `max_points_per_element`
        self._max_particles_per_cell: Optional[int] = None

    @property
    def name(self):
        """Name of the quadrature, which is the name of the concrete class."""
        return self.__class__.__name__

    @Quadrature.domain.setter
    def domain(self, domain: GeometryDomain):
        # Allow changing the quadrature domain as long as underlying geometry and element kind are the same
        if self.domain is not None and (
            domain.element_kind != self.domain.element_kind or domain.geometry.base != self.domain.geometry.base
        ):
            raise RuntimeError(
                "The new domain must use the same base geometry and kind of elements as the current one."
            )

        self._domain = domain

    @wp.struct
    class Arg:
        # Start offset of each cell's particle list; a cell's particle count is the
        # difference between two consecutive entries
        cell_particle_offsets: wp.array(dtype=int)
        # Particle indices, addressed through `cell_particle_offsets`
        cell_particle_indices: wp.array(dtype=int)
        # Per-particle quadrature weight
        particle_fraction: wp.array(dtype=float)
        # Per-particle coordinates within its cell
        particle_coords: wp.array(dtype=Coords)

    @cached_arg_value
    def arg_value(self, device) -> Arg:
        """Assemble the kernel argument structure, with all arrays moved to ``device``."""
        arg = PicQuadrature.Arg()
        arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
        arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
        arg.particle_fraction = self._particle_fraction.to(device)
        arg.particle_coords = self.particle_coords.to(device)
        return arg

    def total_point_count(self):
        """Total number of quadrature points, i.e. the number of particles."""
        return self.particle_coords.shape[0]

    def active_cell_count(self):
        """Number of cells containing at least one particle"""
        return self._cell_count

    def max_points_per_element(self):
        """Maximum number of particles found in a single cell.

        The value is computed with a kernel launch on first call, then cached.
        """
        if self._max_particles_per_cell is None:
            max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.array.device)
            wp.launch(
                PicQuadrature._max_particles_per_cell_kernel,
                # The offsets array holds one more entry than there are cells
                self._cell_particle_offsets.array.shape[0] - 1,
                device=max_ppc.device,
                inputs=[self._cell_particle_offsets.array, max_ppc],
            )
            self._max_particles_per_cell = int(max_ppc.numpy()[0])
        return self._max_particles_per_cell

    # Number of particles in the element: difference of consecutive offsets
    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
        return qp_arg.cell_particle_offsets[element_index + 1] - qp_arg.cell_particle_offsets[element_index]

    # Coordinates of the `index`-th particle of the element
    @wp.func
    def point_coords(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
    ):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return qp_arg.particle_coords[particle_index]

    # Weight (particle fraction) of the `index`-th particle of the element
    @wp.func
    def point_weight(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
    ):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return qp_arg.particle_fraction[particle_index]

    # Global particle index of the `index`-th particle of the element
    @wp.func
    def point_index(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
    ):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return particle_index

    # Position of the `index`-th particle of the element in the cell-sorted ordering
    @wp.func
    def point_evaluation_index(
        elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, index: int
    ):
        return qp_arg.cell_particle_offsets[element_index] + index

    def fill_element_mask(self, mask: "wp.array(dtype=int)"):
        """Fills a mask array such that all non-empty elements are set to 1, all empty elements to zero.

        Args:
            mask: Int warp array with size at least equal to `self.domain.geometry_element_count()`
        """

        # The kernel runs on the mask's device, so make sure the offsets live there too
        # (`to` is the same device-transfer call used in `arg_value`)
        wp.launch(
            kernel=PicQuadrature._fill_mask_kernel,
            dim=self.domain.geometry_element_count(),
            device=mask.device,
            inputs=[self._cell_particle_offsets.array.to(mask.device), mask],
        )

    # Writes 0 for elements whose particle range is empty, 1 otherwise
    @wp.kernel
    def _fill_mask_kernel(
        element_particle_offsets: wp.array(dtype=int),
        element_mask: wp.array(dtype=int),
    ):
        i = wp.tid()
        element_mask[i] = wp.where(element_particle_offsets[i] == element_particle_offsets[i + 1], 0, 1)

    # Gives each particle a weight of 1 / (number of particles in its cell);
    # particles without a cell get a zero weight
    @wp.kernel
    def _compute_uniform_fraction(
        cell_index: wp.array(dtype=ElementIndex),
        cell_particle_offsets: wp.array(dtype=int),
        cell_fraction: wp.array(dtype=float),
    ):
        p = wp.tid()

        cell = cell_index[p]
        if cell == NULL_ELEMENT_INDEX:
            cell_fraction[p] = 0.0
        else:
            cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
            cell_fraction[p] = 1.0 / float(cell_particle_count)

    def _bin_particles(self, positions, measures, temporary_store: Optional[TemporaryStore]):
        """Associate each particle with a cell, then build per-cell particle lists and weights.

        Args:
            positions: either an array of particle positions, which are looked up in the domain,
                or a ``(cell_indices, particle_coords)`` tuple that is used as-is.
            measures: optional per-particle measures, forwarded to :meth:`_compute_fraction`.
            temporary_store: pool from which temporary arrays are borrowed.

        Raises:
            ValueError: if the provided cell index and coordinates arrays differ in shape.
        """
        if wp.types.is_array(positions):
            # Initialize from positions
            @dynamic_kernel(suffix=f"{self.domain.name}")
            def bin_particles(
                cell_arg_value: self.domain.ElementArg,
                positions: wp.array(dtype=positions.dtype),
                cell_index: wp.array(dtype=ElementIndex),
                cell_coords: wp.array(dtype=Coords),
            ):
                p = wp.tid()
                sample = self.domain.element_lookup(cell_arg_value, positions[p])

                cell_index[p] = sample.element_index
                cell_coords[p] = sample.element_coords

            device = positions.device

            # Keep the temporaries referenced from the instance alongside their arrays
            self._cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
            self.cell_indices = self._cell_index_temp.array

            self._particle_coords_temp = borrow_temporary(
                temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
            )
            self.particle_coords = self._particle_coords_temp.array

            wp.launch(
                dim=positions.shape[0],
                kernel=bin_particles,
                inputs=[
                    self.domain.element_arg_value(device),
                    positions,
                    self.cell_indices,
                    self.particle_coords,
                ],
                device=device,
            )

        else:
            # Cell indices and coordinates are provided directly by the caller
            self.cell_indices, self.particle_coords = positions
            if self.cell_indices.shape != self.particle_coords.shape:
                raise ValueError("Cell index and coordinates arrays must have the same shape")

            self._cell_index_temp = None
            self._particle_coords_temp = None

        # Group particle indices by cell; the last returned value is not needed here
        self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
            self.domain.geometry_element_count(),
            self.cell_indices,
            return_unique_nodes=True,
            temporary_store=temporary_store,
        )

        self._compute_fraction(self.cell_indices, measures, temporary_store)

    def _compute_fraction(self, cell_index, measures, temporary_store: Optional[TemporaryStore]):
        """Compute the per-particle quadrature weights.

        Args:
            cell_index: array holding the cell index of each particle.
            measures: per-particle measures, or ``None`` to split each cell uniformly
                between its particles.
            temporary_store: pool from which the weights array is borrowed.

        Raises:
            ValueError: if ``measures`` does not have the same shape as ``cell_index``.
        """
        device = cell_index.device

        self._particle_fraction_temp = borrow_temporary(
            temporary_store, shape=cell_index.shape, dtype=float, device=device, requires_grad=self._requires_grad
        )
        self._particle_fraction = self._particle_fraction_temp.array

        if measures is None:
            # Split fraction uniformly over all particles in cell

            wp.launch(
                dim=cell_index.shape,
                kernel=PicQuadrature._compute_uniform_fraction,
                inputs=[
                    cell_index,
                    self._cell_particle_offsets.array,
                    self._particle_fraction,
                ],
                device=device,
            )

        else:
            # Fraction from particle measure

            if measures.shape != cell_index.shape:
                raise ValueError("Measures should be a 1d array of length equal to particle count")

            @dynamic_kernel(suffix=f"{self.domain.name}")
            def compute_fraction(
                cell_arg_value: self.domain.ElementArg,
                measures: wp.array(dtype=float),
                cell_index: wp.array(dtype=ElementIndex),
                cell_coords: wp.array(dtype=Coords),
                cell_fraction: wp.array(dtype=float),
            ):
                p = wp.tid()

                cell = cell_index[p]
                if cell == NULL_ELEMENT_INDEX:
                    cell_fraction[p] = 0.0
                else:
                    sample = make_free_sample(cell_index[p], cell_coords[p])
                    cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)

            wp.launch(
                dim=measures.shape[0],
                kernel=compute_fraction,
                inputs=[
                    self.domain.element_arg_value(device),
                    measures,
                    cell_index,
                    self.particle_coords,
                    self._particle_fraction,
                ],
                device=device,
            )

    # Reduces the per-cell particle counts into max_count[0] with an atomic max
    @wp.kernel
    def _max_particles_per_cell_kernel(offsets: wp.array(dtype=int), max_count: wp.array(dtype=int)):
        cell = wp.tid()
        particle_count = offsets[cell + 1] - offsets[cell]
        wp.atomic_max(max_count, 0, particle_count)