warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang has been flagged as potentially problematic; see the registry's advisory for this release for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -16,10 +16,11 @@
16
16
  from typing import Any, Optional, Tuple, Union
17
17
 
18
18
  import warp as wp
19
- from warp.fem.cache import TemporaryStore, borrow_temporary, cached_arg_value, dynamic_kernel
20
- from warp.fem.domain import GeometryDomain
21
- from warp.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, make_free_sample
22
- from warp.fem.utils import compress_node_indices
19
+ from warp._src.fem.cache import TemporaryStore, borrow_temporary, dynamic_kernel
20
+ from warp._src.fem.domain import GeometryDomain
21
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, OUTSIDE, Coords, ElementIndex, make_free_sample
22
+ from warp._src.fem.utils import compress_node_indices
23
+ from warp._src.types import is_array
23
24
 
24
25
  from .quadrature import Quadrature
25
26
 
@@ -63,7 +64,7 @@ class PicQuadrature(Quadrature):
63
64
 
64
65
  @property
65
66
  def name(self):
66
- return f"{self.__class__.__name__}"
67
+ return self.__class__.__name__
67
68
 
68
69
  @Quadrature.domain.setter
69
70
  def domain(self, domain: GeometryDomain):
@@ -84,15 +85,9 @@ class PicQuadrature(Quadrature):
84
85
  particle_fraction: wp.array(dtype=float)
85
86
  particle_coords: wp.array(dtype=Coords)
86
87
 
87
- @cached_arg_value
88
- def arg_value(self, device) -> Arg:
89
- arg = PicQuadrature.Arg()
90
- self.fill_arg(arg, device)
91
- return arg
92
-
93
88
  def fill_arg(self, args: Arg, device):
94
- args.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
95
- args.cell_particle_indices = self._cell_particle_indices.array.to(device)
89
+ args.cell_particle_offsets = self._cell_particle_offsets.to(device)
90
+ args.cell_particle_indices = self._cell_particle_indices.to(device)
96
91
  args.particle_fraction = self._particle_fraction.to(device)
97
92
  args.particle_coords = self.particle_coords.to(device)
98
93
 
@@ -101,16 +96,16 @@ class PicQuadrature(Quadrature):
101
96
 
102
97
  def active_cell_count(self):
103
98
  """Number of cells containing at least one particle"""
104
- return self._cell_count
99
+ return self._cell_count.numpy()[0]
105
100
 
106
101
  def max_points_per_element(self):
107
102
  if self._max_particles_per_cell is None:
108
- max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.array.device)
103
+ max_ppc = wp.zeros(shape=(1,), dtype=int, device=self._cell_particle_offsets.device)
109
104
  wp.launch(
110
105
  PicQuadrature._max_particles_per_cell_kernel,
111
- self._cell_particle_offsets.array.shape[0] - 1,
106
+ self._cell_particle_offsets.shape[0] - 1,
112
107
  device=max_ppc.device,
113
- inputs=[self._cell_particle_offsets.array, max_ppc],
108
+ inputs=[self._cell_particle_offsets, max_ppc],
114
109
  )
115
110
  self._max_particles_per_cell = int(max_ppc.numpy()[0])
116
111
  return self._max_particles_per_cell
@@ -157,7 +152,7 @@ class PicQuadrature(Quadrature):
157
152
  kernel=PicQuadrature._fill_mask_kernel,
158
153
  dim=self.domain.geometry_element_count(),
159
154
  device=mask.device,
160
- inputs=[self._cell_particle_offsets.array, mask],
155
+ inputs=[self._cell_particle_offsets, mask],
161
156
  )
162
157
 
163
158
  @wp.kernel
@@ -184,7 +179,7 @@ class PicQuadrature(Quadrature):
184
179
  cell_fraction[p] = 1.0 / float(cell_particle_count)
185
180
 
186
181
  def _bin_particles(self, positions, measures, max_dist: float, temporary_store: TemporaryStore):
187
- if wp.types.is_array(positions):
182
+ if is_array(positions):
188
183
  device = positions.device
189
184
  if not self.domain.supports_lookup(device):
