warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ from typing import Any
2
2
 
3
3
  import warp as wp
4
4
 
5
- from warp.fem import domain
5
+ from warp.fem import domain, cache
6
6
  from warp.fem.types import ElementIndex, Coords
7
+ from warp.fem.space import FunctionSpace
7
8
 
8
9
  from ..polynomial import Polynomial
9
10
 
@@ -11,8 +12,11 @@ from ..polynomial import Polynomial
11
12
  class Quadrature:
12
13
  """Interface class for quadrature rules"""
13
14
 
14
- Arg: wp.codegen.Struct
15
- """Structure containing arguments to be passed to device functions"""
15
+ @wp.struct
16
+ class Arg:
17
+ """Structure containing arguments to be passed to device functions"""
18
+
19
+ pass
16
20
 
17
21
  def __init__(self, domain: domain.GeometryDomain):
18
22
  self._domain = domain
@@ -22,27 +26,46 @@ class Quadrature:
22
26
  """Domain over which this quadrature is defined"""
23
27
  return self._domain
24
28
 
25
- def eval_arg_value(self, device) -> wp.codegen.StructInstance:
29
+ def arg_value(self, device) -> "Arg":
26
30
  """
27
31
  Value of the argument to be passed to device
28
32
  """
29
- pass
33
+ arg = RegularQuadrature.Arg()
34
+ return arg
30
35
 
31
36
  def total_point_count(self):
32
37
  """Total number of quadrature points over the domain"""
33
- pass
38
+ raise NotImplementedError()
39
+
40
+ def points_per_element(self):
41
+ """Number of points per element if constant, or ``None`` if varying"""
42
+ return None
34
43
 
35
- def point_count(arg: Any, element_index: ElementIndex):
44
+ @staticmethod
45
+ def point_count(elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex):
36
46
  """Number of quadrature points for a given element"""
37
- pass
47
+ raise NotImplementedError()
38
48
 
39
- def point_coords(arg: Any, element_index: ElementIndex, qp_index: int):
40
- """Coordinates in element of the qp_index'th quadrature point"""
41
- pass
49
+ @staticmethod
50
+ def point_coords(
51
+ elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
52
+ ):
53
+ """Coordinates in element of the element's qp_index'th quadrature point"""
54
+ raise NotImplementedError()
42
55
 
43
- def point_weight(arg: Any, element_index: ElementIndex, qp_index: int):
44
- """Weight of the qp_index'th quadrature point"""
45
- pass
56
+ @staticmethod
57
+ def point_weight(
58
+ elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
59
+ ):
60
+ """Weight of the element's qp_index'th quadrature point"""
61
+ raise NotImplementedError()
62
+
63
+ @staticmethod
64
+ def point_index(
65
+ elt_arg: "domain.GeometryDomain.ElementArg", qp_arg: Arg, element_index: ElementIndex, qp_index: int
66
+ ):
67
+ """Global index of the element's qp_index'th quadrature point"""
68
+ raise NotImplementedError()
46
69
 
47
70
  def __str__(self) -> str:
48
71
  return self.name
@@ -64,31 +87,29 @@ class RegularQuadrature(Quadrature):
64
87
 
65
88
  self._element_quadrature = domain.reference_element().instantiate_quadrature(order, family)
66
89
 
67
- N = wp.constant(len(self.points))
90
+ self._N = wp.constant(len(self.points))
68
91
 
69
- WeightVec = wp.vec(length=N, dtype=wp.float32)
70
- CoordMat = wp.mat(shape=(N, 3), dtype=wp.float32)
92
+ WeightVec = wp.vec(length=self._N, dtype=wp.float32)
93
+ CoordMat = wp.mat(shape=(self._N, 3), dtype=wp.float32)
71
94
 
72
- POINTS = wp.constant(CoordMat(self.points))
73
- WEIGHTS = wp.constant(WeightVec(self.weights))
95
+ self._POINTS = wp.constant(CoordMat(self.points))
96
+ self._WEIGHTS = wp.constant(WeightVec(self.weights))
74
97
 
