warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,532 @@
1
+ from typing import Optional
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import OUTSIDE, Coords, ElementIndex, Sample, make_free_sample
11
+
12
+ from .element import LinearEdge, Square
13
+ from .geometry import Geometry
14
+
15
+ # from .closest_point import project_on_tet_at_origin
16
+
17
+
18
@wp.struct
class Quadmesh2DCellArg:
    # Device-side data needed to evaluate per-cell (quad) quantities.

    # vertex indices of each quad (2D int array, indexed as [quad, corner])
    quad_vertex_indices: wp.array2d(dtype=int)
    # 2D position of each vertex
    positions: wp.array(dtype=wp.vec2)

    # for neighbor cell lookup
    vertex_quad_offsets: wp.array(dtype=int)
    vertex_quad_indices: wp.array(dtype=int)
26
+
27
+
28
@wp.struct
class Quadmesh2DSideArg:
    # Device-side data needed to evaluate per-side (edge) quantities.

    # cell-level data, embedded so that side functions can reach vertex positions
    cell_arg: Quadmesh2DCellArg
    # the two vertex indices of each edge
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # the two quad indices stored for each edge; component 0 is read as the "inner" cell
    # and component 1 as the "outer" cell by the side_*_cell_index functions
    edge_quad_indices: wp.array(dtype=wp.vec2i)
33
+
34
+
35
+ class Quadmesh2D(Geometry):
36
+ """Two-dimensional quadrilateral mesh geometry"""
37
+
38
+ dimension = 2
39
+
40
def __init__(
    self, quad_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
):
    """
    Constructs a two-dimensional quadrilateral mesh.

    Args:
        quad_vertex_indices: warp array of shape (num_quads, 4) containing vertex indices for each quad, in counter-clockwise order
        positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex
        temporary_store: shared pool from which to allocate temporary arrays
    """

    self.quad_vertex_indices = quad_vertex_indices
    self.positions = positions

    # Topology arrays; declared here so that every attribute exists on the instance,
    # then filled in by `_build_topology`
    self._edge_vertex_indices: wp.array = None
    self._edge_quad_indices: wp.array = None
    self._vertex_quad_offsets: wp.array = None
    self._vertex_quad_indices: wp.array = None
    self._boundary_edge_indices: wp.array = None
    self._build_topology(temporary_store)
60
+
61
def cell_count(self):
    """Number of cells (quads), i.e. the leading dimension of `quad_vertex_indices`."""
    num_quads = self.quad_vertex_indices.shape[0]
    return num_quads
63
+
64
def vertex_count(self):
    """Number of vertices, i.e. the leading dimension of `positions`."""
    num_vertices = self.positions.shape[0]
    return num_vertices
66
+
67
def side_count(self):
    """Number of sides (edges), i.e. the leading dimension of the edge vertex index array."""
    num_edges = self._edge_vertex_indices.shape[0]
    return num_edges
69
+
70
def boundary_side_count(self):
    """Number of boundary sides, i.e. the leading dimension of the boundary edge index array."""
    num_boundary_edges = self._boundary_edge_indices.shape[0]
    return num_boundary_edges
72
+
73
def reference_cell(self) -> Square:
    """Returns a new `Square`, the reference element for cells of this geometry."""
    reference = Square()
    return reference
75
+
76
def reference_side(self) -> LinearEdge:
    """Returns a new `LinearEdge`, the reference element for sides of this geometry."""
    reference = LinearEdge()
    return reference
78
+
79
@property
def edge_quad_indices(self) -> wp.array:
    """Read-only access to the per-edge quad index array built by `_build_topology`."""
    quads_per_edge = self._edge_quad_indices
    return quads_per_edge
82
+
83
@property
def edge_vertex_indices(self) -> wp.array:
    """Read-only access to the per-edge vertex index array built by `_build_topology`."""
    vertices_per_edge = self._edge_vertex_indices
    return vertices_per_edge
86
+
87
# Struct types used as the `args` parameter of the cell and side device functions below
CellArg = Quadmesh2DCellArg
SideArg = Quadmesh2DSideArg
89
+
90
@wp.struct
class SideIndexArg:
    # Device-side data for `boundary_side_index`.

    # maps a boundary side number to its index among all sides
    boundary_edge_indices: wp.array(dtype=int)
