warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -4,9 +4,10 @@ import warp as wp
4
4
 
5
5
  from warp.fem.types import ElementIndex, Coords, make_free_sample
6
6
  from warp.fem.geometry import Geometry
7
- from warp.fem import cache, utils
7
+ from warp.fem.quadrature import Quadrature
8
+ from warp.fem import cache
8
9
 
9
- from .topology import SpaceTopology
10
+ from .topology import SpaceTopology, DiscontinuousSpaceTopology
10
11
  from .shape import ShapeFunction
11
12
 
12
13
 
@@ -25,19 +26,10 @@ class BasisSpace:
25
26
 
26
27
  pass
27
28
 
28
- def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
29
+ def __init__(self, topology: SpaceTopology):
29
30
  self._topology = topology
30
- self._shape = shape
31
31
 
32
32
  self.NODES_PER_ELEMENT = self._topology.NODES_PER_ELEMENT
33
- self.ORDER = self._shape.ORDER
34
-
35
- if hasattr(shape, "element_node_triangulation"):
36
- self.node_triangulation = self._node_triangulation
37
- if hasattr(shape, "element_node_tets"):
38
- self.node_tets = self._node_tets
39
- if hasattr(shape, "element_node_hexes"):
40
- self.node_hexes = self._node_hexes
41
33
 
42
34
  @property
43
35
  def topology(self) -> SpaceTopology:
@@ -49,19 +41,10 @@ class BasisSpace:
49
41
  """Underlying geometry of the basis space"""
50
42
  return self._topology.geometry
51
43
 
52
- @property
53
- def shape(self) -> ShapeFunction:
54
- """Shape functions used for defining individual element basis"""
55
- return self._shape
56
-
57
44
  def basis_arg_value(self, device) -> "BasisArg":
58
45
  """Value for the argument structure to be passed to device functions"""
59
46
  return BasisSpace.BasisArg()
60
47
 
61
- @property
62
- def name(self):
63
- return f"{self.topology.name}_{self._shape.name}"
64
-
65
48
  # Helpers for generating node positions
66
49
 
67
50
  def node_positions(self, out: Optional[wp.array] = None) -> wp.array:
@@ -117,6 +100,56 @@ class BasisSpace:
117
100
 
118
101
  return node_positions
119
102
 
103
+ def make_node_coords_in_element(self):
104
+ raise NotImplementedError()
105
+
106
+ def make_node_quadrature_weight(self):
107
+ raise NotImplementedError()
108
+
109
+ def make_element_inner_weight(self):
110
+ raise NotImplementedError()
111
+
112
+ def make_element_outer_weight(self):
113
+ return self.make_element_inner_weight()
114
+
115
+ def make_element_inner_weight_gradient(self):
116
+ raise NotImplementedError()
117
+
118
+ def make_element_outer_weight_gradient(self):
119
+ return self.make_element_inner_weight_gradient()
120
+
121
+ def make_trace_node_quadrature_weight(self):
122
+ raise NotImplementedError()
123
+
124
+ def trace(self) -> "TraceBasisSpace":
125
+ return TraceBasisSpace(self)
126
+
127
+
128
+ class ShapeBasisSpace(BasisSpace):
129
+ """Base class for defining shape-function-based basis spaces."""
130
+
131
+ def __init__(self, topology: SpaceTopology, shape: ShapeFunction):
132
+ super().__init__(topology)
133
+ self._shape = shape
134
+
135
+ self.ORDER = self._shape.ORDER
136
+
137
+ if hasattr(shape, "element_node_triangulation"):
138
+ self.node_triangulation = self._node_triangulation
139
+ if hasattr(shape, "element_node_tets"):
140
+ self.node_tets = self._node_tets
141
+ if hasattr(shape, "element_node_hexes"):
142
+ self.node_hexes = self._node_hexes
143
+
144
+ @property
145
+ def shape(self) -> ShapeFunction:
146
+ """Shape functions used for defining individual element basis"""
147
+ return self._shape
148
+
149
+ @property
150
+ def name(self):
151
+ return f"{self.topology.name}_{self._shape.name}"
152
+
120
153
  def make_node_coords_in_element(self):
121
154
  shape_node_coords_in_element = self._shape.make_node_coords_in_element()
