warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/fem/operator.py ADDED
@@ -0,0 +1,191 @@
1
+ import inspect
2
+ from typing import Callable, Any
3
+
4
+ import warp as wp
5
+
6
+ from warp.fem.types import Domain, Field, Sample
7
+ from warp.fem import utils
8
+
9
+
10
class Integrand:
    """Wrapper around a device function written in terms of Field and Domain arguments.

    The wrapped function may contain arbitrary expressions over Field and Domain variables;
    it is turned into a proper warp.Function once concrete Field types are resolved at call time.
    """

    def __init__(self, func: Callable):
        # Everything below is derived from the wrapped callable itself
        self.func = func
        self.name = wp.codegen.make_full_qualified_name(func)
        self.module = wp.get_module(func.__module__)
        self.argspec = inspect.getfullargspec(func)
20
+
21
+
22
class Operator:
    """
    Syntactic sugar over Field and Domain evaluation functions and arguments.

    Stores the decorated placeholder function together with the resolver callable
    that was passed to the decorator.
    """

    def __init__(self, func: Callable, resolver: Callable):
        self.func, self.resolver = func, resolver
30
+
31
+
32
def integrand(func: Callable):
    """Decorator for functions to be integrated (or interpolated) using warp.fem"""
    wrapped = Integrand(func)
    # Expose the user's docstring on the wrapper object
    wrapped.__doc__ = func.__doc__
    return wrapped
37
+
38
+
39
def operator(resolver: Callable):
    """Decorator for functions operating on Field-like or Domain-like data inside warp.fem integrands"""

    def wrap_operator(func: Callable):
        # Pair the placeholder function with its resolver and keep its docstring
        wrapped = Operator(func, resolver)
        wrapped.__doc__ = func.__doc__
        return wrapped

    return wrap_operator
48
+
49
+
50
+ # Domain operators
51
+
52
+
53
@operator(resolver=lambda dmn: dmn.element_position)
def position(domain: Domain, s: Sample):
    """Evaluates the world position of the sample point `s`"""
57
+
58
+
59
@operator(resolver=lambda dmn: dmn.eval_normal)
def normal(domain: Domain, s: Sample):
    """Evaluates the element normal at the sample point `s`. Null for interior points."""
63
+
64
+
65
@operator(resolver=lambda dmn: dmn.element_deformation_gradient)
def deformation_gradient(domain: Domain, s: Sample):
    """Evaluates, at the sample point `s`, the gradient of the domain position with respect to the element reference space"""
69
+
70
+
71
@operator(resolver=lambda dmn: dmn.element_lookup)
def lookup(domain: Domain, x: Any) -> Sample:
    """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the domain.

    Args:
        x: world position of the point to look-up in the geometry
        guess: (optional) :class:`Sample` initial guess, may help perform the query

    Notes:
        Currently this operator is only fully supported for :class:`Grid2D` and :class:`Grid3D` geometries.
        For :class:`TriangleMesh2D` and :class:`Tetmesh` geometries, the operator requires providing `guess`.
    """
84
+
85
+
86
@operator(resolver=lambda dmn: dmn.element_measure)
def measure(domain: Domain, s: Sample) -> float:
    """Returns the measure (volume, area, or length) determinant of an element at a sample point `s`"""
90
+
91
+
92
@operator(resolver=lambda dmn: dmn.element_measure_ratio)
def measure_ratio(domain: Domain, s: Sample) -> float:
    """Returns the maximum ratio between the measure of this element and that of higher-dimensional neighbours."""
96
+
97
+
98
+ # Field operators
99
+ # On a side, inner and outer are such that normal goes from inner to outer
100
+
101
+
102
@operator(resolver=lambda f: f.eval_inner)
def inner(f: Field, s: Sample):
    """Evaluates the field at a sample point `s`. On oriented sides, uses the inner element"""
106
+
107
+
108
@operator(resolver=lambda f: f.eval_grad_inner)
def grad(f: Field, s: Sample):
    """Evaluates the field gradient at a sample point `s`. On oriented sides, uses the inner element"""
112
+
113
+
114
@operator(resolver=lambda f: f.eval_div_inner)
def div(f: Field, s: Sample):
    """Evaluates the field divergence at a sample point `s`. On oriented sides, uses the inner element"""
