warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300) hide show
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,294 @@
1
+ from typing import Any
2
+
3
+ import warp as wp
4
+
5
+ from warp.fem import domain, cache
6
+ from warp.fem.types import ElementIndex, Coords
7
+ from warp.fem.space import FunctionSpace
8
+
9
+ from ..polynomial import Polynomial
10
+
11
+
12
class Quadrature:
    """Interface class for quadrature rules.

    The base class stores the integration domain and declares the accessors
    that concrete rules provide: ``point_count``, ``point_coords``,
    ``point_weight`` and ``point_index``. The base versions of these accessors
    and of ``total_point_count`` raise ``NotImplementedError``.
    """

    @wp.struct
    class Arg:
        """Structure containing arguments to be passed to device functions"""

        pass

    # NOTE: ``__init__`` must stay above the ``domain`` property. Its annotation
    # is evaluated while the class body executes and has to resolve to the
    # imported ``domain`` module, not to the property defined below.
    def __init__(self, domain: domain.GeometryDomain):
        """Create a quadrature rule defined over ``domain``."""
        self._domain = domain

    @property
    def domain(self):
        """Domain over which this quadrature is defined"""
        return self._domain

    def arg_value(self, device) -> "Arg":
        """
        Value of the argument to be passed to device

        The base implementation returns a default-constructed ``Quadrature.Arg``
        and does not use ``device``.
        """
        # Refer to the base struct directly rather than through a subclass
        # (``RegularQuadrature.Arg`` resolves to this same object by inheritance).
        return Quadrature.Arg()

    def total_point_count(self):
        """Total number of quadrature points over the domain"""
        raise NotImplementedError()

    def points_per_element(self):
        """Number of points per element if constant, or ``None`` if varying"""
        return None

    @staticmethod
    def point_count(elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex):
        """Number of quadrature points for a given element"""
        raise NotImplementedError()

    @staticmethod
    def point_coords(
        elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
    ):
        """Coordinates in element of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_weight(
        elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
    ):
        """Weight of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    @staticmethod
    def point_index(
        elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
    ):
        """Global index of the element's qp_index'th quadrature point"""
        raise NotImplementedError()

    def __str__(self) -> str:
        """Return ``self.name``; this base class does not define ``name`` itself."""
        return self.name
72
+
73
+
74
class RegularQuadrature(Quadrature):
    """Quadrature formula that uses the same set of points and weights for every element.

    The points and weights are obtained from the domain's reference element
    for the requested ``order`` and ``family``.
    """

    def __init__(
        self,
        domain: domain.GeometryDomain,
        order: int,
        family: Polynomial = None,
    ):
        super().__init__(domain)

        self.order = order
        self.family = family

        # Indexable result: entry 0 is exposed as ``points``, entry 1 as ``weights``.
        self._element_quadrature = domain.reference_element().instantiate_quadrature(order, family)

        self._N = wp.constant(len(self.points))

        # Fixed-size types holding one 3-component row per point / one weight per point.
        CoordMat = wp.mat(shape=(self._N, 3), dtype=wp.float32)
        WeightVec = wp.vec(length=self._N, dtype=wp.float32)

        self._POINTS = wp.constant(CoordMat(self.points))
        self._WEIGHTS = wp.constant(WeightVec(self.weights))

        # Instance-level accessors replacing the placeholders declared on the base class.
        self.point_count = self._make_point_count()
        self.point_index = self._make_point_index()
        self.point_coords = self._make_point_coords()
        self.point_weight = self._make_point_weight()

    @property
    def name(self):
        """Identifier built from the class name, domain name, family and order."""
        return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"

    def total_point_count(self):
        """Number of points per element times the number of geometry elements of the domain."""
        per_element = len(self.points)
        return per_element * self.domain.geometry_element_count()

    def points_per_element(self):
        """Constant number of quadrature points per element."""
        return self._N

    @property
    def points(self):
        """Per-element quadrature points (entry 0 of the element quadrature)."""
        return self._element_quadrature[0]

    @property
    def weights(self):
        """Per-element quadrature weights (entry 1 of the element quadrature)."""
        return self._element_quadrature[1]

    def _make_point_count(self):
        """Build the function returning the constant per-element point count."""
        N = self._N

        @cache.dynamic_func(suffix=self.name)
        def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
            return N

        return point_count

    def _make_point_coords(self):
        """Build the function reading a point's coordinates from the constant matrix."""
        POINTS = self._POINTS

        @cache.dynamic_func(suffix=self.name)
        def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return Coords(POINTS[qp_index, 0], POINTS[qp_index, 1], POINTS[qp_index, 2])

        return point_coords

    def _make_point_weight(self):
        """Build the function reading a point's weight from the constant vector."""
        WEIGHTS = self._WEIGHTS

        @cache.dynamic_func(suffix=self.name)
        def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return WEIGHTS[qp_index]

        return point_weight

    def _make_point_index(self):
        """Build the function computing ``N * element_index + qp_index``."""
        N = self._N

        @cache.dynamic_func(suffix=self.name)
        def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return N * element_index + qp_index

        return point_index