93
+
94
+ # Geometry device interface
95
+
96
@cached_arg_value
def cell_arg_value(self, device) -> CellArg:
    """Builds a `CellArg` struct whose arrays are obtained by calling `.to(device)` on the mesh arrays."""
    cell_arg = self.CellArg()

    # connectivity and vertex coordinates
    cell_arg.quad_vertex_indices = self.quad_vertex_indices.to(device)
    cell_arg.positions = self.positions.to(device)

    # vertex-to-quad adjacency
    cell_arg.vertex_quad_offsets = self._vertex_quad_offsets.to(device)
    cell_arg.vertex_quad_indices = self._vertex_quad_indices.to(device)

    return cell_arg
106
+
107
@wp.func
def cell_position(args: CellArg, s: Sample):
    """Position of sample `s`, as a weighted sum of the four corner positions of its quad"""
    corners = args.quad_vertex_indices[s.element_index]

    hi = s.element_coords
    lo = Coords(1.0) - s.element_coords

    # weight of each corner is a product of (x, y) factors, lo = 1 - coord, hi = coord
    # corner 0 : lo lo
    # corner 1 : hi lo
    # corner 2 : hi hi
    # corner 3 : lo hi
    p0 = args.positions[corners[0]]
    p1 = args.positions[corners[1]]
    p2 = args.positions[corners[2]]
    p3 = args.positions[corners[3]]

    return lo[0] * lo[1] * p0 + hi[0] * lo[1] * p1 + hi[0] * hi[1] * p2 + lo[0] * hi[1] * p3
125
+
126
@wp.func
def cell_deformation_gradient(cell_arg: CellArg, s: Sample):
    """Gradient of the cell position with respect to the element coordinates, at sample `s`"""
    corners = cell_arg.quad_vertex_indices[s.element_index]

    hi = s.element_coords
    lo = Coords(1.0) - s.element_coords

    # each term is (corner position) outer (gradient of that corner's bilinear weight)
    term0 = wp.outer(cell_arg.positions[corners[0]], wp.vec2(-lo[1], -lo[0]))
    term1 = wp.outer(cell_arg.positions[corners[1]], wp.vec2(lo[1], -hi[0]))
    term2 = wp.outer(cell_arg.positions[corners[2]], wp.vec2(hi[1], hi[0]))
    term3 = wp.outer(cell_arg.positions[corners[3]], wp.vec2(-hi[1], lo[0]))

    return term0 + term1 + term2 + term3
140
+
141
@wp.func
def cell_inverse_deformation_gradient(cell_arg: CellArg, s: Sample):
    """Inverse of the cell deformation gradient at sample `s`"""
    gradient = Quadmesh2D.cell_deformation_gradient(cell_arg, s)
    return wp.inverse(gradient)
144
+
145
@wp.func
def cell_measure(args: CellArg, s: Sample):
    """Absolute value of the determinant of the cell deformation gradient at sample `s`"""
    jacobian_det = wp.determinant(Quadmesh2D.cell_deformation_gradient(args, s))
    return wp.abs(jacobian_det)
148
+
149
@wp.func
def cell_normal(args: CellArg, s: Sample):
    """Always returns the zero 2D vector; both arguments are unused"""
    zero = wp.vec2(0.0)
    return zero
152
+
153
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
    """Builds a `SideIndexArg` struct holding the boundary edge indices obtained via `.to(device)`."""
    index_arg = self.SideIndexArg()
    index_arg.boundary_edge_indices = self._boundary_edge_indices.to(device)
    return index_arg
160
+
161
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Looks up the side index stored for the given boundary side number"""
    side = args.boundary_edge_indices[boundary_side_index]
    return side
166
+
167
@cached_arg_value
def side_arg_value(self, device) -> SideArg:
    """Builds the `SideArg` struct for `device`.

    The struct embeds the cell argument returned by `cell_arg_value(device)` together with
    the per-edge vertex and quad index arrays, each obtained by calling `.to(device)`.
    """
    args = self.SideArg()

    args.cell_arg = self.cell_arg_value(device)
    args.edge_vertex_indices = self._edge_vertex_indices.to(device)
    args.edge_quad_indices = self._edge_quad_indices.to(device)

    return args
176
+
177
@wp.func
def side_position(args: SideArg, s: Sample):
    """Position of sample `s` on its edge, interpolated linearly between the two edge vertices"""
    edge_vertices = args.edge_vertex_indices[s.element_index]
    start = args.cell_arg.positions[edge_vertices[0]]
    end = args.cell_arg.positions[edge_vertices[1]]
    return (1.0 - s.element_coords[0]) * start + s.element_coords[0] * end
183
+
184
@wp.func
def side_deformation_gradient(args: SideArg, s: Sample):
    """Vector from the first to the second vertex of the edge containing sample `s`"""
    edge_vertices = args.edge_vertex_indices[s.element_index]
    start = args.cell_arg.positions[edge_vertices[0]]
    end = args.cell_arg.positions[edge_vertices[1]]
    tangent = end - start
    return tangent
190
+
191
@wp.func
def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the inner cell, evaluated at the cell point matching side sample `s`"""
    inner_sample = make_free_sample(
        Quadmesh2D.side_inner_cell_index(args, s.element_index),
        Quadmesh2D.side_inner_cell_coords(args, s.element_index, s.element_coords),
    )
    return Quadmesh2D.cell_inverse_deformation_gradient(args.cell_arg, inner_sample)
