warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,369 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.types import ElementIndex, Coords
4
+ from warp.fem.polynomial import Polynomial, is_closed
5
+ from warp.fem.geometry import Quadmesh2D
6
+ from warp.fem import cache
7
+
8
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
+
11
+ from .shape import ShapeFunction, ConstantShapeFunction
12
+ from .shape import (
13
+ SquareBipolynomialShapeFunctions,
14
+ SquareSerendipityShapeFunctions,
15
+ SquareNonConformingPolynomialShapeFunctions,
16
+ )
17
+
18
+
19
@wp.struct
class Quadmesh2DTopologyArg:
    # Device-side topology data consumed by the quadmesh space topologies'
    # warp functions (built by Quadmesh2DSpaceTopology.topo_arg_value).

    # Per-edge pair of vertex indices; element_node_index implementations compare
    # a quad's local vertex against entry [side][0] to orient edge-interior nodes.
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # (cell_count, 4) table: entry [q, k] is the global index of the edge joining
    # local vertices k and (k + 1) % 4 of quad q.
    quad_edge_indices: wp.array2d(dtype=int)

    # Number of mesh vertices (mesh.vertex_count())
    vertex_count: int
    # Number of mesh edges (mesh.side_count())
    edge_count: int
26
+
27
+
28
class Quadmesh2DSpaceTopology(SpaceTopology):
    """Base topology for function spaces defined on a :class:`Quadmesh2D`.

    On construction, builds a ``(cell_count, 4)`` table that maps each quad's
    local edge slot ``k`` (the edge joining local vertices ``k`` and
    ``(k + 1) % 4``) to the corresponding global edge index.
    """

    TopologyArg = Quadmesh2DTopologyArg

    def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        self._compute_quad_edge_indices()

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Assemble the device-side topology struct for ``device``."""
        mesh = self._mesh

        topo = Quadmesh2DTopologyArg()
        topo.quad_edge_indices = self._quad_edge_indices.to(device)
        topo.edge_vertex_indices = mesh.edge_vertex_indices.to(device)
        topo.vertex_count = mesh.vertex_count()
        topo.edge_count = mesh.side_count()
        return topo

    def _compute_quad_edge_indices(self):
        """Fill ``self._quad_edge_indices`` with one kernel thread per mesh edge."""
        mesh = self._mesh
        device = mesh.quad_vertex_indices.device

        self._quad_edge_indices = wp.empty(
            shape=(mesh.cell_count(), 4),
            dtype=int,
            device=device,
        )

        wp.launch(
            kernel=Quadmesh2DSpaceTopology._compute_quad_edge_indices_kernel,
            dim=mesh.edge_quad_indices.shape,
            device=device,
            inputs=[
                mesh.edge_quad_indices,
                mesh.edge_vertex_indices,
                mesh.quad_vertex_indices,
                self._quad_edge_indices,
            ],
        )

    @wp.func
    def _find_edge_index_in_quad(
        edge_vtx: wp.vec2i,
        quad_vtx: wp.vec4i,
    ):
        # Local edge k joins quad vertices k and k + 1 (in either direction).
        # Slot 3 (vertices 3 and 0) is the fallback when slots 0..2 do not match.
        for k in range(3):
            v_start = quad_vtx[k]
            v_end = quad_vtx[k + 1]
            same_dir = edge_vtx[0] == v_start and edge_vtx[1] == v_end
            flipped = edge_vtx[1] == v_start and edge_vtx[0] == v_end
            if same_dir or flipped:
                return k
        return 3

    @wp.kernel
    def _compute_quad_edge_indices_kernel(
        edge_quad_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        quad_vertex_indices: wp.array2d(dtype=int),
        quad_edge_indices: wp.array2d(dtype=int),
    ):
        # One thread per edge: write this edge's index into the matching
        # slot of each quad adjacent to it.
        e = wp.tid()

        edge_vtx = edge_vertex_indices[e]
        adjacent_quads = edge_quad_indices[e]

        q0 = adjacent_quads[0]
        q0_vtx = wp.vec4i(
            quad_vertex_indices[q0, 0],
            quad_vertex_indices[q0, 1],
            quad_vertex_indices[q0, 2],
            quad_vertex_indices[q0, 3],
        )
        q0_slot = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, q0_vtx)
        quad_edge_indices[q0, q0_slot] = e

        # When both entries name the same quad (presumably a boundary edge),
        # the single write above is enough.
        q1 = adjacent_quads[1]
        if q1 != q0:
            q1_vtx = wp.vec4i(
                quad_vertex_indices[q1, 0],
                quad_vertex_indices[q1, 1],
                quad_vertex_indices[q1, 2],
                quad_vertex_indices[q1, 3],
            )
            q1_slot = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, q1_vtx)
            quad_edge_indices[q1, q1_slot] = e
109
+
110
+
111
class Quadmesh2DDiscontinuousSpaceTopology(
    DiscontinuousSpaceTopologyMixin,
    SpaceTopology,
):
    """Discontinuous space topology on a quadmesh.

    Node indexing is delegated to ``DiscontinuousSpaceTopologyMixin``
    (presumably an independent set of nodes per cell).
    """

    def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
117
+
118
+
119
class Quadmesh2DBasisSpace(ShapeBasisSpace):
    """Common base for shape-function basis spaces defined on a quadmesh."""

    def __init__(self, topology: Quadmesh2DSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        # Typed handle on the underlying mesh geometry
        mesh: Quadmesh2D = topology.geometry
        self._mesh = mesh
124
+
125
+
126
class Quadmesh2DPiecewiseConstantBasis(Quadmesh2DBasisSpace):
    """Cell-wise constant basis: a ``ConstantShapeFunction`` on a discontinuous topology."""

    def __init__(self, mesh: Quadmesh2D):
        shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=2)
        super().__init__(
            shape=shape,
            topology=Quadmesh2DDiscontinuousSpaceTopology(mesh, shape),
        )

    def trace(self):
        """Return the restriction of this basis to mesh sides."""
        return Quadmesh2DPiecewiseConstantBasis.Trace(self)

    class Trace(TraceBasisSpace):
        """Side trace of the piecewise-constant quadmesh basis."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Quadmesh2D.SideArg,
            basis_arg: Quadmesh2DBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Every trace node is placed at side coordinate 0.5
            # (presumably the side midpoint).
            return Coords(0.5, 0.0, 0.0)

        def make_node_coords_in_element(self):
            return self._node_coords_in_element