118
+
119
+
120
@operator(resolver=lambda f: f.eval_outer)
def outer(f: Field, s: Sample):
    """Evaluates the field at a sample point `s`. On oriented sides, uses the outer element.

    On interior points and on domain boundaries, this is equivalent to :func:`inner`.
    """
124
+
125
+
126
@operator(resolver=lambda f: f.eval_grad_outer)
def grad_outer(f: Field, s: Sample):
    """Evaluates the field gradient at a sample point `s`. On oriented sides, uses the outer element.

    On interior points and on domain boundaries, this is equivalent to :func:`grad`.
    """
130
+
131
+
132
# The resolver must select the field's *divergence* evaluator for the outer element,
# mirroring `div` -> `eval_div_inner`; it previously pointed at `eval_grad_outer`,
# which made `div_outer` evaluate the gradient instead of the divergence.
# NOTE(review): assumes fields expose `eval_div_outer` alongside `eval_div_inner` -- confirm against the field classes.
@operator(resolver=lambda f: f.eval_div_outer)
def div_outer(f: Field, s: Sample):
    """Evaluates the field divergence at a sample point `s`. On oriented sides, uses the outer element.

    On interior points and on domain boundaries, this is equivalent to :func:`div`.
    """
    pass
136
+
137
+
138
@operator(resolver=lambda f: f.eval_degree)
def degree(f: Field):
    """Polynomial degree of a field"""
142
+
143
+
144
@operator(resolver=lambda f: f.at_node)
def at_node(f: Field, s: Sample):
    """For a Test or Trial field, returns a copy of the Sample `s` moved to the coordinates of the node being evaluated"""
148
+
149
+
150
+ # Common derived operators, for convenience
151
+
152
+
153
@integrand
def D(f: Field, s: Sample):
    """Symmetric part of the (inner) gradient of the field at `s`"""
    gradient = grad(f, s)
    return utils.symmetric_part(gradient)
157
+
158
+
159
@integrand
def curl(f: Field, s: Sample):
    """Skew part of the (inner) gradient of the field at `s`, as a vector such that ``wp.cross(curl(u), v) = skew(grad(u)) v``"""
    gradient = grad(f, s)
    return utils.skew_part(gradient)
163
+
164
+
165
@integrand
def jump(f: Field, s: Sample):
    """Jump between inner and outer element values on an interior side. Zero for interior points or domain boundaries"""
    inner_value = inner(f, s)
    outer_value = outer(f, s)
    return inner_value - outer_value
169
+
170
+
171
@integrand
def average(f: Field, s: Sample):
    """Average between inner and outer element values"""
    inner_value = inner(f, s)
    outer_value = outer(f, s)
    return 0.5 * (inner_value + outer_value)
175
+
176
+
177
@integrand
def grad_jump(f: Field, s: Sample):
    """Jump between inner and outer element gradients on an interior side. Zero for interior points or domain boundaries"""
    inner_gradient = grad(f, s)
    outer_gradient = grad_outer(f, s)
    return inner_gradient - outer_gradient
181
+
182
+
183
@integrand
def grad_average(f: Field, s: Sample):
    """Average between inner and outer element gradients"""
    inner_gradient = grad(f, s)
    outer_gradient = grad_outer(f, s)
    return 0.5 * (inner_gradient + outer_gradient)
187
+
188
+
189
# Set the default call operators for argument types, so that inside an integrand
# `field(s)` is shorthand for `inner(field, s)` and `domain(s)` is shorthand for
# `position(domain, s)`.
Field.call_operator = inner
Domain.call_operator = position
warp/fem/polynomial.py ADDED
@@ -0,0 +1,213 @@
1
+ import math
2
+ from enum import Enum
3
+
4
+ import numpy as np
5
+
6
+
7
class Polynomial(Enum):
    """Polynomial family defining interpolation nodes over an interval"""

    #: Gauss--Legendre 1D polynomial family (does not include endpoints)
    GAUSS_LEGENDRE = 0

    #: Lobatto--Gauss--Legendre 1D polynomial family (includes endpoints)
    LOBATTO_GAUSS_LEGENDRE = 1

    #: Closed 1D polynomial family with uniformly distributed nodes (includes endpoints)
    EQUISPACED_CLOSED = 2

    #: Open 1D polynomial family with uniformly distributed nodes (does not include endpoints)
    EQUISPACED_OPEN = 3

    def __str__(self):
        # Print the bare member name, without the class prefix
        return self.name
