warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's release page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +882 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/{builtins.py → _src/builtins.py} +1435 -379
  8. warp/_src/codegen.py +4361 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/{fem → _src/fem}/domain.py +42 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +32 -15
  22. warp/{fem → _src/fem}/field/restriction.py +3 -1
  23. warp/{fem → _src/fem}/field/virtual.py +55 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
  28. warp/{fem → _src/fem}/geometry/element.py +34 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +50 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
  34. warp/{fem → _src/fem}/geometry/partition.py +123 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
  38. warp/{fem → _src/fem}/integrate.py +166 -158
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
  46. warp/_src/fem/space/basis_space.py +681 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
  48. warp/{fem → _src/fem}/space/function_space.py +16 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
  53. warp/{fem → _src/fem}/space/partition.py +119 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
  55. warp/{fem → _src/fem}/space/restriction.py +68 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
  63. warp/_src/fem/space/topology.py +461 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
  65. warp/_src/fem/types.py +114 -0
  66. warp/_src/fem/utils.py +488 -0
  67. warp/_src/jax.py +188 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +389 -0
  70. warp/_src/jax_experimental/ffi.py +1286 -0
  71. warp/_src/jax_experimental/xla_ffi.py +658 -0
  72. warp/_src/marching_cubes.py +710 -0
  73. warp/_src/math.py +416 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +165 -0
  76. warp/_src/optim/linear.py +1608 -0
  77. warp/_src/optim/sgd.py +114 -0
  78. warp/_src/paddle.py +408 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +291 -0
  81. warp/_src/render/render_opengl.py +3638 -0
  82. warp/_src/render/render_usd.py +939 -0
  83. warp/_src/render/utils.py +162 -0
  84. warp/_src/sparse.py +2718 -0
  85. warp/_src/tape.py +1208 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +393 -0
  88. warp/_src/types.py +5888 -0
  89. warp/_src/utils.py +1695 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.so +0 -0
  92. warp/bin/warp.so +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +3 -3
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +521 -250
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +18 -17
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +578 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
@@ -16,13 +16,16 @@
16
16
  from typing import Any, Optional, Tuple, Union
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
20
- from warp.fem.domain import GeometryDomain
21
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, make_free_sample
22
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem.cache import TemporaryStore, borrow_temporary, dynamic_kernel
20
+ from warp._src.fem.domain import GeometryDomain
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, make_free_sample
22
+ from warp._src.fem.utils import compress_node_indices
23
+ from warp._src.types import is_array
23
24
 
24
25
  from .quadrature import Quadrature
25
26
 
27
+ _wp_module_name_ = "warp.fem.quadrature.pic_quadrature"
28
+
26
29
 
27
30
  class PicQuadrature(Quadrature):
