warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315) hide show
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,577 @@
1
+ from typing import Optional
2
+
3
+ import warp as wp
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
10
+ from warp.fem.types import (
11
+ NULL_ELEMENT_INDEX,
12
+ OUTSIDE,
13
+ Coords,
14
+ ElementIndex,
15
+ Sample,
16
+ make_free_sample,
17
+ )
18
+
19
+ from .closest_point import project_on_tri_at_origin
20
+ from .element import LinearEdge, Triangle
21
+ from .geometry import Geometry
22
+
23
+
24
@wp.struct
class Trimesh2DCellArg:
    # Device-side arguments describing the cells (triangles) of a 2D triangle mesh.

    # vertex indices of each triangle
    tri_vertex_indices: wp.array2d(dtype=int)
    # 2d position of each vertex
    positions: wp.array(dtype=wp.vec2)

    # for neighbor cell lookup: the triangles touching vertex `v` are
    # vertex_tri_indices[vertex_tri_offsets[v] : vertex_tri_offsets[v + 1]]
    vertex_tri_offsets: wp.array(dtype=int)
    vertex_tri_indices: wp.array(dtype=int)

    # one 2x2 deformation gradient per triangle
    deformation_gradients: wp.array(dtype=wp.mat22f)
34
+
35
+
36
@wp.struct
class Trimesh2DSideArg:
    # Device-side arguments describing the sides (edges) of a 2D triangle mesh.

    # arguments of the cells that the edges belong to
    cell_arg: Trimesh2DCellArg
    # the two vertex indices of each edge
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # the two triangle indices stored for each edge (component 0: inner, component 1: outer)
    edge_tri_indices: wp.array(dtype=wp.vec2i)
41
+
42
+
43
+ class Trimesh2D(Geometry):
44
+ """Two-dimensional triangular mesh geometry"""
45
+
46
+ dimension = 2
47
+
48
def __init__(
    self, tri_vertex_indices: wp.array, positions: wp.array, temporary_store: Optional[TemporaryStore] = None
):
    """
    Constructs a two-dimensional triangular mesh.

    Args:
        tri_vertex_indices: warp array of shape (num_tris, 3) containing vertex indices for each tri
        positions: warp array of shape (num_vertices, 2) containing 2d position for each vertex
        temporary_store: shared pool from which to allocate temporary arrays
    """

    self.tri_vertex_indices = tri_vertex_indices
    self.positions = positions

    # Topology arrays; all of them are filled in by _build_topology().
    # _boundary_edge_indices is declared here with its siblings so that every
    # attribute read by the accessors is visible in one place.
    self._edge_vertex_indices: wp.array = None
    self._edge_tri_indices: wp.array = None
    self._vertex_tri_offsets: wp.array = None
    self._vertex_tri_indices: wp.array = None
    self._boundary_edge_indices: wp.array = None
    self._build_topology(temporary_store)

    # Per-cell deformation gradients, filled in by _compute_deformation_gradients()
    self._deformation_gradients: wp.array = None
    self._compute_deformation_gradients()
71
+
72
def cell_count(self):
    """Number of cells (triangles): the first dimension of ``tri_vertex_indices``."""
    tri_shape = self.tri_vertex_indices.shape
    return tri_shape[0]
74
+
75
def vertex_count(self):
    """Number of vertices: the first dimension of ``positions``."""
    positions_shape = self.positions.shape
    return positions_shape[0]
77
+
78
def side_count(self):
    """Number of sides (edges): the first dimension of the edge vertex index array."""
    edge_shape = self._edge_vertex_indices.shape
    return edge_shape[0]
80
+
81
def boundary_side_count(self):
    """Number of boundary sides: the first dimension of the boundary edge index array."""
    boundary_shape = self._boundary_edge_indices.shape
    return boundary_shape[0]
83
+
84
def reference_cell(self) -> Triangle:
    """Returns a new ``Triangle`` reference element for the cells of this geometry."""
    ref_cell = Triangle()
    return ref_cell
86
+
87
def reference_side(self) -> LinearEdge:
    """Returns a new ``LinearEdge`` reference element for the sides of this geometry."""
    ref_side = LinearEdge()
    return ref_side
