warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/tests/test_fem.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
1
+ # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
2
  # NVIDIA CORPORATION and its licensors retain all intellectual property
3
3
  # and proprietary rights in and to this software, related documentation
4
4
  # and any modifications thereto. Any use, reproduction, disclosure or
@@ -10,17 +10,13 @@ import unittest
10
10
 
11
11
  import numpy as np
12
12
  import warp as wp
13
- from warp.tests.test_base import *
13
+ from warp.tests.unittest_utils import *
14
14
 
15
15
 
16
16
  from warp.fem import Field, Sample, Domain, Coords
17
- from warp.fem import Grid2D, Grid3D, Trimesh2D, Tetmesh, Quadmesh2D, Hexmesh, Geometry
18
- from warp.fem import make_polynomial_space, SymmetricTensorMapper, SkewSymmetricTensorMapper
19
- from warp.fem import make_test
20
- from warp.fem import Cells, Sides, BoundarySides
21
- from warp.fem import integrate, interpolate
22
17
  from warp.fem import integrand, div, grad, curl, D, normal
23
- from warp.fem import RegularQuadrature, Polynomial
18
+ import warp.fem as fem
19
+
24
20
  from warp.fem.types import make_free_sample
25
21
  from warp.fem.geometry.closest_point import project_on_tri_at_origin, project_on_tet_at_origin
26
22
  from warp.fem.geometry import DeformedGeometry
@@ -39,13 +35,13 @@ def linear_form(s: Sample, u: Field):
39
35
  def test_integrate_gradient(test_case, device):
40
36
  with wp.ScopedDevice(device):
41
37
  # Grid geometry
42
- geo = Grid2D(res=wp.vec2i(5))
38
+ geo = fem.Grid2D(res=wp.vec2i(5))
43
39
 
44
40
  # Domain and function spaces
45
- domain = Cells(geometry=geo)
46
- quadrature = RegularQuadrature(domain=domain, order=3)
41
+ domain = fem.Cells(geometry=geo)
42
+ quadrature = fem.RegularQuadrature(domain=domain, order=3)
47
43
 
48
- scalar_space = make_polynomial_space(geo, degree=3)
44
+ scalar_space = fem.make_polynomial_space(geo, degree=3)
49
45
 
50
46
  u = scalar_space.make_field()
51
47
  u.dof_values = wp.zeros_like(u.dof_values, requires_grad=True)
@@ -56,17 +52,100 @@ def test_integrate_gradient(test_case, device):
56
52
 
57
53
  # forward pass
58
54
  with tape:
59
- integrate(linear_form, quadrature=quadrature, fields={"u": u}, output=result)
55
+ fem.integrate(linear_form, quadrature=quadrature, fields={"u": u}, output=result)
60
56
 
61
57
  tape.backward(result)
62
58
 
63
- test = make_test(space=scalar_space, domain=domain)
64
- rhs = integrate(linear_form, quadrature=quadrature, fields={"u": test})
59
+ test = fem.make_test(space=scalar_space, domain=domain)
60
+ rhs = fem.integrate(linear_form, quadrature=quadrature, fields={"u": test})
65
61
 
66
62
  err = np.linalg.norm(rhs.numpy() - u.dof_values.grad.numpy())
67
63
  test_case.assertLess(err, 1.0e-8)
68
64
 
69
65
 