24
+
25
def is_closed(family: Polynomial):
    """Whether the polynomial roots include interval endpoints"""
    closed_families = (Polynomial.LOBATTO_GAUSS_LEGENDRE, Polynomial.EQUISPACED_CLOSED)
    return family in closed_families
28
+
29
+
30
def _gauss_legendre_quadrature_1d(n: int):
    """Return the `n`-point Gauss--Legendre quadrature rule on the interval [0, 1].

    Args:
        n: number of quadrature points, at least 1

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length `n`; the weights sum to 1.
        For ``n <= 5`` the closed-form rules are used and the node ordering is the one
        hard-coded below (not sorted). For ``n > 5`` the rule is computed numerically
        and nodes are in ascending order.

    Raises:
        NotImplementedError: if `n` is smaller than 1
    """
    if n == 1:
        coords = [0.0]
        weights = [2.0]
    elif n == 2:
        coords = [-math.sqrt(1.0 / 3), math.sqrt(1.0 / 3)]
        weights = [1.0, 1.0]
    elif n == 3:
        coords = [0.0, -math.sqrt(3.0 / 5.0), math.sqrt(3.0 / 5.0)]
        weights = [8.0 / 9.0, 5.0 / 9.0, 5.0 / 9.0]
    elif n == 4:
        c_a = math.sqrt(3.0 / 7.0 - 2.0 / 7.0 * math.sqrt(6.0 / 5.0))
        c_b = math.sqrt(3.0 / 7.0 + 2.0 / 7.0 * math.sqrt(6.0 / 5.0))
        w_a = (18.0 + math.sqrt(30.0)) / 36.0
        w_b = (18.0 - math.sqrt(30.0)) / 36.0
        coords = [c_a, -c_a, c_b, -c_b]
        weights = [w_a, w_a, w_b, w_b]
    elif n == 5:
        c_a = 1.0 / 3.0 * math.sqrt(5.0 - 2.0 * math.sqrt(10.0 / 7.0))
        c_b = 1.0 / 3.0 * math.sqrt(5.0 + 2.0 * math.sqrt(10.0 / 7.0))
        w_a = (322.0 + 13.0 * math.sqrt(70.0)) / 900.0
        w_b = (322.0 - 13.0 * math.sqrt(70.0)) / 900.0
        coords = [0.0, c_a, -c_a, c_b, -c_b]
        weights = [128.0 / 225.0, w_a, w_a, w_b, w_b]
    elif n > 5:
        # No closed form is tabulated beyond 5 points: fall back to NumPy's
        # numerically computed rule on [-1, 1] instead of failing.
        coords, weights = np.polynomial.legendre.leggauss(n)
    else:
        raise NotImplementedError(f"Gauss-Legendre quadrature requires at least 1 point, got {n}")

    # Shift from [-1, 1] to [0, 1]
    weights = 0.5 * np.array(weights)
    coords = 0.5 * np.array(coords) + 0.5

    return coords, weights
62
+
63
+
64
def _lobatto_gauss_legendre_quadrature_1d(n: int):
    """Return the `n`-point Lobatto--Gauss--Legendre rule on [0, 1] for `n` in 2..5.

    Returns a tuple ``(coords, weights)`` of NumPy arrays; raises NotImplementedError
    for any other point count.
    """
    node_4 = 1.0 / math.sqrt(5.0)
    node_5 = math.sqrt(3.0 / 7.0)

    # (point count, nodes on [-1, 1], weights on [-1, 1])
    reference_rules = (
        (2, [-1.0, 1.0], [1.0, 1.0]),
        (3, [-1.0, 0.0, 1.0], [1.0 / 3.0, 4.0 / 3.0, 1.0 / 3.0]),
        (4, [-1.0, -node_4, node_4, 1.0], [1.0 / 6.0, 5.0 / 6.0, 5.0 / 6.0, 1.0 / 6.0]),
        (5, [-1.0, -node_5, 0.0, node_5, 1.0], [1.0 / 10.0, 49.0 / 90.0, 32.0 / 45.0, 49.0 / 90.0, 1.0 / 10.0]),
    )

    for point_count, ref_coords, ref_weights in reference_rules:
        if n == point_count:
            # Shift from [-1, 1] to [0, 1]
            weights = 0.5 * np.array(ref_weights)
            coords = 0.5 * np.array(ref_coords) + 0.5
            return coords, weights

    raise NotImplementedError
