warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271):
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -1,134 +1,66 @@
1
1
  import warp as wp
2
- import numpy as np
3
2
 
4
-
5
- from warp.fem.types import ElementIndex, Coords, OUTSIDE, vec2i, vec3i
3
+ from warp.fem.types import ElementIndex, Coords
6
4
  from warp.fem.geometry import Trimesh2D
5
+ from warp.fem import cache
7
6
 
8
- from .dof_mapper import DofMapper
9
- from .nodal_function_space import NodalFunctionSpace, NodalFunctionSpaceTrace
7
+ from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
8
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
9
 
10
+ from .shape import ShapeFunction, ConstantShapeFunction
11
+ from .shape import Triangle2DPolynomialShapeFunctions, Triangle2DNonConformingPolynomialShapeFunctions
11
12
 
12
- class Trimesh2DFunctionSpace(NodalFunctionSpace):
13
- DIMENSION = wp.constant(2)
14
13
 
15
- @wp.struct
16
- class SpaceArg:
17
- geo_arg: Trimesh2D.SideArg
14
+ @wp.struct
15
+ class Trimesh2DTopologyArg:
16
+ edge_vertex_indices: wp.array(dtype=wp.vec2i)
17
+ tri_edge_indices: wp.array2d(dtype=int)
18
18
 
19
- reference_transforms: wp.array(dtype=wp.mat22f)
20
- tri_edge_indices: wp.array2d(dtype=int)
19
+ vertex_count: int
20
+ edge_count: int
21
21
 
22
- vertex_count: int
23
- edge_count: int
24
22
 
25
- def __init__(self, mesh: Trimesh2D, dtype: type = float, dof_mapper: DofMapper = None):
26
- super().__init__(dtype, dof_mapper)
27
- self._mesh = mesh
23
+ class Trimesh2DSpaceTopology(SpaceTopology):
24
+ TopologyArg = Trimesh2DTopologyArg
28
25
 
29
- self._reference_transforms: wp.array = None
30
- self._tri_edge_indices: wp.array = None
26
+ def __init__(self, mesh: Trimesh2D, shape: ShapeFunction):
27
+ super().__init__(mesh, shape.NODES_PER_ELEMENT)
28
+ self._mesh = mesh
29
+ self._shape = shape
31
30
 
32
- self._compute_reference_transforms()
33
31
  self._compute_tri_edge_indices()
34
32
 
35
- @property
36
- def geometry(self) -> Trimesh2D:
37
- return self._mesh
38
-
39
- def space_arg_value(self, device):
40
- arg = self.SpaceArg()
41
- arg.geo_arg = self.geometry.side_arg_value(device)
42
- arg.reference_transforms = self._reference_transforms.to(device)
33
+ @cache.cached_arg_value
34
+ def topo_arg_value(self, device):
35
+ arg = Trimesh2DTopologyArg()
43
36
  arg.tri_edge_indices = self._tri_edge_indices.to(device)
37
+ arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
44
38
 
45
39
  arg.vertex_count = self._mesh.vertex_count()
46
40
  arg.edge_count = self._mesh.side_count()
47
41
  return arg
48
42
 
49
- class Trace(NodalFunctionSpaceTrace):
50
- def __init__(self, space: NodalFunctionSpace):
51
- super().__init__(space)
52
- self.ORDER = space.ORDER
53
-
54
- @wp.func
55
- def _inner_cell_index(args: SpaceArg, side_index: ElementIndex):
56
- return Trimesh2D.side_inner_cell_index(args.geo_arg, side_index)
57
-
58
- @wp.func
59
- def _outer_cell_index(args: SpaceArg, side_index: ElementIndex):
60
- return Trimesh2D.side_outer_cell_index(args.geo_arg, side_index)
61
-
62
- @wp.func
63
- def _inner_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
64
- tri_index = Trimesh2D.side_inner_cell_index(args.geo_arg, side_index)
65
- return Trimesh2D.edge_to_tri_coords(args.geo_arg, side_index, tri_index, side_coords)
66
-
67
- @wp.func
68
- def _outer_cell_coords(args: SpaceArg, side_index: ElementIndex, side_coords: Coords):
69
- tri_index = Trimesh2D.side_outer_cell_index(args.geo_arg, side_index)
70
- return Trimesh2D.edge_to_tri_coords(args.geo_arg, side_index, tri_index, side_coords)
71
-
72
- @wp.func
73
- def _cell_to_side_coords(
74
- args: SpaceArg,
75
- side_index: ElementIndex,
76
- element_index: ElementIndex,
77
- element_coords: Coords,
78
- ):
79
- return Trimesh2D.tri_to_edge_coords(args.geo_arg, side_index, element_index, element_coords)
80
-
81
- def _compute_reference_transforms(self):
82
- self._reference_transforms = wp.empty(
83
- dtype=wp.mat22f, device=self._mesh.positions.device, shape=(self._mesh.cell_count())
84
- )
85
-
86
- wp.launch(
87
- kernel=Trimesh2DFunctionSpace._compute_reference_transforms_kernel,
88
- dim=self._reference_transforms.shape,
89
- device=self._reference_transforms.device,
90
- inputs=[self._mesh.tri_vertex_indices, self._mesh.positions, self._reference_transforms],
91
- )
92
-
93
43
  def _compute_tri_edge_indices(self):