66
+ @fem.integrand
67
+ def bilinear_field(s: fem.Sample, domain: fem.Domain):
68
+ x = domain(s)
69
+ return x[0] * x[1]
70
+
71
+
72
+ @fem.integrand
73
+ def grad_field(s: fem.Sample, p: fem.Field):
74
+ return fem.grad(p, s)
75
+
76
+
77
+ def test_interpolate_gradient(test_case, device):
78
+ with wp.ScopedDevice(device):
79
+ # Quad mesh with single element
80
+ # so we can test gradient with respect to vertex positions
81
+ positions = wp.array([[0.0, 0.0], [0.0, 2.0], [2.0, 0.0], [2.0, 2.0]], dtype=wp.vec2, requires_grad=True)
82
+ quads = wp.array([[0, 2, 3, 1]], dtype=int)
83
+ geo = fem.Quadmesh2D(quads, positions)
84
+
85
+ # Quadratic scalar space
86
+ scalar_space = fem.make_polynomial_space(geo, degree=2)
87
+
88
+ # Point-based vector space
89
+ # So we can test gradient with respect to inteprolation point position
90
+ point_coords = wp.array([[[0.5, 0.5, 0.0]]], dtype=fem.Coords, requires_grad=True)
91
+ interpolation_nodes = fem.PointBasisSpace(
92
+ fem.ExplicitQuadrature(domain=fem.Cells(geo), points=point_coords, weights=wp.array([[1.0]], dtype=float))
93
+ )
94
+ vector_space = fem.make_collocated_function_space(interpolation_nodes, dtype=wp.vec2)
95
+
96
+ # Initialize scalar field with known function
97
+ scalar_field = scalar_space.make_field()
98
+ scalar_field.dof_values.requires_grad = True
99
+ fem.interpolate(bilinear_field, dest=scalar_field)
100
+
101
+ # Interpolate gradient at center point
102
+ vector_field = vector_space.make_field()
103
+ vector_field.dof_values.requires_grad = True
104
+ tape = wp.Tape()
105
+ with tape:
106
+ fem.interpolate(grad_field, dest=vector_field, fields={"p": scalar_field})
107
+
108
+ assert_np_equal(vector_field.dof_values.numpy(), np.array([[1.0, 1.0]]))
109
+
110
+ vector_field.dof_values.grad.assign([1.0, 0.0])
111
+ tape.backward()
112
+
113
+ assert_np_equal(scalar_field.dof_values.grad.numpy(), np.array([0.0, 0.0, 0.0, 0.0, 0.0, -0.5, 0.0, 0.5, 0.0]))
114
+ assert_np_equal(
115
+ geo.positions.grad.numpy(),
116
+ np.array(
117
+ [
118
+ [0.25, 0.25],
119
+ [0.25, 0.25],
120
+ [-0.25, -0.25],
121
+ [-0.25, -0.25],
122
+ ]
123
+ ),
124
+ )
125
+ assert_np_equal(point_coords.grad.numpy(), np.array([[[0.0, 2.0, 0.0]]]))
126
+
127
+ tape.zero()
128
+ scalar_field.dof_values.grad.zero_()
129
+ geo.positions.grad.zero_()
130
+ point_coords.grad.zero_()
131
+
132
+ vector_field.dof_values.grad.assign([0.0, 1.0])
133
+ tape.backward()
134
+
135
+ assert_np_equal(scalar_field.dof_values.grad.numpy(), np.array([0.0, 0.0, 0.0, 0.0, -0.5, 0.0, 0.5, 0.0, 0.0]))
136
+ assert_np_equal(
137
+ geo.positions.grad.numpy(),
138
+ np.array(
139
+ [
140
+ [0.25, 0.25],
141
+ [-0.25, -0.25],
142
+ [0.25, 0.25],
143
+ [-0.25, -0.25],
144
+ ]
145
+ ),
146
+ )
147
+ assert_np_equal(point_coords.grad.numpy(), np.array([[[2.0, 0.0, 0.0]]]))
148
+
70
149
  @integrand
71
150
  def vector_divergence_form(s: Sample, u: Field, q: Field):
72
151
  return div(u, s) * q(s)
@@ -87,14 +166,14 @@ def test_vector_divergence_theorem(test_case, device):
87
166
 
88
167
  with wp.ScopedDevice(device):
89
168
  # Grid geometry
90
- geo = Grid2D(res=wp.vec2i(5))
169
+ geo = fem.Grid2D(res=wp.vec2i(5))
91
170
 
92
171
  # Domain and function spaces
93
- interior = Cells(geometry=geo)
94
- boundary = BoundarySides(geometry=geo)
172
+ interior = fem.Cells(geometry=geo)
173
+ boundary = fem.BoundarySides(geometry=geo)
95
174
 
96
- vector_space = make_polynomial_space(geo, degree=2, dtype=wp.vec2)
97
- scalar_space = make_polynomial_space(geo, degree=1, dtype=float)
175
+ vector_space = fem.make_polynomial_space(geo, degree=2, dtype=wp.vec2)
176
+ scalar_space = fem.make_polynomial_space(geo, degree=1, dtype=float)
98
177
 
99
178
  u = vector_space.make_field()
100
179
  u.dof_values = rng.random(size=(u.dof_values.shape[0], 2))
@@ -103,15 +182,15 @@ def test_vector_divergence_theorem(test_case, device):
103
182
  constant_one = scalar_space.make_field()
104
183
  constant_one.dof_values.fill_(1.0)
105
184
 