122
155
 
@@ -156,14 +189,10 @@ class BasisSpace:
156
189
  coords: Coords,
157
190
  node_index_in_elt: int,
158
191
  ):
159
- a = self.BasisArg()
160
192
  return shape_element_inner_weight(coords, node_index_in_elt)
161
193
 
162
194
  return element_inner_weight
163
195
 
164
- def make_element_outer_weight(self):
165
- return self.make_element_inner_weight()
166
-
167
196
  def make_element_inner_weight_gradient(self):
168
197
  shape_element_inner_weight_gradient = self._shape.make_element_inner_weight_gradient()
169
198
 
@@ -179,11 +208,22 @@ class BasisSpace:
179
208
 
180
209
  return element_inner_weight_gradient
181
210
 
182
- def make_element_outer_weight_gradient(self):
183
- return self.make_element_inner_weight_gradient()
211
+ def make_trace_node_quadrature_weight(self, trace_basis):
212
+ shape_trace_node_quadrature_weight = self._shape.make_trace_node_quadrature_weight()
184
213
 
185
- def trace(self) -> "TraceBasisSpace":
186
- return TraceBasisSpace(self)
214
+ @cache.dynamic_func(suffix=self.name)
215
+ def trace_node_quadrature_weight(
216
+ geo_side_arg: trace_basis.geometry.SideArg,
217
+ basis_arg: trace_basis.BasisArg,
218
+ element_index: ElementIndex,
219
+ node_index_in_elt: int,
220
+ ):
221
+ neighbour_elem, index_in_neighbour = trace_basis.topology.neighbor_cell_index(
222
+ geo_side_arg, element_index, node_index_in_elt
223
+ )
224
+ return shape_trace_node_quadrature_weight(index_in_neighbour)
225
+
226
+ return trace_node_quadrature_weight
187
227
 
188
228
  def _node_triangulation(self):
189
229
  element_node_indices = self._topology.element_node_indices().numpy()
@@ -211,7 +251,9 @@ class TraceBasisSpace(BasisSpace):
211
251
  """Auto-generated trace space evaluating the cell-defined basis on the geometry sides"""
212
252
 
213
253
  def __init__(self, basis: BasisSpace):
214
- super().__init__(basis.topology.trace(), basis.shape)
254
+ super().__init__(basis.topology.trace())
255
+
256
+ self.ORDER = basis.ORDER
215
257
 
216
258
  self._basis = basis
217
259
  self.BasisArg = self._basis.BasisArg
@@ -247,21 +289,7 @@ class TraceBasisSpace(BasisSpace):
247
289
  return trace_node_coords_in_element
248
290
 
249
291
  def make_node_quadrature_weight(self):
250
- shape_trace_node_quadrature_weight = self._shape.make_trace_node_quadrature_weight()
251
-
252
- @cache.dynamic_func(suffix=self._basis.name)
253
- def trace_node_quadrature_weight(
254
- geo_side_arg: self.geometry.SideArg,
255
- basis_arg: self.BasisArg,
256
- element_index: ElementIndex,
257
- node_index_in_elt: int,
258
- ):
259
- neighbour_elem, index_in_neighbour = self.topology.neighbor_cell_index(
260
- geo_side_arg, element_index, node_index_in_elt
261
- )
262
- return shape_trace_node_quadrature_weight(index_in_neighbour)
263
-
264
- return trace_node_quadrature_weight
292
+ return self._basis.make_trace_node_quadrature_weight(self)
265
293
 
266
294
  def make_element_inner_weight(self):
267
295
  cell_inner_weight = self._basis.make_element_inner_weight()
@@ -337,9 +365,7 @@ class TraceBasisSpace(BasisSpace):
337
365
 
338
366
  cell_coords = self.geometry.side_inner_cell_coords(geo_side_arg, element_index, coords)
339
367
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
340
- return cell_inner_weight_gradient(
341
- geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell
342
- )
368
+ return cell_inner_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
343
369
 
344
370
  return trace_element_inner_weight_gradient
345
371
 
@@ -361,11 +387,103 @@ class TraceBasisSpace(BasisSpace):
361
387
 
362
388
  cell_coords = self.geometry.side_outer_cell_coords(geo_side_arg, element_index, coords)