156
+
157
+
158
class NodalQuadrature(Quadrature):
    """Quadrature using space node points as quadrature points

    Note that in contrast to the `nodal=True` flag for :func:`integrate`, this quadrature does not make any assumption
    about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
    """

    def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
        super().__init__(domain)

        self._space = space

        # The argument struct must exist before the accessors are built,
        # since their signatures are annotated with ``self.Arg``.
        self.Arg = self._make_arg()

        self.point_count = self._make_point_count()
        self.point_index = self._make_point_index()
        self.point_coords = self._make_point_coords()
        self.point_weight = self._make_point_weight()

    @property
    def name(self):
        """Identifier built from the class name and the name of the function space."""
        return f"{self.__class__.__name__}_{self._space.name}"

    def total_point_count(self):
        """Node count of the underlying function space."""
        return self._space.node_count()

    def points_per_element(self):
        """Number of nodes per element, as declared by the space topology."""
        return self._space.topology.NODES_PER_ELEMENT

    def _make_arg(self):
        """Build the struct bundling the space and topology arguments."""

        @cache.dynamic_struct(suffix=self.name)
        class Arg:
            space_arg: self._space.SpaceArg
            topo_arg: self._space.topology.TopologyArg

        return Arg

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill an ``Arg`` with the space and topology argument values for ``device``."""
        arg = self.Arg()
        arg.space_arg = self._space.space_arg_value(device)
        arg.topo_arg = self._space.topology.topo_arg_value(device)
        return arg

    def _make_point_count(self):
        """Build the function returning the per-element node count."""
        N = self._space.topology.NODES_PER_ELEMENT

        @cache.dynamic_func(suffix=self.name)
        def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
            return N

        return point_count

    def _make_point_coords(self):
        """Build the function forwarding to the space's ``node_coords_in_element``."""

        @cache.dynamic_func(suffix=self.name)
        def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_coords

    def _make_point_weight(self):
        """Build the function forwarding to the space's ``node_quadrature_weight``."""

        @cache.dynamic_func(suffix=self.name)
        def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)

        return point_weight

    def _make_point_index(self):
        """Build the function forwarding to the topology's ``element_node_index``."""

        @cache.dynamic_func(suffix=self.name)
        def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
            return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)

        return point_index