94
44
  self._tri_edge_indices = wp.empty(
95
45
  dtype=int, device=self._mesh.tri_vertex_indices.device, shape=(self._mesh.cell_count(), 3)
96
46
  )
97
47
 
98
48
  wp.launch(
99
- kernel=Trimesh2DFunctionSpace._compute_tri_edge_indices_kernel,
100
- dim=self._mesh._edge_tri_indices.shape,
49
+ kernel=Trimesh2DSpaceTopology._compute_tri_edge_indices_kernel,
50
+ dim=self._mesh.edge_tri_indices.shape,
101
51
  device=self._mesh.tri_vertex_indices.device,
102
52
  inputs=[
103
- self._mesh._edge_tri_indices,
104
- self._mesh._edge_vertex_indices,
53
+ self._mesh.edge_tri_indices,
54
+ self._mesh.edge_vertex_indices,
105
55
  self._mesh.tri_vertex_indices,
106
56
  self._tri_edge_indices,
107
57
  ],
108
58
  )
109
59
 
110
- @wp.kernel
111
- def _compute_reference_transforms_kernel(
112
- tri_vertex_indices: wp.array2d(dtype=int),
113
- positions: wp.array(dtype=wp.vec2f),
114
- transforms: wp.array(dtype=wp.mat22f),
115
- ):
116
- t = wp.tid()
117
-
118
- p0 = positions[tri_vertex_indices[t, 0]]
119
- p1 = positions[tri_vertex_indices[t, 1]]
120
- p2 = positions[tri_vertex_indices[t, 2]]
121
-
122
- e1 = p1 - p0
123
- e2 = p2 - p0
124
-
125
- mat = wp.mat22(e1, e2)
126
- transforms[t] = wp.transpose(wp.inverse(mat))
127
-
128
60
  @wp.func
129
61
  def _find_edge_index_in_tri(
130
- edge_vtx: vec2i,
131
- tri_vtx: vec3i,
62
+ edge_vtx: wp.vec2i,
63
+ tri_vtx: wp.vec3i,
132
64
  ):
133
65
  for k in range(2):