89
+
90
@property
def edge_tri_indices(self) -> wp.array:
    """Array holding, for each edge, the pair of triangle indices stored for it."""
    tri_pairs = self._edge_tri_indices
    return tri_pairs
93
+
94
@property
def edge_vertex_indices(self) -> wp.array:
    """Array holding, for each edge, the pair of its vertex indices."""
    vertex_pairs = self._edge_vertex_indices
    return vertex_pairs
97
+
98
+ CellArg = Trimesh2DCellArg
99
+ SideArg = Trimesh2DSideArg
100
+
101
@wp.struct
class SideIndexArg:
    # Device-side arguments for mapping boundary sides to side indices.

    # side (edge) index of each boundary side
    boundary_edge_indices: wp.array(dtype=int)
104
+
105
+ # Geometry device interface
106
+
107
@cached_arg_value
def cell_arg_value(self, device) -> CellArg:
    """Assembles the ``CellArg`` struct with every mesh array moved to ``device``.

    Args:
        device: device that the struct's arrays are transferred to
    """
    cell_arg = self.CellArg()

    # connectivity and vertex positions
    cell_arg.tri_vertex_indices = self.tri_vertex_indices.to(device)
    cell_arg.positions = self.positions.to(device)

    # vertex-to-triangle adjacency
    cell_arg.vertex_tri_offsets = self._vertex_tri_offsets.to(device)
    cell_arg.vertex_tri_indices = self._vertex_tri_indices.to(device)

    # per-cell deformation gradients
    cell_arg.deformation_gradients = self._deformation_gradients.to(device)

    return cell_arg
118
+
119
@wp.func
def cell_position(args: CellArg, s: Sample):
    """Position of sample ``s``: its three element coordinates weight the triangle's vertex positions."""
    tri_idx = args.tri_vertex_indices[s.element_index]

    p0 = args.positions[tri_idx[0]]
    p1 = args.positions[tri_idx[1]]
    p2 = args.positions[tri_idx[2]]

    return s.element_coords[0] * p0 + s.element_coords[1] * p1 + s.element_coords[2] * p2
127
+
128
@wp.func
def cell_deformation_gradient(args: CellArg, s: Sample):
    """Stored deformation gradient of the cell that sample ``s`` lies in."""
    cell_index = s.element_index
    return args.deformation_gradients[cell_index]
131
+
132
@wp.func
def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
    """Inverse of the stored deformation gradient of the cell that sample ``s`` lies in."""
    cell_index = s.element_index
    return wp.inverse(args.deformation_gradients[cell_index])
135
+
136
@wp.func
def _project_on_tri(args: CellArg, pos: wp.vec2, tri_index: int):
    """Projects ``pos`` on triangle ``tri_index``.

    All quantities are expressed relative to the triangle's first vertex before
    delegating to ``project_on_tri_at_origin``; returns its (dist, coords) result.
    """
    origin = args.positions[args.tri_vertex_indices[tri_index, 0]]

    # query point and the two triangle edges leaving the first vertex
    rel_pos = pos - origin
    edge_a = args.positions[args.tri_vertex_indices[tri_index, 1]] - origin
    edge_b = args.positions[args.tri_vertex_indices[tri_index, 2]] - origin

    dist, coords = project_on_tri_at_origin(rel_pos, edge_a, edge_b)
    return dist, coords
146
+
147
@wp.func
def cell_lookup(args: CellArg, pos: wp.vec2, guess: Sample):
    """Searches the triangles around the guess cell for the one closest to ``pos``.

    The candidates are all triangles sharing a vertex with ``guess.element_index``.
    Returns a free sample built from the closest triangle and its projected coordinates.
    """
    best_tri = int(NULL_ELEMENT_INDEX)
    best_coords = Coords(OUTSIDE)
    best_dist = float(1.0e8)

    for corner in range(3):
        vtx = args.tri_vertex_indices[guess.element_index, corner]

        # range of triangles adjacent to this vertex
        first = args.vertex_tri_offsets[vtx]
        last = args.vertex_tri_offsets[vtx + 1]

        for slot in range(first, last):
            candidate = args.vertex_tri_indices[slot]
            dist, coords = Trimesh2D._project_on_tri(args, pos, candidate)
            if dist <= best_dist:
                best_dist = dist
                best_tri = candidate
                best_coords = coords

    return make_free_sample(best_tri, best_coords)