28
31
  """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
@@ -63,7 +66,7 @@ class PicQuadrature(Quadrature):
63
66
 
64
67
  @property
65
68
  def name(self):
66
- return f"{self.__class__.__name__}"
69
+ return self.__class__.__name__
67
70
 
68
71
  @Quadrature.domain.setter
69
72
  def domain(self, domain: GeometryDomain):
@@ -84,15 +87,9 @@ class PicQuadrature(Quadrature):
84
87
  particle_fraction: wp.array(dtype=float)
85
88
  particle_coords: wp.array(dtype=Coords)
86
89
 
87
- @cached_arg_value
88
- def arg_value(self, device) -> Arg:
89
- arg = PicQuadrature.Arg()
90
- self.fill_arg(arg, device)
91
- return arg
92
-
93
90
  def fill_arg(self, args: Arg, device):
94
- args.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
95
- args.cell_particle_indices = self._cell_particle_indices.array.to(device)
91
+ args.cell_particle_offsets = self._cell_particle_offsets.to(device)
92
+ args.cell_particle_indices = self._cell_particle_indices.to(device)
96
93
  args.particle_fraction = self._particle_fraction.to(device)
97
94
  args.particle_coords = self.particle_coords.to(device)
98
95
 
@@ -101,16 +98,16 @@ class PicQuadrature(Quadrature):
101
98
 
102
99
  def active_cell_count(self):
103
100
  """Number of cells containing at least one particle"""
104
- return self._cell_count
101
+ return self._cell_count.numpy()[0]
105
102
 
106
103
  def max_points_per_element(self):
107
104
  if self._max_particles_per_cell is None:
108
- max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.array.device)
105
+ max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.device)
109
106
  wp.launch(
110
107
  PicQuadrature._max_particles_per_cell_kernel,
111
- self._cell_particle_offsets.array.shape[0] - 1,
108
+ self._cell_particle_offsets.shape[0] - 1,
112
109
  device=max_ppc.device,
113
- inputs=[self._cell_particle_offsets.array, max_ppc],
110
+ inputs=[self._cell_particle_offsets, max_ppc],
114
111
  )
115
112
  self._max_particles_per_cell = int(max_ppc.numpy()[0])
116
113
  return self._max_particles_per_cell
@@ -157,7 +154,7 @@ class PicQuadrature(Quadrature):
157
154
  kernel=PicQuadrature._fill_mask_kernel,
158
155
  dim=self.domain.geometry_element_count(),
159
156
  device=mask.device,
160
- inputs=[self._cell_particle_offsets.array, mask],
157
+ inputs=[self._cell_particle_offsets, mask],
161
158
  )
162
159
 
163
160
  @wp.kernel
@@ -184,7 +181,7 @@ class PicQuadrature(Quadrature):
184
181
  cell_fraction[p] = 1.0 / float(cell_particle_count)
185
182
 
186
183
  def _bin_particles(self, positions, measures, max_dist: float, temporary_store: TemporaryStore):
187
- if wp.types.is_array(positions):
184
+ if is_array(positions):
188
185
  device = positions.device
189
186
  if not self.domain.supports_lookup(device):
190
187
  raise RuntimeError(
@@ -272,7 +269,7 @@ class PicQuadrature(Quadrature):
272
269
  kernel=PicQuadrature._compute_uniform_fraction,
273
270
  inputs=[
274
271
  cell_index,
275
- self._cell_particle_offsets.array,
272
+ self._cell_particle_offsets,
276
273
  self._particle_fraction,
277
274
  ],
278
275
  device=device,
@@ -13,17 +13,20 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from functools import cached_property
16
17
  from typing import Any, ClassVar, Optional
17
18
 
18
19
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.domain import GeometryDomain
21
- from warp.fem.geometry import Element
22
- from warp.fem.space import FunctionSpace
23
- from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.domain import GeometryDomain
22
+ from warp._src.fem.geometry import Element
23
+ from warp._src.fem.space.function_space import FunctionSpace
24
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
24
25
 
25
26
  from ..polynomial import Polynomial
26
27
 
28
+ _wp_module_name_ = "warp.fem.quadrature.quadrature"
29
+
27
30
 
28
31
  @wp.struct
29
32
  class QuadraturePointElementIndex:
@@ -48,18 +51,22 @@ class Quadrature:
48
51
  """Domain over which this quadrature is defined"""
49
52
  return self._domain
50
53
 
54
+ @cache.cached_arg_value
51
55
  def arg_value(self, device) -> "Arg":
52
56
  """
53
57
  Value of the argument to be passed to device
54
58
  """
55
- arg = Quadrature.Arg()
59
+ arg = self.Arg()
60
+ self.fill_arg(arg, device)
56
61
  return arg
57
62
 
58
63
  def fill_arg(self, arg: Arg, device):
59
64
  """
60
65
  Fill the argument with the value of the argument to be passed to device
61
66
  """
62
- pass
67
+ if self.arg_value is __class__.arg_value:
68
+ raise NotImplementedError()
69
+ arg.assign(self.arg_value(device))
63
70
 
64
71
  def total_point_count(self):
65
72
  """Number of unique quadrature points that can be indexed by this rule.
@@ -160,6 +167,8 @@ class Quadrature:
160
167
  ):
161
168
  domain_element_index = wp.tid()
162
169
  element_index = self.domain.element_index(domain_index_arg, domain_element_index)
170
+ if element_index == NULL_ELEMENT_INDEX:
171
+ return
163
172
 
164
173
  qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
165
174
  for k in range(qp_point_count):
@@ -217,7 +226,9 @@ class _QuadratureWithRegularEvaluationPoints(Quadrature):
217
226
  cache.setup_dynamic_attributes(self, cls=__class__)
218
227
 
219
228
  ElementIndexArg = Quadrature.Arg
220
- element_index_arg_value = Quadrature.arg_value
229
+
230
+ def element_index_arg_value(self, device):
231
+ return Quadrature.Arg()
221
232
 
222
233
  def evaluation_point_count(self):
223
234
  return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT
@@ -267,22 +278,33 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
267
278
  _cache: ClassVar = {}
268
279
 
269
280
  def __init__(self, element: Element, order: int, family: Polynomial):