190
185
  raise RuntimeError(
@@ -272,7 +267,7 @@ class PicQuadrature(Quadrature):
272
267
  kernel=PicQuadrature._compute_uniform_fraction,
273
268
  inputs=[
274
269
  cell_index,
275
- self._cell_particle_offsets.array,
270
+ self._cell_particle_offsets,
276
271
  self._particle_fraction,
277
272
  ],
278
273
  device=device,
@@ -13,14 +13,15 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from functools import cached_property
16
17
  from typing import Any, ClassVar, Optional
17
18
 
18
19
  import warp as wp
19
- from warp.fem import cache
20
- from warp.fem.domain import GeometryDomain
21
- from warp.fem.geometry import Element
22
- from warp.fem.space import FunctionSpace
23
- from warp.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
20
+ from warp._src.fem import cache
21
+ from warp._src.fem.domain import GeometryDomain
22
+ from warp._src.fem.geometry import Element
23
+ from warp._src.fem.space.function_space import FunctionSpace
24
+ from warp._src.fem.types import NULL_ELEMENT_INDEX, Coords, ElementIndex, QuadraturePointIndex
24
25
 
25
26
  from ..polynomial import Polynomial
26
27
 
@@ -48,18 +49,22 @@ class Quadrature:
48
49
  """Domain over which this quadrature is defined"""
49
50
  return self._domain
50
51
 
52
+ @cache.cached_arg_value
51
53
  def arg_value(self, device) -> "Arg":
52
54
  """
53
55
  Value of the argument to be passed to device
54
56
  """
55
- arg = Quadrature.Arg()
57
+ arg = self.Arg()
58
+ self.fill_arg(arg, device)
56
59
  return arg
57
60
 
58
61
  def fill_arg(self, arg: Arg, device):
59
62
  """
60
63
  Fill the argument with the value of the argument to be passed to device
61
64
  """
62
- pass
65
+ if self.arg_value is __class__.arg_value:
66
+ raise NotImplementedError()
67
+ arg.assign(self.arg_value(device))
63
68
 
64
69
  def total_point_count(self):
65
70
  """Number of unique quadrature points that can be indexed by this rule.
@@ -160,6 +165,8 @@ class Quadrature:
160
165
  ):
161
166
  domain_element_index = wp.tid()
162
167
  element_index = self.domain.element_index(domain_index_arg, domain_element_index)
168
+ if element_index == NULL_ELEMENT_INDEX:
169
+ return
163
170
 
164
171
  qp_point_count = self.point_count(domain_arg, qp_arg, domain_element_index, element_index)
165
172
  for k in range(qp_point_count):
@@ -217,7 +224,9 @@ class _QuadratureWithRegularEvaluationPoints(Quadrature):
217
224
  cache.setup_dynamic_attributes(self, cls=__class__)
218
225
 
219
226
  ElementIndexArg = Quadrature.Arg
220
- element_index_arg_value = Quadrature.arg_value
227
+
228
+ def element_index_arg_value(self, device):
229
+ return Quadrature.Arg()
221
230
 
222
231
  def evaluation_point_count(self):
223
232
  return self.domain.element_count() * self._EVALUATION_POINTS_PER_ELEMENT
@@ -267,22 +276,33 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
267
276
  _cache: ClassVar = {}
268
277
 
269
278
  def __init__(self, element: Element, order: int, family: Polynomial):
270
- self.points, self.weights = element.instantiate_quadrature(order, family)
279
+ self.points, self.weights = element.prototype.instantiate_quadrature(order, family)
271
280
  self.count = wp.constant(len(self.points))
272
281
 
273
282
  @cache.cached_arg_value
274
283
  def arg_value(self, device):
275
284
  arg = RegularQuadrature.Arg()
276
- self.fill_arg(arg, device)
277
- return arg
278
285
 
279
- def fill_arg(self, arg: "RegularQuadrature.Arg", device):
286
+ # pause graph capture while we copy from host
287
+ # we want the cached result to be available outside of the graph
288
+ if device.is_capturing:
289
+ graph = wp.context.capture_pause()
290
+ else:
291
+ graph = None
292
+
280
293
  arg.points = wp.array(self.points, device=device, dtype=Coords)
281
294
  arg.weights = wp.array(self.weights, device=device, dtype=float)
282
295
 
296
+ if graph is not None:
297
+ wp.context.capture_resume(graph)
298
+ return arg
299
+
300
+ def fill_arg(self, arg: "RegularQuadrature.Arg", device):
301
+ arg.assign(self.arg_value(device))
302
+
283
303
  @staticmethod
284
304
  def get(element: Element, order: int, family: Polynomial):
285
- key = (element.__class__.__name__, order, family)
305
+ key = (element.value, order, family)
286
306
  try:
287
307
  return RegularQuadrature.CachedFormula._cache[key]
288
308
  except KeyError:
@@ -311,7 +331,7 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
311
331
 
312
332
  cache.setup_dynamic_attributes(self)
313
333
 
314
- @property
334
+ @cached_property
315
335
  def name(self):
316
336
  return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
317
337
 
@@ -329,9 +349,6 @@ class RegularQuadrature(_QuadratureWithRegularEvaluationPoints):
329
349
  def weights(self):
330
350
  return self._formula.weights
331
351
 
332
- def arg_value(self, device):
333
- return self._formula.arg_value(device)
334
-
335
352
  def fill_arg(self, arg: "RegularQuadrature.Arg", device):
336
353
  self._formula.fill_arg(arg, device)
337
354
 
@@ -398,20 +415,27 @@ class NodalQuadrature(Quadrature):
398
415
  any assumption about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
399
416
  """
400
417
 
401
- def __init__(self, domain: Optional[GeometryDomain], space: FunctionSpace):
418
+ _dynamic_attribute_constructors: ClassVar = {
419
+ "Arg": lambda obj: obj._make_arg(),
420
+ "point_count": lambda obj: obj._make_point_count(),
421
+ "point_index": lambda obj: obj._make_point_index(),
422
+ "point_coords": lambda obj: obj._make_point_coords(),
423
+ "point_weight": lambda obj: obj._make_point_weight(),
424
+ "point_evaluation_index": lambda obj: obj._make_point_evaluation_index(),
425
+ }
426
+
427
+ def __init__(
428
+ self,
429
+ domain: Optional[GeometryDomain],
430
+ space: Optional[FunctionSpace],
431
+ ):
402
432
  self._space = space
403
433
 
404
434
  super().__init__(domain)
405
435
 
406
- self.Arg = self._make_arg()
407
-
408
- self.point_count = self._make_point_count()
409
- self.point_index = self._make_point_index()
410
- self.point_coords = self._make_point_coords()
411
- self.point_weight = self._make_point_weight()
412
- self.point_evaluation_index = self._make_point_evaluation_index()
436
+ cache.setup_dynamic_attributes(self)
413
437
 
414
- @property
438
+ @cached_property
415
439
  def name(self):
416
440
  return f"{self.__class__.__name__}_{self._space.name}"
417
441
 
@@ -429,12 +453,6 @@ class NodalQuadrature(Quadrature):
429
453
 
430
454
  return Arg
431
455
 
432
- @cache.cached_arg_value
433
- def arg_value(self, device):
434
- arg = self.Arg()
435
- self.fill_arg(arg, device)
436
- return arg
437
-
438
456
  def fill_arg(self, arg: "NodalQuadrature.Arg", device):
439
457
  self._space.fill_space_arg(arg.space_arg, device)
440
458
  self._space.topology.fill_topo_arg(arg.topo_arg, device)
@@ -486,7 +504,8 @@ class NodalQuadrature(Quadrature):
486
504
  element_index: ElementIndex,
487
505
  qp_index: int,
488
506
  ):