167
+
168
@wp.func
def cell_measure(args: CellArg, s: Sample):
    """Half the absolute determinant of the deformation gradient of the sample's cell."""
    det = wp.determinant(args.deformation_gradients[s.element_index])
    return 0.5 * wp.abs(det)
171
+
172
@wp.func
def cell_normal(args: CellArg, s: Sample):
    """Always returns the zero 2d vector, whatever the sample."""
    return wp.vec2(0.0)
175
+
176
@cached_arg_value
def side_index_arg_value(self, device) -> SideIndexArg:
    """Assembles the ``SideIndexArg`` struct with the boundary edge indices moved to ``device``.

    Args:
        device: device that the struct's array is transferred to
    """
    index_arg = self.SideIndexArg()
    index_arg.boundary_edge_indices = self._boundary_edge_indices.to(device)
    return index_arg
183
+
184
@wp.func
def boundary_side_index(args: SideIndexArg, boundary_side_index: int):
    """Maps the index of a boundary side to its index among all sides"""

    side_index = args.boundary_edge_indices[boundary_side_index]
    return side_index
189
+
190
@cached_arg_value
def side_arg_value(self, device) -> SideArg:
    """Assembles the ``SideArg`` struct with every edge array moved to ``device``.

    The nested cell arguments are obtained from ``cell_arg_value`` for the same device.

    Args:
        device: device that the struct's arrays are transferred to
    """
    # The return annotation is SideArg: that is the struct built and returned here.
    args = self.SideArg()

    args.cell_arg = self.cell_arg_value(device)
    args.edge_vertex_indices = self._edge_vertex_indices.to(device)
    args.edge_tri_indices = self._edge_tri_indices.to(device)

    return args
199
+
200
@wp.func
def side_position(args: SideArg, s: Sample):
    """Position of sample ``s`` on its edge: a blend of the two edge vertices by the first element coordinate."""
    edge_idx = args.edge_vertex_indices[s.element_index]

    v0 = args.cell_arg.positions[edge_idx[0]]
    v1 = args.cell_arg.positions[edge_idx[1]]

    return (1.0 - s.element_coords[0]) * v0 + s.element_coords[0] * v1
206
+
207
@wp.func
def side_deformation_gradient(args: SideArg, s: Sample):
    """Vector going from the first to the second vertex of the sample's edge."""
    edge_idx = args.edge_vertex_indices[s.element_index]
    start_pos = args.cell_arg.positions[edge_idx[0]]
    end_pos = args.cell_arg.positions[edge_idx[1]]
    return end_pos - start_pos
213
+
214
@wp.func
def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the inner cell of the sample's edge."""
    inner_cell = Trimesh2D.side_inner_cell_index(args, s.element_index)
    return wp.inverse(args.cell_arg.deformation_gradients[inner_cell])
218
+
219
@wp.func
def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
    """Inverse deformation gradient of the outer cell of the sample's edge."""
    outer_cell = Trimesh2D.side_outer_cell_index(args, s.element_index)
    return wp.inverse(args.cell_arg.deformation_gradients[outer_cell])
223
+
224
@wp.func
def side_measure(args: SideArg, s: Sample):
    """Length of the sample's edge."""
    edge_idx = args.edge_vertex_indices[s.element_index]
    v0 = args.cell_arg.positions[edge_idx[0]]
    v1 = args.cell_arg.positions[edge_idx[1]]

    edge = v1 - v0
    return wp.length(edge)
230
+
231
@wp.func
def side_measure_ratio(args: SideArg, s: Sample):
    """Edge length divided by the smaller of the measures of the edge's inner and outer cells."""
    inner = Trimesh2D.side_inner_cell_index(args, s.element_index)
    outer = Trimesh2D.side_outer_cell_index(args, s.element_index)

    inner_measure = Trimesh2D.cell_measure(args.cell_arg, make_free_sample(inner, Coords()))
    outer_measure = Trimesh2D.cell_measure(args.cell_arg, make_free_sample(outer, Coords()))

    return Trimesh2D.side_measure(args, s) / wp.min(inner_measure, outer_measure)