270
- self.points, self.weights = element.instantiate_quadrature(order, family)
281
+ self.points, self.weights = element.prototype.instantiate_quadrature(order, family)
271
282
  self.count = wp.constant(len(self.points))
272
283
 
273
284
  @cache.cached_arg_value
274
285
  def arg_value(self, device):
275
286
  arg = RegularQuadrature.Arg()
276
- self.fill_arg(arg, device)
277
- return arg
278
287
 
279
- def fill_arg(self, arg: "RegularQuadrature.Arg", device):
288
+ # pause graph capture while we copy from host
289
+ # we want the cached result to be available outside of the graph
290
+ if device.is_capturing:
291
+ graph = wp.context.capture_pause()
292
+ else:
293
+ graph = None
294
+
280
295
  arg.points = wp.array(self.points, device=device, dtype=Coords)
281
296
  arg.weights = wp.array(self.weights, device=device, dtype=float)
282
297
 
298
+ if graph is not None:
299
+ wp.context.capture_resume(graph)
300
+ return arg
301
+
302
+ def fill_arg(self, arg: "RegularQuadrature.Arg", device):
303
+ arg.assign(self.arg_value(device))
304
+
283
305
  @staticmethod
284
306
  def get(element: Element, order: int, family: Polynomial):
285
- key = (element.__class__.__name__, order, family)
307
+ key = (element.value, order, family)
286
308
  try:
287
309
  return RegularQuadrature.CachedFormula._cache[key]
288
310
  except KeyError:
@@ -311,7 +333,7 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
311
333
 
312
334
  cache.setup_dynamic_attributes(self)
313
335
 
314
- @property
336
+ @cached_property
315
337
  def name(self):
316
338
  return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
317
339
 
@@ -329,9 +351,6 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
329
351
  def weights(self):
330
352
  return self._formula.weights
331
353
 
332
- def arg_value(self, device):
333
- return self._formula.arg_value(device)
334
-
335
354
  def fill_arg(self, arg: "RegularQuadrature.Arg", device):
336
355
  self._formula.fill_arg(arg, device)
337
356
 
@@ -398,20 +417,27 @@ class NodalQuadrature(Quadrature):
398
417
  any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