147
+
148
+
149
class Quadmesh2DBipolynomialSpaceTopology(Quadmesh2DSpaceTopology):
    """Continuous topology for bipolynomial (tensor-product) spaces on a quadmesh.

    Global node numbering: all mesh vertices first, then ``ORDER - 1`` interior
    nodes per edge, then ``(ORDER - 1) ** 2`` interior nodes per cell.
    """

    def __init__(self, mesh: Quadmesh2D, shape: SquareBipolynomialShapeFunctions):
        super().__init__(mesh, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Return the number of global nodes (vertex + edge-interior + cell-interior)."""
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_SIDE = max(0, ORDER - 1)
        INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_SIDE**2

        return (
            self._mesh.vertex_count()
            + self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
            + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
        )

    def _make_element_node_index(self):
        """Build the warp function mapping ``(cell, local node)`` to a global node index.

        The local node index is split as ``node_i * (ORDER + 1) + node_j``. Per the
        vertex cases below, ``(node_i, node_j)`` equal to ``(0, 0)``, ``(ORDER, 0)``,
        ``(ORDER, ORDER)`` and ``(0, ORDER)`` are quad vertices 0, 1, 2 and 3, and
        local edge ``k`` joins quad vertices ``k`` and ``(k + 1) % 4``.

        Edge-interior nodes are numbered by their distance (in nodes, minus one)
        from the edge's reference vertex ``edge_vertex_indices[side][0]``. Both quads
        sharing an edge compute that distance identically, whichever local edge slot
        the edge occupies in each, so shared nodes receive the same global index.

        Note: ``wp.select(cond, a, b)`` returns ``a`` when ``cond`` is false and
        ``b`` when it is true.
        """
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_SIDE = wp.constant(max(0, ORDER - 1))
        INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_SIDE**2)

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Quadmesh2D.CellArg,
            topo_arg: Quadmesh2DTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_i = node_index_in_elt // (ORDER + 1)
            node_j = node_index_in_elt - (ORDER + 1) * node_i

            # Vertices
            if node_i == 0:
                if node_j == 0:
                    return geo_arg.quad_vertex_indices[element_index, 0]
                elif node_j == ORDER:
                    return geo_arg.quad_vertex_indices[element_index, 3]

                # 3-0 edge. Local vertex 3 sits at node_j == ORDER and vertex 0 at
                # node_j == 0, so the distance from the reference vertex is
                # ORDER - node_j if it is vertex 3, and node_j otherwise.
                side_index = topo_arg.quad_edge_indices[element_index, 3]
                local_vs = geo_arg.quad_vertex_indices[element_index, 3]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, node_j, ORDER - node_j) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            elif node_i == ORDER:
                if node_j == 0:
                    return geo_arg.quad_vertex_indices[element_index, 1]
                elif node_j == ORDER:
                    return geo_arg.quad_vertex_indices[element_index, 2]

                # 1-2 edge. Local vertex 1 sits at node_j == 0, so the distance from
                # the reference vertex is node_j if it is vertex 1, ORDER - node_j otherwise.
                side_index = topo_arg.quad_edge_indices[element_index, 1]
                local_vs = geo_arg.quad_vertex_indices[element_index, 1]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, ORDER - node_j, node_j) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            if node_j == 0:
                # 0-1 edge. Local vertex 0 sits at node_i == 0, so the distance from
                # the reference vertex is node_i if it is vertex 0, ORDER - node_i otherwise.
                side_index = topo_arg.quad_edge_indices[element_index, 0]
                local_vs = geo_arg.quad_vertex_indices[element_index, 0]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, ORDER - node_i, node_i) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            elif node_j == ORDER:
                # 2-3 edge. Local vertex 2 sits at node_i == ORDER, so the distance from
                # the reference vertex is ORDER - node_i if it is vertex 2, node_i otherwise.
                side_index = topo_arg.quad_edge_indices[element_index, 2]
                local_vs = geo_arg.quad_vertex_indices[element_index, 2]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, node_i, ORDER - node_i) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            # Cell-interior node: stored after all vertex and edge-interior nodes
            return (
                topo_arg.vertex_count
                + topo_arg.edge_count * INTERIOR_NODES_PER_SIDE
                + element_index * INTERIOR_NODES_PER_CELL
                + (node_i - 1) * INTERIOR_NODES_PER_SIDE
                + node_j
                - 1
            )

        return element_node_index
238
+
239
+
240
class Quadmesh2DBipolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Continuous bipolynomial basis on a quadmesh.

    ``family`` falls back to ``Polynomial.LOBATTO_GAUSS_LEGENDRE`` when ``None``
    and must satisfy ``is_closed``; otherwise ``ValueError`` is raised.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        shape = SquareBipolynomialShapeFunctions(degree, family=family)
        super().__init__(
            forward_base_topology(Quadmesh2DBipolynomialSpaceTopology, mesh, shape),
            shape,
        )
257
+
258
+
259
class Quadmesh2DDGBipolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous (DG) bipolynomial basis on a quadmesh.

    ``family`` falls back to ``Polynomial.LOBATTO_GAUSS_LEGENDRE`` when ``None``.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        shape = SquareBipolynomialShapeFunctions(degree, family=family)
        super().__init__(Quadmesh2DDiscontinuousSpaceTopology(mesh, shape), shape)
273
+
274
+
275
class Quadmesh2DSerendipitySpaceTopology(Quadmesh2DSpaceTopology):
    """Continuous topology for serendipity spaces on a quadmesh.

    Only vertex and edge nodes exist: global numbering is all mesh vertices first,
    followed by ``ORDER - 1`` nodes per edge (no cell-interior nodes).
    """

    def __init__(self, grid: Quadmesh2D, shape: SquareSerendipityShapeFunctions):
        # NOTE(review): parameter is named `grid` although it receives a Quadmesh2D
        super().__init__(grid, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Return the number of global nodes: vertices plus ``ORDER - 1`` per edge."""
        return self.geometry.vertex_count() + (self._shape.ORDER - 1) * self.geometry.side_count()

    def _make_element_node_index(self):
        """Build the warp function mapping ``(cell, local node)`` to a global node index.

        Edge nodes are numbered by their position along the edge starting from the
        edge's reference vertex ``edge_vertex_indices[side][0]``, so both quads
        sharing an edge agree on the numbering.
        """
        ORDER = self._shape.ORDER

        # Maps the shape function's vertex index to the local quad vertex index
        SHAPE_TO_QUAD_IDX = wp.constant(wp.vec4i([0, 3, 1, 2]))

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Quadmesh2D.CellArg,
            topo_arg: Quadmesh2DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if node_type == SquareSerendipityShapeFunctions.VERTEX:
                return cell_arg.quad_vertex_indices[element_index, SHAPE_TO_QUAD_IDX[type_index]]

            side_offset, index_in_side = SquareSerendipityShapeFunctions.side_offset_and_index(type_index)

            # Select the local quad edge holding this node. For local edges 2 (2-3)
            # and 3 (3-0) the index is reversed so that it counts from the edge's
            # local start vertex (edge k starts at local vertex k); this presumes the
            # shape function's in-side index grows with the reference coordinate.
            if node_type == SquareSerendipityShapeFunctions.EDGE_X:
                if side_offset == 0:
                    side_start = 0
                else:
                    side_start = 2
                    index_in_side = ORDER - 2 - index_in_side
            else:
                if side_offset == 0:
                    side_start = 3
                    index_in_side = ORDER - 2 - index_in_side
                else:
                    side_start = 1

            side_index = topo_arg.quad_edge_indices[element_index, side_start]
            local_vs = cell_arg.quad_vertex_indices[element_index, side_start]
            global_vs = topo_arg.edge_vertex_indices[side_index][0]
            if local_vs != global_vs:
                # Flip indexing direction
                # (the global edge starts at the other end than this quad's local edge)
                index_in_side = ORDER - 2 - index_in_side

            return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

        return element_node_index
326
+
327
+
328
class Quadmesh2DSerendipityBasisSpace(Quadmesh2DBasisSpace):
    """Continuous serendipity basis on a quadmesh.

    ``family`` falls back to ``Polynomial.LOBATTO_GAUSS_LEGENDRE`` when ``None``.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = SquareSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=forward_base_topology(Quadmesh2DSerendipitySpaceTopology, mesh, shape=serendipity_shape),
            shape=serendipity_shape,
        )
342
+
343
+
344
class Quadmesh2DDGSerendipityBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous (DG) serendipity basis on a quadmesh.

    ``family`` falls back to ``Polynomial.LOBATTO_GAUSS_LEGENDRE`` when ``None``.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = SquareSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=Quadmesh2DDiscontinuousSpaceTopology(mesh, shape=serendipity_shape),
            shape=serendipity_shape,
        )
358
+
359
+
360
class Quadmesh2DPolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous basis on a quadmesh built from non-conforming polynomial shape functions."""

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
    ):
        nc_shape = SquareNonConformingPolynomialShapeFunctions(degree)
        super().__init__(Quadmesh2DDiscontinuousSpaceTopology(mesh, nc_shape), nc_shape)
