warp-lang 1.0.0b2__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (269)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.so +0 -0
  57. warp/bin/warp.so +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/fem/field/discrete_field.py +0 -80
  257. warp/fem/space/nodal_function_space.py +0 -233
  258. warp/tests/test_all.py +0 -223
  259. warp/tests/test_array_scan.py +0 -60
  260. warp/tests/test_base.py +0 -208
  261. warp/tests/test_unresolved_func.py +0 -7
  262. warp/tests/test_unresolved_symbol.py +0 -7
  263. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  264. warp_lang-1.0.0b2.dist-info/RECORD +0 -378
  265. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  266. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  267. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  268. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  269. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,369 @@
1
+ import warp as wp
2
+
3
+ from warp.fem.types import ElementIndex, Coords
4
+ from warp.fem.polynomial import Polynomial, is_closed
5
+ from warp.fem.geometry import Quadmesh2D
6
+ from warp.fem import cache
7
+
8
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
+
11
+ from .shape import ShapeFunction, ConstantShapeFunction
12
+ from .shape import (
13
+ SquareBipolynomialShapeFunctions,
14
+ SquareSerendipityShapeFunctions,
15
+ SquareNonConformingPolynomialShapeFunctions,
16
+ )
17
+
18
+
19
@wp.struct
class Quadmesh2DTopologyArg:
    # Kernel-side arguments of Quadmesh2DSpaceTopology (filled by its topo_arg_value()).

    # Two vertex indices per mesh edge; entry [0] is compared against local quad
    # vertices to determine the edge's orientation relative to each quad
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # Per quad, global edge index of local edges 0-1, 1-2, 2-3 and 3-0 (columns 0..3)
    quad_edge_indices: wp.array2d(dtype=int)

    # Number of mesh vertices (first block of global node indices)
    vertex_count: int
    # Number of mesh edges, as returned by the mesh's side_count()
    edge_count: int