134
66
  if (edge_vtx[0] == tri_vtx[k] and edge_vtx[1] == tri_vtx[k + 1]) or (
@@ -139,8 +71,8 @@ class Trimesh2DFunctionSpace(NodalFunctionSpace):
139
71
 
140
72
  @wp.kernel
141
73
  def _compute_tri_edge_indices_kernel(
142
- edge_tri_indices: wp.array(dtype=vec2i),
143
- edge_vertex_indices: wp.array(dtype=vec2i),
74
+ edge_tri_indices: wp.array(dtype=wp.vec2i),
75
+ edge_vertex_indices: wp.array(dtype=wp.vec2i),
144
76
  tri_vertex_indices: wp.array2d(dtype=int),
145
77
  tri_edge_indices: wp.array2d(dtype=int),
146
78
  ):
@@ -150,710 +82,140 @@ class Trimesh2DFunctionSpace(NodalFunctionSpace):
150
82
  edge_tris = edge_tri_indices[e]
151
83
 
152
84
  t0 = edge_tris[0]
153
- t0_vtx = vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
154
- t0_edge = Trimesh2DFunctionSpace._find_edge_index_in_tri(edge_vtx, t0_vtx)
85
+ t0_vtx = wp.vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
86
+ t0_edge = Trimesh2DSpaceTopology._find_edge_index_in_tri(edge_vtx, t0_vtx)
155
87
  tri_edge_indices[t0, t0_edge] = e
156
88
 
157
89
  t1 = edge_tris[1]
158
90
  if t1 != t0:
159
- t1_vtx = vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
160
- t1_edge = Trimesh2DFunctionSpace._find_edge_index_in_tri(edge_vtx, t1_vtx)
91
+ t1_vtx = wp.vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
92
+ t1_edge = Trimesh2DSpaceTopology._find_edge_index_in_tri(edge_vtx, t1_vtx)
161
93
  tri_edge_indices[t1, t1_edge] = e
162
94
 
163
95
 
164
- class Trimesh2DPiecewiseConstantSpace(Trimesh2DFunctionSpace):
165
- ORDER = wp.constant(0)
166
- NODES_PER_ELEMENT = wp.constant(1)
167
-
168
- def __init__(self, grid: Trimesh2D, dtype: type = float, dof_mapper: DofMapper = None):
169
- super().__init__(grid, dtype, dof_mapper)
170
-
171
- self.element_outer_weight = self.element_inner_weight
172
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
173
-
174
- def node_count(self) -> int:
175
- return self._mesh.cell_count()
176
-
177
- def node_positions(self):
178
- vtx_pos = self._mesh.positions.numpy()
179
- tri_vtx = self._mesh.tri_vertex_indices.numpy()
180
-
181
- tri_pos = vtx_pos[tri_vtx]
182
- centers = tri_pos.sum(axis=1) / 3.0
183
-
184
- return centers[:,0], centers[:,1]
185
-
186
- @wp.func
187
- def element_node_index(
188
- args: Trimesh2DFunctionSpace.SpaceArg,
189
- element_index: ElementIndex,
190
- node_index_in_elt: int,
191
- ):
192
- return element_index
193
-
194
- @wp.func
195
- def node_coords_in_element(
196
- args: Trimesh2DFunctionSpace.SpaceArg,
197
- element_index: ElementIndex,
198
- node_index_in_elt: int,
199
- ):
200
- if node_index_in_elt == 0:
201
- return Coords(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)
202
-
203
- return Coords(OUTSIDE)
204
-
205
- @wp.func
206
- def node_quadrature_weight(
207
- args: Trimesh2DFunctionSpace.SpaceArg,
208
- element_index: ElementIndex,
209
- node_index_in_elt: int,
210
- ):
211
- return 1.0
212
-
213
- @wp.func
214
- def element_inner_weight(
215
- args: Trimesh2DFunctionSpace.SpaceArg,
216
- element_index: ElementIndex,
217
- coords: Coords,
218
- node_index_in_elt: int,
219
- ):
220
- if node_index_in_elt == 0:
221
- return 1.0
222
- return 0.0
223
-
224
- @wp.func
225
- def element_inner_weight_gradient(
226
- args: Trimesh2DFunctionSpace.SpaceArg,
227
- element_index: ElementIndex,
228
- coords: Coords,
229
- node_index_in_elt: int,
230
- ):
231
- return wp.vec2(0.0)
96
+ class Trimesh2DDiscontinuousSpaceTopology(
97
+ DiscontinuousSpaceTopologyMixin,
98
+ SpaceTopology,
99
+ ):
100
+ def __init__(self, mesh: Trimesh2D, shape: ShapeFunction):
101
+ super().__init__(mesh, shape.NODES_PER_ELEMENT)
232
102
 
233
- class Trace(Trimesh2DFunctionSpace.Trace):
234
- NODES_PER_ELEMENT = wp.constant(2)
235
- ORDER = wp.constant(0)
236
103
 
237
- def __init__(self, space: "Trimesh2DPiecewiseConstantSpace"):
238
- super().__init__(space)
104
+ class Trimesh2DBasisSpace(ShapeBasisSpace):
105
+ def __init__(self, topology: Trimesh2DSpaceTopology, shape: ShapeFunction):
106
+ super().__init__(topology, shape)
239
107
 
240
- self.element_node_index = self._make_element_node_index(space)
108
+ self._mesh: Trimesh2D = topology.geometry
241
109
 
242
- self.element_inner_weight = self._make_element_inner_weight(space)
243
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
244
110
 
245
- self.element_outer_weight = self._make_element_outer_weight(space)
246
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
111
+ class Trimesh2DPiecewiseConstantBasis(Trimesh2DBasisSpace):
112
+ def __init__(self, mesh: Trimesh2D):
113
+ shape = ConstantShapeFunction(mesh.reference_cell(), space_dimension=2)
114
+ topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
115
+ super().__init__(shape=shape, topology=topology)
247
116
 
117
+ class Trace(TraceBasisSpace):
248
118
  @wp.func
249
- def node_coords_in_element(
250
- args: Trimesh2DFunctionSpace.SpaceArg,
119
+ def _node_coords_in_element(
120
+ side_arg: Trimesh2D.SideArg,
121
+ basis_arg: Trimesh2DBasisSpace.BasisArg,
251
122
  element_index: ElementIndex,
252
123
  node_index_in_element: int,
253
124
  ):
254
- if node_index_in_element == 0:
255
- return Coords(0.5, 0.0, 0.0)
256
- elif node_index_in_element == 1:
257
- return Coords(0.5, 0.0, 0.0)
258
-
259
- return Coords(OUTSIDE)
125
+ return Coords(0.5, 0.0, 0.0)
260
126
 
261
- @wp.func
262
- def node_quadrature_weight(
263
- args: Trimesh2DFunctionSpace.SpaceArg,
264
- element_index: ElementIndex,
265
- node_index_in_elt: int,
266
- ):
267
- return 1.0
127
+ def make_node_coords_in_element(self):
128
+ return self._node_coords_in_element
268
129
 
269
130
  def trace(self):
270
- return Trimesh2DPiecewiseConstantSpace.Trace(self)
271
-
272
-
273
- def _triangle_node_index(tx: int, ty: int, degree: int):
274
- VERTEX_NODE_COUNT = 3
275
- SIDE_INTERIOR_NODE_COUNT = degree - 1
276
-
277
- # Index in similar order to e.g. VTK
278
- # First vertices, then edge (counterclokwise) then interior points (recursively)
279
-
280
- if tx == 0:
281
- if ty == 0:
282
- return 0
283
- elif ty == degree:
284
- return 2
285
- else:
286
- edge_index = 2
287
- return VERTEX_NODE_COUNT + SIDE_INTERIOR_NODE_COUNT * edge_index + (SIDE_INTERIOR_NODE_COUNT - ty)
288
- elif ty == 0:
289
- if tx == degree:
290
- return 1
291
- else:
292
- edge_index = 0
293
- return VERTEX_NODE_COUNT + SIDE_INTERIOR_NODE_COUNT * edge_index + tx - 1
294
- elif tx + ty == degree:
295
- edge_index = 1
296
- return VERTEX_NODE_COUNT + SIDE_INTERIOR_NODE_COUNT * edge_index + ty - 1
297
-
298
- vertex_edge_node_count = 3 * degree
299
- return vertex_edge_node_count + _triangle_node_index(tx - 1, ty - 1, degree - 2)
300
-
301
-
302
- class Trimesh2DPolynomialShapeFunctions:
303
- INVALID = wp.constant(-1)
304
- VERTEX = wp.constant(0)
305
- EDGE = wp.constant(1)
306
- INTERIOR = wp.constant(2)
307
-
308
- def __init__(self, degree: int):
309
- self.ORDER = wp.constant(degree)
131
+ return Trimesh2DPiecewiseConstantBasis.Trace(self)
310
132
 
311
- self.NODES_PER_ELEMENT = wp.constant((degree + 1) * (degree + 2) // 2)
312
- self.NODES_PER_SIDE = wp.constant(degree + 1)
313
-
314
- triangle_coords = np.empty((self.NODES_PER_ELEMENT, 2), dtype=int)
315
-
316
- for tx in range(degree + 1):
317
- for ty in range(degree + 1 - tx):
318
- index = _triangle_node_index(tx, ty, degree)
319
- triangle_coords[index] = [tx, ty]
320
-
321
- CoordTypeVec = wp.mat(dtype=int, shape=(self.NODES_PER_ELEMENT, 2))
322
- self.NODE_TRIANGLE_COORDS = wp.constant(CoordTypeVec(triangle_coords))
323
-
324
- self.node_type_and_type_index = self._get_node_type_and_type_index()
325
- self._node_triangle_coordinates = self._get_node_triangle_coordinates()
326
-
327
- @property
328
- def name(self) -> str:
329
- return f"{self.ORDER}"
330
-
331
- def _get_node_triangle_coordinates(self):
332
- NODE_TRIANGLE_COORDS = self.NODE_TRIANGLE_COORDS
333
-
334
- def node_triangle_coordinates(
335
- node_index_in_elt: int,
336
- ):
337
- return vec2i(NODE_TRIANGLE_COORDS[node_index_in_elt, 0], NODE_TRIANGLE_COORDS[node_index_in_elt, 1])
338
-
339
- from warp.fem import cache
340
-
341
- return cache.get_func(node_triangle_coordinates, self.name)
342
-
343
- def _get_node_type_and_type_index(self):
344
- ORDER = self.ORDER
345
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
346
-
347
- def node_type_and_index(
348
- node_index_in_elt: int,
349
- ):
350
- if node_index_in_elt < 0 or node_index_in_elt >= NODES_PER_ELEMENT:
351
- return Trimesh2DPolynomialShapeFunctions.INVALID, Trimesh2DPolynomialShapeFunctions.INVALID
352
-
353
- if node_index_in_elt < 3:
354
- return Trimesh2DPolynomialShapeFunctions.VERTEX, node_index_in_elt
355
-
356
- if node_index_in_elt < 3 * ORDER:
357
- return Trimesh2DPolynomialShapeFunctions.EDGE, (node_index_in_elt - 3)
358
-
359
- return Trimesh2DPolynomialShapeFunctions.INTERIOR, (node_index_in_elt - 3 * ORDER)
360
-
361
- from warp.fem import cache
362
-
363
- return cache.get_func(node_type_and_index, self.name)
364
-
365
- def make_node_coords_in_element(self):
366
- ORDER = self.ORDER
367
-
368
- def node_coords_in_element(
369
- args: Trimesh2DFunctionSpace.SpaceArg,
370
- element_index: ElementIndex,
371
- node_index_in_elt: int,
372
- ):
373
- tri_coords = self._node_triangle_coordinates(node_index_in_elt)
374
- cx = float(tri_coords[0]) / float(ORDER)
375
- cy = float(tri_coords[1]) / float(ORDER)
376
- return Coords(1.0 - cx - cy, cx, cy)
377
-
378
- from warp.fem import cache
379
-
380
- return cache.get_func(node_coords_in_element, self.name)
381
-
382
- def make_node_quadrature_weight(self):
383
- ORDER = self.ORDER
384
-
385
- def node_uniform_quadrature_weight(
386
- args: Trimesh2DFunctionSpace.SpaceArg,
387
- element_index: ElementIndex,
388
- node_index_in_elt: int,
389
- ):
390
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
391
-
392
- base_weight = 1.0 / float(3 * ORDER * ORDER)
393
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
394
- return base_weight
395
- if node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
396
- return 2.0 * base_weight
397
- return 4.0 * base_weight
398
-
399
- def node_linear_quadrature_weight(
400
- args: Trimesh2DFunctionSpace.SpaceArg,
401
- element_index: ElementIndex,
402
- node_index_in_elt: int,
403
- ):
404
- return 1.0 / 3.0
405
-
406
- from warp.fem import cache
407
-
408
- if ORDER == 1:
409
- return cache.get_func(node_linear_quadrature_weight, self.name)
410
- return cache.get_func(node_uniform_quadrature_weight, self.name)
411
-
412
- def make_trace_node_quadrature_weight(self):
413
- ORDER = self.ORDER
414
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
415
-
416
- def trace_uniform_node_quadrature_weight(
417
- args: Trimesh2DFunctionSpace.SpaceArg,
418
- element_index: ElementIndex,
419
- node_index_in_elt: int,
420
- ):
421
- if node_index_in_elt >= NODES_PER_ELEMENT:
422
- node_index_in_cell = node_index_in_elt - NODES_PER_ELEMENT
423
- else:
424
- node_index_in_cell = node_index_in_elt
425
-
426
- # We're either on a side interior or at a vertex
427
- node_type, type_index = self.node_type_and_type_index(node_index_in_cell)
428
-
429
- base_weight = 1.0 / float(ORDER)
430
- return wp.select(node_type == Trimesh2DPolynomialShapeFunctions.VERTEX, base_weight, 0.5 * base_weight)
431
-
432
- def trace_linear_node_quadrature_weight(
433
- args: Trimesh2DFunctionSpace.SpaceArg,
434
- element_index: ElementIndex,
435
- node_index_in_elt: int,
436
- ):
437
- return 0.5
438
-
439
- from warp.fem import cache
440
-
441
- if ORDER == 1:
442
- return cache.get_func(trace_linear_node_quadrature_weight, self.name)
443
-
444
- return cache.get_func(trace_uniform_node_quadrature_weight, self.name)
445
-
446
- def make_element_inner_weight(self):
447
- ORDER = self.ORDER
448
-
449
- def element_inner_weight_linear(
450
- args: Trimesh2DFunctionSpace.SpaceArg,
451
- element_index: ElementIndex,
452
- coords: Coords,
453
- node_index_in_elt: int,
454
- ):
455
- if node_index_in_elt < 0 or node_index_in_elt >= 3:
456
- return 0.0
457
-
458
- return coords[node_index_in_elt]
459
-
460
- def element_inner_weight_quadratic(
461
- args: Trimesh2DFunctionSpace.SpaceArg,
462
- element_index: ElementIndex,
463
- coords: Coords,
464
- node_index_in_elt: int,
465
- ):
466
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
467
133
 
468
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
469
- # Vertex
470
- return coords[type_index] * (2.0 * coords[type_index] - 1.0)
134
+ class Trimesh2DPolynomialSpaceTopology(Trimesh2DSpaceTopology):
135
+ def __init__(self, mesh: Trimesh2D, shape: Triangle2DPolynomialShapeFunctions):
136
+ super().__init__(mesh, shape)
471
137
 
472
- elif node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
473
- # Edge
474
- c1 = type_index
475
- c2 = (type_index + 1) % 3
476
- return 4.0 * coords[c1] * coords[c2]
477
-
478
- return 0.0
479
-
480
- def element_inner_weight_cubic(
481
- args: Trimesh2DFunctionSpace.SpaceArg,
482
- element_index: ElementIndex,
483
- coords: Coords,
484
- node_index_in_elt: int,
485
- ):
486
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
487
-
488
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
489
- # Vertex
490
- return 0.5 * coords[type_index] * (3.0 * coords[type_index] - 1.0) * (3.0 * coords[type_index] - 2.0)
491
-
492
- elif node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
493
- # Edge
494
- edge = type_index // 2
495
- k = type_index - 2 * edge
496
- c1 = (edge + k) % 3
497
- c2 = (edge + 1 - k) % 3
498
-
499
- return 4.5 * coords[c1] * coords[c2] * (3.0 * coords[c1] - 1.0)
500
-
501
- elif node_type == Trimesh2DPolynomialShapeFunctions.INTERIOR:
502
- # Interior
503
- return 27.0 * coords[0] * coords[1] * coords[2]
504
-
505
- return 0.0
506
-
507
- from warp.fem import cache
508
-
509
- if ORDER == 1:
510
- return cache.get_func(element_inner_weight_linear, self.name)
511
- elif ORDER == 2:
512
- return cache.get_func(element_inner_weight_quadratic, self.name)
513
- elif ORDER == 3:
514
- return cache.get_func(element_inner_weight_cubic, self.name)
515
-
516
- return None
517
-
518
- def make_element_inner_weight_gradient(self):
519
- ORDER = self.ORDER
520
-
521
- def element_inner_weight_gradient_linear(
522
- args: Trimesh2DFunctionSpace.SpaceArg,
523
- element_index: ElementIndex,
524
- coords: Coords,
525
- node_index_in_elt: int,
526
- ):
527
- if node_index_in_elt < 0 or node_index_in_elt >= 3:
528
- return wp.vec2(0.0)
529
-
530
- dw_dc = wp.vec3(0.0)
531
- dw_dc[node_index_in_elt] = 1.0
532
-
533
- dw_du = wp.vec2(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0])
534
- return args.reference_transforms[element_index] * dw_du
535
-
536
- def element_inner_weight_gradient_quadratic(
537
- args: Trimesh2DFunctionSpace.SpaceArg,
538
- element_index: ElementIndex,
539
- coords: Coords,
540
- node_index_in_elt: int,
541
- ):
542
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
543
-
544
- dw_dc = wp.vec3(0.0)
545
-
546
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
547
- # Vertex
548
- dw_dc[type_index] = 4.0 * coords[type_index] - 1.0
549
-
550
- elif node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
551
- # Edge
552
- c1 = type_index
553
- c2 = (type_index + 1) % 3
554
- dw_dc[c1] = 4.0 * coords[c2]
555
- dw_dc[c2] = 4.0 * coords[c1]
556
-
557
- dw_du = wp.vec2(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0])
558
- return args.reference_transforms[element_index] * dw_du
559
-
560
- def element_inner_weight_gradient_cubic(
561
- args: Trimesh2DFunctionSpace.SpaceArg,
562
- element_index: ElementIndex,
563
- coords: Coords,
564
- node_index_in_elt: int,
565
- ):
566
- node_type, type_index = self.node_type_and_type_index(node_index_in_elt)
567
-
568
- dw_dc = wp.vec3(0.0)
569
-
570
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
571
- # Vertex
572
- dw_dc[type_index] = (
573
- 0.5 * 27.0 * coords[type_index] * coords[type_index] - 9.0 * coords[type_index] + 1.0
574
- )
575
-
576
- elif node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
577
- # Edge
578
- edge = type_index // 2
579
- k = type_index - 2 * edge
580
- c1 = (edge + k) % 3
581
- c2 = (edge + 1 - k) % 3
582
-
583
- dw_dc[c1] = 4.5 * coords[c2] * (6.0 * coords[c1] - 1.0)
584
- dw_dc[c2] = 4.5 * coords[c1] * (3.0 * coords[c1] - 1.0)
585
-
586
- elif node_type == Trimesh2DPolynomialShapeFunctions.INTERIOR:
587
- # Interior
588
- dw_dc = wp.vec3(
589
- 27.0 * coords[1] * coords[2], 27.0 * coords[2] * coords[0], 27.0 * coords[0] * coords[1]
590
- )
591
-
592
- dw_du = wp.vec2(dw_dc[1] - dw_dc[0], dw_dc[2] - dw_dc[0])
593
- return args.reference_transforms[element_index] * dw_du
594
-
595
- from warp.fem import cache
596
-
597
- if ORDER == 1:
598
- return cache.get_func(element_inner_weight_gradient_linear, self.name)
599
- elif ORDER == 2:
600
- return cache.get_func(element_inner_weight_gradient_quadratic, self.name)
601
- elif ORDER == 3:
602
- return cache.get_func(element_inner_weight_gradient_cubic, self.name)
603
-
604
- return None
605
-
606
- @staticmethod
607
- def node_positions(space):
608
- if space.ORDER == 1:
609
- return np.transpose(space._mesh.positions.numpy())
610
-
611
- NODES_PER_ELEMENT = space.NODES_PER_ELEMENT
612
-
613
- def fill_node_positions_fn(
614
- space_arg: space.SpaceArg,
615
- node_positions: wp.array(dtype=wp.vec2),
616
- ):
617
- element_index = wp.tid()
618
- tri_idx = space_arg.geo_arg.tri_vertex_indices[element_index]
619
- p0 = space_arg.geo_arg.positions[tri_idx[0]]
620
- p1 = space_arg.geo_arg.positions[tri_idx[1]]
621
- p2 = space_arg.geo_arg.positions[tri_idx[2]]
622
-
623
- for n in range(NODES_PER_ELEMENT):
624
- node_index = space.element_node_index(space_arg, element_index, n)
625
- coords = space.node_coords_in_element(space_arg, element_index, n)
626
-
627
- pos = coords[0] * p0 + coords[1] * p1 + coords[2] * p2
628
-
629
- node_positions[node_index] = pos
630
-
631
- from warp.fem import cache
632
-
633
- fill_node_positions = cache.get_kernel(
634
- fill_node_positions_fn,
635
- suffix=space.name,
636
- )
637
-
638
- device = space._mesh.tri_vertex_indices.device
639
- node_positions = wp.empty(
640
- shape=space.node_count(),
641
- dtype=wp.vec2,
642
- device=device,
643
- )
644
- wp.launch(
645
- dim=space._mesh.cell_count(),
646
- kernel=fill_node_positions,
647
- inputs=[
648
- space.space_arg_value(device),
649
- node_positions,
650
- ],
651
- device=device,
652
- )
653
-
654
- return np.transpose(node_positions.numpy())
655
-
656
- @staticmethod
657
- def node_triangulation(space):
658
- if space.ORDER == 1:
659
- return space._mesh.tri_vertex_indices.numpy()
660
-
661
- NODES_PER_ELEMENT = space.NODES_PER_ELEMENT
662
-
663
- def fill_element_node_indices_fn(
664
- space_arg: space.SpaceArg,
665
- element_node_indices: wp.array2d(dtype=int),
666
- ):
667
- element_index = wp.tid()
668
- for n in range(NODES_PER_ELEMENT):
669
- element_node_indices[element_index, n] = space.element_node_index(space_arg, element_index, n)
670
-
671
- from warp.fem import cache
138
+ self.element_node_index = self._make_element_node_index()
672
139
 
673
- fill_element_node_indices = cache.get_kernel(
674
- fill_element_node_indices_fn,
675
- suffix=space.name,
676
- )
140
+ def node_count(self) -> int:
141
+ INTERIOR_NODES_PER_SIDE = max(0, self._shape.ORDER - 1)
142
+ INTERIOR_NODES_PER_CELL = max(0, self._shape.ORDER - 2) * max(0, self._shape.ORDER - 1) // 2
677
143
 
678
- device = space._mesh.tri_vertex_indices.device
679
- element_node_indices = wp.empty(
680
- shape=(space._mesh.cell_count(), NODES_PER_ELEMENT),
681
- dtype=int,
682
- device=device,
683
- )
684
- wp.launch(
685
- dim=element_node_indices.shape[0],
686
- kernel=fill_element_node_indices,
687
- inputs=[
688
- space.space_arg_value(device),
689
- element_node_indices,
690
- ],
691
- device=device,
144
+ return (
145
+ self._mesh.vertex_count()
146
+ + self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
147
+ + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
692
148
  )
693
149
 
694
- element_node_indices = element_node_indices.numpy()
695
- if space.ORDER == 2:
696
- element_triangles = [[0, 3, 5], [3, 1, 4], [2, 5, 4], [3, 4, 5]]
697
- elif space.ORDER == 3:
698
- element_triangles = [
699
- [0, 3, 8],
700
- [3, 4, 9],
701
- [4, 1, 5],
702
- [8, 3, 9],
703
- [4, 5, 9],
704
- [8, 9, 7],
705
- [9, 5, 6],
706
- [6, 7, 9],
707
- [7, 6, 2],
708
- ]
709
-
710
- tri_indices = element_node_indices[:, element_triangles].reshape(-1, 3)
711
- return tri_indices
712
-
713
-
714
- class Trimesh2DPolynomialSpace(Trimesh2DFunctionSpace):
715
-
716
- def __init__(self, grid: Trimesh2D, degree: int, dtype: type = float, dof_mapper: DofMapper = None):
717
- super().__init__(grid, dtype, dof_mapper)
718
-
719
- self._shape = Trimesh2DPolynomialShapeFunctions(degree)
720
-
721
- self.ORDER = self._shape.ORDER
722
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
723
-
724
- self.element_node_index = self._make_element_node_index()
725
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
726
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
727
- self.element_inner_weight = self._shape.make_element_inner_weight()
728
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
729
-
730
- self.element_outer_weight = self.element_inner_weight
731
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
732
-
733
150
  def _make_element_node_index(self):
734
- INTERIOR_NODES_PER_SIDE = wp.constant(max(0, self.ORDER - 1))
735
- INTERIOR_NODES_PER_CELL = wp.constant(max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 2)
151
+ INTERIOR_NODES_PER_SIDE = wp.constant(max(0, self._shape.ORDER - 1))
152
+ INTERIOR_NODES_PER_CELL = wp.constant(max(0, self._shape.ORDER - 2) * max(0, self._shape.ORDER - 1) // 2)
736
153
 
154
+ @cache.dynamic_func(suffix=self.name)
737
155
  def element_node_index(
738
- args: Trimesh2DFunctionSpace.SpaceArg,
156
+ geo_arg: Trimesh2D.CellArg,
157
+ topo_arg: Trimesh2DTopologyArg,
739
158
  element_index: ElementIndex,
740
159
  node_index_in_elt: int,
741
160
  ):
742
161
  node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
743
162
 
744
- if node_type == Trimesh2DPolynomialShapeFunctions.VERTEX:
745
- return args.geo_arg.tri_vertex_indices[element_index][type_index]
163
+ if node_type == Triangle2DPolynomialShapeFunctions.VERTEX:
164
+ return geo_arg.tri_vertex_indices[element_index][type_index]
746
165
 
747
- global_offset = args.vertex_count
166
+ global_offset = topo_arg.vertex_count
748
167
 
749
- if node_type == Trimesh2DPolynomialShapeFunctions.EDGE:
168
+ if node_type == Triangle2DPolynomialShapeFunctions.EDGE:
750
169
  edge = type_index // INTERIOR_NODES_PER_SIDE
751
170
  edge_node = type_index - INTERIOR_NODES_PER_SIDE * edge
752
171
 
753
- global_edge_index = args.tri_edge_indices[element_index][edge]
172
+ global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
754
173
 
755
174
  if (
756
- args.geo_arg.edge_vertex_indices[global_edge_index][0]
757
- != args.geo_arg.tri_vertex_indices[element_index][edge]
175
+ topo_arg.edge_vertex_indices[global_edge_index][0]
176
+ != geo_arg.tri_vertex_indices[element_index][edge]
758
177
  ):
759
178
  edge_node = INTERIOR_NODES_PER_SIDE - 1 - edge_node
760
179
 
761
180
  return global_offset + INTERIOR_NODES_PER_SIDE * global_edge_index + edge_node
762
181
 
763
- global_offset += INTERIOR_NODES_PER_SIDE * args.edge_count
182
+ global_offset += INTERIOR_NODES_PER_SIDE * topo_arg.edge_count
764
183
  return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index
765
184
 
766
- from warp.fem import cache
767
-
768
- return cache.get_func(element_node_index, self.name)
769
-
770
- def node_count(self) -> int:
771
- INTERIOR_NODES_PER_SIDE = wp.constant(max(0, self.ORDER - 1))
772
- INTERIOR_NODES_PER_CELL = wp.constant(max(0, self.ORDER - 2) * max(0, self.ORDER - 1) // 2)
773
-
774
- return (
775
- self._mesh.vertex_count()
776
- + self._mesh.side_count() * INTERIOR_NODES_PER_SIDE
777
- + self._mesh.cell_count() * INTERIOR_NODES_PER_CELL
778
- )
779
-
780
- def node_positions(self):
781
- return Trimesh2DPolynomialShapeFunctions.node_positions(self)
782
-
783
- def node_triangulation(self):
784
- return Trimesh2DPolynomialShapeFunctions.node_triangulation(self)
185
+ return element_node_index
785
186
 
786
- class Trace(Trimesh2DFunctionSpace.Trace):
787
- NODES_PER_ELEMENT = wp.constant(2)
788
- ORDER = wp.constant(0)
789
187
 
790
- def __init__(self, space: "Trimesh2DPolynomialSpace"):
791
- super().__init__(space)
792
-
793
- self.element_node_index = self._make_element_node_index(space)
794
- self.node_coords_in_element = self._make_node_coords_in_element(space)
795
- self.node_quadrature_weight = space._shape.make_trace_node_quadrature_weight()
796
-
797
- self.element_inner_weight = self._make_element_inner_weight(space)
798
- self.element_inner_weight_gradient = self._make_element_inner_weight_gradient(space)
799
-
800
- self.element_outer_weight = self._make_element_outer_weight(space)
801
- self.element_outer_weight_gradient = self._make_element_outer_weight_gradient(space)
802
-
803
- def trace(self):
804
- return Trimesh2DPolynomialSpace.Trace(self)
805
-
806
-
807
- class Trimesh2DDGPolynomialSpace(Trimesh2DFunctionSpace):
188
+ class Trimesh2DPolynomialBasisSpace(Trimesh2DBasisSpace):
808
189
  def __init__(
809
190
  self,
810
191
  mesh: Trimesh2D,
811
192
  degree: int,
812
- dtype: type = float,
813
- dof_mapper: DofMapper = None,
814
193
  ):
815
- super().__init__(mesh, dtype, dof_mapper)
816
-
817
- self._shape = Trimesh2DPolynomialShapeFunctions(degree)
194
+ shape = Triangle2DPolynomialShapeFunctions(degree)
195
+ topology = forward_base_topology(Trimesh2DPolynomialSpaceTopology, mesh, shape)
818
196
 
819
- self.ORDER = self._shape.ORDER
820
- self.NODES_PER_ELEMENT = self._shape.NODES_PER_ELEMENT
197
+ super().__init__(topology, shape)
821
198
 
822
- self.element_node_index = self._make_element_node_index()
823
- self.node_coords_in_element = self._shape.make_node_coords_in_element()
824
- self.node_quadrature_weight = self._shape.make_node_quadrature_weight()
825
- self.element_inner_weight = self._shape.make_element_inner_weight()
826
- self.element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
827
-
828
- self.element_outer_weight = self.element_inner_weight
829
- self.element_outer_weight_gradient = self.element_inner_weight_gradient
830
-
831
- def node_count(self) -> int:
832
- return self._mesh.cell_count() * self.NODES_PER_ELEMENT
833
-
834
- def node_positions(self):
835
- return Trimesh2DPolynomialShapeFunctions.node_positions(self)
836
-
837
- def node_triangulation(self):
838
- return Trimesh2DPolynomialShapeFunctions.node_triangulation(self)
839
-
840
- def _make_element_node_index(self):
841
- NODES_PER_ELEMENT = self.NODES_PER_ELEMENT
842
199
 
843
- def element_node_index(
844
- args: Trimesh2DFunctionSpace.SpaceArg,
845
- element_index: ElementIndex,
846
- node_index_in_elt: int,
847
- ):
848
- return element_index * NODES_PER_ELEMENT + node_index_in_elt
200
+ class Trimesh2DDGPolynomialBasisSpace(Trimesh2DBasisSpace):
201
+ def __init__(
202
+ self,
203
+ mesh: Trimesh2D,
204
+ degree: int,
205
+ ):
206
+ shape = Triangle2DPolynomialShapeFunctions(degree)
207
+ topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
849
208
 
850
- from warp.fem import cache
209
+ super().__init__(topology, shape)
851
210
 
852
- return cache.get_func(element_node_index, f"{self.name}_{self.ORDER}")
853
211
 
854
- class Trace(Trimesh2DPolynomialSpace.Trace):
855
- def __init__(self, space: "Trimesh2DDGPolynomialSpace"):
856
- super().__init__(space)
212
+ class Trimesh2DNonConformingPolynomialBasisSpace(Trimesh2DBasisSpace):
213
+ def __init__(
214
+ self,
215
+ mesh: Trimesh2D,
216
+ degree: int,
217
+ ):
218
+ shape = Triangle2DNonConformingPolynomialShapeFunctions(degree)
219
+ topology = Trimesh2DDiscontinuousSpaceTopology(mesh, shape)
857
220
 
858
- def trace(self):
859
- return Trimesh2DDGPolynomialSpace.Trace(self)
221
+ super().__init__(topology, shape)