75
- self.point_count = self._make_point_count(N)
76
- self.point_index = self._make_point_index(N)
77
- self.point_coords = self._make_point_coords(POINTS, self.name)
78
- self.point_weight = self._make_point_weight(WEIGHTS, self.name)
98
+ self.point_count = self._make_point_count()
99
+ self.point_index = self._make_point_index()
100
+ self.point_coords = self._make_point_coords()
101
+ self.point_weight = self._make_point_weight()
79
102
 
80
103
  @property
81
104
  def name(self):
82
- return (
83
- f"{self.__class__.__name__}_{self.domain.reference_element().__class__.__name__}_{self.family}_{self.order}"
84
- )
85
-
86
- def __str__(self) -> str:
87
- return self.name
105
+ return f"{self.__class__.__name__}_{self.domain.name}_{self.family}_{self.order}"
88
106
 
89
107
  def total_point_count(self):
90
108
  return len(self.points) * self.domain.geometry_element_count()
91
109
 
110
+ def points_per_element(self):
111
+ return self._N
112
+
92
113
  @property
93
114
  def points(self):
94
115
  return self._element_quadrature[0]
@@ -97,46 +118,177 @@ class RegularQuadrature(Quadrature):
97
118
  def weights(self):
98
119
  return self._element_quadrature[1]
99
120
 
100
- @wp.struct
101
- class Arg:
102
- pass
121
+ def _make_point_count(self):
122
+ N = self._N
103
123
 
104
- def arg_value(self, device) -> Arg:
105
- arg = RegularQuadrature.Arg()
124
+ @cache.dynamic_func(suffix=self.name)
125
+ def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
126
+ return N
127
+
128
+ return point_count
129
+
130
+ def _make_point_coords(self):
131
+ POINTS = self._POINTS
132
+
133
+ @cache.dynamic_func(suffix=self.name)
134
+ def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
135
+ return Coords(POINTS[qp_index, 0], POINTS[qp_index, 1], POINTS[qp_index, 2])
136
+
137
+ return point_coords
138
+
139
+ def _make_point_weight(self):
140
+ WEIGHTS = self._WEIGHTS
141
+
142
+ @cache.dynamic_func(suffix=self.name)
143
+ def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
144
+ return WEIGHTS[qp_index]
145
+
146
+ return point_weight
147
+
148
+ def _make_point_index(self):
149
+ N = self._N
150
+
151
+ @cache.dynamic_func(suffix=self.name)
152
+ def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
153
+ return N * element_index + qp_index
154
+
155
+ return point_index
156
+
157
+
158
+ class NodalQuadrature(Quadrature):
159
+ """Quadrature using space node points as quadrature points
160
+
161
+ Note that in contrast to the `nodal=True` flag for :func:`integrate`, this quadrature does not make any assumption
162
+ about orthogonality of shape functions, and is thus safe to use for arbitrary integrands.
163
+ """
164
+
165
+ def __init__(self, domain: domain.GeometryDomain, space: FunctionSpace):
166
+ super().__init__(domain)
167
+
168
+ self._space = space
169
+
170
+ self.Arg = self._make_arg()
171
+
172
+ self.point_count = self._make_point_count()
173
+ self.point_index = self._make_point_index()
174
+ self.point_coords = self._make_point_coords()
175
+ self.point_weight = self._make_point_weight()
176
+
177
+ @property
178
+ def name(self):
179
+ return f"{self.__class__.__name__}_{self._space.name}"
180
+
181
+ def total_point_count(self):
182
+ return self._space.node_count()
183
+
184
+ def points_per_element(self):
185
+ return self._space.topology.NODES_PER_ELEMENT
186
+
187
+ def _make_arg(self):
188
+ @cache.dynamic_struct(suffix=self.name)
189
+ class Arg:
190
+ space_arg: self._space.SpaceArg
191
+ topo_arg: self._space.topology.TopologyArg
192
+
193
+ return Arg
194
+
195
+ @cache.cached_arg_value
196
+ def arg_value(self, device):
197
+ arg = self.Arg()
198
+ arg.space_arg = self._space.space_arg_value(device)
199
+ arg.topo_arg = self._space.topology.topo_arg_value(device)
106
200
  return arg