363
389
  geo_cell_arg = self.geometry.side_to_cell_arg(geo_side_arg)
364
- return cell_outer_weight_gradient(
365
- geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell
366
- )
390
+ return cell_outer_weight_gradient(geo_cell_arg, basis_arg, cell_index, cell_coords, index_in_cell)
367
391
 
368
392
  return trace_element_outer_weight_gradient
369
393
 
370
394
  def __eq__(self, other: "TraceBasisSpace") -> bool:
371
395
  return self._topo == other._topo
396
+
397
+
398
+ class PointBasisSpace(BasisSpace):
399
+ """An unstructured :class:`BasisSpace` that is non-zero at a finite set of points only.
400
+
401
+ The node locations and nodal quadrature weights are defined by a :class:`Quadrature` formula.
402
+ """
403
+
404
+ def __init__(self, quadrature: Quadrature):
405
+ self._quadrature = quadrature
406
+
407
+ if quadrature.points_per_element() is None:
408
+ raise NotImplementedError("Varying number of points per element is not supported yet")
409
+
410
+ topology = DiscontinuousSpaceTopology(
411
+ geometry=quadrature.domain.geometry, nodes_per_element=quadrature.points_per_element()
412
+ )
413
+ super().__init__(topology)
414
+
415
+ self.BasisArg = quadrature.Arg
416
+ self.basis_arg_value = quadrature.arg_value
417
+ self.ORDER = 0
418
+
419
+ self.make_element_outer_weight = self.make_element_inner_weight
420
+ self.make_element_outer_weight_gradient = self.make_element_outer_weight_gradient
421
+
422
+ @property
423
+ def name(self):
424
+ return f"{self._quadrature.name}_Point"
425
+
426
+ def make_node_coords_in_element(self):
427
+ @cache.dynamic_func(suffix=self.name)
428
+ def node_coords_in_element(
429
+ elt_arg: self._quadrature.domain.ElementArg,
430
+ basis_arg: self.BasisArg,
431
+ element_index: ElementIndex,
432
+ node_index_in_elt: int,
433
+ ):
434
+ return self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
435
+
436
+ return node_coords_in_element
437
+
438
+ def make_node_quadrature_weight(self):
439
+ @cache.dynamic_func(suffix=self.name)
440
+ def node_quadrature_weight(
441
+ elt_arg: self._quadrature.domain.ElementArg,
442
+ basis_arg: self.BasisArg,
443
+ element_index: ElementIndex,
444
+ node_index_in_elt: int,
445
+ ):
446
+ return self._quadrature.point_weight(elt_arg, basis_arg, element_index, node_index_in_elt)
447
+
448
+ return node_quadrature_weight
449
+
450
+ def make_element_inner_weight(self):
451
+ @cache.dynamic_func(suffix=self.name)
452
+ def element_inner_weight(
453
+ elt_arg: self._quadrature.domain.ElementArg,
454
+ basis_arg: self.BasisArg,
455
+ element_index: ElementIndex,
456
+ coords: Coords,
457
+ node_index_in_elt: int,
458
+ ):
459
+ qp_coord = self._quadrature.point_coords(elt_arg, basis_arg, element_index, node_index_in_elt)
460
+ return wp.select(wp.length_sq(coords - qp_coord) < 0.001, 0.0, 1.0)
461
+
462
+ return element_inner_weight
463
+
464
+ def make_element_inner_weight_gradient(self):
465
+ gradient_vec = cache.cached_vec_type(length=self.geometry.dimension, dtype=float)
466
+
467
+ @cache.dynamic_func(suffix=self.name)
468
+ def element_inner_weight_gradient(
469
+ elt_arg: self._quadrature.domain.ElementArg,
470
+ basis_arg: self.BasisArg,
471
+ element_index: ElementIndex,
472
+ coords: Coords,
473
+ node_index_in_elt: int,
474
+ ):
475
+ return gradient_vec(0.0)
476
+
477
+ return element_inner_weight_gradient
478
+
479
+ def make_trace_node_quadrature_weight(self, trace_basis):
480
+ @cache.dynamic_func(suffix=self.name)
481
+ def trace_node_quadrature_weight(
482
+ elt_arg: trace_basis.geometry.SideArg,
483
+ basis_arg: trace_basis.BasisArg,
484
+ element_index: ElementIndex,
485
+ node_index_in_elt: int,
486
+ ):
487
+ return 0.0
488
+
489
+ return trace_node_quadrature_weight
@@ -7,7 +7,7 @@ from warp.fem.geometry import Grid2D
7
7
  from warp.fem import cache