26
+
27
+
28
class Quadmesh2DSpaceTopology(SpaceTopology):
    """Base topology for continuous function spaces defined on a Quadmesh2D.

    On construction, computes for every quad the global indices of its four
    edges (local edge k joins quad vertices k and k+1, edge 3 joins vertices
    3 and 0). Subclasses use this table to number nodes shared across cells.
    """

    TopologyArg = Quadmesh2DTopologyArg

    def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh
        self._shape = shape

        self._compute_quad_edge_indices()

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build the Quadmesh2DTopologyArg struct, copying arrays to ``device``."""
        mesh = self._mesh

        arg = Quadmesh2DTopologyArg()
        arg.quad_edge_indices = self._quad_edge_indices.to(device)
        arg.edge_vertex_indices = mesh.edge_vertex_indices.to(device)

        arg.vertex_count = mesh.vertex_count()
        arg.edge_count = mesh.side_count()
        return arg

    def _compute_quad_edge_indices(self):
        """Allocate and fill ``self._quad_edge_indices`` (cell_count x 4) on the mesh device."""
        mesh = self._mesh
        target_device = mesh.quad_vertex_indices.device

        self._quad_edge_indices = wp.empty(
            shape=(mesh.cell_count(), 4),
            dtype=int,
            device=target_device,
        )

        kernel_inputs = [
            mesh.edge_quad_indices,
            mesh.edge_vertex_indices,
            mesh.quad_vertex_indices,
            self._quad_edge_indices,
        ]
        wp.launch(
            kernel=Quadmesh2DSpaceTopology._compute_quad_edge_indices_kernel,
            dim=mesh.edge_quad_indices.shape,
            device=target_device,
            inputs=kernel_inputs,
        )

    @wp.func
    def _find_edge_index_in_quad(
        edge_vtx: wp.vec2i,
        quad_vtx: wp.vec4i,
    ):
        # Local edge k joins quad vertices k and k+1 (either orientation);
        # if none of edges 0..2 match, the edge must be 3 (vertices 3 and 0)
        for k in range(3):
            v_first = quad_vtx[k]
            v_second = quad_vtx[k + 1]
            if (edge_vtx[0] == v_first and edge_vtx[1] == v_second) or (
                edge_vtx[0] == v_second and edge_vtx[1] == v_first
            ):
                return k
        return 3

    @wp.kernel
    def _compute_quad_edge_indices_kernel(
        edge_quad_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        quad_vertex_indices: wp.array2d(dtype=int),
        quad_edge_indices: wp.array2d(dtype=int),
    ):
        # One thread per mesh edge: register the edge in each adjacent quad's table
        e = wp.tid()

        edge_vtx = edge_vertex_indices[e]
        adjacent_quads = edge_quad_indices[e]

        first_quad = adjacent_quads[0]
        first_quad_vtx = wp.vec4i(
            quad_vertex_indices[first_quad, 0],
            quad_vertex_indices[first_quad, 1],
            quad_vertex_indices[first_quad, 2],
            quad_vertex_indices[first_quad, 3],
        )
        first_local_edge = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, first_quad_vtx)
        quad_edge_indices[first_quad, first_local_edge] = e

        # The second adjacent quad is skipped when it equals the first
        # (presumably a boundary edge with a single neighbour)
        second_quad = adjacent_quads[1]
        if second_quad != first_quad:
            second_quad_vtx = wp.vec4i(
                quad_vertex_indices[second_quad, 0],
                quad_vertex_indices[second_quad, 1],
                quad_vertex_indices[second_quad, 2],
                quad_vertex_indices[second_quad, 3],
            )
            second_local_edge = Quadmesh2DSpaceTopology._find_edge_index_in_quad(edge_vtx, second_quad_vtx)
            quad_edge_indices[second_quad, second_local_edge] = e
109
+
110
+
111
class Quadmesh2DDiscontinuousSpaceTopology(
    DiscontinuousSpaceTopologyMixin,
    SpaceTopology,
):
    """Discontinuous topology on a Quadmesh2D.

    Element-to-node behaviour comes from DiscontinuousSpaceTopologyMixin,
    which precedes SpaceTopology in the MRO.
    """

    def __init__(self, mesh: Quadmesh2D, shape: ShapeFunction):
        nodes_per_element = shape.NODES_PER_ELEMENT
        super().__init__(mesh, nodes_per_element)
117
+
118
+
119
class Quadmesh2DBasisSpace(ShapeBasisSpace):
    """Shape-function basis space whose topology lives on a Quadmesh2D."""

    def __init__(self, topology: Quadmesh2DSpaceTopology, shape: ShapeFunction):
        super().__init__(topology, shape)

        # Keep a typed handle on the underlying mesh geometry
        quad_mesh: Quadmesh2D = topology.geometry
        self._mesh = quad_mesh
124
+
125
+
126
class Quadmesh2DPiecewiseConstantBasis(Quadmesh2DBasisSpace):
    """Piecewise-constant basis on a Quadmesh2D.

    Uses a ConstantShapeFunction on the mesh reference cell together with a
    discontinuous topology.
    """

    def __init__(self, mesh: Quadmesh2D):
        constant_shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=2)
        super().__init__(
            shape=constant_shape,
            topology=Quadmesh2DDiscontinuousSpaceTopology(mesh, constant_shape),
        )

    class Trace(TraceBasisSpace):
        """Trace of the piecewise-constant basis on quad sides."""

        @wp.func
        def _node_coords_in_element(
            side_arg: Quadmesh2D.SideArg,
            basis_arg: Quadmesh2DBasisSpace.BasisArg,
            element_index: ElementIndex,
            node_index_in_element: int,
        ):
            # Every node is placed at side coordinate 0.5 (presumably the side midpoint)
            return Coords(0.5, 0.0, 0.0)

        def make_node_coords_in_element(self):
            return self._node_coords_in_element

    def trace(self):
        """Return the trace basis space of this piecewise-constant basis."""
        return Quadmesh2DPiecewiseConstantBasis.Trace(self)
147
+
148
+
149
class Quadmesh2DBipolynomialSpaceTopology(Quadmesh2DSpaceTopology):
    """Continuous topology for bipolynomial shape functions on a Quadmesh2D.

    Global node numbering:
      * ``[0, vertex_count)``: one node per mesh vertex,
      * then ``ORDER - 1`` interior nodes per mesh edge (contiguous per edge),
      * then ``(ORDER - 1)**2`` interior nodes per cell (contiguous per cell).
    """

    def __init__(self, mesh: Quadmesh2D, shape: SquareBipolynomialShapeFunctions):
        super().__init__(mesh, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Total number of global nodes: vertices + edge-interior + cell-interior nodes."""
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_SIDE = max(0, ORDER - 1)
        INTERIOR_NODES_PER_CELL = INTERIOR_NODES_PER_SIDE**2

        return (
            self._mesh.vertex_count()
            + self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
            + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
        )

    def _make_element_node_index(self):
        """Build the warp function mapping (element, local node) to a global node index.

        Local node ``n`` has lattice coordinates ``node_i = n // (ORDER + 1)``
        (along quad vertices 0 -> 1) and ``node_j = n % (ORDER + 1)`` (along
        quad vertices 0 -> 3).

        Nodes lying strictly inside a mesh edge are numbered by their distance
        from that edge's global start vertex (``edge_vertex_indices[side][0]``),
        on all four local edges. This makes the numbering independent of each
        quad's local orientation, so two quads sharing an edge agree on which
        global node sits at which position. It is the same convention as the
        serendipity topology.
        """
        ORDER = self._shape.ORDER
        INTERIOR_NODES_PER_SIDE = wp.constant(max(0, ORDER - 1))
        INTERIOR_NODES_PER_CELL = wp.constant(INTERIOR_NODES_PER_SIDE**2)

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: Quadmesh2D.CellArg,
            topo_arg: Quadmesh2DTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_i = node_index_in_elt // (ORDER + 1)
            node_j = node_index_in_elt - (ORDER + 1) * node_i

            # Vertices
            if node_i == 0:
                if node_j == 0:
                    return geo_arg.quad_vertex_indices[element_index, 0]
                elif node_j == ORDER:
                    return geo_arg.quad_vertex_indices[element_index, 3]

                # 3-0 edge: node_j is the distance from vertex 0, ORDER - node_j from vertex 3
                side_index = topo_arg.quad_edge_indices[element_index, 3]
                local_vs = geo_arg.quad_vertex_indices[element_index, 3]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                # wp.select(cond, a, b) returns b when cond holds: measure from the global start vertex
                index_in_side = wp.select(local_vs == global_vs, node_j, ORDER - node_j) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            elif node_i == ORDER:
                if node_j == 0:
                    return geo_arg.quad_vertex_indices[element_index, 1]
                elif node_j == ORDER:
                    return geo_arg.quad_vertex_indices[element_index, 2]

                # 1-2 edge: node_j is the distance from vertex 1
                side_index = topo_arg.quad_edge_indices[element_index, 1]
                local_vs = geo_arg.quad_vertex_indices[element_index, 1]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, ORDER - node_j, node_j) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            if node_j == 0:
                # 0-1 edge: node_i is the distance from vertex 0
                side_index = topo_arg.quad_edge_indices[element_index, 0]
                local_vs = geo_arg.quad_vertex_indices[element_index, 0]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, ORDER - node_i, node_i) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            elif node_j == ORDER:
                # 2-3 edge: node_i is the distance from vertex 3, ORDER - node_i from vertex 2
                side_index = topo_arg.quad_edge_indices[element_index, 2]
                local_vs = geo_arg.quad_vertex_indices[element_index, 2]
                global_vs = topo_arg.edge_vertex_indices[side_index][0]
                index_in_side = wp.select(local_vs == global_vs, node_i, ORDER - node_i) - 1

                return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

            # Cell-interior node: not shared, numbered row-major after all edge nodes
            return (
                topo_arg.vertex_count
                + topo_arg.edge_count * INTERIOR_NODES_PER_SIDE
                + element_index * INTERIOR_NODES_PER_CELL
                + (node_i - 1) * INTERIOR_NODES_PER_SIDE
                + node_j
                - 1
            )

        return element_node_index