239
+
240
@wp.func
def side_normal(args: SideArg, s: Sample):
    """Unit vector obtained by rotating the edge direction (first to second vertex) into (-y, x)."""
    edge_idx = args.edge_vertex_indices[s.element_index]
    v0 = args.cell_arg.positions[edge_idx[0]]
    v1 = args.cell_arg.positions[edge_idx[1]]
    tangent = v1 - v0

    rotated = wp.vec2(-tangent[1], tangent[0])
    return wp.normalize(rotated)
248
+
249
@wp.func
def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
    """Inner cell of a side: first component of the edge's triangle index pair."""
    tri_pair = arg.edge_tri_indices[side_index]
    return tri_pair[0]
252
+
253
@wp.func
def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
    """Outer cell of a side: second component of the edge's triangle index pair."""
    tri_pair = arg.edge_tri_indices[side_index]
    return tri_pair[1]
256
+
257
@wp.func
def edge_to_tri_coords(args: SideArg, side_index: ElementIndex, tri_index: ElementIndex, side_coords: Coords):
    """Converts a coordinate along edge ``side_index`` to coordinates in triangle ``tri_index``.

    The first edge vertex receives weight ``1 - side_coords[0]`` and the second one
    ``side_coords[0]``; each weight goes to the triangle corner matching that vertex.
    A vertex matching neither of the first two corners is assigned to the third.
    """
    edge_vidx = args.edge_vertex_indices[side_index]
    tri_vidx = args.cell_arg.tri_vertex_indices[tri_index]

    corner0 = tri_vidx[0]
    corner1 = tri_vidx[1]

    cx = float(0.0)
    cy = float(0.0)
    cz = float(0.0)

    # weight of the edge's first vertex
    if edge_vidx[0] == corner0:
        cx = 1.0 - side_coords[0]
    elif edge_vidx[0] == corner1:
        cy = 1.0 - side_coords[0]
    else:
        cz = 1.0 - side_coords[0]

    # weight of the edge's second vertex
    if edge_vidx[1] == corner0:
        cx = side_coords[0]
    elif edge_vidx[1] == corner1:
        cy = side_coords[0]
    else:
        cz = side_coords[0]

    return Coords(cx, cy, cz)
284
+
285
@wp.func
def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Coordinates of a side point expressed in the side's inner cell."""
    inner = Trimesh2D.side_inner_cell_index(args, side_index)
    return Trimesh2D.edge_to_tri_coords(args, side_index, inner, side_coords)
289
+
290
@wp.func
def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
    """Coordinates of a side point expressed in the side's outer cell."""
    outer = Trimesh2D.side_outer_cell_index(args, side_index)
    return Trimesh2D.edge_to_tri_coords(args, side_index, outer, side_coords)
294
+
295
@wp.func
def side_from_cell_coords(
    args: SideArg,
    side_index: ElementIndex,
    tri_index: ElementIndex,
    tri_coords: Coords,
):
    """Converts coordinates in triangle ``tri_index`` to a coordinate along edge ``side_index``.

    Returns ``Coords(OUTSIDE)`` unless the weights of the two edge vertices sum to
    more than 0.999, in which case the second edge vertex's weight is the side coordinate.
    """
    edge_vidx = args.edge_vertex_indices[side_index]
    tri_vidx = args.cell_arg.tri_vertex_indices[tri_index]

    # Both corners default to the third one; only the first two are searched below
    start = int(2)
    end = int(2)

    for corner in range(2):
        vtx = tri_vidx[corner]
        if edge_vidx[1] == vtx:
            end = corner
        elif edge_vidx[0] == vtx:
            start = corner

    edge_weight = tri_coords[start] + tri_coords[end]

    return wp.select(edge_weight > 0.999, Coords(OUTSIDE), Coords(tri_coords[end], 0.0, 0.0))
318
+
319
@wp.func
def side_to_cell_arg(side_arg: SideArg):
    """Extracts the nested cell arguments from the side arguments."""
    return side_arg.cell_arg