@@ -0,0 +1,160 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.domain import GeometryDomain
4
+ from warp.fem.types import NodeElementIndex
5
+ from warp.fem.utils import compress_node_indices
6
+ from warp.fem.cache import cached_arg_value, borrow_temporary, borrow_temporary_like, TemporaryStore
7
+
8
+ from .function_space import FunctionSpace
9
+ from .partition import SpacePartition
10
+
11
+ wp.set_module_options({"enable_backward": False})
12
+
13
+
14
class SpaceRestriction:
    """Restriction of a space partition to a given GeometryDomain.

    On construction, the class first tabulates the partition node index of every node of every
    domain element. It then compresses that table into a per-node list of
    (domain element index, index in element) pairs. The pairs are stored as an offsets array
    plus flat index arrays, and are exposed to kernels through :class:`SpaceRestriction.NodeArg`.

    Args:
        space_partition: Space partition to restrict.
        domain: Geometry domain to restrict to. Its dimension must either equal the space
            topology dimension, or be one less, in which case the topology's ``trace()`` is used.
        device: Device on which the index-building kernel is launched and temporaries are borrowed.
        temporary_store: Optional store from which intermediate and result arrays are borrowed.

    Raises:
        ValueError: If the domain dimension is incompatible with the space topology dimension.
    """

    def __init__(
        self,
        space_partition: SpacePartition,
        domain: GeometryDomain,
        device=None,
        temporary_store: TemporaryStore = None,
    ):
        space_topology = space_partition.space_topology

        # A domain of one lower dimension is matched against the topology's trace
        if domain.dimension == space_topology.dimension - 1:
            space_topology = space_topology.trace()

        if domain.dimension != space_topology.dimension:
            raise ValueError("Incompatible space and domain dimensions")

        self.space_partition = space_partition
        self.space_topology = space_topology
        self.domain = domain

        self._compute_node_element_indices(device=device, temporary_store=temporary_store)

    def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
        """Build the node -> (element, index in element) lookup tables.

        The construction runs in three steps:

        1. Fill an (element_count, NODES_PER_ELEMENT) table of partition node indices.
        2. Compress that table with ``compress_node_indices``.
        3. Split each sorted flat table index back into a domain element index and an
           index within the element.
        """
        from warp.fem import cache

        NODES_PER_ELEMENT = self.space_topology.NODES_PER_ELEMENT

        @cache.dynamic_kernel(suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}")
        def fill_element_node_indices(
            element_arg: self.domain.ElementArg,
            domain_index_arg: self.domain.ElementIndexArg,
            topo_arg: self.space_topology.TopologyArg,
            partition_arg: self.space_partition.PartitionArg,
            element_node_indices: wp.array2d(dtype=int),
        ):
            # One thread per domain element: map each of its nodes to a partition node index
            domain_element_index = wp.tid()
            element_index = self.domain.element_index(domain_index_arg, domain_element_index)
            for n in range(NODES_PER_ELEMENT):
                space_nidx = self.space_topology.element_node_index(element_arg, topo_arg, element_index, n)
                partition_nidx = self.space_partition.partition_node_index(partition_arg, space_nidx)
                element_node_indices[domain_element_index, n] = partition_nidx

        element_node_indices = borrow_temporary(
            temporary_store,
            shape=(self.domain.element_count(), NODES_PER_ELEMENT),
            dtype=int,
            device=device,
        )
        wp.launch(
            dim=element_node_indices.array.shape[0],
            kernel=fill_element_node_indices,
            inputs=[
                self.domain.element_arg_value(device),
                self.domain.element_index_arg_value(device),
                self.space_topology.topo_arg_value(device),
                self.space_partition.partition_arg_value(device),
                element_node_indices.array,
            ],
            device=device,
        )

        # Build compressed map from node to element indices
        flattened_node_indices = element_node_indices.array.flatten()
        (
            self._dof_partition_element_offsets,
            node_array_indices,
            self._node_count,
            self._dof_partition_indices,
        ) = compress_node_indices(
            self.space_partition.node_count(), flattened_node_indices, temporary_store=temporary_store
        )

        # Extract element index and index in element
        self._dof_element_indices = borrow_temporary_like(flattened_node_indices, temporary_store)
        self._dof_indices_in_element = borrow_temporary_like(flattened_node_indices, temporary_store)
        wp.launch(
            kernel=SpaceRestriction._split_vertex_element_index,
            dim=flattened_node_indices.shape,
            inputs=[
                NODES_PER_ELEMENT,
                node_array_indices.array,
                self._dof_element_indices.array,
                self._dof_indices_in_element.array,
            ],
            device=flattened_node_indices.device,
        )

        # Intermediate buffers are not used past this point: return both to the temporary store
        # (previously only node_array_indices was released explicitly)
        node_array_indices.release()
        element_node_indices.release()

    def node_count(self):
        """Node count returned by ``compress_node_indices``.

        Presumably this is the number of distinct partition nodes referenced by the domain's
        elements (TODO confirm).
        """
        return self._node_count

    def partition_element_offsets(self):
        """Offsets array, indexed by partition node, into the per-node element lists."""
        return self._dof_partition_element_offsets.array

    def node_partition_indices(self):
        """Array mapping each restriction-local node index to its partition node index."""
        return self._dof_partition_indices.array

    def total_node_element_count(self):
        """Total number of (node, element) entries, i.e. element count times nodes per element."""
        return self._dof_element_indices.array.size

    # Kernel-side view of the node -> element lookup tables built at construction
    @wp.struct
    class NodeArg:
        # Per partition node, start offset into dof_element_indices / dof_indices_in_element
        dof_element_offsets: wp.array(dtype=int)
        # Domain element index of each (node, element) entry, grouped by node
        dof_element_indices: wp.array(dtype=int)
        # Partition node index of each restriction-local node
        dof_partition_indices: wp.array(dtype=int)
        # Index of the node within its element, for each (node, element) entry
        dof_indices_in_element: wp.array(dtype=int)

    @cached_arg_value
    def node_arg(self, device):
        """Return a :class:`NodeArg` whose arrays are transferred to ``device`` with ``.to(device)``."""
        arg = SpaceRestriction.NodeArg()
        arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
        arg.dof_element_indices = self._dof_element_indices.array.to(device)
        arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
        arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
        return arg

    # Partition node index of restriction-local node `node_index`
    @wp.func
    def node_partition_index(args: NodeArg, node_index: int):
        return args.dof_partition_indices[node_index]

    # Number of domain elements that reference restriction-local node `node_index`
    @wp.func
    def node_element_count(args: NodeArg, node_index: int):
        partition_node_index = SpaceRestriction.node_partition_index(args, node_index)
        return args.dof_element_offsets[partition_node_index + 1] - args.dof_element_offsets[partition_node_index]

    # The `element_index`-th (domain element, index in element) pair referencing node `node_index`
    @wp.func
    def node_element_index(args: NodeArg, node_index: int, element_index: int):
        partition_node_index = SpaceRestriction.node_partition_index(args, node_index)
        offset = args.dof_element_offsets[partition_node_index] + element_index
        domain_element_index = args.dof_element_indices[offset]
        index_in_element = args.dof_indices_in_element[offset]
        return NodeElementIndex(domain_element_index, index_in_element)

    # Splits each flat (element, slot) table index read from `sorted_indices` into its
    # row (element index) and column (index within the element)
    @wp.kernel
    def _split_vertex_element_index(
        vertex_per_element: int,
        sorted_indices: wp.array(dtype=int),
        vertex_element_index: wp.array(dtype=int),
        vertex_index_in_element: wp.array(dtype=int),
    ):
        idx = sorted_indices[wp.tid()]
        element_index = idx // vertex_per_element
        vertex_element_index[wp.tid()] = element_index
        vertex_index_in_element[wp.tid()] = idx - vertex_per_element * element_index
@@ -0,0 +1,15 @@
1
+ from .shape_function import ShapeFunction, ConstantShapeFunction
2
+
3
+ from .triangle_shape_function import Triangle2DPolynomialShapeFunctions, Triangle2DNonConformingPolynomialShapeFunctions
4
+ from .tet_shape_function import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions
5
+
6
+ from .square_shape_function import (
7
+ SquareBipolynomialShapeFunctions,
8
+ SquareSerendipityShapeFunctions,
9
+ SquareNonConformingPolynomialShapeFunctions,
10
+ )
11
+ from .cube_shape_function import (
12
+ CubeSerendipityShapeFunctions,
13
+ CubeTripolynomialShapeFunctions,
14
+ CubeNonConformingPolynomialShapeFunctions,
15
+ )