238
+
239
+
240
class Quadmesh2DBipolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Continuous bipolynomial basis of a given degree on a Quadmesh2D.

    ``family`` defaults to Gauss-Lobatto-Legendre when ``None``. Raises
    ValueError when the family is not closed, since continuity requires nodes
    on the element boundary.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        if not is_closed(family):
            raise ValueError("A closed polynomial family is required to define a continuous function space")

        bipoly_shape = SquareBipolynomialShapeFunctions(degree, family=family)
        super().__init__(
            forward_base_topology(Quadmesh2DBipolynomialSpaceTopology, mesh, bipoly_shape),
            bipoly_shape,
        )
257
+
258
+
259
class Quadmesh2DDGBipolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous (DG) bipolynomial basis of a given degree on a Quadmesh2D.

    ``family`` defaults to Gauss-Lobatto-Legendre when ``None``; any family is
    accepted since no inter-element continuity is required.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        bipoly_shape = SquareBipolynomialShapeFunctions(degree, family=family)
        super().__init__(Quadmesh2DDiscontinuousSpaceTopology(mesh, bipoly_shape), bipoly_shape)
273
+
274
+
275
class Quadmesh2DSerendipitySpaceTopology(Quadmesh2DSpaceTopology):
    """Continuous topology for serendipity shape functions on a Quadmesh2D.

    Global nodes are the mesh vertices followed by ``ORDER - 1`` nodes per
    mesh edge; there are no cell-interior nodes.
    """

    def __init__(self, grid: Quadmesh2D, shape: SquareSerendipityShapeFunctions):
        super().__init__(grid, shape)

        self.element_node_index = self._make_element_node_index()

    def node_count(self) -> int:
        """Number of global nodes: one per vertex plus ``ORDER - 1`` per edge."""
        geo = self.geometry
        nodes_per_side = self._shape.ORDER - 1
        return geo.vertex_count() + nodes_per_side * geo.side_count()

    def _make_element_node_index(self):
        """Build the warp function mapping (element, local node) to a global node index.

        Edge nodes are re-indexed so that their position is measured from the
        edge's global start vertex (``edge_vertex_indices[side][0]``).
        """
        ORDER = self._shape.ORDER

        # Shape-function vertex ordering -> quad local vertex ordering
        SHAPE_TO_QUAD_IDX = wp.constant(wp.vec4i([0, 3, 1, 2]))

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            cell_arg: Quadmesh2D.CellArg,
            topo_arg: Quadmesh2DSpaceTopology.TopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            if node_type == SquareSerendipityShapeFunctions.VERTEX:
                return cell_arg.quad_vertex_indices[element_index, SHAPE_TO_QUAD_IDX[type_index]]

            side_offset, index_in_side = SquareSerendipityShapeFunctions.side_offset_and_index(type_index)

            # Pick the local quad edge carrying the node, flipping the in-side index
            # for the edges whose local orientation runs opposite to the shape axis
            if node_type == SquareSerendipityShapeFunctions.EDGE_X:
                if side_offset != 0:
                    local_edge = 2
                    index_in_side = ORDER - 2 - index_in_side
                else:
                    local_edge = 0
            else:
                if side_offset != 0:
                    local_edge = 1
                else:
                    local_edge = 3
                    index_in_side = ORDER - 2 - index_in_side

            side_index = topo_arg.quad_edge_indices[element_index, local_edge]
            local_start = cell_arg.quad_vertex_indices[element_index, local_edge]
            global_start = topo_arg.edge_vertex_indices[side_index][0]
            if local_start != global_start:
                # Local and global edge orientations disagree: count from the other end
                index_in_side = ORDER - 2 - index_in_side

            return topo_arg.vertex_count + (ORDER - 1) * side_index + index_in_side

        return element_node_index