322
+
323
def _build_topology(self, temporary_store: TemporaryStore):
    """Builds the vertex-to-triangle adjacency, the unique edge list and the boundary edge indices.

    Fills ``_vertex_tri_offsets``, ``_vertex_tri_indices``, ``_edge_vertex_indices``,
    ``_edge_tri_indices`` and ``_boundary_edge_indices``.

    Args:
        temporary_store: pool from which the intermediate arrays are borrowed
    """
    from warp.fem.utils import compress_node_indices, masked_indices
    from warp.utils import array_scan

    device = self.tri_vertex_indices.device
    num_vertices = self.vertex_count()
    num_tris = self.cell_count()

    # Vertex-to-triangle adjacency
    vertex_tri_offsets, vertex_tri_indices, _, __ = compress_node_indices(
        num_vertices, self.tri_vertex_indices, temporary_store=temporary_store
    )
    self._vertex_tri_offsets = vertex_tri_offsets.detach()
    self._vertex_tri_indices = vertex_tri_indices.detach()

    vertex_start_edge_count = borrow_temporary(temporary_store, dtype=int, device=device, shape=num_vertices)
    vertex_start_edge_count.array.zero_()
    vertex_start_edge_offsets = borrow_temporary_like(vertex_start_edge_count, temporary_store=temporary_store)

    vertex_edge_ends = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * num_tris))
    vertex_edge_tris = borrow_temporary(temporary_store, dtype=int, device=device, shape=(3 * num_tris, 2))

    # Count face edges starting at each vertex
    wp.launch(
        kernel=Trimesh2D._count_starting_edges_kernel,
        device=device,
        dim=num_tris,
        inputs=[self.tri_vertex_indices, vertex_start_edge_count.array],
    )

    array_scan(in_array=vertex_start_edge_count.array, out_array=vertex_start_edge_offsets.array, inclusive=False)

    # Count number of unique edges (deduplicate across faces)
    # The unique counts reuse the storage of the starting-edge counts
    vertex_unique_edge_count = vertex_start_edge_count
    wp.launch(
        kernel=Trimesh2D._count_unique_starting_edges_kernel,
        device=device,
        dim=num_vertices,
        inputs=[
            self._vertex_tri_offsets,
            self._vertex_tri_indices,
            self.tri_vertex_indices,
            vertex_start_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_tris.array,
        ],
    )

    vertex_unique_edge_offsets = borrow_temporary_like(vertex_start_edge_offsets, temporary_store=temporary_store)
    array_scan(in_array=vertex_unique_edge_count.array, out_array=vertex_unique_edge_offsets.array, inclusive=False)

    # Get back edge count to host
    if device.is_cuda:
        edge_count = borrow_temporary(temporary_store, shape=(1,), dtype=int, device="cpu", pinned=True)
        # Last vertex will not own any edge, so its count will be zero; just fetching last prefix count is ok
        wp.copy(dest=edge_count.array, src=vertex_unique_edge_offsets.array, src_offset=num_vertices - 1, count=1)
        wp.synchronize_stream(wp.get_stream(device))
        edge_count = int(edge_count.array.numpy()[0])
    else:
        edge_count = int(vertex_unique_edge_offsets.array.numpy()[num_vertices - 1])

    self._edge_vertex_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)
    self._edge_tri_indices = wp.empty(shape=(edge_count,), dtype=wp.vec2i, device=device)

    boundary_mask = borrow_temporary(temporary_store=temporary_store, shape=(edge_count,), dtype=int, device=device)

    # Compress edge data
    wp.launch(
        kernel=Trimesh2D._compress_edges_kernel,
        device=device,
        dim=num_vertices,
        inputs=[
            vertex_start_edge_offsets.array,
            vertex_unique_edge_offsets.array,
            vertex_unique_edge_count.array,
            vertex_edge_ends.array,
            vertex_edge_tris.array,
            self._edge_vertex_indices,
            self._edge_tri_indices,
            boundary_mask.array,
        ],
    )

    # Intermediate per-vertex arrays are no longer needed
    vertex_start_edge_offsets.release()
    vertex_unique_edge_offsets.release()
    vertex_unique_edge_count.release()
    vertex_edge_ends.release()
    vertex_edge_tris.release()

    # Flip normals if necessary
    wp.launch(
        kernel=Trimesh2D._flip_edge_normals,
        device=device,
        dim=self.side_count(),
        inputs=[self._edge_vertex_indices, self._edge_tri_indices, self.tri_vertex_indices, self.positions],
    )

    # Gather the indices of the edges flagged in the boundary mask
    boundary_edge_indices, _ = masked_indices(boundary_mask.array, temporary_store=temporary_store)
    self._boundary_edge_indices = boundary_edge_indices.detach()

    boundary_mask.release()