231
+
232
+
233
class ExplicitQuadrature(Quadrature):
    """Quadrature using explicit per-cell points and weights. The number of quadrature points per cell is assumed
    to be constant and deduced from the shape of the points and weights arrays.

    Args:
        domain: Domain of definition of the quadrature formula
        points: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
        weights: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.

    Raises:
        ValueError: if ``points`` and ``weights`` do not have the same shape.

    See also: :class:`PicQuadrature`
    """

    @wp.struct
    class Arg:
        # Number of quadrature points stored for each cell
        points_per_cell: int
        # Point coordinates, indexed by (element_index, qp_index)
        points: wp.array2d(dtype=Coords)
        # Point weights, indexed by (element_index, qp_index)
        weights: wp.array2d(dtype=float)

    def __init__(self, domain: domain.GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
        super().__init__(domain)

        if points.shape != weights.shape:
            raise ValueError("Points and weights arrays must have the same shape")

        self._points = points
        self._weights = weights
        # The second array dimension is the per-cell point count.
        self._points_per_cell = points.shape[1]

    @property
    def name(self):
        """Identifier of this quadrature: the class name only."""
        return f"{self.__class__.__name__}"

    def total_point_count(self):
        """Total number of entries in the weights array."""
        return self._weights.size

    def points_per_element(self):
        """Constant number of quadrature points per cell."""
        return self._points_per_cell

    @cache.cached_arg_value
    def arg_value(self, device):
        """Fill an ``Arg`` with the per-cell count and the arrays converted with ``.to(device)``."""
        arg = self.Arg()
        arg.points_per_cell = self._points_per_cell
        arg.points = self._points.to(device)
        arg.weights = self._weights.to(device)

        return arg

    # Number of quadrature points for the element: the stored per-cell count.
    @wp.func
    def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
        return qp_arg.points_per_cell

    # Coordinates of the element's qp_index'th point, read from the points array.
    @wp.func
    def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        return qp_arg.points[element_index, qp_index]

    # Weight of the element's qp_index'th point, read from the weights array.
    @wp.func
    def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        return qp_arg.weights[element_index, qp_index]

    # Global point index: points_per_cell * element_index + qp_index.
    @wp.func
    def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
        return qp_arg.points_per_cell * element_index + qp_index
@@ -0,0 +1,292 @@
1
+ from typing import Optional
2
+ from enum import Enum
3
+
4
+ import warp.fem.domain as _domain
5
+ import warp.fem.geometry as _geometry
6
+ import warp.fem.polynomial as _polynomial
7
+
8
+ from .function_space import FunctionSpace
9
+ from .topology import SpaceTopology
10
+ from .basis_space import BasisSpace, PointBasisSpace
11
+ from .collocated_function_space import CollocatedFunctionSpace
12
+
13
+ from .grid_2d_function_space import (
14
+ GridPiecewiseConstantBasis,
15
+ GridBipolynomialBasisSpace,
16
+ GridDGBipolynomialBasisSpace,
17
+ GridSerendipityBasisSpace,
18
+ GridDGSerendipityBasisSpace,
19
+ GridDGPolynomialBasisSpace,
20
+ )
21
+ from .grid_3d_function_space import (
22
+ GridTripolynomialBasisSpace,
23
+ GridDGTripolynomialBasisSpace,
24
+ Grid3DPiecewiseConstantBasis,
25
+ Grid3DSerendipityBasisSpace,
26
+ Grid3DDGSerendipityBasisSpace,
27
+ Grid3DDGPolynomialBasisSpace,
28
+ )
29
+ from .trimesh_2d_function_space import (
30
+ Trimesh2DPiecewiseConstantBasis,
31
+ Trimesh2DPolynomialBasisSpace,
32
+ Trimesh2DDGPolynomialBasisSpace,
33
+ Trimesh2DNonConformingPolynomialBasisSpace,
34
+ )
35
+ from .tetmesh_function_space import (
36
+ TetmeshPiecewiseConstantBasis,
37
+ TetmeshPolynomialBasisSpace,
38
+ TetmeshDGPolynomialBasisSpace,
39
+ TetmeshNonConformingPolynomialBasisSpace,
40
+ )
41
+ from .quadmesh_2d_function_space import (
42
+ Quadmesh2DPiecewiseConstantBasis,
43
+ Quadmesh2DBipolynomialBasisSpace,
44
+ Quadmesh2DDGBipolynomialBasisSpace,
45
+ Quadmesh2DSerendipityBasisSpace,
46
+ Quadmesh2DDGSerendipityBasisSpace,
47
+ Quadmesh2DPolynomialBasisSpace,
48
+ )
49
+ from .hexmesh_function_space import (
50
+ HexmeshPiecewiseConstantBasis,
51
+ HexmeshTripolynomialBasisSpace,
52
+ HexmeshDGTripolynomialBasisSpace,
53
+ HexmeshSerendipityBasisSpace,
54
+ HexmeshDGSerendipityBasisSpace,
55
+ HexmeshPolynomialBasisSpace,
56
+ )
57
+
58
+ from .partition import SpacePartition, make_space_partition
59
+ from .restriction import SpaceRestriction
60
+
61
+
62
+ from .dof_mapper import DofMapper, IdentityMapper, SymmetricTensorMapper, SkewSymmetricTensorMapper
63
+
64
+
65
def make_space_restriction(
    space: Optional[FunctionSpace] = None,
    space_partition: Optional[SpacePartition] = None,
    domain: Optional[_domain.GeometryDomain] = None,
    space_topology: Optional[SpaceTopology] = None,
    device=None,
    temporary_store: "Optional[warp.fem.cache.TemporaryStore]" = None,
) -> SpaceRestriction:
    """
    Restricts a function space partition to a Domain, i.e. a subset of its elements.

    One of `space_partition`, `space_topology`, or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) if neither `space_partition` nor `space_topology` are provided, the space defining the topology to restrict
        space_partition: the subset of nodes from the space topology to consider
        domain: the domain to restrict the space to, defaults to all cells of the space geometry or partition.
        space_topology: the space topology to be restricted, if `space_partition` is ``None``.
        device: device on which to perform and store computations
        temporary_store: shared pool from which to allocate temporary arrays

    Returns:
        the restriction of the space partition to the domain

    Raises:
        ValueError: if none of `space_partition`, `space_topology` and `space` is provided
    """

    if space_partition is None:
        if space_topology is None:
            # Explicit check rather than `assert`, so the validation survives `python -O`
            if space is None:
                raise ValueError("One of `space_partition`, `space_topology`, or `space` must be provided")
            space_topology = space.topology

        if domain is None:
            domain = _domain.Cells(geometry=space_topology.geometry)

        # Build the partition from the topology, using the geometry partition of the domain
        space_partition = make_space_partition(
            space_topology=space_topology, geometry_partition=domain.geometry_partition
        )
    elif domain is None:
        # No domain given: use all cells of the geometry partition underlying the space partition
        domain = _domain.Cells(geometry=space_partition.geo_partition)

    return SpaceRestriction(
        space_partition=space_partition, domain=domain, device=device, temporary_store=temporary_store
    )
104
+
105
+
106
class ElementBasis(Enum):
    """Kind of basis functions used on each individual element of a geometry"""

    LAGRANGE = 0
    """Lagrange basis functions :math:`P_k` for simplices, tensor products :math:`Q_k` for squares and cubes"""

    SERENDIPITY = 1
    """Serendipity elements :math:`S_k`, corresponding to Lagrange nodes with interior points removed (for degree <= 3)"""

    NONCONFORMING_POLYNOMIAL = 2
    """Simplex Lagrange basis functions :math:`P_{kd}` embedded into non conforming reference elements (e.g. squares or cubes). Discontinuous only."""
115
+
116
+
117
def _tensor_product_basis_space(
    geo,
    degree,
    element_basis,
    discontinuous,
    family,
    *,
    constant,
    continuous,
    dg,
    serendipity,
    dg_serendipity,
    nonconforming,
):
    """Pick and build the basis space for a grid, quadmesh or hexmesh geometry from the given candidate classes."""

    if degree == 0:
        return constant(geo)

    # Serendipity spaces are only used above degree 1; otherwise fall through to the tensor-product spaces
    if element_basis == ElementBasis.SERENDIPITY and degree > 1:
        space_cls = dg_serendipity if discontinuous else serendipity
        return space_cls(geo, degree=degree, family=family)

    # Non-conforming polynomial spaces do not take a polynomial family
    if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
        return nonconforming(geo, degree=degree)

    space_cls = dg if discontinuous else continuous
    return space_cls(geo, degree=degree, family=family)


def _simplex_basis_space(geo, degree, element_basis, discontinuous, *, constant, continuous, dg, nonconforming):
    """Pick and build the basis space for a triangle or tetrahedron mesh from the given candidate classes."""

    if degree == 0:
        return constant(geo)

    if element_basis == ElementBasis.SERENDIPITY and degree > 2:
        raise NotImplementedError("Serendipity variant not implemented yet")

    if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
        return nonconforming(geo, degree=degree)

    # Simplex spaces do not take a polynomial family
    space_cls = dg if discontinuous else continuous
    return space_cls(geo, degree=degree)


def make_polynomial_basis_space(
    geo: _geometry.Geometry,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> BasisSpace:
    """
    Equips a geometry with a polynomial basis.

    Args:
        geo: the Geometry on which to build the space
        degree: polynomial degree of the per-element shape functions
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e., piecewise-constant shape functions.
        element_basis: type of basis function for the individual elements, defaults to :attr:`ElementBasis.LAGRANGE`
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.
            Only forwarded to the Lagrange and serendipity spaces of grid, quadmesh and hexmesh geometries;
            ignored for triangle and tetrahedron meshes and for :attr:`ElementBasis.NONCONFORMING_POLYNOMIAL`.

    Returns:
        the constructed basis space

    Raises:
        NotImplementedError: if the geometry type is not supported, or if a serendipity basis of degree
            greater than 2 is requested on a triangle or tetrahedron mesh
    """

    # For deformed geometries, the basis type is selected from the underlying base geometry,
    # but the space itself is still built on `geo`
    base_geo = geo.base if isinstance(geo, _geometry.DeformedGeometry) else geo

    if element_basis is None:
        element_basis = ElementBasis.LAGRANGE

    if isinstance(base_geo, _geometry.Grid2D):
        return _tensor_product_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            family,
            constant=GridPiecewiseConstantBasis,
            continuous=GridBipolynomialBasisSpace,
            dg=GridDGBipolynomialBasisSpace,
            serendipity=GridSerendipityBasisSpace,
            dg_serendipity=GridDGSerendipityBasisSpace,
            nonconforming=GridDGPolynomialBasisSpace,
        )

    if isinstance(base_geo, _geometry.Grid3D):
        return _tensor_product_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            family,
            constant=Grid3DPiecewiseConstantBasis,
            continuous=GridTripolynomialBasisSpace,
            dg=GridDGTripolynomialBasisSpace,
            serendipity=Grid3DSerendipityBasisSpace,
            dg_serendipity=Grid3DDGSerendipityBasisSpace,
            nonconforming=Grid3DDGPolynomialBasisSpace,
        )

    if isinstance(base_geo, _geometry.Trimesh2D):
        return _simplex_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            constant=Trimesh2DPiecewiseConstantBasis,
            continuous=Trimesh2DPolynomialBasisSpace,
            dg=Trimesh2DDGPolynomialBasisSpace,
            nonconforming=Trimesh2DNonConformingPolynomialBasisSpace,
        )

    if isinstance(base_geo, _geometry.Tetmesh):
        return _simplex_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            constant=TetmeshPiecewiseConstantBasis,
            continuous=TetmeshPolynomialBasisSpace,
            dg=TetmeshDGPolynomialBasisSpace,
            nonconforming=TetmeshNonConformingPolynomialBasisSpace,
        )

    if isinstance(base_geo, _geometry.Quadmesh2D):
        return _tensor_product_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            family,
            constant=Quadmesh2DPiecewiseConstantBasis,
            continuous=Quadmesh2DBipolynomialBasisSpace,
            dg=Quadmesh2DDGBipolynomialBasisSpace,
            serendipity=Quadmesh2DSerendipityBasisSpace,
            dg_serendipity=Quadmesh2DDGSerendipityBasisSpace,
            nonconforming=Quadmesh2DPolynomialBasisSpace,
        )

    if isinstance(base_geo, _geometry.Hexmesh):
        return _tensor_product_basis_space(
            geo,
            degree,
            element_basis,
            discontinuous,
            family,
            constant=HexmeshPiecewiseConstantBasis,
            continuous=HexmeshTripolynomialBasisSpace,
            dg=HexmeshDGTripolynomialBasisSpace,
            serendipity=HexmeshSerendipityBasisSpace,
            dg_serendipity=HexmeshDGSerendipityBasisSpace,
            nonconforming=HexmeshPolynomialBasisSpace,
        )

    raise NotImplementedError(
        f"No polynomial basis space available for geometry of type '{type(base_geo).__name__}'"
    )