107
201
 
108
- @staticmethod
109
- def _make_point_count(N):
110
- def point_count(arg: RegularQuadrature.Arg, element_index: ElementIndex):
202
+ def _make_point_count(self):
203
+ N = self._space.topology.NODES_PER_ELEMENT
204
+
205
+ @cache.dynamic_func(suffix=self.name)
206
+ def point_count(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex):
111
207
  return N
112
208
 
113
- from warp.fem.cache import get_func
209
+ return point_count
114
210
 
115
- return get_func(point_count, str(N))
211
+ def _make_point_coords(self):
212
+ @cache.dynamic_func(suffix=self.name)
213
+ def point_coords(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
214
+ return self._space.node_coords_in_element(elt_arg, qp_arg.space_arg, element_index, qp_index)
116
215
 
117
- @staticmethod
118
- def _make_point_coords(POINTS, name):
119
- def point_coords(arg: RegularQuadrature.Arg, element_index: ElementIndex, index: int):
120
- return Coords(POINTS[index, 0], POINTS[index, 1], POINTS[index, 2])
216
+ return point_coords
121
217
 
122
- from warp.fem.cache import get_func
218
+ def _make_point_weight(self):
219
+ @cache.dynamic_func(suffix=self.name)
220
+ def point_weight(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
221
+ return self._space.node_quadrature_weight(elt_arg, qp_arg.space_arg, element_index, qp_index)
123
222
 
124
- return get_func(point_coords, name)
223
+ return point_weight
125
224
 
126
- @staticmethod
127
- def _make_point_weight(WEIGHTS, name):
128
- def point_weight(arg: RegularQuadrature.Arg, element_index: ElementIndex, index: int):
129
- return WEIGHTS[index]
225
+ def _make_point_index(self):
226
+ @cache.dynamic_func(suffix=self.name)
227
+ def point_index(elt_arg: self.domain.ElementArg, qp_arg: self.Arg, element_index: ElementIndex, qp_index: int):
228
+ return self._space.topology.element_node_index(elt_arg, qp_arg.topo_arg, element_index, qp_index)
130
229
 
131
- from warp.fem.cache import get_func
230
+ return point_index
132
231
 
133
- return get_func(point_weight, name)
134
232
 
135
- @staticmethod
136
- def _make_point_index(N):
137
- def point_index(arg: RegularQuadrature.Arg, element_index: ElementIndex, index: int):
138
- return N * element_index + index
233
+ class ExplicitQuadrature(Quadrature):
234
+ """Quadrature using explicit per-cell points and weights. The number of quadrature points per cell is assumed
235
+ to be constant and deduced from the shape of the points and weights arrays.
236
+
237
+ Args:
238
+ domain: Domain of definition of the quadrature formula
239
+ points: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the coordinates of each quadrature point.
240
+ weights: 2d array of shape ``(domain.geometry_element_count(), points_per_cell)`` containing the weight for each quadrature point.
241
+
242
+ See also: :class:`PicQuadrature`
243
+ """
244
+
245
+ @wp.struct
246
+ class Arg:
247
+ points_per_cell: int
248
+ points: wp.array2d(dtype=Coords)
249
+ weights: wp.array2d(dtype=float)
250
+
251
+ def __init__(self, domain: domain.GeometryDomain, points: "wp.array2d(dtype=Coords)", weights: "wp.array2d(dtype=float)"):
252
+ super().__init__(domain)
253
+
254
+ if points.shape != weights.shape:
255
+ raise ValueError("Points and weights arrays must have the same shape")
256
+
257
+ self._points_per_cell = points.shape[1]
258
+ self._points = points
259
+ self._weights = weights
260
+
261
+ @property
262
+ def name(self):
263
+ return f"{self.__class__.__name__}"
264
+
265
+ def total_point_count(self):
266
+ return self._weights.size
267
+
268
+ def points_per_element(self):
269
+ return self._points_per_cell
270
+
271
+ @cache.cached_arg_value
272
+ def arg_value(self, device):
273
+ arg = self.Arg()
274
+ arg.points_per_cell = self._points_per_cell
275
+ arg.points = self._points.to(device)
276
+ arg.weights = self._weights.to(device)
277
+
278
+ return arg
279
+
280
+ @wp.func
281
+ def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
282
+ return qp_arg.points_per_cell
283
+
284
+ @wp.func
285
+ def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
286
+ return qp_arg.points[element_index, qp_index]
139
287
 
140
- from warp.fem.cache import get_func
288
+ @wp.func
289
+ def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
290
+ return qp_arg.weights[element_index, qp_index]
141
291
 
142
- return get_func(point_index, str(N))
292
+ @wp.func
293
+ def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, qp_index: int):
294
+ return qp_arg.points_per_cell * element_index + qp_index
@@ -1,115 +1,292 @@
1
1
  from typing import Optional
2
+ from enum import Enum
2
3
 
3
- import warp.fem.domain
4
- import warp.fem.geometry
5
- import warp.fem.polynomial
4
+ import warp.fem.domain as _domain
5
+ import warp.fem.geometry as _geometry
6
+ import warp.fem.polynomial as _polynomial
6
7
 
7
8
  from .function_space import FunctionSpace
8
- from .nodal_function_space import NodalFunctionSpace
9
+ from .topology import SpaceTopology
10
+ from .basis_space import BasisSpace, PointBasisSpace
11
+ from .collocated_function_space import CollocatedFunctionSpace
9
12
 
10
13
  from .grid_2d_function_space import (
11
- GridPiecewiseConstantSpace,
12
- GridBipolynomialSpace,
13
- GridDGBipolynomialSpace,
14
+ GridPiecewiseConstantBasis,
15
+ GridBipolynomialBasisSpace,
16
+ GridDGBipolynomialBasisSpace,
17
+ GridSerendipityBasisSpace,
18
+ GridDGSerendipityBasisSpace,
19
+ GridDGPolynomialBasisSpace,
14
20
  )
15
21
  from .grid_3d_function_space import (
16
- GridTripolynomialSpace,
17
- GridDGTripolynomialSpace,
18
- Grid3DPiecewiseConstantSpace,
22
+ GridTripolynomialBasisSpace,
23
+ GridDGTripolynomialBasisSpace,
24
+ Grid3DPiecewiseConstantBasis,
25
+ Grid3DSerendipityBasisSpace,
26
+ Grid3DDGSerendipityBasisSpace,
27
+ Grid3DDGPolynomialBasisSpace,
19
28
  )
20
29
  from .trimesh_2d_function_space import (
21
- Trimesh2DPiecewiseConstantSpace,
22
- Trimesh2DPolynomialSpace,
23
- Trimesh2DDGPolynomialSpace,
30
+ Trimesh2DPiecewiseConstantBasis,
31
+ Trimesh2DPolynomialBasisSpace,
32
+ Trimesh2DDGPolynomialBasisSpace,
33
+ Trimesh2DNonConformingPolynomialBasisSpace,
34
+ )
35
+ from .tetmesh_function_space import (
36
+ TetmeshPiecewiseConstantBasis,
37
+ TetmeshPolynomialBasisSpace,
38
+ TetmeshDGPolynomialBasisSpace,
39
+ TetmeshNonConformingPolynomialBasisSpace,
40
+ )
41
+ from .quadmesh_2d_function_space import (
42
+ Quadmesh2DPiecewiseConstantBasis,
43
+ Quadmesh2DBipolynomialBasisSpace,
44
+ Quadmesh2DDGBipolynomialBasisSpace,
45
+ Quadmesh2DSerendipityBasisSpace,
46
+ Quadmesh2DDGSerendipityBasisSpace,
47
+ Quadmesh2DPolynomialBasisSpace,
48
+ )
49
+ from .hexmesh_function_space import (
50
+ HexmeshPiecewiseConstantBasis,
51
+ HexmeshTripolynomialBasisSpace,
52
+ HexmeshDGTripolynomialBasisSpace,
53
+ HexmeshSerendipityBasisSpace,
54
+ HexmeshDGSerendipityBasisSpace,
55
+ HexmeshPolynomialBasisSpace,
24
56
  )
25
- from .tetmesh_function_space import TetmeshPiecewiseConstantSpace, TetmeshPolynomialSpace, TetmeshDGPolynomialSpace
26
57
 
27
58
  from .partition import SpacePartition, make_space_partition
28
59
  from .restriction import SpaceRestriction
29
60
 
30
61
 
31
- from .dof_mapper import DofMapper, IdentityMapper, SymmetricTensorMapper
62
+ from .dof_mapper import DofMapper, IdentityMapper, SymmetricTensorMapper, SkewSymmetricTensorMapper
32
63
 
33
64
 
34
65
  def make_space_restriction(
35
- space: FunctionSpace,
66
+ space: Optional[FunctionSpace] = None,
36
67
  space_partition: Optional[SpacePartition] = None,
37
- domain: Optional[warp.fem.domain.GeometryDomain] = None,
68
+ domain: Optional[_domain.GeometryDomain] = None,
69
+ space_topology: Optional[SpaceTopology] = None,
38
70
  device=None,
71
+ temporary_store: "Optional[warp.fem.cache.TemporaryStore]" = None,
39
72
  ) -> SpaceRestriction:
40
73
  """
41
- Restricts a function space to a Domain, i.e. a subset of its elements.
74
+ Restricts a function space partition to a Domain, i.e. a subset of its elements.
75
+
76
+ One of `space_partition`, `space_topology`, or `space` must be provided (and will be considered in that order).
42
77
 
43
78
  Args:
44
- space: the space to be restricted
45
- space_partition: if provided, the subset of nodes from ``space`` to consider
79
+ space: (deprecated) if neither `space_partition` nor `space_topology` are provided, the space defining the topology to restrict
80
+ space_partition: the subset of nodes from the space topology to consider
46
81
  domain: the domain to restrict the space to, defaults to all cells of the space geometry or partition.
82
+ space_topology: the space topology to be restricted, if `space_partition` is ``None``.
83
+ device: device on which to perform and store computations
84
+ temporary_store: shared pool from which to allocate temporary arrays
47
85
  """
48
- if domain is None:
49
- if space_partition is None:
50
- domain = warp.fem.domain.Cells(geometry=space.geometry)
51
- else:
52
- domain = warp.fem.domain.Cells(geometry=space_partition.geo_partition)
53
- return SpaceRestriction(space=space, space_partition=space_partition, domain=domain, device=device)
54
86
 
87
+ if space_partition is None:
88
+ if space_topology is None:
89
+ assert space is not None
90
+ space_topology = space.topology
55
91
 
56
- def make_polynomial_space(
57
- geo: warp.fem.geometry.Geometry,
58
- dtype: type = float,
59
- dof_mapper: Optional[DofMapper] = None,
92
+ if domain is None:
93
+ domain = _domain.Cells(geometry=space_topology.geometry)
94
+
95
+ space_partition = make_space_partition(
96
+ space_topology=space_topology, geometry_partition=domain.geometry_partition
97
+ )
98
+ elif domain is None:
99
+ domain = _domain.Cells(geometry=space_partition.geo_partition)
100
+
101
+ return SpaceRestriction(
102
+ space_partition=space_partition, domain=domain, device=device, temporary_store=temporary_store
103
+ )
104
+
105
+
106
+ class ElementBasis(Enum):
107
+ """Choice of basis function to equip individual elements"""
108
+
109
+ LAGRANGE = 0
110
+ """Lagrange basis functions :math:`P_k` for simplices, tensor products :math:`Q_k` for squares and cubes"""
111
+ SERENDIPITY = 1
112
+ """Serendipity elements :math:`S_k`, corresponding to Lagrange nodes with interior points removed (for degree <= 3)"""
113
+ NONCONFORMING_POLYNOMIAL = 2
114
+ """Simplex Lagrange basis functions :math:`P_{kd}` embedded into non-conforming reference elements (e.g. squares or cubes). Discontinuous only."""
115
+
116
+
117
def make_polynomial_basis_space(
    geo: _geometry.Geometry,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> BasisSpace:
    """
    Equips a geometry with a polynomial basis.

    Args:
        geo: the Geometry on which to build the space. If ``geo`` is a :class:`DeformedGeometry`, the basis type is selected
            from the type of its base geometry, but the basis is still built on ``geo`` itself.
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements. Defaults to :attr:`ElementBasis.LAGRANGE`.
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e, piecewise-constant shape functions,
            and for :attr:`ElementBasis.NONCONFORMING_POLYNOMIAL`, for which this flag is ignored.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.
            Not used for degree 0, for simplex geometries (triangle and tetrahedron meshes), nor for non-conforming bases.

    Returns:
        the constructed basis space

    Raises:
        NotImplementedError: if the geometry type is not supported, or if a serendipity basis of degree > 2 is requested
            on a simplex geometry.
    """

    # Dispatch on the type of the underlying (undeformed) geometry, but always build the basis on `geo`.
    base_geo = geo.base if isinstance(geo, _geometry.DeformedGeometry) else geo

    if element_basis is None:
        element_basis = ElementBasis.LAGRANGE

    if isinstance(base_geo, _geometry.Grid2D):
        if degree == 0:
            return GridPiecewiseConstantBasis(geo)

        # Serendipity and Lagrange coincide at degree 1, so only degree > 1 needs a dedicated space
        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            if discontinuous:
                return GridDGSerendipityBasisSpace(geo, degree=degree, family=family)
            else:
                return GridSerendipityBasisSpace(geo, degree=degree, family=family)

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return GridDGPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return GridDGBipolynomialBasisSpace(geo, degree=degree, family=family)
        else:
            return GridBipolynomialBasisSpace(geo, degree=degree, family=family)

    if isinstance(base_geo, _geometry.Grid3D):
        if degree == 0:
            return Grid3DPiecewiseConstantBasis(geo)

        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            if discontinuous:
                return Grid3DDGSerendipityBasisSpace(geo, degree=degree, family=family)
            else:
                return Grid3DSerendipityBasisSpace(geo, degree=degree, family=family)

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return Grid3DDGPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return GridDGTripolynomialBasisSpace(geo, degree=degree, family=family)
        else:
            return GridTripolynomialBasisSpace(geo, degree=degree, family=family)

    if isinstance(base_geo, _geometry.Trimesh2D):
        if degree == 0:
            return Trimesh2DPiecewiseConstantBasis(geo)

        # Serendipity requests of degree <= 2 fall through to the Lagrange basis below
        if element_basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet")

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return Trimesh2DNonConformingPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return Trimesh2DDGPolynomialBasisSpace(geo, degree=degree)
        else:
            return Trimesh2DPolynomialBasisSpace(geo, degree=degree)

    if isinstance(base_geo, _geometry.Tetmesh):
        if degree == 0:
            return TetmeshPiecewiseConstantBasis(geo)

        # Serendipity requests of degree <= 2 fall through to the Lagrange basis below
        if element_basis == ElementBasis.SERENDIPITY and degree > 2:
            raise NotImplementedError("Serendipity variant not implemented yet")

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return TetmeshNonConformingPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return TetmeshDGPolynomialBasisSpace(geo, degree=degree)
        else:
            return TetmeshPolynomialBasisSpace(geo, degree=degree)

    if isinstance(base_geo, _geometry.Quadmesh2D):
        if degree == 0:
            return Quadmesh2DPiecewiseConstantBasis(geo)

        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            if discontinuous:
                return Quadmesh2DDGSerendipityBasisSpace(geo, degree=degree, family=family)
            else:
                return Quadmesh2DSerendipityBasisSpace(geo, degree=degree, family=family)

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return Quadmesh2DPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return Quadmesh2DDGBipolynomialBasisSpace(geo, degree=degree, family=family)
        else:
            return Quadmesh2DBipolynomialBasisSpace(geo, degree=degree, family=family)

    if isinstance(base_geo, _geometry.Hexmesh):
        if degree == 0:
            return HexmeshPiecewiseConstantBasis(geo)

        if element_basis == ElementBasis.SERENDIPITY and degree > 1:
            if discontinuous:
                return HexmeshDGSerendipityBasisSpace(geo, degree=degree, family=family)
            else:
                return HexmeshSerendipityBasisSpace(geo, degree=degree, family=family)

        if element_basis == ElementBasis.NONCONFORMING_POLYNOMIAL:
            return HexmeshPolynomialBasisSpace(geo, degree=degree)

        if discontinuous:
            return HexmeshDGTripolynomialBasisSpace(geo, degree=degree, family=family)
        else:
            return HexmeshTripolynomialBasisSpace(geo, degree=degree, family=family)

    raise NotImplementedError(f"Unsupported geometry type for polynomial basis: {type(base_geo).__name__}")
246
+
247
+
248
def make_collocated_function_space(
    basis_space: BasisSpace, dtype: type = float, dof_mapper: Optional[DofMapper] = None
) -> CollocatedFunctionSpace:
    """
    Constructs a function space from a basis space and a value type, such that all degrees of freedom of the value type are stored at each of the basis nodes.

    Args:
        basis_space: the basis space whose nodes will carry the degrees of freedom
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.

    Returns:
        the constructed function space
    """
    return CollocatedFunctionSpace(basis_space, dtype=dtype, dof_mapper=dof_mapper)
263
+
264
+
265
def make_polynomial_space(
    geo: _geometry.Geometry,
    dtype: type = float,
    dof_mapper: Optional[DofMapper] = None,
    degree: int = 1,
    element_basis: Optional[ElementBasis] = None,
    discontinuous: bool = False,
    family: Optional[_polynomial.Polynomial] = None,
) -> CollocatedFunctionSpace:
    """
    Equips a geometry with a collocated, polynomial function space.
    Equivalent to successive calls to :func:`make_polynomial_basis_space` and :func:`make_collocated_function_space`.

    Args:
        geo: the Geometry on which to build the space
        dtype: value type of the function space. If ``dof_mapper`` is provided, the value type from the DofMapper will be used instead.
        dof_mapper: mapping from node degrees of freedom to function values, defaults to Identity. Useful for reduced coordinates, e.g. :py:class:`SymmetricTensorMapper` maps 2x2 (resp 3x3) symmetric tensors to 3 (resp 6) degrees of freedom.
        degree: polynomial degree of the per-element shape functions
        element_basis: type of basis function for the individual elements
        discontinuous: if True, use Discontinuous Galerkin shape functions. Discontinuous is implied if degree is 0, i.e, piecewise-constant shape functions.
        family: Polynomial family used to generate the shape function basis. If not provided, a reasonable basis is chosen.

    Returns:
        the constructed function space
    """

    # Forward by keyword so this stays correct if the basis-space factory's parameter order ever changes
    basis_space = make_polynomial_basis_space(
        geo,
        degree=degree,
        element_basis=element_basis,
        discontinuous=discontinuous,
        family=family,
    )
    return make_collocated_function_space(basis_space, dtype=dtype, dof_mapper=dof_mapper)