106
- interior_quadrature = RegularQuadrature(domain=interior, order=vector_space.degree)
107
- boundary_quadrature = RegularQuadrature(domain=boundary, order=vector_space.degree)
108
- div_int = integrate(
185
+ interior_quadrature = fem.RegularQuadrature(domain=interior, order=vector_space.degree)
186
+ boundary_quadrature = fem.RegularQuadrature(domain=boundary, order=vector_space.degree)
187
+ div_int = fem.integrate(
109
188
  vector_divergence_form,
110
189
  quadrature=interior_quadrature,
111
190
  fields={"u": u, "q": constant_one},
112
191
  kernel_options={"enable_backward": False},
113
192
  )
114
- boundary_int = integrate(
193
+ boundary_int = fem.integrate(
115
194
  vector_boundary_form,
116
195
  quadrature=boundary_quadrature,
117
196
  fields={"u": u.trace(), "q": constant_one.trace()},
@@ -124,21 +203,21 @@ def test_vector_divergence_theorem(test_case, device):
124
203
  q = scalar_space.make_field()
125
204
  q.dof_values = rng.random(size=q.dof_values.shape[0])
126
205
 
127
- interior_quadrature = RegularQuadrature(domain=interior, order=vector_space.degree + scalar_space.degree)
128
- boundary_quadrature = RegularQuadrature(domain=boundary, order=vector_space.degree + scalar_space.degree)
129
- div_int = integrate(
206
+ interior_quadrature = fem.RegularQuadrature(domain=interior, order=vector_space.degree + scalar_space.degree)
207
+ boundary_quadrature = fem.RegularQuadrature(domain=boundary, order=vector_space.degree + scalar_space.degree)
208
+ div_int = fem.integrate(
130
209
  vector_divergence_form,
131
210
  quadrature=interior_quadrature,
132
211
  fields={"u": u, "q": q},
133
212
  kernel_options={"enable_backward": False},
134
213
  )
135
- grad_int = integrate(
214
+ grad_int = fem.integrate(
136
215
  vector_grad_form,
137
216
  quadrature=interior_quadrature,
138
217
  fields={"u": u, "q": q},
139
218
  kernel_options={"enable_backward": False},
140
219
  )
141
- boundary_int = integrate(
220
+ boundary_int = fem.integrate(
142
221
  vector_boundary_form,
143
222
  quadrature=boundary_quadrature,
144
223
  fields={"u": u.trace(), "q": q.trace()},
@@ -168,14 +247,14 @@ def test_tensor_divergence_theorem(test_case, device):
168
247
 
169
248
  with wp.ScopedDevice(device):
170
249
  # Grid geometry
171
- geo = Grid2D(res=wp.vec2i(5))
250
+ geo = fem.Grid2D(res=wp.vec2i(5))
172
251
 
173
252
  # Domain and function spaces
174
- interior = Cells(geometry=geo)
175
- boundary = BoundarySides(geometry=geo)
253
+ interior = fem.Cells(geometry=geo)
254
+ boundary = fem.BoundarySides(geometry=geo)
176
255
 
177
- tensor_space = make_polynomial_space(geo, degree=2, dtype=wp.mat22)
178
- vector_space = make_polynomial_space(geo, degree=1, dtype=wp.vec2)
256
+ tensor_space = fem.make_polynomial_space(geo, degree=2, dtype=wp.mat22)
257
+ vector_space = fem.make_polynomial_space(geo, degree=1, dtype=wp.vec2)
179
258
 
180
259
  tau = tensor_space.make_field()
181
260
  tau.dof_values = rng.random(size=(tau.dof_values.shape[0], 2, 2))
@@ -184,15 +263,15 @@ def test_tensor_divergence_theorem(test_case, device):
184
263
  constant_vec = vector_space.make_field()
185
264
  constant_vec.dof_values.fill_(wp.vec2(0.5, 2.0))
186
265
 
187
- interior_quadrature = RegularQuadrature(domain=interior, order=tensor_space.degree)
188
- boundary_quadrature = RegularQuadrature(domain=boundary, order=tensor_space.degree)
189
- div_int = integrate(
266
+ interior_quadrature = fem.RegularQuadrature(domain=interior, order=tensor_space.degree)
267
+ boundary_quadrature = fem.RegularQuadrature(domain=boundary, order=tensor_space.degree)
268
+ div_int = fem.integrate(
190
269
  tensor_divergence_form,
191
270
  quadrature=interior_quadrature,
192
271
  fields={"tau": tau, "v": constant_vec},
193
272
  kernel_options={"enable_backward": False},
194
273
  )
195
- boundary_int = integrate(
274
+ boundary_int = fem.integrate(
196
275
  tensor_boundary_form,
197
276
  quadrature=boundary_quadrature,
198
277
  fields={"tau": tau.trace(), "v": constant_vec.trace()},
@@ -205,21 +284,21 @@ def test_tensor_divergence_theorem(test_case, device):
205
284
  v = vector_space.make_field()
206
285
  v.dof_values = rng.random(size=(v.dof_values.shape[0], 2))
207
286
 
208
- interior_quadrature = RegularQuadrature(domain=interior, order=tensor_space.degree + vector_space.degree)
209
- boundary_quadrature = RegularQuadrature(domain=boundary, order=tensor_space.degree + vector_space.degree)
210
- div_int = integrate(
287
+ interior_quadrature = fem.RegularQuadrature(domain=interior, order=tensor_space.degree + vector_space.degree)
288
+ boundary_quadrature = fem.RegularQuadrature(domain=boundary, order=tensor_space.degree + vector_space.degree)
289
+ div_int = fem.integrate(
211
290
  tensor_divergence_form,
212
291
  quadrature=interior_quadrature,
213
292
  fields={"tau": tau, "v": v},
214
293
  kernel_options={"enable_backward": False},
215
294
  )
216
- grad_int = integrate(
295
+ grad_int = fem.integrate(
217
296
  tensor_grad_form,
218
297
  quadrature=interior_quadrature,
219
298
  fields={"tau": tau, "v": v},
220
299
  kernel_options={"enable_backward": False},
221
300
  )
222
- boundary_int = integrate(
301
+ boundary_int = fem.integrate(
223
302
  tensor_boundary_form,
224
303
  quadrature=boundary_quadrature,
225
304
  fields={"tau": tau.trace(), "v": v.trace()},
@@ -239,18 +318,18 @@ def test_grad_decomposition(test_case, device):
239
318
 
240
319
  with wp.ScopedDevice(device):
241
320
  # Grid geometry
242
- geo = Grid3D(res=wp.vec3i(5))
321
+ geo = fem.Grid3D(res=wp.vec3i(5))
243
322
 
244
323
  # Domain and function spaces
245
- domain = Cells(geometry=geo)
246
- quadrature = RegularQuadrature(domain=domain, order=4)
324
+ domain = fem.Cells(geometry=geo)
325
+ quadrature = fem.RegularQuadrature(domain=domain, order=4)
247
326
 
248
- vector_space = make_polynomial_space(geo, degree=2, dtype=wp.vec3)
327
+ vector_space = fem.make_polynomial_space(geo, degree=2, dtype=wp.vec3)
249
328
  u = vector_space.make_field()
250
329
 
251
330
  u.dof_values = rng.random(size=(u.dof_values.shape[0], 3))
252
331
 
253
- err = integrate(grad_decomposition, quadrature=quadrature, fields={"u": u, "v": u})
332
+ err = fem.integrate(grad_decomposition, quadrature=quadrature, fields={"u": u, "v": u})
254
333
  test_case.assertLess(err, 1.0e-8)
255
334
 
256
335
 
@@ -300,7 +379,7 @@ def _gen_hexmesh(N):
300
379
  return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
301
380
 
302
381
 
303
- def _launch_test_geometry_kernel(geo: Geometry, device):
382
+ def _launch_test_geometry_kernel(geo: fem.Geometry, device):
304
383
  @dynamic_kernel(suffix=geo.name, kernel_options={"enable_backward": False})
305
384
  def test_geo_cells_kernel(
306
385
  cell_arg: geo.CellArg,
@@ -368,7 +447,7 @@ def _launch_test_geometry_kernel(geo: Geometry, device):
368
447
 
369
448
  cell_measures = wp.zeros(dtype=float, device=device, shape=geo.cell_count())
370
449
 
371
- cell_quadrature = RegularQuadrature(Cells(geo), order=2)
450
+ cell_quadrature = fem.RegularQuadrature(fem.Cells(geo), order=2)
372
451
  cell_qps = wp.array(cell_quadrature.points, dtype=Coords, device=device)
373
452
  cell_qp_weights = wp.array(cell_quadrature.weights, dtype=float, device=device)
374
453
 
@@ -381,7 +460,7 @@ def _launch_test_geometry_kernel(geo: Geometry, device):
381
460
 
382
461
  side_measures = wp.zeros(dtype=float, device=device, shape=geo.side_count())
383
462
 
384
- side_quadrature = RegularQuadrature(Sides(geo), order=2)
463
+ side_quadrature = fem.RegularQuadrature(fem.Sides(geo), order=2)
385
464
  side_qps = wp.array(side_quadrature.points, dtype=Coords, device=device)
386
465
  side_qp_weights = wp.array(side_quadrature.weights, dtype=float, device=device)
387
466
 
@@ -398,7 +477,7 @@ def _launch_test_geometry_kernel(geo: Geometry, device):
398
477
  def test_grid_2d(test_case, device):
399
478
  N = 3
400
479
 
401
- geo = Grid2D(res=wp.vec2i(N))
480
+ geo = fem.Grid2D(res=wp.vec2i(N))
402
481
 
403
482
  test_case.assertEqual(geo.cell_count(), N**2)
404
483
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 2)
@@ -417,7 +496,7 @@ def test_triangle_mesh(test_case, device):
417
496
  with wp.ScopedDevice(device):
418
497
  positions, tri_vidx = _gen_trimesh(N)
419
498
 
420
- geo = Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
499
+ geo = fem.Trimesh2D(tri_vertex_indices=tri_vidx, positions=positions)
421
500
 
422
501
  test_case.assertEqual(geo.cell_count(), 2 * (N) ** 2)
423
502
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 2)
@@ -436,7 +515,7 @@ def test_quad_mesh(test_case, device):
436
515
  with wp.ScopedDevice(device):
437
516
  positions, quad_vidx = _gen_quadmesh(N)
438
517
 
439
- geo = Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions)
518
+ geo = fem.Quadmesh2D(quad_vertex_indices=quad_vidx, positions=positions)
440
519
 
441
520
  test_case.assertEqual(geo.cell_count(), N**2)
442
521
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 2)
@@ -452,7 +531,7 @@ def test_quad_mesh(test_case, device):
452
531
  def test_grid_3d(test_case, device):
453
532
  N = 3
454
533
 
455
- geo = Grid3D(res=wp.vec3i(N))
534
+ geo = fem.Grid3D(res=wp.vec3i(N))
456
535
 
457
536
  test_case.assertEqual(geo.cell_count(), (N) ** 3)
458
537
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 3)
@@ -472,7 +551,7 @@ def test_tet_mesh(test_case, device):
472
551
  with wp.ScopedDevice(device):
473
552
  positions, tet_vidx = _gen_tetmesh(N)
474
553
 
475
- geo = Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
554
+ geo = fem.Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
476
555
 
477
556
  test_case.assertEqual(geo.cell_count(), 5 * (N) ** 3)
478
557
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 3)
@@ -492,7 +571,7 @@ def test_hex_mesh(test_case, device):
492
571
  with wp.ScopedDevice(device):
493
572
  positions, tet_vidx = _gen_hexmesh(N)
494
573
 
495
- geo = Hexmesh(hex_vertex_indices=tet_vidx, positions=positions)
574
+ geo = fem.Hexmesh(hex_vertex_indices=tet_vidx, positions=positions)
496
575
 
497
576
  test_case.assertEqual(geo.cell_count(), (N) ** 3)
498
577
  test_case.assertEqual(geo.vertex_count(), (N + 1) ** 3)
@@ -518,15 +597,15 @@ def test_deformed_geometry(test_case, device):
518
597
  with wp.ScopedDevice(device):
519
598
  positions, tet_vidx = _gen_tetmesh(N)
520
599
 
521
- geo = Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
600
+ geo = fem.Tetmesh(tet_vertex_indices=tet_vidx, positions=positions)
522
601
 
523
602
  translation = [1.0, 2.0, 3.0]
524
603
  rotation = [0.0, math.pi / 4.0, 0.0]
525
604
  scale = 2.0
526
605
 
527
- vector_space = make_polynomial_space(geo, dtype=wp.vec3, degree=2)
606
+ vector_space = fem.make_polynomial_space(geo, dtype=wp.vec3, degree=2)
528
607
  pos_field = vector_space.make_field()
529
- interpolate(
608
+ fem.interpolate(
530
609
  _rigid_deformation_field,
531
610
  dest=pos_field,
532
611
  values={"translation": translation, "rotation": rotation, "scale": scale},
@@ -705,9 +784,9 @@ def test_dof_mapper(test_case, device):
705
784
  matrix_types = [wp.mat22, wp.mat33]
706
785
 
707
786
  # Symmetric mapper
708
- for mapping in SymmetricTensorMapper.Mapping:
787
+ for mapping in fem.SymmetricTensorMapper.Mapping:
709
788
  for dtype in matrix_types:
710
- mapper = SymmetricTensorMapper(dtype, mapping=mapping)
789
+ mapper = fem.SymmetricTensorMapper(dtype, mapping=mapping)
711
790
  dof_dtype = mapper.dof_dtype
712
791
 
713
792
  for k in range(dof_dtype._length_):
@@ -727,7 +806,7 @@ def test_dof_mapper(test_case, device):
727
806
 
728
807
  # Skew-symmetric mapper
729
808
  for dtype in matrix_types:
730
- mapper = SkewSymmetricTensorMapper(dtype)
809
+ mapper = fem.SkewSymmetricTensorMapper(dtype)
731
810
  dof_dtype = mapper.dof_dtype
732
811
 
733
812
  if hasattr(dof_dtype, "_length_"):
@@ -879,9 +958,9 @@ def test_square_shape_functions(test_case, device):
879
958
  param_delta = wp.normalize(wp.vec2(wp.randf(state), wp.randf(state))) * epsilon
880
959
  return param_delta, Coords(param_delta[0], param_delta[1], 0.0)
881
960
 
882
- Q_1 = shape.SquareBipolynomialShapeFunctions(degree=1, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
883
- Q_2 = shape.SquareBipolynomialShapeFunctions(degree=2, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
884
- Q_3 = shape.SquareBipolynomialShapeFunctions(degree=3, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
961
+ Q_1 = shape.SquareBipolynomialShapeFunctions(degree=1, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
962
+ Q_2 = shape.SquareBipolynomialShapeFunctions(degree=2, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
963
+ Q_3 = shape.SquareBipolynomialShapeFunctions(degree=3, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
885
964
 
886
965
  test_shape_function_weight(test_case, Q_1, square_coord_sampler, SQUARE_CENTER_COORDS)
887
966
  test_shape_function_weight(test_case, Q_2, square_coord_sampler, SQUARE_CENTER_COORDS)
@@ -893,8 +972,19 @@ def test_square_shape_functions(test_case, device):
893
972
  test_shape_function_gradient(test_case, Q_2, square_coord_sampler, square_coord_delta_sampler)
894
973
  test_shape_function_gradient(test_case, Q_3, square_coord_sampler, square_coord_delta_sampler)
895
974
 
896
- S_2 = shape.SquareSerendipityShapeFunctions(degree=2, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
897
- S_3 = shape.SquareSerendipityShapeFunctions(degree=3, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
975
+ Q_1 = shape.SquareBipolynomialShapeFunctions(degree=1, family=fem.Polynomial.GAUSS_LEGENDRE)
976
+ Q_2 = shape.SquareBipolynomialShapeFunctions(degree=2, family=fem.Polynomial.GAUSS_LEGENDRE)
977
+ Q_3 = shape.SquareBipolynomialShapeFunctions(degree=3, family=fem.Polynomial.GAUSS_LEGENDRE)
978
+
979
+ test_shape_function_weight(test_case, Q_1, square_coord_sampler, SQUARE_CENTER_COORDS)
980
+ test_shape_function_weight(test_case, Q_2, square_coord_sampler, SQUARE_CENTER_COORDS)
981
+ test_shape_function_weight(test_case, Q_3, square_coord_sampler, SQUARE_CENTER_COORDS)
982
+ test_shape_function_gradient(test_case, Q_1, square_coord_sampler, square_coord_delta_sampler)
983
+ test_shape_function_gradient(test_case, Q_2, square_coord_sampler, square_coord_delta_sampler)
984
+ test_shape_function_gradient(test_case, Q_3, square_coord_sampler, square_coord_delta_sampler)
985
+
986
+ S_2 = shape.SquareSerendipityShapeFunctions(degree=2, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
987
+ S_3 = shape.SquareSerendipityShapeFunctions(degree=3, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
898
988
 
899
989
  test_shape_function_weight(test_case, S_2, square_coord_sampler, SQUARE_CENTER_COORDS)
900
990
  test_shape_function_weight(test_case, S_3, square_coord_sampler, SQUARE_CENTER_COORDS)
@@ -930,9 +1020,9 @@ def test_cube_shape_functions(test_case, device):
930
1020
  param_delta = wp.normalize(wp.vec3(wp.randf(state), wp.randf(state), wp.randf(state))) * epsilon
931
1021
  return param_delta, param_delta
932
1022
 
933
- Q_1 = shape.CubeTripolynomialShapeFunctions(degree=1, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
934
- Q_2 = shape.CubeTripolynomialShapeFunctions(degree=2, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
935
- Q_3 = shape.CubeTripolynomialShapeFunctions(degree=3, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
1023
+ Q_1 = shape.CubeTripolynomialShapeFunctions(degree=1, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
1024
+ Q_2 = shape.CubeTripolynomialShapeFunctions(degree=2, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
1025
+ Q_3 = shape.CubeTripolynomialShapeFunctions(degree=3, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
936
1026
 
937
1027
  test_shape_function_weight(test_case, Q_1, cube_coord_sampler, CUBE_CENTER_COORDS)
938
1028
  test_shape_function_weight(test_case, Q_2, cube_coord_sampler, CUBE_CENTER_COORDS)
@@ -944,8 +1034,19 @@ def test_cube_shape_functions(test_case, device):
944
1034
  test_shape_function_gradient(test_case, Q_2, cube_coord_sampler, cube_coord_delta_sampler)
945
1035
  test_shape_function_gradient(test_case, Q_3, cube_coord_sampler, cube_coord_delta_sampler)
946
1036
 
947
- S_2 = shape.CubeSerendipityShapeFunctions(degree=2, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
948
- S_3 = shape.CubeSerendipityShapeFunctions(degree=3, family=Polynomial.LOBATTO_GAUSS_LEGENDRE)
1037
+ Q_1 = shape.CubeTripolynomialShapeFunctions(degree=1, family=fem.Polynomial.GAUSS_LEGENDRE)
1038
+ Q_2 = shape.CubeTripolynomialShapeFunctions(degree=2, family=fem.Polynomial.GAUSS_LEGENDRE)
1039
+ Q_3 = shape.CubeTripolynomialShapeFunctions(degree=3, family=fem.Polynomial.GAUSS_LEGENDRE)
1040
+
1041
+ test_shape_function_weight(test_case, Q_1, cube_coord_sampler, CUBE_CENTER_COORDS)
1042
+ test_shape_function_weight(test_case, Q_2, cube_coord_sampler, CUBE_CENTER_COORDS)
1043
+ test_shape_function_weight(test_case, Q_3, cube_coord_sampler, CUBE_CENTER_COORDS)
1044
+ test_shape_function_gradient(test_case, Q_1, cube_coord_sampler, cube_coord_delta_sampler)
1045
+ test_shape_function_gradient(test_case, Q_2, cube_coord_sampler, cube_coord_delta_sampler)
1046
+ test_shape_function_gradient(test_case, Q_3, cube_coord_sampler, cube_coord_delta_sampler)
1047
+
1048
+ S_2 = shape.CubeSerendipityShapeFunctions(degree=2, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
1049
+ S_3 = shape.CubeSerendipityShapeFunctions(degree=3, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
949
1050
 
950
1051
  test_shape_function_weight(test_case, S_2, cube_coord_sampler, CUBE_CENTER_COORDS)
951
1052
  test_shape_function_weight(test_case, S_3, cube_coord_sampler, CUBE_CENTER_COORDS)
@@ -1054,35 +1155,117 @@ def test_tet_shape_functions(test_case, device):
1054
1155
  wp.synchronize()
1055
1156
 
1056
1157
 
1057
- def register(parent):
1058
- devices = get_test_devices()
1158
+ def test_point_basis(test_case, device):
1159
+ geo = fem.Grid2D(res=wp.vec2i(2))
1160
+
1161
+ domain = fem.Cells(geo)
1162
+
1163
+ quadrature = fem.RegularQuadrature(domain, order=2, family=fem.Polynomial.GAUSS_LEGENDRE)
1164
+ point_basis = fem.PointBasisSpace(quadrature)
1165
+
1166
+ point_space = fem.make_collocated_function_space(point_basis)
1167
+ point_test = fem.make_test(point_space, domain=domain)
1168
+
1169
+ # Sample at particle positions
1170
+ ones = fem.integrate(linear_form, fields={"u": point_test}, nodal=True)
1171
+ test_case.assertAlmostEqual(np.sum(ones.numpy()), 1.0, places=5)
1172
+
1173
+ # Sampling outside of particle positions
1174
+ other_quadrature = fem.RegularQuadrature(domain, order=2, family=fem.Polynomial.LOBATTO_GAUSS_LEGENDRE)
1175
+ zeros = fem.integrate(linear_form, quadrature=other_quadrature, fields={"u": point_test})
1176
+
1177
+ test_case.assertAlmostEqual(np.sum(zeros.numpy()), 0.0, places=5)
1178
+
1179
+
1180
+ @fem.integrand
1181
+ def _bicubic(s: Sample, domain: Domain):
1182
+ x = domain(s)
1183
+ return wp.pow(x[0], 3.0) * wp.pow(x[1], 3.0)
1184
+
1185
+
1186
+ @fem.integrand
1187
+ def _piecewise_constant(s: Sample):
1188
+ return float(s.element_index)
1189
+
1190
+
1191
+ def test_particle_quadratures(test_case, device):
1192
+ geo = fem.Grid2D(res=wp.vec2i(2))
1193
+
1194
+ domain = fem.Cells(geo)
1195
+ points, weights = domain.reference_element().instantiate_quadrature(order=4, family=fem.Polynomial.GAUSS_LEGENDRE)
1196
+ points_per_cell = len(points)
1197
+
1198
+ points = points * domain.element_count()
1199
+ weights = weights * domain.element_count()
1200
+
1201
+ points = wp.array(points, shape=(domain.element_count(), points_per_cell), dtype=Coords, device=device)
1202
+ weights = wp.array(weights, shape=(domain.element_count(), points_per_cell), dtype=float, device=device)
1203
+
1204
+ explicit_quadrature = fem.ExplicitQuadrature(domain, points, weights)
1205
+
1206
+ test_case.assertEqual(explicit_quadrature.points_per_element(), points_per_cell)
1207
+ test_case.assertEqual(explicit_quadrature.total_point_count(), points_per_cell * geo.cell_count())
1208
+
1209
+ val = fem.integrate(_bicubic, quadrature=explicit_quadrature)
1210
+ test_case.assertAlmostEqual(val, 1.0 / 16, places=5)
1211
+
1212
+ element_indices = wp.array([3, 3, 2], dtype=int, device=device)
1213
+ element_coords = wp.array(
1214
+ [
1215
+ [0.25, 0.5, 0.0],
1216
+ [0.5, 0.25, 0.0],
1217
+ [0.5, 0.5, 0.0],
1218
+ ],
1219
+ dtype=Coords,
1220
+ device=device,
1221
+ )
1222
+
1223
+ pic_quadrature = fem.PicQuadrature(domain, positions=(element_indices, element_coords))
1224
+
1225
+ test_case.assertIsNone(pic_quadrature.points_per_element())
1226
+ test_case.assertEqual(pic_quadrature.total_point_count(), 3)
1227
+ test_case.assertEqual(pic_quadrature.active_cell_count(), 2)
1228
+
1229
+ val = fem.integrate(_piecewise_constant, quadrature=pic_quadrature)
1230
+ test_case.assertAlmostEqual(val, 1.25, places=5)
1231
+
1232
+
1233
+ devices = get_test_devices()
1234
+
1235
+
1236
+ class TestFem(unittest.TestCase):
1237
+ pass
1238
+
1239
+
1240
+ add_function_test(TestFem, "test_regular_quadrature", test_regular_quadrature)
1241
+ add_function_test(TestFem, "test_closest_point_queries", test_closest_point_queries)
1242
+ add_function_test(TestFem, "test_grad_decomposition", test_grad_decomposition, devices=devices)
1243
+ add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)
1244
+ add_function_test(TestFem, "test_interpolate_gradient", test_interpolate_gradient, devices=devices)
1245
+ add_function_test(TestFem, "test_vector_divergence_theorem", test_vector_divergence_theorem, devices=devices)
1246
+ add_function_test(TestFem, "test_tensor_divergence_theorem", test_tensor_divergence_theorem, devices=devices)
1247
+ add_function_test(TestFem, "test_grid_2d", test_grid_2d, devices=devices)
1248
+ add_function_test(TestFem, "test_triangle_mesh", test_triangle_mesh, devices=devices)
1249
+ add_function_test(TestFem, "test_quad_mesh", test_quad_mesh, devices=devices)
1250
+ add_function_test(TestFem, "test_grid_3d", test_grid_3d, devices=devices)
1251
+ add_function_test(TestFem, "test_tet_mesh", test_tet_mesh, devices=devices)
1252
+ add_function_test(TestFem, "test_hex_mesh", test_hex_mesh, devices=devices)
1253
+ add_function_test(TestFem, "test_deformed_geometry", test_deformed_geometry, devices=devices)
1254
+ add_function_test(TestFem, "test_dof_mapper", test_dof_mapper)
1255
+ add_function_test(TestFem, "test_point_basis", test_point_basis)
1256
+ add_function_test(TestFem, "test_particle_quadratures", test_particle_quadratures)
1257
+
1059
1258
 
1060
- class TestFem(parent):
1061
- pass
1259
+ class TestFemShapeFunctions(unittest.TestCase):
1260
+ pass
1062
1261
 
1063
- add_function_test(TestFem, "test_regular_quadrature", test_regular_quadrature)
1064
- add_function_test(TestFem, "test_closest_point_queries", test_closest_point_queries)
1065
- add_function_test(TestFem, "test_grad_decomposition", test_grad_decomposition, devices=devices)
1066
- add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)
1067
- add_function_test(TestFem, "test_vector_divergence_theorem", test_vector_divergence_theorem, devices=devices)
1068
- add_function_test(TestFem, "test_tensor_divergence_theorem", test_tensor_divergence_theorem, devices=devices)
1069
- add_function_test(TestFem, "test_grid_2d", test_grid_2d, devices=devices)
1070
- add_function_test(TestFem, "test_triangle_mesh", test_triangle_mesh, devices=devices)
1071
- add_function_test(TestFem, "test_quad_mesh", test_quad_mesh, devices=devices)
1072
- add_function_test(TestFem, "test_grid_3d", test_grid_3d, devices=devices)
1073
- add_function_test(TestFem, "test_tet_mesh", test_tet_mesh, devices=devices)
1074
- add_function_test(TestFem, "test_hex_mesh", test_hex_mesh, devices=devices)
1075
- add_function_test(TestFem, "test_deformed_geometry", test_deformed_geometry, devices=devices)
1076
- add_function_test(TestFem, "test_dof_mapper", test_dof_mapper)
1077
- add_function_test(TestFem, "test_square_shape_functions", test_square_shape_functions)
1078
- add_function_test(TestFem, "test_cube_shape_functions", test_cube_shape_functions)
1079
- add_function_test(TestFem, "test_tri_shape_functions", test_tri_shape_functions)
1080
- add_function_test(TestFem, "test_tet_shape_functions", test_tet_shape_functions)
1081
1262
 
1082
- return TestFem
1263
+ add_function_test(TestFemShapeFunctions, "test_square_shape_functions", test_square_shape_functions)
1264
+ add_function_test(TestFemShapeFunctions, "test_cube_shape_functions", test_cube_shape_functions)
1265
+ add_function_test(TestFemShapeFunctions, "test_tri_shape_functions", test_tri_shape_functions)
1266
+ add_function_test(TestFemShapeFunctions, "test_tet_shape_functions", test_tet_shape_functions)
1083
1267
 
1084
1268
 
1085
1269
  if __name__ == "__main__":
1086
1270
  wp.build.clear_kernel_cache()
1087
- _ = register(unittest.TestCase)
1088
1271
  unittest.main(verbosity=2)