424
+
425
def _compute_deformation_gradients(self):
    """Allocate and fill ``self._deformation_gradients``.

    The array holds one ``wp.mat22f`` per cell (``self.cell_count()`` entries)
    and lives on the same device as ``self.positions``. Its contents are
    written by ``Trimesh2D._compute_deformation_gradients_kernel``, launched
    with one thread per array entry.
    """
    gradients = wp.empty(dtype=wp.mat22f, device=self.positions.device, shape=(self.cell_count()))
    self._deformation_gradients = gradients

    # Launch dimensions and device are taken from the freshly allocated array
    launch_inputs = [self.tri_vertex_indices, self.positions, gradients]
    wp.launch(
        kernel=Trimesh2D._compute_deformation_gradients_kernel,
        dim=gradients.shape,
        device=gradients.device,
        inputs=launch_inputs,
    )
434
+
435
# One thread per triangle: for each of the three edges of the triangle,
# increment the counter of the edge endpoint with the smaller vertex index.
# Edges shared by several triangles are counted once per triangle here.
@wp.kernel
def _count_starting_edges_kernel(
    tri_vertex_indices: wp.array2d(dtype=int), vertex_start_edge_count: wp.array(dtype=int)
):
    tri = wp.tid()
    for corner in range(3):
        va = tri_vertex_indices[tri, corner]
        vb = tri_vertex_indices[tri, (corner + 1) % 3]

        # The edge is attributed to its lowest-index endpoint; atomic because
        # several triangles may touch the same vertex concurrently.
        wp.atomic_add(vertex_start_edge_count, wp.min(va, vb), 1)
448
+
449
# Linear search for `needle` in values[beg:end] (end exclusive).
# Returns the index of the first match, or -1 when the range holds no match
# (including when the range is empty).
@wp.func
def _find(
    needle: int,
    values: wp.array(dtype=int),
    beg: int,
    end: int,
):
    for pos in range(beg, end):
        if values[pos] == needle:
            return pos

    # Sentinel: needle not present in the scanned range
    return -1
461
+
462
# One thread per vertex `v`: walk the triangles listed for `v` and record,
# without duplicates, every edge for which `v` is the lowest-index endpoint.
#
# For each such edge, the slot range starting at vertex_start_edge_offsets[v]
# receives the other endpoint (edge_ends) and a pair of triangle indices
# (edge_tris). An edge seen only once keeps the same triangle in both columns.
# The number of distinct edges found overwrites vertex_start_edge_count[v].
@wp.kernel
def _count_unique_starting_edges_kernel(
    vertex_tri_offsets: wp.array(dtype=int),
    vertex_tri_indices: wp.array(dtype=int),
    tri_vertex_indices: wp.array2d(dtype=int),
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_start_edge_count: wp.array(dtype=int),
    edge_ends: wp.array(dtype=int),
    edge_tris: wp.array2d(dtype=int),
):
    v = wp.tid()

    # Range of entries of vertex_tri_indices belonging to this vertex
    adj_beg = vertex_tri_offsets[v]
    adj_end = vertex_tri_offsets[v + 1]

    # Slots [first_slot, next_slot) hold the edges recorded so far
    first_slot = vertex_start_edge_offsets[v]
    next_slot = first_slot

    for adj in range(adj_beg, adj_end):
        tri = vertex_tri_indices[adj]

        for corner in range(3):
            va = tri_vertex_indices[tri, corner]
            vb = tri_vertex_indices[tri, (corner + 1) % 3]

            # Only handle edges whose lowest-index endpoint is this vertex
            if v == wp.min(va, vb):
                far_end = wp.max(va, vb)

                # Has an edge toward far_end already been recorded?
                slot = Trimesh2D._find(far_end, edge_ends, first_slot, next_slot)

                if slot == -1:
                    # New edge: both triangle columns start as this triangle
                    edge_ends[next_slot] = far_end
                    edge_tris[next_slot, 0] = tri
                    edge_tris[next_slot, 1] = tri
                    next_slot = next_slot + 1
                else:
                    # Known edge: store this triangle in the second column
                    edge_tris[slot, 1] = tri

    vertex_start_edge_count[v] = next_slot - first_slot