196
+
197
@wp.func
def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the outer cell, evaluated at the cell point matching side sample `s`"""
    outer_sample = make_free_sample(
        Quadmesh2D.side_outer_cell_index(args, s.element_index),
        Quadmesh2D.side_outer_cell_coords(args, s.element_index, s.element_coords),
    )
    return Quadmesh2D.cell_inverse_deformation_gradient(args.cell_arg, outer_sample)
202
+
203
@wp.func
def side_measure(args: SideArg, s: Sample):
    """Length of the edge containing sample `s`"""
    edge_vertices = args.edge_vertex_indices[s.element_index]
    start = args.cell_arg.positions[edge_vertices[0]]
    end = args.cell_arg.positions[edge_vertices[1]]
    tangent = end - start
    return wp.length(tangent)
209
+
210
@wp.func
def side_measure_ratio(args: SideArg, s: Sample):
    """Side measure divided by the smaller of the inner and outer cell measures at sample `s`"""
    inner_sample = make_free_sample(
        Quadmesh2D.side_inner_cell_index(args, s.element_index),
        Quadmesh2D.side_inner_cell_coords(args, s.element_index, s.element_coords),
    )
    outer_sample = make_free_sample(
        Quadmesh2D.side_outer_cell_index(args, s.element_index),
        Quadmesh2D.side_outer_cell_coords(args, s.element_index, s.element_coords),
    )

    smallest_cell_measure = wp.min(
        Quadmesh2D.cell_measure(args.cell_arg, inner_sample),
        Quadmesh2D.cell_measure(args.cell_arg, outer_sample),
    )
    return Quadmesh2D.side_measure(args, s) / smallest_cell_measure
220
+
221
@wp.func
def side_normal(args: SideArg, s: Sample):
    """Unit vector obtained by rotating the edge direction (first to second vertex) by a quarter turn"""
    edge_vertices = args.edge_vertex_indices[s.element_index]
    start = args.cell_arg.positions[edge_vertices[0]]
    end = args.cell_arg.positions[edge_vertices[1]]
    tangent = end - start

    # (x, y) -> (-y, x), then normalized
    perpendicular = wp.vec2(-tangent[1], tangent[0])
    return wp.normalize(perpendicular)
229
+
230
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
    """First of the two quad indices stored for the given side"""
    adjacent_quads = arg.edge_quad_indices[side_index]
    return adjacent_quads[0]
233
+
234
@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
    """Second of the two quad indices stored for the given side"""
    adjacent_quads = arg.edge_quad_indices[side_index]
    return adjacent_quads[1]
237
+
238
@wp.func
def edge_to_quad_coords(args: SideArg, side_index: ElementIndex, quad_index: ElementIndex, side_coords: Coords):
    """Converts a coordinate along edge `side_index` to coordinates within quad `quad_index`

    Only the first component of `side_coords` is read. The branch taken depends on which quad
    corner matches the edge start vertex, and on whether the edge end vertex is the next corner.
    """
    edge_vidx = args.edge_vertex_indices[side_index]
    quad_vidx = args.cell_arg.quad_vertex_indices[quad_index]

    # start and end vertices of the edge
    vs = edge_vidx[0]
    ve = edge_vidx[1]

    s = side_coords[0]

    # NOTE: wp.select(cond, a, b) yields `a` when cond is false and `b` when cond is true,
    # so the *last* argument of each call is the case where the edge runs to the next corner
    if vs == quad_vidx[0]:
        # corner 0 -> corner 1: (s, 0); otherwise: (0, s)
        return wp.select(ve == quad_vidx[1], Coords(0.0, s, 0.0), Coords(s, 0.0, 0.0))
    elif vs == quad_vidx[1]:
        # corner 1 -> corner 2: (1, s); otherwise: (1 - s, 0)
        return wp.select(ve == quad_vidx[2], Coords(1.0 - s, 0.0, 0.0), Coords(1.0, s, 0.0))
    elif vs == quad_vidx[2]:
        # corner 2 -> corner 3: (1 - s, 1); otherwise: (1, 1 - s)
        return wp.select(ve == quad_vidx[3], Coords(1.0, 1.0 - s, 0.0), Coords(1.0 - s, 1.0, 0.0))

    # fall-through: the start vertex matched none of corners 0..2 (presumably it is corner 3)
    # corner 3 -> corner 0: (0, 1 - s); otherwise: (s, 1)
    return wp.select(ve == quad_vidx[0], Coords(s, 1.0, 0.0), Coords(0.0, 1.0 - s, 0.0))
256
+
257
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Coordinates of the side point `side_coords`, expressed within the inner cell of the side"""
    quad_index = Quadmesh2D.side_inner_cell_index(args, side_index)
    coords_in_quad = Quadmesh2D.edge_to_quad_coords(args, side_index, quad_index, side_coords)
    return coords_in_quad