399
418
  """
400
419
 
401
- def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
420
+ _dynamic_attribute_constructors: ClassVar = {
421
+ "Arg": lambda obj: obj._make_arg(),
422
+ "point_count": lambda obj: obj._make_point_count(),
423
+ "point_index": lambda obj: obj._make_point_index(),
424
+ "point_coords": lambda obj: obj._make_point_coords(),
425
+ "point_weight": lambda obj: obj._make_point_weight(),
426
+ "point_evaluation_index": lambda obj: obj._make_point_evaluation_index(),
427
+ }
428
+
429
+ def __init__(
430
+ self,
431
+ domain: Optional[GeometryDomain],
432
+ space: Optional[FunctionSpace],
433
+ ):
402
434
  self._space = space
403
435
 
404
436
  super().__init__(domain)
405
437
 
406
- self.Arg = self._make_arg()
407
-
408
- self.point_count = self._make_point_count()
409
- self.point_index = self._make_point_index()
410
- self.point_coords = self._make_point_coords()
411
- self.point_weight = self._make_point_weight()
412
- self.point_evaluation_index = self._make_point_evaluation_index()
438
+ cache.setup_dynamic_attributes(self)
413
439
 
414
- @property
440
+ @cached_property
415
441
  def name(self):
416
442
  return f"{self.__class__.__name__}_{self._space.name}"
417
443
 
@@ -429,12 +455,6 @@ class NodalQuadrature(Quadrature):
429
455
 
430
456
  return Arg
431
457
 
432
- @cache.cached_arg_value
433
- def arg_value(self, device):
434
- arg = self.Arg()
435
- self.fill_arg(arg, device)
436
- return arg
437
-
438
458
  def fill_arg(self, arg: "NodalQuadrature.Arg", device):
439
459
  self._space.fill_space_arg(arg.space_arg, device)
440
460
  self._space.topology.fill_topo_arg(arg.topo_arg, device)
@@ -486,7 +506,8 @@ class NodalQuadrature(Quadrature):
486
506
  element_index: ElementIndex,
487
507
  qp_index: int,
488
508
  ):
489
- return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
509
+ node_index = self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
510
+ return node_index
490
511
 
491
512
  return point_index
492
513
 
@@ -529,7 +550,12 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
529
550
  points: wp.array2d(dtype=Coords)
530
551
  weights: wp.array2d(dtype=float)
531
552
 
532
- def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
553
+ def __init__(
554
+ self,
555
+ domain: GeometryDomain,
556
+ points: "wp.array2d(dtype=Coords)",
557
+ weights: "wp.array2d(dtype=float)",
558
+ ):
533
559
  if points.shape != weights.shape:
534
560
  raise ValueError("Points and weights arrays must have the same shape")
535
561
 
@@ -554,7 +580,7 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
554
580
  self._points = points
555
581
  self._weights = weights
556
582
 
557
- @property
583
+ @cached_property
558
584
  def name(self):
559
585
  return f"{self.__class__.__name__}_{self._whole_geo}_{self._points_per_cell}"
560
586
 
@@ -564,52 +590,76 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
564
590
  def max_points_per_element(self):
565
591
  return self._points_per_cell
566
592
 
567
- def arg_value(self, device):
568
- arg = self.Arg()
569
- self.fill_arg(arg, device)
570
- return arg
571
-
572
593
  def fill_arg(self, arg: "ExplicitQuadrature.Arg", device):
573
594
  arg.points_per_cell = self._points_per_cell
574
595
  arg.points = self._points.to(device)
575
596
  arg.weights = self._weights.to(device)
576
597
 
577
598
  @wp.func
578
- def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
599
+ def point_count(
600
+ elt_arg: Any,
601
+ qp_arg: Arg,
602
+ domain_element_index: ElementIndex,
603
+ element_index: ElementIndex,
604
+ ):
579
605
  return qp_arg.points.shape[1]
580
606
 
581
607
  @wp.func
582
608
  def _point_coords_domain(
583
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
609
+ elt_arg: Any,
610
+ qp_arg: Arg,
611
+ domain_element_index: ElementIndex,
612
+ element_index: ElementIndex,
613
+ qp_index: int,
584
614
  ):
585
615
  return qp_arg.points[domain_element_index, qp_index]
586
616
 
587
617
  @wp.func
588
618
  def _point_weight_domain(
589
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
619
+ elt_arg: Any,
620
+ qp_arg: Arg,
621
+ domain_element_index: ElementIndex,
622
+ element_index: ElementIndex,
623
+ qp_index: int,
590
624
  ):
591
625
  return qp_arg.weights[domain_element_index, qp_index]
592
626
 
593
627
  @wp.func
594
628
  def _point_index_domain(
595
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
629
+ elt_arg: Any,
630
+ qp_arg: Arg,
631
+ domain_element_index: ElementIndex,
632
+ element_index: ElementIndex,
633
+ qp_index: int,
596
634
  ):
597
635
  return qp_arg.points_per_cell * domain_element_index + qp_index
598
636
 
599
637
  @wp.func
600
638
  def _point_coords_geo(
601
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
639
+ elt_arg: Any,
640
+ qp_arg: Arg,
641
+ domain_element_index: ElementIndex,
642
+ element_index: ElementIndex,
643
+ qp_index: int,
602
644
  ):
603
645
  return qp_arg.points[element_index, qp_index]
604
646
 
605
647
  @wp.func
606
648
  def _point_weight_geo(
607
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
649
+ elt_arg: Any,
650
+ qp_arg: Arg,
651
+ domain_element_index: ElementIndex,
652
+ element_index: ElementIndex,
653
+ qp_index: int,
608
654
  ):
609
655
  return qp_arg.weights[element_index, qp_index]
610
656
 
611
657
  @wp.func
612
658
  def _point_index_geo(
613
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
659
+ elt_arg: Any,
660
+ qp_arg: Arg,
661
+ domain_element_index: ElementIndex,
662
+ element_index: ElementIndex,
663
+ qp_index: int,
614
664
  ):
615
665
  return qp_arg.points_per_cell * element_index + qp_index
@@ -0,0 +1,248 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # isort: skip_file
17
+
18
+ from enum import Enum
19
+ from typing import Optional
20
+
21
+ import warp._src.fem.domain as _domain
22
+ import warp._src.fem.geometry as _geometry
23
+ import warp._src.fem.polynomial as _polynomial
24
+
25
+ from .function_space import FunctionSpace
26
+ from .basis_function_space import CollocatedFunctionSpace, ContravariantFunctionSpace, CovariantFunctionSpace
27
+ from .topology import SpaceTopology, RegularDiscontinuousSpaceTopology
28
+ from .basis_space import BasisSpace, ShapeBasisSpace
29
+ from .shape import ElementBasis, make_element_shape_function, ShapeFunction
30
+
31
+ from .grid_2d_function_space import make_grid_2d_space_topology
32
+
33
+ from .grid_3d_function_space import make_grid_3d_space_topology
34
+
35
+ from .trimesh_function_space import make_trimesh_space_topology
36
+
37
+ from .tetmesh_function_space import make_tetmesh_space_topology
38
+
39
+ from .quadmesh_function_space import make_quadmesh_space_topology
40
+
41
+ from .hexmesh_function_space import make_hexmesh_space_topology
42
+
43
+ from .nanogrid_function_space import make_nanogrid_space_topology
44
+
45
+
46
+ from .partition import SpacePartition, make_space_partition
47
+ from .restriction import SpaceRestriction
48
+
49
+
50
+ from .dof_mapper import DofMapper, IdentityMapper, SymmetricTensorMapper, SkewSymmetricTensorMapper
51
+
52
+
53
def make_space_restriction(
    space: Optional[FunctionSpace] = None,
    space_partition: Optional[SpacePartition] = None,
    domain: Optional[_domain.GeometryDomain] = None,
    space_topology: Optional[SpaceTopology] = None,
    device=None,
    temporary_store: "Optional[warp.fem.TemporaryStore]" = None,  # noqa: F821
) -> SpaceRestriction:
    """
    Restricts a function space partition to a Domain, i.e. a subset of its elements.

    One of `space_partition`, `space_topology`, or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) if neither `space_partition` nor `space_topology` are provided, the space defining the topology to restrict
        space_partition: the subset of nodes from the space topology to consider
        domain: the domain to restrict the space to, defaults to all cells of the space geometry or partition.
        space_topology: the space topology to be restricted, if `space_partition` is ``None``.
        device: device on which to perform and store computations
        temporary_store: shared pool from which to allocate temporary arrays

    Returns:
        the restriction of the space partition to ``domain``

    Raises:
        ValueError: if none of `space_partition`, `space_topology` or `space` is provided
    """

    if space_partition is None:
        if space_topology is None:
            # Explicit check instead of `assert`: asserts are stripped under `python -O`,
            # which would turn this into an opaque AttributeError on `None.topology`.
            if space is None:
                raise ValueError("One of `space_partition`, `space_topology`, or `space` must be provided")
            space_topology = space.topology

        if domain is None:
            domain = _domain.Cells(geometry=space_topology.geometry)

        # Partition the topology according to the domain's geometry partition
        space_partition = make_space_partition(
            space_topology=space_topology, geometry_partition=domain.geometry_partition
        )
    elif domain is None:
        domain = _domain.Cells(geometry=space_partition.geo_partition)

    return SpaceRestriction(
        space_partition=space_partition, domain=domain, device=device, temporary_store=temporary_store
    )
92
+
93
+
94
def make_polynomial_basis_space(
    geo: _geometry.Geometry,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> BasisSpace:
    """
    Equips a geometry with a polynomial basis.

    Args:
        geo: the Geometry on which to build the space
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e, piecewise-constant shape functions.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed basis space
    """

    reference_cell = geo.reference_cell()
    shape_function = make_element_shape_function(
        reference_cell, degree=degree, element_basis=element_basis, family=family
    )

    # Non-conforming polynomial bases are always given a discontinuous topology
    use_discontinuous = discontinuous or element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL

    space_topology = make_element_based_space_topology(geo, shape_function, use_discontinuous)
    return ShapeBasisSpace(space_topology, shape_function)
121
+
122
+
123
def make_element_based_space_topology(
    geo: _geometry.Geometry,
    shape: ShapeFunction,
    discontinuous: bool = False,
) -> SpaceTopology:
    """
    Makes a space topology from a geometry and an element-based shape function.

    Args:
        geo: The geometry to make the topology for
        shape: The shape function to make the topology for
        discontinuous: Whether to make a discontinuous topology. A discontinuous topology is also
            used whenever ``shape.ORDER == 0``, regardless of this flag.

    Returns:
        The element-based space topology

    Raises:
        NotImplementedError: If the geometry type is not supported
        ValueError: If the shape function is not supported for the given geometry
    """

    topology = None
    base_geo = geo.base

    if discontinuous or shape.ORDER == 0:
        # Discontinuous / piecewise-constant spaces use the generic topology for every geometry type
        topology = RegularDiscontinuousSpaceTopology(geo, shape.NODES_PER_ELEMENT)
    elif isinstance(base_geo, _geometry.Grid2D):
        topology = make_grid_2d_space_topology(geo, shape)
    elif isinstance(base_geo, _geometry.Grid3D):
        topology = make_grid_3d_space_topology(geo, shape)
    elif isinstance(base_geo, _geometry.Trimesh):
        topology = make_trimesh_space_topology(geo, shape)
    elif isinstance(base_geo, _geometry.Tetmesh):
        topology = make_tetmesh_space_topology(geo, shape)
    elif isinstance(base_geo, _geometry.Quadmesh):
        topology = make_quadmesh_space_topology(geo, shape)
    elif isinstance(base_geo, _geometry.Hexmesh):
        topology = make_hexmesh_space_topology(geo, shape)
    elif isinstance(base_geo, (_geometry.Nanogrid, _geometry.AdaptiveNanogrid)):
        topology = make_nanogrid_space_topology(geo, shape)

    # Reached with None when no branch above matched the base geometry type
    if topology is None:
        raise NotImplementedError(f"Unsupported geometry type {geo.name}")

    return topology
168
+
169
+
170
def make_collocated_function_space(
    basis_space: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None
) -> CollocatedFunctionSpace:
    """
    Constructs a function space from a scalar-valued basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.

    Args:
        basis_space: the scalar-valued basis space on which to build the function space
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.

    Returns:
        the constructed function space

    Raises:
        ValueError: if ``basis_space`` is not scalar-valued
    """

    if basis_space.value != ShapeFunction.Value.Scalar:
        raise ValueError("Collocated function spaces may only be constructed from scalar-valued basis")

    return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
189
+
190
+
191
def make_covariant_function_space(
    basis_space: BasisSpace,
) -> CovariantFunctionSpace:
    """
    Constructs a covariant function space from a vector-valued basis space.

    Args:
        basis_space: basis space whose ``value`` must be ``ShapeFunction.Value.CovariantVector``

    Returns:
        the constructed covariant function space

    Raises:
        ValueError: if ``basis_space`` is not covariant vector-valued
    """

    is_covariant = basis_space.value == ShapeFunction.Value.CovariantVector
    if is_covariant:
        return CovariantFunctionSpace(basis_space)

    raise ValueError("Covariant function spaces may only be constructed from covariant vector-valued basis")
201
+
202
+
203
def make_contravariant_function_space(
    basis_space: BasisSpace,
) -> ContravariantFunctionSpace:
    """
    Constructs a contravariant function space from a vector-valued basis space.

    Args:
        basis_space: basis space whose ``value`` must be ``ShapeFunction.Value.ContravariantVector``

    Returns:
        the constructed contravariant function space

    Raises:
        ValueError: if ``basis_space`` is not contravariant vector-valued
    """

    is_contravariant = basis_space.value == ShapeFunction.Value.ContravariantVector
    if is_contravariant:
        return ContravariantFunctionSpace(basis_space)

    raise ValueError("Contravariant function spaces may only be constructed from contravariant vector-valued basis")
213
+
214
+
215
def make_polynomial_space(
    geo: _geometry.Geometry,
    dtype: type = float,
    dof_mapper: Optional[DofMapper] = None,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> FunctionSpace:
    """
    Equips a geometry with a polynomial function space.

    Equivalent to successive calls to :func:`make_polynomial_basis_space` then :func:`make_collocated_function_space`,
    :func:`make_covariant_function_space` or :func:`make_contravariant_function_space`, depending on the value type
    of the resulting basis.

    Args:
        geo: the Geometry on which to build the space
        dtype: value type the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
            Only used when the basis is scalar-valued (collocated space).
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.
            Only used when the basis is scalar-valued (collocated space).
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e, piecewise-constant shape functions.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed function space: a :class:`CovariantFunctionSpace` or :class:`ContravariantFunctionSpace` when the
        basis is covariant or contravariant vector-valued, otherwise a :class:`CollocatedFunctionSpace`

    Raises:
        ValueError: if the basis is neither scalar-valued nor covariant/contravariant vector-valued
    """

    # Keyword arguments guard against silent breakage if the parameter order of
    # `make_polynomial_basis_space` ever changes.
    basis_space = make_polynomial_basis_space(
        geo, degree=degree, element_basis=element_basis, discontinuous=discontinuous, family=family
    )

    if basis_space.value == ShapeFunction.Value.CovariantVector:
        return make_covariant_function_space(basis_space)
    if basis_space.value == ShapeFunction.Value.ContravariantVector:
        return make_contravariant_function_space(basis_space)

    return make_collocated_function_space(basis_space, dtype=dtype, dof_mapper=dof_mapper)