85
+
86
+
87
def _uniform_open_quadrature_1d(n: int):
    """Return `n` uniformly spaced interior points of [0, 1] with simple composite weights.

    Points are located at ``k / (n + 1)`` for ``k = 1..n``. Interior points get weight
    ``1 / (n + 1)`` and the two outermost points get 3/2 of that, so the weights sum to 1.

    Args:
        n: number of points, at least 1

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length `n`

    Raises:
        ValueError: if `n` is smaller than 1
    """
    if n < 1:
        raise ValueError(f"Open uniform quadrature requires at least 1 point, got {n}")

    step = 1.0 / (n + 1)
    coords = np.linspace(step, 1.0 - step, n)

    if n == 1:
        # A single point is both "first" and "last": applying the boundary
        # correction would leave a total weight of 0.75, so give it full weight.
        return coords, np.ones(1)

    weights = np.full(n, 1.0 / (n + 1))

    # Boundaries have 3/2 the weight
    weights[0] = 1.5 / (n + 1)
    weights[-1] = 1.5 / (n + 1)

    return coords, weights
97
+
98
+
99
def _uniform_closed_quadrature_1d(n: int):
    """Return `n` uniformly spaced points of [0, 1], endpoints included, with trapezoidal weights.

    Interior points get weight ``1 / (n - 1)`` and the two endpoints half of that,
    so the weights sum to 1.

    Args:
        n: number of points, at least 2 (both endpoints are always included)

    Returns:
        Tuple ``(coords, weights)`` of NumPy arrays of length `n`

    Raises:
        ValueError: if `n` is smaller than 2
    """
    if n < 2:
        # Fewer than two points cannot contain both endpoints, and the weight
        # formula below would divide by zero for n == 1.
        raise ValueError(f"Closed uniform quadrature requires at least 2 points, got {n}")

    coords = np.linspace(0.0, 1.0, n)
    weights = np.full(n, 1.0 / (n - 1))

    # Boundaries have half the weight
    weights[0] = 0.5 / (n - 1)
    weights[-1] = 0.5 / (n - 1)

    return coords, weights
108
+
109
+
110
def _open_newton_cotes_quadrature_1d(n: int):
    """Return the `n`-point open Newton--Cotes rule on [0, 1] for `n` in 1..7.

    Nodes are the `n` uniformly spaced interior points ``k / (n + 1)``. Returns a tuple
    ``(coords, weights)`` of NumPy arrays; raises NotImplementedError for other counts.
    """
    spacing = 1.0 / (n + 1)
    coords = np.linspace(spacing, 1.0 - spacing, n)

    # Weisstein, Eric W. "Newton-Cotes Formulas." From MathWorld--A Wolfram Web Resource.
    # https://mathworld.wolfram.com/Newton-CotesFormulas.html
    # Each entry is (point count, weight numerators, common denominator)
    tabulated = (
        (1, [1.0], 1.0),
        (2, [1.0, 1.0], 2.0),
        (3, [2.0, -1.0, 2.0], 3.0),
        (4, [11.0, 1.0, 1.0, 11.0], 24.0),
        (5, [11.0, -14.0, 26.0, -14.0, 11.0], 20.0),
        (6, [611.0, -453.0, 562.0, 562.0, -453.0, 611.0], 1440.0),
        (7, [460.0, -954.0, 2196.0, -2459.0, 2196.0, -954.0, 460.0], 945.0),
    )

    for point_count, numerators, denominator in tabulated:
        if n == point_count:
            return coords, np.array(numerators) / denominator

    raise NotImplementedError