326
+
327
+
328
class Quadmesh2DSerendipityBasisSpace(Quadmesh2DBasisSpace):
    """Continuous serendipity basis of a given degree on a Quadmesh2D.

    ``family`` defaults to Gauss-Lobatto-Legendre when ``None``.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = SquareSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=forward_base_topology(Quadmesh2DSerendipitySpaceTopology, mesh, shape=serendipity_shape),
            shape=serendipity_shape,
        )
342
+
343
+
344
class Quadmesh2DDGSerendipityBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous (DG) serendipity basis of a given degree on a Quadmesh2D.

    ``family`` defaults to Gauss-Lobatto-Legendre when ``None``.
    """

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
        family: Polynomial,
    ):
        family = Polynomial.LOBATTO_GAUSS_LEGENDRE if family is None else family

        serendipity_shape = SquareSerendipityShapeFunctions(degree, family=family)
        super().__init__(
            topology=Quadmesh2DDiscontinuousSpaceTopology(mesh, shape=serendipity_shape),
            shape=serendipity_shape,
        )
358
+
359
+
360
class Quadmesh2DPolynomialBasisSpace(Quadmesh2DBasisSpace):
    """Discontinuous basis built from non-conforming polynomial shape functions of the given degree."""

    def __init__(
        self,
        mesh: Quadmesh2D,
        degree: int,
    ):
        poly_shape = SquareNonConformingPolynomialShapeFunctions(degree)
        super().__init__(Quadmesh2DDiscontinuousSpaceTopology(mesh, poly_shape), poly_shape)