503
+
504
# One thread per vertex `v`: copy the first vertex_unique_edge_count[v] edge
# records from the uncompressed per-vertex slots (starting at
# vertex_start_edge_offsets[v]) into the output arrays, starting at
# vertex_unique_edge_offsets[v].
#
# Each output edge gets its endpoint pair (v, other end), its triangle pair,
# and a boundary flag: 1 when both triangle indices are equal, 0 otherwise.
@wp.kernel
def _compress_edges_kernel(
    vertex_start_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_offsets: wp.array(dtype=int),
    vertex_unique_edge_count: wp.array(dtype=int),
    uncompressed_edge_ends: wp.array(dtype=int),
    uncompressed_edge_tris: wp.array2d(dtype=int),
    edge_vertex_indices: wp.array(dtype=wp.vec2i),
    edge_tri_indices: wp.array(dtype=wp.vec2i),
    boundary_mask: wp.array(dtype=int),
):
    v = wp.tid()

    src_beg = vertex_start_edge_offsets[v]
    dst_beg = vertex_unique_edge_offsets[v]
    n_unique = vertex_unique_edge_count[v]

    for i in range(n_unique):
        src = src_beg + i
        dst = dst_beg + i

        tri_a = uncompressed_edge_tris[src, 0]
        tri_b = uncompressed_edge_tris[src, 1]

        edge_vertex_indices[dst] = wp.vec2i(v, uncompressed_edge_ends[src])
        edge_tri_indices[dst] = wp.vec2i(tri_a, tri_b)

        # Identical triangle indices mean the edge was seen from one triangle only
        if tri_a != tri_b:
            boundary_mask[dst] = 0
        else:
            boundary_mask[dst] = 1
534
+
535
# One thread per edge: build the vector (-d.y, d.x) from the edge direction
# d = v1 - v0 and compare it with the direction from the edge midpoint to the
# centroid of the edge's first triangle. When that vector points toward the
# centroid, the two vertex indices of the edge are swapped in place.
@wp.kernel
def _flip_edge_normals(
    edge_vertex_indices: wp.array(dtype=wp.vec2i),
    edge_tri_indices: wp.array(dtype=wp.vec2i),
    tri_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec2),
):
    e = wp.tid()

    endpoints = edge_vertex_indices[e]
    first_tri = edge_tri_indices[e][0]

    # Centroid of the first triangle attached to this edge
    centroid = (
        positions[tri_vertex_indices[first_tri, 0]]
        + positions[tri_vertex_indices[first_tri, 1]]
        + positions[tri_vertex_indices[first_tri, 2]]
    ) / 3.0

    p_start = positions[endpoints[0]]
    p_end = positions[endpoints[1]]

    midpoint = 0.5 * (p_end + p_start)
    direction = p_end - p_start
    normal = wp.vec2(-direction[1], direction[0])

    # If the normal points toward the first triangle's centroid, flip indices
    to_centroid = centroid - midpoint
    if wp.dot(to_centroid, normal) > 0.0:
        edge_vertex_indices[e] = wp.vec2i(endpoints[1], endpoints[0])
561
+
562
# One thread per triangle: take the two edge vectors leaving the triangle's
# first vertex (toward its second and third vertices) and store them as a
# 2x2 matrix in transforms[cell].
# NOTE(review): whether wp.mat22(vec, vec) lays the vectors out as rows or
# columns is defined by Warp's constructor - confirm against the Warp docs.
@wp.kernel
def _compute_deformation_gradients_kernel(
    tri_vertex_indices: wp.array2d(dtype=int),
    positions: wp.array(dtype=wp.vec2f),
    transforms: wp.array(dtype=wp.mat22f),
):
    cell = wp.tid()

    i0 = tri_vertex_indices[cell, 0]
    i1 = tri_vertex_indices[cell, 1]
    i2 = tri_vertex_indices[cell, 2]

    origin = positions[i0]
    edge_a = positions[i1] - origin
    edge_b = positions[i2] - origin

    transforms[cell] = wp.mat22(edge_a, edge_b)