489
- return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
507
+ node_index = self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
508
+ return node_index
490
509
 
491
510
  return point_index
492
511
 
@@ -529,7 +548,12 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
529
548
  points: wp.array2d(dtype=Coords)
530
549
  weights: wp.array2d(dtype=float)
531
550
 
532
- def __init__(self, domain: GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
551
+ def __init__(
552
+ self,
553
+ domain: GeometryDomain,
554
+ points: "wp.array2d(dtype=Coords)",
555
+ weights: "wp.array2d(dtype=float)",
556
+ ):
533
557
  if points.shape != weights.shape:
534
558
  raise ValueError("Points and weights arrays must have the same shape")
535
559
 
@@ -554,7 +578,7 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
554
578
  self._points = points
555
579
  self._weights = weights
556
580
 
557
- @property
581
+ @cached_property
558
582
  def name(self):
559
583
  return f"{self.__class__.__name__}_{self._whole_geo}_{self._points_per_cell}"
560
584
 
@@ -564,52 +588,76 @@ class ExplicitQuadrature(_QuadratureWithRegularEvaluationPoints):
564
588
  def max_points_per_element(self):
565
589
  return self._points_per_cell
566
590
 
567
- def arg_value(self, device):
568
- arg = self.Arg()
569
- self.fill_arg(arg, device)
570
- return arg
571
-
572
591
  def fill_arg(self, arg: "ExplicitQuadrature.Arg", device):
573
592
  arg.points_per_cell = self._points_per_cell
574
593
  arg.points = self._points.to(device)
575
594
  arg.weights = self._weights.to(device)
576
595
 
577
596
  @wp.func
578
- def point_count(elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex):
597
+ def point_count(
598
+ elt_arg: Any,
599
+ qp_arg: Arg,
600
+ domain_element_index: ElementIndex,
601
+ element_index: ElementIndex,
602
+ ):
579
603
  return qp_arg.points.shape[1]
580
604
 
581
605
  @wp.func
582
606
  def _point_coords_domain(
583
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
607
+ elt_arg: Any,
608
+ qp_arg: Arg,
609
+ domain_element_index: ElementIndex,
610
+ element_index: ElementIndex,
611
+ qp_index: int,
584
612
  ):
585
613
  return qp_arg.points[domain_element_index, qp_index]
586
614
 
587
615
  @wp.func
588
616
  def _point_weight_domain(
589
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
617
+ elt_arg: Any,
618
+ qp_arg: Arg,
619
+ domain_element_index: ElementIndex,
620
+ element_index: ElementIndex,
621
+ qp_index: int,
590
622
  ):
591
623
  return qp_arg.weights[domain_element_index, qp_index]
592
624
 
593
625
  @wp.func
594
626
  def _point_index_domain(
595
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
627
+ elt_arg: Any,
628
+ qp_arg: Arg,
629
+ domain_element_index: ElementIndex,
630
+ element_index: ElementIndex,
631
+ qp_index: int,
596
632
  ):
597
633
  return qp_arg.points_per_cell * domain_element_index + qp_index
598
634
 
599
635
  @wp.func
600
636
  def _point_coords_geo(
601
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
637
+ elt_arg: Any,
638
+ qp_arg: Arg,
639
+ domain_element_index: ElementIndex,
640
+ element_index: ElementIndex,
641
+ qp_index: int,
602
642
  ):
603
643
  return qp_arg.points[element_index, qp_index]
604
644
 
605
645
  @wp.func
606
646
  def _point_weight_geo(
607
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
647
+ elt_arg: Any,
648
+ qp_arg: Arg,
649
+ domain_element_index: ElementIndex,
650
+ element_index: ElementIndex,
651
+ qp_index: int,
608
652
  ):
609
653
  return qp_arg.weights[element_index, qp_index]
610
654
 
611
655
  @wp.func
612
656
  def _point_index_geo(
613
- elt_arg: Any, qp_arg: Arg, domain_element_index: ElementIndex, element_index: ElementIndex, qp_index: int
657
+ elt_arg: Any,
658
+ qp_arg: Arg,
659
+ domain_element_index: ElementIndex,
660
+ element_index: ElementIndex,
661
+ qp_index: int,
614
662
  ):
615
663
  return qp_arg.points_per_cell * element_index + qp_index
@@ -0,0 +1,248 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # isort: skip_file
17
+
18
+ from enum import Enum
19
+ from typing import Optional
20
+
21
+ import warp._src.fem.domain as _domain
22
+ import warp._src.fem.geometry as _geometry
23
+ import warp._src.fem.polynomial as _polynomial
24
+
25
+ from .function_space import FunctionSpace
26
+ from .basis_function_space import CollocatedFunctionSpace, ContravariantFunctionSpace, CovariantFunctionSpace
27
+ from .topology import SpaceTopology, RegularDiscontinuousSpaceTopology
28
+ from .basis_space import BasisSpace, ShapeBasisSpace
29
+ from .shape import ElementBasis, make_element_shape_function, ShapeFunction
30
+
31
+ from .grid_2d_function_space import make_grid_2d_space_topology
32
+
33
+ from .grid_3d_function_space import make_grid_3d_space_topology
34
+
35
+ from .trimesh_function_space import make_trimesh_space_topology
36
+
37
+ from .tetmesh_function_space import make_tetmesh_space_topology
38
+
39
+ from .quadmesh_function_space import make_quadmesh_space_topology
40
+
41
+ from .hexmesh_function_space import make_hexmesh_space_topology
42
+
43
+ from .nanogrid_function_space import make_nanogrid_space_topology
44
+
45
+
46
+ from .partition import SpacePartition, make_space_partition
47
+ from .restriction import SpaceRestriction
48
+
49
+
50
+ from .dof_mapper import DofMapper, IdentityMapper, SymmetricTensorMapper, SkewSymmetricTensorMapper
51
+
52
+
53
def make_space_restriction(
    space: Optional[FunctionSpace] = None,
    space_partition: Optional[SpacePartition] = None,
    domain: Optional[_domain.GeometryDomain] = None,
    space_topology: Optional[SpaceTopology] = None,
    device=None,
    temporary_store: "Optional[warp.fem.TemporaryStore]" = None,  # noqa: F821
) -> SpaceRestriction:
    """
    Restricts a function space partition to a Domain, i.e. a subset of its elements.

    One of `space_partition`, `space_topology`, or `space` must be provided (and will be considered in that order).

    Args:
        space: (deprecated) if neither `space_partition` nor `space_topology` are provided, the space defining the topology to restrict
        space_partition: the subset of nodes from the space topology to consider
        domain: the domain to restrict the space to, defaults to all cells of the space geometry or partition.
        space_topology: the space topology to be restricted, if `space_partition` is ``None``.
        device: device on which to perform and store computations
        temporary_store: shared pool from which to allocate temporary arrays

    Returns:
        the restriction of the space partition to `domain`

    Raises:
        ValueError: if none of `space_partition`, `space_topology` or `space` is provided
    """

    if space_partition is None:
        if space_topology is None:
            # Explicit check rather than ``assert``: asserts are stripped under ``python -O``,
            # which would turn a missing argument into an opaque AttributeError on ``None``.
            if space is None:
                raise ValueError("One of `space_partition`, `space_topology` or `space` must be provided")
            space_topology = space.topology

        if domain is None:
            domain = _domain.Cells(geometry=space_topology.geometry)

        # Partition the topology according to the geometry partition of the target domain
        space_partition = make_space_partition(
            space_topology=space_topology, geometry_partition=domain.geometry_partition
        )
    elif domain is None:
        domain = _domain.Cells(geometry=space_partition.geo_partition)

    return SpaceRestriction(
        space_partition=space_partition, domain=domain, device=device, temporary_store=temporary_store
    )
92
+
93
+
94
def make_polynomial_basis_space(
    geo: _geometry.Geometry,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> BasisSpace:
    """
    Equips a geometry with a polynomial basis.

    Args:
        geo: the Geometry on which to build the space
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0,
            i.e., piecewise-constant shape functions, and when `element_basis` is ``ElementBasis.NONCONFORMING_POLYNOMIAL``.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed basis space
    """

    reference_cell = geo.reference_cell()
    shape = make_element_shape_function(reference_cell, degree=degree, element_basis=element_basis, family=family)

    # Non-conforming polynomial bases are always paired with a discontinuous topology
    use_discontinuous_topology = discontinuous or element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL
    topology = make_element_based_space_topology(geo, shape, use_discontinuous_topology)

    return ShapeBasisSpace(topology, shape)
121
+
122
+
123
def make_element_based_space_topology(
    geo: _geometry.Geometry,
    shape: ShapeFunction,
    discontinuous: bool = False,
) -> SpaceTopology:
    """
    Builds a space topology for a geometry from an element-based shape function.

    Discontinuous and piecewise-constant (order 0) shapes use a regular per-element node layout
    regardless of the geometry type; otherwise the topology builder is selected from the type of
    the underlying base geometry.

    Args:
        geo: The geometry to make the topology for
        shape: The shape function to make the topology for
        discontinuous: Whether to make a discontinuous topology

    Returns:
        The element-based space topology

    Raises:
        NotImplementedError: If the geometry type is not supported
        ValueError: If the shape function is not supported for the given geometry
    """

    base_geo = geo.base

    if discontinuous or shape.ORDER == 0:
        return RegularDiscontinuousSpaceTopology(geo, shape.NODES_PER_ELEMENT)

    # Ordered dispatch on the base geometry type; the first matching entry wins
    topology_builders = (
        ((_geometry.Grid2D,), make_grid_2d_space_topology),
        ((_geometry.Grid3D,), make_grid_3d_space_topology),
        ((_geometry.Trimesh,), make_trimesh_space_topology),
        ((_geometry.Tetmesh,), make_tetmesh_space_topology),
        ((_geometry.Quadmesh,), make_quadmesh_space_topology),
        ((_geometry.Hexmesh,), make_hexmesh_space_topology),
        ((_geometry.Nanogrid, _geometry.AdaptiveNanogrid), make_nanogrid_space_topology),
    )

    topology = None
    for geometry_types, build_topology in topology_builders:
        if isinstance(base_geo, geometry_types):
            topology = build_topology(geo, shape)
            break

    # Also covers a builder returning None
    if topology is None:
        raise NotImplementedError(f"Unsupported geometry type {geo.name}")

    return topology
168
+
169
+
170
def make_collocated_function_space(
    basis_space: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None
) -> CollocatedFunctionSpace:
    """
    Constructs a function space from a scalar-valued basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.

    Args:
        basis_space: the scalar-valued basis space on which to build the function space
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.

    Returns:
        the constructed function space

    Raises:
        ValueError: if ``basis_space`` is not scalar-valued
    """

    # Collocation stores every dof of the value type at each node, which requires scalar shape functions
    if basis_space.value != ShapeFunction.Value.Scalar:
        raise ValueError("Collocated function spaces may only be constructed from scalar-valued basis")

    return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
189
+
190
+
191
def make_covariant_function_space(
    basis_space: BasisSpace,
) -> CovariantFunctionSpace:
    """
    Constructs a covariant function space from a vector-valued basis space.

    Args:
        basis_space: basis space whose shape functions are covariant vectors

    Returns:
        the constructed covariant function space

    Raises:
        ValueError: if ``basis_space`` is not covariant vector-valued
    """

    required_value = ShapeFunction.Value.CovariantVector
    if basis_space.value != required_value:
        raise ValueError("Covariant function spaces may only be constructed from covariant vector-valued basis")

    return CovariantFunctionSpace(basis_space)
201
+
202
+
203
def make_contravariant_function_space(
    basis_space: BasisSpace,
) -> ContravariantFunctionSpace:
    """
    Constructs a contravariant function space from a vector-valued basis space.

    Args:
        basis_space: basis space whose shape functions are contravariant vectors

    Returns:
        the constructed contravariant function space

    Raises:
        ValueError: if ``basis_space`` is not contravariant vector-valued
    """

    required_value = ShapeFunction.Value.ContravariantVector
    if basis_space.value != required_value:
        raise ValueError("Contravariant function spaces may only be constructed from contravariant vector-valued basis")

    return ContravariantFunctionSpace(basis_space)
213
+
214
+
215
def make_polynomial_space(
    geo: _geometry.Geometry,
    dtype: type = float,
    dof_mapper: Optional[DofMapper] = None,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> FunctionSpace:
    """
    Equips a geometry with a polynomial function space.

    Equivalent to successive calls to :func:`make_polynomial_basis_space` then :func:`make_collocated_function_space`,
    :func:`make_covariant_function_space` or :func:`make_contravariant_function_space`, depending on the value type
    of the resulting basis.

    Args:
        geo: the Geometry on which to build the space
        dtype: value type of the function space. Only used for scalar-valued bases (collocated spaces).
            If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Only used for scalar-valued
            bases (collocated spaces). Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3)
            symmetric tensors to 3 (resp 6) degrees of freedom.
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e., piecewise-constant shape functions.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed function space: a :class:`CovariantFunctionSpace` or :class:`ContravariantFunctionSpace` for
        covariant or contravariant vector-valued bases, otherwise a :class:`CollocatedFunctionSpace`
    """

    basis_space = make_polynomial_basis_space(
        geo, degree=degree, element_basis=element_basis, discontinuous=discontinuous, family=family
    )

    # dtype and dof_mapper are not forwarded to covariant/contravariant (vector-valued) spaces
    if basis_space.value == ShapeFunction.Value.CovariantVector:
        return make_covariant_function_space(basis_space)
    if basis_space.value == ShapeFunction.Value.ContravariantVector:
        return make_contravariant_function_space(basis_space)

    return make_collocated_function_space(basis_space, dtype=dtype, dof_mapper=dof_mapper)