135
+
136
+
137
def _closed_newton_cotes_quadrature_1d(n: int):
    """Return the `n`-point closed Newton--Cotes rule on [0, 1] for `n` in 2..8.

    Nodes are `n` uniformly spaced points including both endpoints. Returns a tuple
    ``(coords, weights)`` of NumPy arrays; raises NotImplementedError for other counts.
    """
    coords = np.linspace(0.0, 1.0, n)

    # OEIS: A093735, A093736
    # Each entry is (point count, per-node numerators, per-node denominators)
    tabulated = (
        (2, [1, 1], [2, 2]),
        (3, [1, 4, 1], [3, 3, 3]),
        (4, [3, 9, 9, 3], [8, 8, 8, 8]),
        (5, [14, 64, 24, 64, 14], [45, 45, 45, 45, 45]),
        (6, [95, 125, 125, 125, 125, 95], [288, 96, 144, 144, 96, 288]),
        (7, [41, 54, 27, 68, 27, 54, 41], [140, 35, 140, 35, 140, 35, 140]),
        (
            8,
            [5257, 25039, 343, 20923, 20923, 343, 25039, 5257],
            [17280, 17280, 640, 17280, 17280, 640, 17280, 17280],
        ),
    )

    for point_count, numerators, denominators in tabulated:
        if n == point_count:
            unit_spacing_weights = np.array(numerators, dtype=float) / np.array(denominators, dtype=float)
            # Normalize with interval length
            return coords, unit_spacing_weights / (n - 1)

    raise NotImplementedError
188
+
189
+
190
def quadrature_1d(point_count: int, family: Polynomial):
    """Return quadrature points and weights for the given family and point count"""

    # Rule builder associated with each supported polynomial family
    builders = (
        (Polynomial.GAUSS_LEGENDRE, _gauss_legendre_quadrature_1d),
        (Polynomial.LOBATTO_GAUSS_LEGENDRE, _lobatto_gauss_legendre_quadrature_1d),
        (Polynomial.EQUISPACED_CLOSED, _closed_newton_cotes_quadrature_1d),
        (Polynomial.EQUISPACED_OPEN, _open_newton_cotes_quadrature_1d),
    )

    for candidate, build_rule in builders:
        if family == candidate:
            return build_rule(point_count)

    raise NotImplementedError
203
+
204
+
205
def lagrange_scales(coords: np.ndarray):
    """Return the scaling factors for Lagrange polynomials with roots at coords

    For each node ``i`` the factor is ``1 / prod_{j != i} (coords[i] - coords[j])``.

    Args:
        coords: 1D sequence of distinct node coordinates. Any array-like is accepted;
            non-floating (e.g. integer) input is converted to double precision so that
            the scales are not truncated. Floating-point arrays keep their dtype.

    Returns:
        NumPy array with the same shape as `coords` containing one scale per node
    """
    coords = np.asarray(coords)
    if not np.issubdtype(coords.dtype, np.inexact):
        # Integer nodes would otherwise yield an integer output array and
        # silently truncate fractional scales such as 0.5 to 0.
        coords = coords.astype(float)

    lagrange_scale = np.empty_like(coords)
    for i, node in enumerate(coords):
        deltas = node - coords
        # Neutral factor in place of the zero self-difference
        deltas[i] = 1.0
        lagrange_scale[i] = 1.0 / np.prod(deltas)

    return lagrange_scale