@@ -3,6 +3,7 @@ import warp as wp
3
3
  from warp.fem.domain import GeometryDomain
4
4
  from warp.fem.types import NodeElementIndex
5
5
  from warp.fem.utils import compress_node_indices
6
+ from warp.fem.cache import cached_arg_value, borrow_temporary, borrow_temporary_like, TemporaryStore
6
7
 
7
8
  from .function_space import FunctionSpace
8
9
  from .partition import SpacePartition
@@ -11,97 +12,107 @@ wp.set_module_options({"enable_backward": False})
11
12
 
12
13
 
13
14
  class SpaceRestriction:
14
- """Restriction of a space to a given GeometryDomain"""
15
+ """Restriction of a space partition to a given GeometryDomain"""
15
16
 
16
17
  def __init__(
17
18
  self,
18
- space: FunctionSpace,
19
- domain: GeometryDomain,
20
19
  space_partition: SpacePartition,
20
+ domain: GeometryDomain,
21
21
  device=None,
22
+ temporary_store: TemporaryStore = None,
22
23
  ):
23
- if domain.dimension() == space.DIMENSION - 1:
24
- space = space.trace()
24
+ space_topology = space_partition.space_topology
25
+
26
+ if domain.dimension == space_topology.dimension - 1:
27
+ space_topology = space_topology.trace()
25
28
 
26
- if domain.dimension() != space.DIMENSION:
29
+ if domain.dimension != space_topology.dimension:
27
30
  raise ValueError("Incompatible space and domain dimensions")
28
31
 
29
- self.space = space
30
32
  self.space_partition = space_partition
33
+ self.space_topology = space_topology
31
34
  self.domain = domain
32
35
 
33
- self._compute_node_element_indices(device=device)
36
+ self._compute_node_element_indices(device=device, temporary_store=temporary_store)
34
37
 
35
- def _compute_node_element_indices(self, device):
38
+ def _compute_node_element_indices(self, device, temporary_store: TemporaryStore):
36
39
  from warp.fem import cache
37
40
 
38
- NODES_PER_ELEMENT = self.space.NODES_PER_ELEMENT
41
+ NODES_PER_ELEMENT = self.space_topology.NODES_PER_ELEMENT
39
42
 