261
+
262
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Coordinates of the side point `side_coords`, expressed within the outer cell of the side"""
    quad_index = Quadmesh2D.side_outer_cell_index(args, side_index)
    coords_in_quad = Quadmesh2D.edge_to_quad_coords(args, side_index, quad_index, side_coords)
    return coords_in_quad
266
+
267
@wp.func
def side_from_cell_coords(
    args: SideArg,
    side_index: ElementIndex,
    quad_index: ElementIndex,
    quad_coords: Coords,
):
    """Converts coordinates within quad `quad_index` to a coordinate along edge `side_index`

    Returns `Coords(ec, 0.0, 0.0)` when the computed off-edge coordinate `oc` is exactly zero,
    and `Coords(OUTSIDE)` otherwise.
    """
    edge_vidx = args.edge_vertex_indices[side_index]
    quad_vidx = args.cell_arg.quad_vertex_indices[quad_index]

    # start and end vertices of the edge
    vs = edge_vidx[0]
    ve = edge_vidx[1]

    cx = quad_coords[0]
    cy = quad_coords[1]

    # oc: coordinate tested against zero below to decide whether the point lies on the edge
    # ec: coordinate returned as the position along the edge
    # NOTE: wp.select(cond, a, b) yields `a` when cond is false and `b` when cond is true
    if vs == quad_vidx[0]:
        oc = wp.select(ve == quad_vidx[1], cx, cy)
        ec = wp.select(ve == quad_vidx[1], cy, cx)
    elif vs == quad_vidx[1]:
        oc = wp.select(ve == quad_vidx[2], cy, 1.0 - cx)
        ec = wp.select(ve == quad_vidx[2], 1.0 - cx, cy)
    elif vs == quad_vidx[2]:
        oc = wp.select(ve == quad_vidx[3], 1.0 - cx, 1.0 - cy)
        ec = wp.select(ve == quad_vidx[3], 1.0 - cy, 1.0 - cx)
    else:
        # the start vertex matched none of corners 0..2 (presumably it is corner 3)
        oc = wp.select(ve == quad_vidx[0], 1.0 - cy, cx)
        ec = wp.select(ve == quad_vidx[0], cx, 1.0 - cy)
    # exact floating-point comparison: only points with oc == 0.0 are mapped to the side
    return wp.select(oc == 0.0, Coords(OUTSIDE), Coords(ec, 0.0, 0.0))
296
+
297
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    """Extracts the cell argument embedded in a side argument"""
    cell_arg = side_arg.cell_arg
    return cell_arg
300
+
301
def _build_topology(self, temporary_store: TemporaryStore):
    """Builds the vertex-to-quad adjacency and the edge arrays of the mesh.

    Sets `_vertex_quad_offsets`, `_vertex_quad_indices`, `_edge_vertex_indices`,
    `_edge_quad_indices` and `_boundary_edge_indices`. Intermediate buffers are borrowed
    from `temporary_store` and released before returning.
    """
    from warp.fem.utils import compress_node_indices, masked_indices
    from warp.utils import array_scan

    # all work is done on the device holding the connectivity array
    device = self.quad_vertex_indices.device

    # vertex -> quads adjacency in offsets/indices form; the last two outputs are unused here
    vertex_quad_offsets, vertex_quad_indices, _, __ = compress_node_indices(
        self.vertex_count(), self.quad_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_quad_offsets = vertex_quad_offsets.detach()
    self._vertex_quad_indices = vertex_quad_indices.detach()

    # per-vertex edge counter (zero-initialized) and a same-shaped buffer for its prefix sum
    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=self.vertex_count())
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    # scratch storage sized for four edges per quad
    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(4 * self.cell_count()))
    vertex_edge_quads = borrow_temporary(
        temporary_store, dtype=int, device=device, shape=(4 * self.cell_count(), 2)
    )

    # Count face edges starting at each vertex
    wp.launch(
        kernel=Quadmesh2D._count_starting_edges_kernel,
        device=device,
        dim=self.cell_count(),
        inputs=[self.quad_vertex_indices, vertex_start_edge_count.array],
    )

    # exclusive prefix sum of the counts -> start offset of each vertex in the scratch arrays
    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges (deduplicate across faces)
    # NOTE: this is an alias, not a copy; the kernel below is handed the same buffer
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Quadmesh2D._count_unique_starting_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            self._vertex_quad_offsets,
            self._vertex_quad_indices,
            self.quad_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_quads.array,
        ],
    )

    # exclusive prefix sum again, reading the same (aliased) count buffer
    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host
    if device.is_cuda:
        # copy the single value into pinned host memory, then wait for the stream before reading it
        edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
        wp.copy(
            dest=edge_count.array, src=vertex_unique_edge_offsets.array, src_offset=self.vertex_count() - 1, count=1
        )
        wp.synchronize_stream(wp.get_stream(device))
        # from here on `edge_count` is a plain int rather than the borrowed temporary
        edge_count = int(edge_count.array.numpy()[0])
    else:
        edge_count = int(vertex_unique_edge_offsets.array.numpy()[self.vertex_count() - 1])

    # persistent edge arrays, allocated directly rather than borrowed from the store
    self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
    self._edge_quad_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)

    # one int per edge, written by the compression kernel and consumed by masked_indices below
    boundary_mask = borrow_temporary(temporary_store=temporary_store, shape=(edge_count,), dtype=int, device=device)

    # Compress edge data
    wp.launch(
        kernel=Quadmesh2D._compress_edges_kernel,
        device=device,
        dim=self.vertex_count(),
        inputs=[
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_quads.array,
            self._edge_vertex_indices,
            self._edge_quad_indices,
            boundary_mask.array,
        ],
    )

    # scratch buffers are no longer needed
    # (vertex_unique_edge_count is the same object as vertex_start_edge_count, hence a single release)
    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
    vertex_edge_quads.release()

    # Flip normals if necessary
    wp.launch(
        kernel=Quadmesh2D._flip_edge_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._edge_vertex_indices, self._edge_quad_indices, self.quad_vertex_indices, self.positions],
    )

    # indices of the edges selected by the mask; the second output is unused here
    boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
    self._boundary_edge_indices = boundary_edge_indices.detach()

    boundary_mask.release()
404
+
405
# Per-thread work: thread `t` reads row `t` of `quad_vertex_indices` and walks
# the four consecutive vertex pairs (k, (k + 1) % 4) of that quad.
# For every pair, the entry of `vertex_start_edge_count` belonging to the
# lower-indexed vertex is atomically incremented by one, so an edge shared by
# several quads is counted once per quad at this stage.
@wp.kernel
def _count_starting_edges_kernel(
    quad_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    quad = wp.tid()
    for corner in range(4):
        a = quad_vertex_indices[quad, corner]
        b = quad_vertex_indices[quad, (corner + 1) % 4]

        # The edge is attributed to whichever end vertex has the smaller index
        wp.atomic_add(vertex_start_edge_count, wp.min(a, b), 1)
418
+
419
# Linear search for `needle` among values[beg], ..., values[end - 1].
# Returns the index of the first matching entry, or -1 when the range
# contains no match (including when the range is empty).
@wp.func
def _find(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    for idx in range(beg, end):
        if needle == values[idx]:
            return idx

    return -1
431
+
432
# Per-thread work: thread `v` visits the quads listed in
# vertex_quad_indices[vertex_quad_offsets[v] : vertex_quad_offsets[v + 1]]
# and, for each quad edge whose lower-indexed end is `v`, records the edge in
# the slots starting at vertex_start_edge_offsets[v]:
#   - edge_ends[slot]     receives the higher-indexed end vertex,
#   - edge_quads[slot, 0] receives the quad in which the edge was first met,
#   - edge_quads[slot, 1] receives the most recent quad in which it was met
#     (equal to column 0 when the edge was met only once).
# On exit, vertex_start_edge_count[v] is overwritten with the number of
# distinct edges recorded for `v`.
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_quad_offsets: wp.array(dtype=int),
    vertex_quad_indices: wp.array(dtype=int),
    quad_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
    edge_quads: wp.array2d(dtype=int),
):
    v = wp.tid()

    first_slot = vertex_start_edge_offsets[v]

    neighbor_beg = vertex_quad_offsets[v]
    neighbor_end = vertex_quad_offsets[v + 1]

    # Next free slot; slots in [first_slot, next_slot) hold the edges found so far
    next_slot = first_slot

    for neighbor in range(neighbor_beg, neighbor_end):
        q = vertex_quad_indices[neighbor]

        for corner in range(4):
            a = quad_vertex_indices[q, corner]
            b = quad_vertex_indices[q, (corner + 1) % 4]

            low_v = wp.min(a, b)
            high_v = wp.max(a, b)

            # Only edges whose lower-indexed end is this vertex are handled here
            if low_v == v:
                # Look for an already recorded edge with the same far end
                slot = Quadmesh2D._find(high_v, edge_ends, first_slot, next_slot)

                if slot == -1:
                    # First encounter: append, using q for both adjacent quads
                    edge_ends[next_slot] = high_v
                    edge_quads[next_slot, 0] = q
                    edge_quads[next_slot, 1] = q
                    next_slot += 1
                else:
                    # Repeat encounter: q becomes the second adjacent quad
                    edge_quads[slot, 1] = q

    vertex_start_edge_count[v] = next_slot - first_slot
473
+
474
# Per-thread work: thread `v` copies its vertex_unique_edge_count[v] recorded
# edges from the uncompressed arrays (starting at vertex_start_edge_offsets[v])
# into the output arrays (starting at vertex_unique_edge_offsets[v]).
# For each copied edge:
#   - edge_vertex_indices gets (v, far end vertex),
#   - edge_quad_indices gets the two stored quad indices,
#   - boundary_mask gets 1 when both stored quad indices are equal, else 0.
@wp.kernel
def _compress_edges_kernel(
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    uncompressed_edge_quads: wp.array2d(dtype=int),
    edge_vertex_indices: wp.array(dtype=wp.vec2i),
    edge_quad_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    v = wp.tid()

    src_beg = vertex_start_edge_offsets[v]
    dst_beg = vertex_unique_edge_offsets[v]
    edge_total = vertex_unique_edge_count[v]

    for k in range(edge_total):
        src = src_beg + k
        dst = dst_beg + k

        edge_vertex_indices[dst] = wp.vec2i(v, uncompressed_edge_ends[src])

        quad_a = uncompressed_edge_quads[src, 0]
        quad_b = uncompressed_edge_quads[src, 1]
        edge_quad_indices[dst] = wp.vec2i(quad_a, quad_b)

        # Two distinct quads -> not flagged; same quad in both slots -> flagged
        if quad_a != quad_b:
            boundary_mask[dst] = 0
        else:
            boundary_mask[dst] = 1
504
+
505
# Per-thread work: thread `e` orients edge `e` relative to the first quad
# stored for it in `edge_quad_indices`.
# The edge vector (x, y) is rotated to the normal (-y, x); if that normal
# points from the edge midpoint toward the centroid of the first quad, the two
# vertex indices of the edge are swapped in `edge_vertex_indices`.
@wp.kernel
def _flip_edge_normals(
    edge_vertex_indices: wp.array(dtype=wp.vec2i),
    edge_quad_indices: wp.array(dtype=wp.vec2i),
    quad_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec2),
):
    e = wp.tid()

    first_quad = edge_quad_indices[e][0]

    corners = quad_vertex_indices[first_quad]
    ends = edge_vertex_indices[e]

    # Average of the four corner positions of the first adjacent quad
    centroid = (
        positions[corners[0]] + positions[corners[1]] + positions[corners[2]] + positions[corners[3]]
    ) / 4.0

    p0 = positions[ends[0]]
    p1 = positions[ends[1]]

    midpoint = 0.5 * (p1 + p0)
    along = p1 - p0
    normal = wp.vec2(-along[1], along[0])

    # if the edge normal points toward the first quad's centroid, flip indices
    if wp.dot(centroid - midpoint, normal) > 0.0:
        edge_vertex_indices[e] = wp.vec2i(ends[1], ends[0])