8
8
 
9
9
  from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
10
- from .basis_space import BasisSpace, TraceBasisSpace
10
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
11
11
 
12
12
  from .shape import ShapeFunction, ConstantShapeFunction
13
13
  from .shape import (
@@ -44,7 +44,7 @@ class Grid2DDiscontinuousSpaceTopology(
44
44
  pass
45
45
 
46
46
 
47
- class Grid2DBasisSpace(BasisSpace):
47
+ class Grid2DBasisSpace(ShapeBasisSpace):
48
48
  def __init__(self, topology: Grid2DSpaceTopology, shape: ShapeFunction):
49
49
  super().__init__(topology, shape)
50
50
 
@@ -7,7 +7,7 @@ from warp.fem.geometry import Grid3D
7
7
  from warp.fem import cache
8
8
 
9
9
  from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
10
- from .basis_space import BasisSpace, TraceBasisSpace
10
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
11
11
 
12
12
  from .shape import ShapeFunction, ConstantShapeFunction
13
13
  from .shape.cube_shape_function import (
@@ -45,7 +45,7 @@ class Grid3DDiscontinuousSpaceTopology(
45
45
  pass
46
46
 
47
47
 
48
- class Grid3DBasisSpace(BasisSpace):
48
+ class Grid3DBasisSpace(ShapeBasisSpace):
49
49
  def __init__(self, topology: Grid3DSpaceTopology, shape: ShapeFunction):
50
50
  super().__init__(topology, shape)
51
51
 
@@ -6,7 +6,7 @@ from warp.fem.geometry import Hexmesh
6
6
  from warp.fem import cache
7
7
 
8
8
  from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
- from .basis_space import BasisSpace, TraceBasisSpace
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
10
 
11
11
  from .shape import ShapeFunction, ConstantShapeFunction
12
12
  from .shape import (
@@ -121,7 +121,7 @@ class HexmeshDiscontinuousSpaceTopology(
121
121
  super().__init__(mesh, shape.NODES_PER_ELEMENT)
122
122
 
123
123
 
124
- class HexmeshBasisSpace(BasisSpace):
124
+ class HexmeshBasisSpace(ShapeBasisSpace):
125
125
  def __init__(self, topology: HexmeshSpaceTopology, shape: ShapeFunction):
126
126
  super().__init__(topology, shape)
127
127
 
@@ -1,15 +1,18 @@
1
1
  from typing import Any, Optional, Union
2
2
 
3
3
  import warp as wp
4
-
4
+ from warp.fem.cache import (
5
+ TemporaryStore,
6
+ borrow_temporary,
7
+ borrow_temporary_like,
8
+ cached_arg_value,
9
+ )
5
10
  from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
6
- from warp.fem.utils import compress_node_indices, _iota_kernel
7
11
  from warp.fem.types import NULL_NODE_INDEX
8
- from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary, borrow_temporary_like
12
+ from warp.fem.utils import _iota_kernel, compress_node_indices
9
13
 
10
- from .topology import SpaceTopology
11
14
  from .function_space import FunctionSpace
12
-
15
+ from .topology import SpaceTopology
13
16
 
14
17
  wp.set_module_options({"enable_backward": False})
15
18
 
@@ -272,7 +275,7 @@ class NodePartition(SpacePartition):
272
275
 
273
276
  if device.is_cuda:
274
277
  # TODO switch to synchronize_event once available
275
- wp.synchronize_stream(wp.get_stream())
278
+ wp.synchronize_stream(wp.get_stream(device))
276
279
 
277
280
  category_offsets.release()
278
281
 
@@ -6,7 +6,7 @@ from warp.fem.geometry import Quadmesh2D
6
6
  from warp.fem import cache
7
7
 
8
8
  from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
9
- from .basis_space import BasisSpace, TraceBasisSpace
9
+ from .basis_space import ShapeBasisSpace, TraceBasisSpace
10
10
 
11
11
  from .shape import ShapeFunction, ConstantShapeFunction
12
12
  from .shape import (
@@ -116,7 +116,7 @@ class Quadmesh2DDiscontinuousSpaceTopology(
116
116
  super().__init__(mesh, shape.NODES_PER_ELEMENT)
117
117
 
118
118
 
119
- class Quadmesh2DBasisSpace(BasisSpace):
119
+ class Quadmesh2DBasisSpace(ShapeBasisSpace):
120
120
  def __init__(self, topology: Quadmesh2DSpaceTopology, shape: ShapeFunction):
121
121
  super().__init__(topology, shape)
122
122
 
@@ -41,6 +41,7 @@ class CubeTripolynomialShapeFunctions:
41
41
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
42
42
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
43
43
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
44
+ self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)
44
45
 
45
46
  self._node_ijk = self._make_node_ijk()
46
47
  self.node_type_and_type_index = self._make_node_type_and_type_index()
@@ -57,21 +58,21 @@ class CubeTripolynomialShapeFunctions:
57
58
  return wp.vec3(float(x), float(y), float(z))
58
59
 
59
60
  def _make_node_ijk(self):
60
- ORDER = self.ORDER
61
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
61
62
 
62
63
  def node_ijk(
63
64
  node_index_in_elt: int,
64
65
  ):
65
- node_i = node_index_in_elt // ((ORDER + 1) * (ORDER + 1))
66
- node_jk = node_index_in_elt - (ORDER + 1) * (ORDER + 1) * node_i
67
- node_j = node_jk // (ORDER + 1)
68
- node_k = node_jk - (ORDER + 1) * node_j
66
+ node_i = node_index_in_elt // (ORDER_PLUS_ONE * ORDER_PLUS_ONE)
67
+ node_jk = node_index_in_elt - ORDER_PLUS_ONE * ORDER_PLUS_ONE * node_i
68
+ node_j = node_jk // ORDER_PLUS_ONE
69
+ node_k = node_jk - ORDER_PLUS_ONE * node_j
69
70
  return node_i, node_j, node_k
70
71
 
71
72
  return cache.get_func(node_ijk, self.name)
72
73
 
73
74
  def _make_node_type_and_type_index(self):
74
- ORDER = wp.constant(self.ORDER)
75
+ ORDER = self.ORDER
75
76
 
76
77
  @cache.dynamic_func(suffix=self.name)
77
78
  def node_type_and_type_index(
@@ -190,13 +191,21 @@ class CubeTripolynomialShapeFunctions:
190
191
  ):
191
192
  return 0.25
192
193
 
194
+ def trace_node_quadrature_weight_open(
195
+ node_index_in_elt: int,
196
+ ):
197
+ return 0.0
198
+
199
+ if not is_closed(self.family):
200
+ return cache.get_func(trace_node_quadrature_weight_open, self.name)
201
+
193
202
  if ORDER == 1:
194
203
  return cache.get_func(trace_node_quadrature_weight_linear, self.name)
195
204
 
196
205
  return cache.get_func(trace_node_quadrature_weight, self.name)
197
206
 
198
207
  def make_element_inner_weight(self):
199
- ORDER = self.ORDER
208
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
200
209
  LOBATTO_COORDS = self.LOBATTO_COORDS
201
210
  LAGRANGE_SCALE = self.LAGRANGE_SCALE
202
211
 
@@ -207,7 +216,7 @@ class CubeTripolynomialShapeFunctions:
207
216
  node_i, node_j, node_k = self._node_ijk(node_index_in_elt)
208
217
 
209
218
  w = float(1.0)
210
- for k in range(ORDER + 1):
219
+ for k in range(ORDER_PLUS_ONE):
211
220
  if k != node_i:
212
221
  w *= coords[0] - LOBATTO_COORDS[k]
213
222
  if k != node_j:
@@ -230,13 +239,13 @@ class CubeTripolynomialShapeFunctions:
230
239
  wz = (1.0 - coords[2]) * (1.0 - v[2]) + v[2] * coords[2]
231
240
  return wx * wy * wz
232
241
 
233
- if ORDER == 1:
242
+ if self.ORDER == 1 and is_closed(self.family):
234
243
  return cache.get_func(element_inner_weight_linear, self.name)
235
244
 
236
245
  return cache.get_func(element_inner_weight, self.name)
237
246
 
238
247
  def make_element_inner_weight_gradient(self):
239
- ORDER = self.ORDER
248
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
240
249
  LOBATTO_COORDS = self.LOBATTO_COORDS
241
250
  LAGRANGE_SCALE = self.LAGRANGE_SCALE
242
251
 
@@ -249,7 +258,7 @@ class CubeTripolynomialShapeFunctions:
249
258
  prefix_xy = float(1.0)
250
259
  prefix_yz = float(1.0)
251
260
  prefix_zx = float(1.0)
252
- for k in range(ORDER + 1):
261
+ for k in range(ORDER_PLUS_ONE):
253
262
  if k != node_i:
254
263
  prefix_yz *= coords[0] - LOBATTO_COORDS[k]
255
264
  if k != node_j:
@@ -265,7 +274,7 @@ class CubeTripolynomialShapeFunctions:
265
274
  grad_y = float(0.0)
266
275
  grad_z = float(0.0)
267
276
 
268
- for k in range(ORDER + 1):
277
+ for k in range(ORDER_PLUS_ONE):
269
278
  if k != node_i:
270
279
  delta_x = coords[0] - LOBATTO_COORDS[k]
271
280
  grad_x = grad_x * delta_x + prefix_x
@@ -308,7 +317,7 @@ class CubeTripolynomialShapeFunctions:
308
317
 
309
318
  return wp.vec3(dx * wy * wz, dy * wz * wx, dz * wx * wy)
310
319
 
311
- if ORDER == 1:
320
+ if self.ORDER == 1 and is_closed(self.family):
312
321
  return cache.get_func(element_inner_weight_gradient_linear, self.name)
313
322
 
314
323
  return cache.get_func(element_inner_weight_gradient, self.name)
@@ -356,6 +365,7 @@ class CubeSerendipityShapeFunctions:
356
365
  self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
357
366
  self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
358
367
  self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
368
+ self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)
359
369
 
360
370
  self.node_type_and_type_index = self._get_node_type_and_type_index()
361
371
  self._node_lobatto_indices = self._get_node_lobatto_indices()
@@ -466,6 +476,7 @@ class CubeSerendipityShapeFunctions:
466
476
 
467
477
  def make_element_inner_weight(self):
468
478
  ORDER = self.ORDER
479
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
469
480
 
470
481
  LOBATTO_COORDS = self.LOBATTO_COORDS
471
482
  LAGRANGE_SCALE = self.LAGRANGE_SCALE
@@ -511,7 +522,7 @@ class CubeSerendipityShapeFunctions:
511
522
  w *= wp.select(node_all[1] == 0, local_coords[1], 1.0 - local_coords[1])
512
523
  w *= wp.select(node_all[2] == 0, local_coords[2], 1.0 - local_coords[2])
513
524
 
514
- for k in range(ORDER + 1):
525
+ for k in range(ORDER_PLUS_ONE):
515
526
  if k != node_all[0]:
516
527
  w *= local_coords[0] - LOBATTO_COORDS[k]
517
528
  w *= LAGRANGE_SCALE[node_all[0]]
@@ -522,6 +533,7 @@ class CubeSerendipityShapeFunctions:
522
533
 
523
534
  def make_element_inner_weight_gradient(self):
524
535
  ORDER = self.ORDER
536
+ ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
525
537
  LOBATTO_COORDS = self.LOBATTO_COORDS
526
538
  LAGRANGE_SCALE = self.LAGRANGE_SCALE
527
539
 
@@ -585,7 +597,7 @@ class CubeSerendipityShapeFunctions:
585
597
  w_alt = LAGRANGE_SCALE[node_all[0]]
586
598
  g_alt = float(0.0)
587
599
  prefix_alt = LAGRANGE_SCALE[node_all[0]]
588
- for k in range(ORDER + 1):
600
+ for k in range(ORDER_PLUS_ONE):
589
601
  if k != node_all[0]:
590
602
  delta_alt = local_coords[0] - LOBATTO_COORDS[k]
591
603
  w_alt *= delta_alt