40
- def fill_element_node_indices_fn(
43
+ @cache.dynamic_kernel(suffix=f"{self.domain.name}_{self.space_topology.name}_{self.space_partition.name}")
44
+ def fill_element_node_indices(
45
+ element_arg: self.domain.ElementArg,
41
46
  domain_index_arg: self.domain.ElementIndexArg,
42
- space_arg: self.space.SpaceArg,
47
+ topo_arg: self.space_topology.TopologyArg,
43
48
  partition_arg: self.space_partition.PartitionArg,
44
49
  element_node_indices: wp.array2d(dtype=int),
45
50
  ):
46
51
  domain_element_index = wp.tid()
47
52
  element_index = self.domain.element_index(domain_index_arg, domain_element_index)
48
53
  for n in range(NODES_PER_ELEMENT):
49
- space_nidx = self.space.element_node_index(space_arg, element_index, n)
54
+ space_nidx = self.space_topology.element_node_index(element_arg, topo_arg, element_index, n)
50
55
  partition_nidx = self.space_partition.partition_node_index(partition_arg, space_nidx)
51
56
  element_node_indices[domain_element_index, n] = partition_nidx
52
57
 
53
- fill_element_node_indices = cache.get_kernel(
54
- fill_element_node_indices_fn,
55
- suffix=f"{self.domain.name}_{self.space.name}_{self.space_partition.name}",
56
- )
57
-
58
- element_node_indices = wp.empty(
58
+ element_node_indices = borrow_temporary(
59
+ temporary_store,
59
60
  shape=(self.domain.element_count(), NODES_PER_ELEMENT),
60
61
  dtype=int,
61
62
  device=device,
62
63
  )
63
64
  wp.launch(
64
- dim=element_node_indices.shape[0],
65
+ dim=element_node_indices.array.shape[0],
65
66
  kernel=fill_element_node_indices,
66
67
  inputs=[
68
+ self.domain.element_arg_value(device),
67
69
  self.domain.element_index_arg_value(device),
68
- self.space.space_arg_value(device),
70
+ self.space_topology.topo_arg_value(device),
69
71
  self.space_partition.partition_arg_value(device),
70
- element_node_indices,
72
+ element_node_indices.array,
71
73
  ],
72
74
  device=device,
73
75
  )
74
76
 
75
77
  # Build compressed map from node to element indices
76
- flattened_node_indices = element_node_indices.reshape(element_node_indices.size)
78
+ flattened_node_indices = element_node_indices.array.flatten()
77
79
  (
78
80
  self._dof_partition_element_offsets,
79
81
  node_array_indices,
80
82
  self._node_count,
81
83
  self._dof_partition_indices,
82
- ) = compress_node_indices(self.space_partition.node_count(), flattened_node_indices)
84
+ ) = compress_node_indices(
85
+ self.space_partition.node_count(), flattened_node_indices, temporary_store=temporary_store
86
+ )
83
87
 
84
88
  # Extract element index and index in element
85
- self._dof_element_indices = wp.empty_like(flattened_node_indices)
86
- self._dof_indices_in_element = wp.empty_like(flattened_node_indices)
89
+ self._dof_element_indices = borrow_temporary_like(flattened_node_indices, temporary_store)
90
+ self._dof_indices_in_element = borrow_temporary_like(flattened_node_indices, temporary_store)
87
91
  wp.launch(
88
92
  kernel=SpaceRestriction._split_vertex_element_index,
89
93
  dim=flattened_node_indices.shape,
90
- inputs=[NODES_PER_ELEMENT, node_array_indices, self._dof_element_indices, self._dof_indices_in_element],
94
+ inputs=[
95
+ NODES_PER_ELEMENT,
96
+ node_array_indices.array,
97
+ self._dof_element_indices.array,
98
+ self._dof_indices_in_element.array,
99
+ ],
91
100
  device=flattened_node_indices.device,
92
101
  )
93
102
 
103
+ node_array_indices.release()
104
+
94
105
  def node_count(self):
95
106
  return self._node_count
96
107
 
97
108
  def partition_element_offsets(self):
98
- return self._dof_partition_element_offsets
109
+ return self._dof_partition_element_offsets.array
99
110
 
100
111
  def node_partition_indices(self):
101
- return self._dof_partition_indices
112
+ return self._dof_partition_indices.array
102
113
 
103
114
  def total_node_element_count(self):
104
- return self._dof_element_indices.size
115
+ return self._dof_element_indices.array.size
105
116
 
106
117
  @wp.struct
107
118
  class NodeArg:
@@ -110,12 +121,13 @@ class SpaceRestriction:
110
121
  dof_partition_indices: wp.array(dtype=int)
111
122
  dof_indices_in_element: wp.array(dtype=int)
112
123
 
124
+ @cached_arg_value
113
125
  def node_arg(self, device):
114
126
  arg = SpaceRestriction.NodeArg()
115
- arg.dof_element_offsets = self._dof_partition_element_offsets.to(device)
116
- arg.dof_element_indices = self._dof_element_indices.to(device)
117
- arg.dof_partition_indices = self._dof_partition_indices.to(device)
118
- arg.dof_indices_in_element = self._dof_indices_in_element.to(device)
127
+ arg.dof_element_offsets = self._dof_partition_element_offsets.array.to(device)
128
+ arg.dof_element_indices = self._dof_element_indices.array.to(device)
129
+ arg.dof_partition_indices = self._dof_partition_indices.array.to(device)
130
+ arg.dof_indices_in_element = self._dof_indices_in_element.array.to(device)
119
131
  return arg
120
132
 
121
133
  @wp.func
@@ -0,0 +1,15 @@
1
+ from .shape_function import ShapeFunction, ConstantShapeFunction
2
+
3
+ from .triangle_shape_function import Triangle2DPolynomialShapeFunctions, Triangle2DNonConformingPolynomialShapeFunctions
4
+ from .tet_shape_function import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions
5
+
6
+ from .square_shape_function import (
7
+ SquareBipolynomialShapeFunctions,
8
+ SquareSerendipityShapeFunctions,
9
+ SquareNonConformingPolynomialShapeFunctions,
10
+ )
11
+ from .cube_shape_function import (
12
+ CubeSerendipityShapeFunctions,
13
+ CubeTripolynomialShapeFunctions,
14
+ CubeNonConformingPolynomialShapeFunctions,
15
+ )