246
+
247
+
248
def make_collocated_function_space(
    basis_space: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None
) -> CollocatedFunctionSpace:
    """
    Constructs a function space from a basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.

    Args:
        basis_space: the basis space from which to build the function space
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.

    Returns:
        the constructed function space
    """
    return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
263
+
264
+
265
def make_polynomial_space(
    geo: _geometry.Geometry,
    dtype: type = float,
    dof_mapper: Optional[DofMapper] = None,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> CollocatedFunctionSpace:
    """
    Equips a geometry with a collocated, polynomial function space.
    Equivalent to successive calls to :func:`make_polynomial_basis_space` and :func:`make_collocated_function_space`.

    Args:
        geo: the Geometry on which to build the space
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.
        degree: polynomial degree of the per-element shape functions
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e., piecewise-constant shape functions.
        element_basis: type of basis function for the individual elements
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed function space
    """

    # Keyword arguments keep this call independent of the parameter order of the basis-space factory
    basis_space = make_polynomial_basis_space(
        geo,
        degree=degree,
        element_basis=element_basis,
        discontinuous=discontinuous,
        family=family,
    )

    # Delegate to the collocated-space factory so both entry points build spaces the same way
    return make_collocated_function_space(basis_space, dtype=dtype, dof_mapper=dof_mapper)