@@ -0,0 +1,2 @@
1
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature
2
+ from .pic_quadrature import PicQuadrature
@@ -0,0 +1,245 @@
1
+ from typing import Union, Tuple, Any, Optional
2
+
3
+ import warp as wp
4
+
5
+ from warp.fem.domain import GeometryDomain
6
+ from warp.fem.types import ElementIndex, Coords, make_free_sample
7
+ from warp.fem.utils import compress_node_indices
8
+ from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary, dynamic_kernel
9
+
10
+ from .quadrature import Quadrature
11
+
12
+
13
+ wp.set_module_options({"enable_backward": False})
14
+
15
+
16
class PicQuadrature(Quadrature):
    """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.

    Useful for Particle-In-Cell and derived methods.

    Args:
        domain: Underlying domain for the quadrature
        positions: Either an array containing the world positions of all particles, or a tuple of arrays containing
         the cell indices and coordinates for each particle. Note that the former requires the underlying geometry to
         define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
        measures: Array containing the measure (area/volume) of each particle, used to define the integration weights.
         If ``None``, defaults to the cell measure divided by the number of particles in the cell.
        temporary_store: shared pool from which to allocate temporary arrays
    """

    def __init__(
        self,
        domain: GeometryDomain,
        positions: Union[
            "wp.array(dtype=wp.vecXd)",
            Tuple[
                "wp.array(dtype=ElementIndex)",
                "wp.array(dtype=Coords)",
            ],
        ],
        measures: Optional["wp.array(dtype=float)"] = None,
        temporary_store: Optional[TemporaryStore] = None,
    ):
        super().__init__(domain)

        # Sort particles per cell and compute their integration weights
        self._bin_particles(positions, measures, temporary_store)

    @property
    def name(self):
        """Name of the quadrature, i.e. the name of its class"""
        return self.__class__.__name__

    @Quadrature.domain.setter
    def domain(self, domain: GeometryDomain):
        # Allow changing the quadrature domain as long as underlying geometry and element kind are the same
        if self.domain is not None and (
            domain.geometry != self.domain.geometry or domain.element_kind != self.domain.element_kind
        ):
            raise RuntimeError(
                "Cannot change the domain to that of a different Geometry and/or using different element kinds."
            )

        self._domain = domain

    # Device-side argument structure consumed by the `point_*` functions below.
    # The particles of cell `c` are the entries of `cell_particle_indices` located in the range
    # `[cell_particle_offsets[c], cell_particle_offsets[c + 1])`.
    @wp.struct
    class Arg:
        # start of each cell's range in `cell_particle_indices`
        cell_particle_offsets: wp.array(dtype=int)
        # particle indices grouped by cell
        cell_particle_indices: wp.array(dtype=int)
        # per-particle integration weight, as returned by `point_weight`
        particle_fraction: wp.array(dtype=float)
        # per-particle coordinates within its cell, as returned by `point_coords`
        particle_coords: wp.array(dtype=Coords)

    @cached_arg_value
    def arg_value(self, device) -> Arg:
        """Build the :class:`Arg` structure with all particle arrays transferred to `device`"""
        arg = PicQuadrature.Arg()
        arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
        arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
        arg.particle_fraction = self._particle_fraction.to(device)
        arg.particle_coords = self._particle_coords.to(device)
        return arg

    def total_point_count(self):
        """Total number of particles, i.e. the length of the particle coordinates array"""
        return self._particle_coords.shape[0]

    def active_cell_count(self):
        """Number of cells containing at least one particle"""
        return self._cell_count

    # Number of particles lying in element `element_index`
    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
        return qp_arg.cell_particle_offsets[element_index + 1] - qp_arg.cell_particle_offsets[element_index]

    # Element coordinates of the `index`-th particle of element `element_index`
    @wp.func
    def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return qp_arg.particle_coords[particle_index]

    # Integration weight (fraction) of the `index`-th particle of element `element_index`
    @wp.func
    def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return qp_arg.particle_fraction[particle_index]

    # Global particle index of the `index`-th particle of element `element_index`
    @wp.func
    def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
        particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
        return particle_index

    def fill_element_mask(self, mask: "wp.array(dtype=int)"):
        """Fills a mask array such that all non-empty elements are set to 1, all empty elements to zero.

        Args:
            mask: Int warp array with size at least equal to `self.domain.geometry_element_count()`
        """

        wp.launch(
            kernel=PicQuadrature._fill_mask_kernel,
            dim=self.domain.geometry_element_count(),
            device=mask.device,
            inputs=[self._cell_particle_offsets.array, mask],
        )

    # One thread per element: an element is empty when its offset range has zero length.
    # `wp.select` is expected to yield its last argument (0) when the condition holds, else 1,
    # which matches the contract documented on `fill_element_mask`.
    @wp.kernel
    def _fill_mask_kernel(
        element_particle_offsets: wp.array(dtype=int),
        element_mask: wp.array(dtype=int),
    ):
        i = wp.tid()
        element_mask[i] = wp.select(element_particle_offsets[i] == element_particle_offsets[i + 1], 1, 0)

    # One thread per particle: the weight is the inverse of the number of particles sharing its cell
    @wp.kernel
    def _compute_uniform_fraction(
        cell_index: wp.array(dtype=ElementIndex),
        cell_particle_offsets: wp.array(dtype=int),
        cell_fraction: wp.array(dtype=float),
    ):
        p = wp.tid()

        cell = cell_index[p]
        cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]

        cell_fraction[p] = 1.0 / float(cell_particle_count)

    def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
        """Group particles per cell, then compute their integration weights.

        Args:
            positions: either a single array of particle positions, which are located in the domain
              through its `element_lookup` function, or a `(cell_index, cell_coords)` tuple of arrays
            measures: per-particle measures forwarded to :meth:`_compute_fraction`, or ``None``
            temporary_store: pool from which temporary arrays are borrowed

        Raises:
            ValueError: if the cell index and coordinates arrays of the tuple have different shapes
        """
        if wp.types.is_array(positions):
            # Initialize from positions
            @dynamic_kernel(suffix=f"{self.domain.name}")
            def bin_particles(
                cell_arg_value: self.domain.ElementArg,
                positions: wp.array(dtype=positions.dtype),
                cell_index: wp.array(dtype=ElementIndex),
                cell_coords: wp.array(dtype=Coords),
            ):
                p = wp.tid()
                sample = self.domain.element_lookup(cell_arg_value, positions[p])

                cell_index[p] = sample.element_index
                cell_coords[p] = sample.element_coords

            device = positions.device

            # Cell indices are only needed while binning; the local reference keeps
            # the borrowed array alive until this method returns
            cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
            cell_index = cell_index_temp.array

            # Coordinates outlive this method, so the temporary is stored on the instance
            self._particle_coords_temp = borrow_temporary(
                temporary_store, shape=positions.shape, dtype=Coords, device=device
            )
            self._particle_coords = self._particle_coords_temp.array

            wp.launch(
                dim=positions.shape[0],
                kernel=bin_particles,
                inputs=[
                    self.domain.element_arg_value(device),
                    positions,
                    cell_index,
                    self._particle_coords,
                ],
                device=device,
            )

        else:
            # Cell indices and coordinates are provided directly by the caller
            cell_index, self._particle_coords = positions
            if cell_index.shape != self._particle_coords.shape:
                raise ValueError("Cell index and coordinates arrays must have the same shape")

            cell_index_temp = None
            self._particle_coords_temp = None

        self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
            self.domain.geometry_element_count(), cell_index
        )

        self._compute_fraction(cell_index, measures, temporary_store)

    def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
        """Compute the per-particle integration weights and store them in `self._particle_fraction`.

        Args:
            cell_index: array containing the index of the cell of each particle
            measures: array containing the measure of each particle; if ``None``, each particle gets
              the inverse of the number of particles in its cell
            temporary_store: pool from which the weights array is borrowed

        Raises:
            ValueError: if `measures` is provided with a shape different from that of `cell_index`
        """
        device = cell_index.device

        self._particle_fraction_temp = borrow_temporary(
            temporary_store, shape=cell_index.shape, dtype=float, device=device
        )
        self._particle_fraction = self._particle_fraction_temp.array

        if measures is None:
            # Split fraction uniformly over all particles in cell

            wp.launch(
                dim=cell_index.shape,
                kernel=PicQuadrature._compute_uniform_fraction,
                inputs=[
                    cell_index,
                    self._cell_particle_offsets.array,
                    self._particle_fraction,
                ],
                device=device,
            )

        else:
            # Fraction from particle measure

            if measures.shape != cell_index.shape:
                raise ValueError("Measures should be a 1d array of length equal to particle count")

            # One thread per particle: ratio of the particle measure to the
            # element measure evaluated at the particle location
            @dynamic_kernel(suffix=f"{self.domain.name}")
            def compute_fraction(
                cell_arg_value: self.domain.ElementArg,
                measures: wp.array(dtype=float),
                cell_index: wp.array(dtype=ElementIndex),
                cell_coords: wp.array(dtype=Coords),
                cell_fraction: wp.array(dtype=float),
            ):
                p = wp.tid()
                sample = make_free_sample(cell_index[p], cell_coords[p])

                cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)

            wp.launch(
                dim=measures.shape[0],
                kernel=compute_fraction,
                inputs=[
                    self.domain.element_arg_value(device),
                    measures,
                    cell_index,
                    self._particle_coords,
                    self._particle_fraction,
                ],
                device=device,
            )