warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +400 -198
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +65 -1
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +341 -224
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.2.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,936 @@
1
+ from typing import Any, Dict, Optional, Tuple
2
+
3
+ import numpy as np
4
+
5
+ import warp as wp
6
+ import warp.fem as fem
7
+ from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
8
+ from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed
9
+
10
# Public API of this helper module; every non-underscored helper defined below is listed.
__all__ = [
    "gen_hexmesh",
    "gen_quadmesh",
    "gen_tetmesh",
    "gen_trimesh",
    "gen_volume",
    "bsr_cg",
    "bsr_solve_saddle",
    "SaddleSystem",
    "invert_diagonal_bsr_matrix",
    "Plot",
]
21
+
22
+
23
+ #
24
+ # Mesh utilities
25
+ #
26
+
27
+
28
def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a triangular mesh by dividing each cell of a dense 2D grid into two triangles

    Args:
        res: Resolution of the grid along each dimension (number of cells along x and y)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions, Triangle vertex indices)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    # res cells per axis means res + 1 vertices per axis
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing so that vertex (i, j) is stored at flat index i * (Ny + 1) + j
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = fem.utils.grid_to_tris(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
57
+
58
+
59
def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a tetrahedral mesh by dividing each cell of a dense 3D grid into five tetrahedra

    Args:
        res: Resolution of the grid along each dimension (number of cells along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions, Tetrahedron vertex indices)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    # res cells per axis means res + 1 vertices per axis
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing so that vertex (i, j, k) is stored at flat index (i * (Ny + 1) + j) * (Nz + 1) + k
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = fem.utils.grid_to_tets(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
90
+
91
+
92
def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Constructs a quadrilateral mesh from a dense 2D grid

    Args:
        res: Resolution of the grid along each dimension (number of cells along x and y)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions, Quadrilateral vertex indices)
    """
    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    # res cells per axis means res + 1 vertices per axis
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing so that vertex (i, j) is stored at flat index i * (Ny + 1) + j
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = fem.utils.grid_to_quads(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
120
+
121
+
122
def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Constructs a hexahedral mesh from a dense 3D grid

    Args:
        res: Resolution of the grid along each dimension (number of cells along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.

    Returns:
        Tuple of warp arrays: (Vertex positions, Hexahedron vertex indices)
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    # res cells per axis means res + 1 vertices per axis
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    # z uses the third component of the bounds (previously the y bounds were used by mistake)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing so that vertex (i, j, k) is stored at flat index (i * (Ny + 1) + j) * (Nz + 1) + k
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = fem.utils.grid_to_hexes(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
153
+
154
+
155
def gen_volume(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None, device=None) -> wp.Volume:
    """Constructs a wp.Volume from a dense 3D grid

    Every voxel of the ``res[0] x res[1] x res[2]`` index box is allocated.

    Args:
        res: Resolution of the grid along each dimension (number of voxels along x, y and z)
        bounds_lo: Position of the lower bound of the axis-aligned grid. Defaults to ``(0, 0, 0)``.
        bounds_hi: Position of the upper bound of the axis-aligned grid. Defaults to ``(1, 1, 1)``.
        device: Device on which to allocate the volume
    """

    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    extents = bounds_hi - bounds_lo
    voxel_size = wp.cw_div(extents, wp.vec3(res))

    x = np.arange(res[0], dtype=int)
    y = np.arange(res[1], dtype=int)
    z = np.arange(res[2], dtype=int)

    # All (i, j, k) voxel coordinates; the enumeration order does not matter for allocation
    ijk = np.transpose(np.meshgrid(x, y, z), axes=(1, 2, 3, 0)).reshape(-1, 3)
    ijk = wp.array(ijk, dtype=wp.vec3i, device=device)

    # Index-to-world translation: keep the original half-voxel offset, but start from bounds_lo so that
    # a non-default lower bound actually moves the volume (it was previously ignored).
    translation = bounds_lo + 0.5 * voxel_size
    return wp.Volume.allocate_by_voxels(ijk, voxel_size=voxel_size, translation=translation, device=device)
181
+
182
+
183
+ #
184
+ # Bsr matrix utilities
185
+ #
186
+
187
+
188
def _get_linear_solver_func(method_name: str):
    """Maps a solver name to the matching routine from ``warp.optim.linear``.

    ``"bicgstab"``, ``"gmres"`` and ``"cr"`` select the solver of the same name; any other value
    (including ``"cg"``) falls back to the conjugate gradient solver.
    """
    from warp.optim.linear import bicgstab, cg, cr, gmres

    named_solvers = {
        "bicgstab": bicgstab,
        "gmres": gmres,
        "cr": cr,
    }
    return named_solvers.get(method_name, cg)
198
+
199
+
200
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the linear system A x = b using an iterative solver, optionally with diagonal preconditioning

    Args:
        A: system left-hand side
        x: result vector and initial guess; overwritten with the solution
        b: system right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: Whether to use diagonal preconditioning (ignored when ``mv_routine`` is provided)
        mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use ("cg", "cr", "bicgstab" or "gmres"), defaults to Conjugate Gradient

    Returns:
        Tuple (residual norm, iteration count)
    """

    # A user-supplied matvec replaces the matrix product and disables preconditioning
    if mv_routine is not None:
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
        M = None
    elif use_diag_precond:
        M = preconditioner(A, "diag")
    else:
        M = None

    solver = _get_linear_solver_func(method_name=method)

    def report_progress(i, err, tol):
        print(f"{solver.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    end_iter, err, atol = solver(
        A=A,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=M,
        callback=None if quiet else report_progress,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{solver.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")

    return err, end_iter
260
+
261
+
262
class SaddleSystem(LinearOperator):
    """Builds a linear operator corresponding to the saddle-point linear system ``[A B^T; B 0]``

    The operator acts on a single flat scalar vector made of the primal unknowns (one vector of length
    ``A.block_shape[0]`` per block row of ``A``) followed by the Lagrange multipliers (one vector of length
    ``B.block_shape[0]`` per block row of ``B``); :meth:`u_slice` and :meth:`p_slice` return views of
    both parts.

    If ``use_diag_precond`` is ``True``, also builds a block-diagonal preconditioner approximating the inverse
    of the system: the diagonal preconditioner of ``A`` on the primal part, and the inverse of the diagonal
    blocks of the Schur complement ``B diag(A)^-1 B^T`` on the multiplier part.

    Args:
        A: primal system matrix
        B: constraint matrix
        Bt: transpose of ``B``; computed with ``bsr_transposed`` when not provided
        use_diag_precond: whether to build the preconditioner exposed as :attr:`preconditioner`
    """

    def __init__(
        self,
        A: BsrMatrix,
        B: BsrMatrix,
        Bt: Optional[BsrMatrix] = None,
        use_diag_precond: bool = True,
    ):
        if Bt is None:
            Bt = bsr_transposed(B)

        self._A = A
        self._B = B
        self._Bt = Bt

        # Vector types of the primal (u) and multiplier (p) parts; p starts right after all u entries
        self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
        self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
        self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)

        saddle_shape = (A.shape[0] + B.shape[0], A.shape[0] + B.shape[0])

        super().__init__(saddle_shape, dtype=A.scalar_type, device=A.device, matvec=self._saddle_mv)

        if use_diag_precond:
            self._preconditioner = self._diag_preconditioner()
        else:
            self._preconditioner = None

    def _diag_preconditioner(self):
        """Builds the block-diagonal preconditioner ``[diag(A)^-1; diag(B diag(A)^-1 B^T)^-1]``"""
        A = self._A
        B = self._B

        M_u = preconditioner(A, "diag")

        A_diag = bsr_get_diag(A)

        # One inverse Schur-complement diagonal block per block row of B
        schur_block_shape = (B.block_shape[0], B.block_shape[0])
        schur_dtype = wp.mat(shape=schur_block_shape, dtype=B.scalar_type)
        schur_inv_diag = wp.empty(dtype=schur_dtype, shape=B.nrow, device=self.device)
        wp.launch(
            _compute_schur_inverse_diagonal,
            dim=B.nrow,
            device=A.device,
            inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
        )

        if schur_block_shape == (1, 1):
            # Downcast 1x1 mats to scalars
            schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)

        M_p = aslinearoperator(schur_inv_diag)

        def precond_mv(x, y, z, alpha, beta):
            # Apply each diagonal preconditioner to its own part of the vectors
            x_u = self.u_slice(x)
            x_p = self.p_slice(x)
            y_u = self.u_slice(y)
            y_p = self.p_slice(y)
            z_u = self.u_slice(z)
            z_p = self.p_slice(z)

            M_u.matvec(x_u, y_u, z_u, alpha=alpha, beta=beta)
            M_p.matvec(x_p, y_p, z_p, alpha=alpha, beta=beta)

        return LinearOperator(
            shape=self.shape,
            dtype=self.dtype,
            device=self.device,
            matvec=precond_mv,
        )

    @property
    def preconditioner(self):
        """Diagonal preconditioner built at construction, or ``None`` if ``use_diag_precond`` was ``False``"""
        return self._preconditioner

    def u_slice(self, a: wp.array):
        """Returns a non-copying view of the primal part of the flat saddle vector ``a``

        NOTE(review): the view is built from a raw pointer and presumably does not keep ``a`` alive;
        keep ``a`` referenced while the view is in use.
        """
        return wp.array(
            ptr=a.ptr,
            dtype=self._u_dtype,
            shape=self._A.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )

    def p_slice(self, a: wp.array):
        """Returns a non-copying view of the multiplier part of ``a``, starting right after the primal part

        NOTE(review): the view is built from a raw pointer and presumably does not keep ``a`` alive;
        keep ``a`` referenced while the view is in use.
        """
        return wp.array(
            ptr=a.ptr + self._p_byte_offset,
            dtype=self._p_dtype,
            shape=self._B.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )

    def _saddle_mv(self, x, y, z, alpha, beta):
        """Computes ``z = alpha * S x + beta * y`` with ``S = [A B^T; B 0]``"""
        x_u = self.u_slice(x)
        x_p = self.p_slice(x)
        z_u = self.u_slice(z)
        z_p = self.p_slice(z)

        # Seed z with y so that the block products below can scale it in place by beta
        if y.ptr != z.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)

        # Order matters: the B^T product accumulates (beta=1) onto the result of the A product
        bsr_mv(self._A, x_u, z_u, alpha=alpha, beta=beta)
        bsr_mv(self._Bt, x_p, z_u, alpha=alpha, beta=1.0)
        bsr_mv(self._B, x_u, z_p, alpha=alpha, beta=beta)
375
+
376
+
377
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver, optionally with diagonal preconditioning

    The preconditioner used is the one stored in ``saddle_system`` (``None`` disables preconditioning).

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess; overwritten with the solution
        x_p: Lagrange multiplier part of the result vector and initial guess; overwritten with the solution
        b_u: primal right-hand-side
        b_p: constraint right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use ("cg", "cr", "bicgstab" or "gmres"), defaults to Conjugate Gradient

    Returns:
        Tuple (residual norm, iteration count)

    """
    # Pack both parts into contiguous saddle vectors, using the system's views to address each part
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    wp.copy(src=x_u, dest=saddle_system.u_slice(x))
    wp.copy(src=x_p, dest=saddle_system.p_slice(x))
    wp.copy(src=b_u, dest=saddle_system.u_slice(b))
    wp.copy(src=b_p, dest=saddle_system.p_slice(b))

    func = _get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=callback,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Unpack the solution back into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
441
+
442
+
443
@wp.kernel
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    # For one block row of B (one thread per row; launched with dim=B.nrow), computes the inverse of
    # the diagonal block of the Schur complement B diag(A)^-1 B^T, i.e.
    #     P_diag[row] = inverse( sum_b  B_b * inverse(A_diag[col(b)]) * B_b^T )
    # over the stored blocks b of that row.
    # B_offsets / B_indices / B_values: BSR row offsets, column indices and block values of B
    # A_diag: diagonal blocks of A, indexed by block column of B
    # P_diag: output, one inverse Schur block per row of B
    row = wp.tid()

    # Zero matrix of the output block type
    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))

    schur = zero

    beg = B_offsets[row]
    end = B_offsets[row + 1]

    for b in range(beg, end):
        B = B_values[b]
        col = B_indices[b]
        Ai = wp.inverse(A_diag[col])
        S = B * Ai * wp.transpose(B)
        schur += S

    P_diag[row] = fem.utils.inverse_qr(schur)
468
+
469
+
470
def invert_diagonal_bsr_matrix(A: BsrMatrix):
    """Inverts, in place, each block of a block-diagonal (e.g. mass) matrix

    One kernel thread is launched per block row and inverts the stored block at that index, so the
    matrix is expected to store exactly one (diagonal) block per row. Scalar-valued matrices are
    handled by viewing their values as 1x1 matrices.
    """

    block_values = (
        A.values
        if wp.types.type_is_matrix(A.values.dtype)
        else A.values.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))
    )

    wp.launch(
        kernel=_block_diagonal_invert,
        dim=A.nrow,
        inputs=[block_values],
        device=block_values.device,
    )
483
+
484
+
485
@wp.kernel
def _block_diagonal_invert(values: wp.array(dtype=Any)):
    # One thread per block: overwrite the block with its inverse
    block_index = wp.tid()
    block = values[block_index]
    values[block_index] = fem.utils.inverse_qr(block)
489
+
490
+
491
+ #
492
+ # Plot utilities
493
+ #
494
+
495
+
496
+ class Plot:
497
+ def __init__(self, stage=None, default_point_radius=0.01):
498
+ self.default_point_radius = default_point_radius
499
+
500
+ self._fields = {}
501
+
502
+ self._usd_renderer = None
503
+ if stage is not None:
504
+ try:
505
+ from warp.render import UsdRenderer
506
+
507
+ self._usd_renderer = UsdRenderer(stage)
508
+ except Exception as err:
509
+ print(f"Could not initialize UsdRenderer for stage '{stage}': {err}.")
510
+
511
+ def begin_frame(self, time):
512
+ if self._usd_renderer is not None:
513
+ self._usd_renderer.begin_frame(time=time)
514
+
515
+ def end_frame(self):
516
+ if self._usd_renderer is not None:
517
+ self._usd_renderer.end_frame()
518
+
519
+ def add_field(self, name: str, field: fem.DiscreteField):
520
+ if self._usd_renderer is not None:
521
+ self._render_to_usd(field)
522
+
523
+ if name not in self._fields:
524
+ field_clone = field.space.make_field(space_partition=field.space_partition)
525
+ self._fields[name] = (field_clone, [])
526
+
527
+ self._fields[name][1].append(field.dof_values.numpy())
528
+
529
+ def _render_to_usd(self, name: str, field: fem.DiscreteField):
530
+ points = field.space.node_positions().numpy()
531
+ values = field.dof_values.numpy()
532
+
533
+ if values.ndim == 2:
534
+ if values.shape[1] == field.space.dimension:
535
+ # use values as displacement
536
+ points += values
537
+ else:
538
+ # use magnitude
539
+ values = np.linalg.norm(values, axis=1)
540
+
541
+ if field.space.dimension == 2:
542
+ z = values if values.ndim == 1 else np.zeros((points.shape[0], 1))
543
+ points = np.hstack((points, z))
544
+
545
+ if hasattr(field.space, "node_triangulation"):
546
+ indices = field.space.node_triangulation()
547
+ self._usd_renderer.render_mesh(name, points=points, indices=indices)
548
+ else:
549
+ self._usd_renderer.render_points(name, points=points, radius=self.default_point_radius)
550
+ elif values.ndim == 1:
551
+ self._usd_renderer.render_points(name, points, radius=values)
552
+ else:
553
+ self._usd_renderer.render_points(name, points, radius=self.default_point_radius)
554
+
555
+ def plot(self, options: Dict[str, Any] = None, backend: str = "auto"):
556
+ if options is None:
557
+ options = {}
558
+
559
+ if backend == "pyvista":
560
+ return self._plot_pyvista(options)
561
+ if backend == "matplotlib":
562
+ return self._plot_matplotlib(options)
563
+
564
+ # try both
565
+ try:
566
+ return self._plot_pyvista(options)
567
+ except ModuleNotFoundError:
568
+ try:
569
+ return self._plot_matplotlib(options)
570
+ except ModuleNotFoundError:
571
+ wp.utils.warn("pyvista or matplotlib must be installed to visualize solution results")
572
+
573
+ def _plot_pyvista(self, options: Dict[str, Any]):
574
+ import pyvista
575
+
576
+ grids = {}
577
+ scales = {}
578
+ markers = {}
579
+
580
+ animate = False
581
+
582
+ for name, (field, values) in self._fields.items():
583
+ cells, types = field.space.vtk_cells()
584
+ node_pos = field.space.node_positions().numpy()
585
+
586
+ args = options.get(name, {})
587
+
588
+ grid_scale = np.max(np.max(node_pos, axis=0) - np.min(node_pos, axis=0))
589
+ value_range = self._get_field_value_range(values, args)
590
+ scales[name] = (grid_scale, value_range)
591
+
592
+ if node_pos.shape[1] == 2:
593
+ node_pos = np.hstack((node_pos, np.zeros((node_pos.shape[0], 1))))
594
+
595
+ grid = pyvista.UnstructuredGrid(cells, types, node_pos)
596
+ grids[name] = grid
597
+
598
+ if len(values) > 1:
599
+ animate = True
600
+
601
+ def set_frame_data(frame):
602
+ for name, (field, values) in self._fields.items():
603
+ if frame > 0 and len(values) == 1:
604
+ continue
605
+
606
+ v = values[frame % len(values)]
607
+ grid = grids[name]
608
+ grid_scale, value_range = scales[name]
609
+ field_args = options.get(name, {})
610
+
611
+ marker = None
612
+
613
+ if field.space.dimension == 2 and v.ndim == 2 and v.shape[1] == 2:
614
+ grid.point_data[name] = np.hstack((v, np.zeros((v.shape[0], 1))))
615
+ else:
616
+ grid.point_data[name] = v
617
+
618
+ if v.ndim == 2:
619
+ grid.point_data[name + "_mag"] = np.linalg.norm(v, axis=1)
620
+
621
+ if "arrows" in field_args:
622
+ glyph_scale = field_args["arrows"].get("glyph_scale", 1.0)
623
+ glyph_scale *= grid_scale / max(1.0e-8, value_range[1] - value_range[0])
624
+ marker = grid.glyph(scale=name, orient=name, factor=glyph_scale)
625
+ elif "contours" in field_args:
626
+ levels = field_args["contours"].get("levels", 10)
627
+ if type(levels) == int:
628
+ levels = np.linspace(*value_range, levels)
629
+ marker = grid.contour(isosurfaces=levels, scalars=name + "_mag" if v.ndim == 2 else name)
630
+ elif field.space.dimension == 2:
631
+ z_scale = grid_scale / max(1.0e-8, value_range[1] - value_range[0])
632
+
633
+ if "streamlines" in field_args:
634
+ center = np.mean(grid.points, axis=0)
635
+ density = field_args["streamlines"].get("density", 1.0)
636
+ cell_size = 1.0 / np.sqrt(field.space.geometry.cell_count())
637
+
638
+ separating_distance = 0.5 / (30.0 * density * cell_size)
639
+ # Try with various sep distance until we get at least one line
640
+ while separating_distance * cell_size < 1.0:
641
+ lines = grid.streamlines_evenly_spaced_2D(
642
+ vectors=name,
643
+ start_position=center,
644
+ separating_distance=separating_distance,
645
+ separating_distance_ratio=0.5,
646
+ step_length=0.25,
647
+ compute_vorticity=False,
648
+ )
649
+ if lines.n_lines > 0:
650
+ break
651
+ separating_distance *= 1.25
652
+ marker = lines.tube(radius=0.0025 * grid_scale / density)
653
+ elif "arrows" in field_args:
654
+ glyph_scale = field_args["arrows"].get("glyph_scale", 1.0)
655
+ glyph_scale *= grid_scale / max(1.0e-8, value_range[1] - value_range[0])
656
+ marker = grid.glyph(scale=name, orient=name, factor=glyph_scale)
657
+ elif "displacement" in field_args:
658
+ grid.points[:, 0:2] = field.space.node_positions().numpy() + v
659
+ else:
660
+ # Extrude surface
661
+ z = v if v.ndim == 1 else grid.point_data[name + "_mag"]
662
+ grid.points[:, 2] = z * z_scale
663
+
664
+ elif field.space.dimension == 3:
665
+ if "streamlines" in field_args:
666
+ center = np.mean(grid.points, axis=0)
667
+ density = field_args["streamlines"].get("density", 1.0)
668
+ cell_size = 1.0 / np.sqrt(field.space.geometry.cell_count())
669
+ lines = grid.streamlines(vectors=name, n_points=int(100 * density))
670
+ marker = lines.tube(radius=0.0025 * grid_scale / density)
671
+ elif "displacement" in field_args:
672
+ grid.points = field.space.node_positions().numpy() + v
673
+
674
+ if frame == 0:
675
+ if v.ndim == 1:
676
+ grid.set_active_scalars(name)
677
+ else:
678
+ grid.set_active_vectors(name)
679
+ grid.set_active_scalars(name + "_mag")
680
+ markers[name] = marker
681
+ elif marker:
682
+ markers[name].copy_from(marker)
683
+
684
+ set_frame_data(0)
685
+
686
+ subplot_rows = options.get("rows", 1)
687
+ subplot_shape = (subplot_rows, (len(grids) + subplot_rows - 1) // subplot_rows)
688
+
689
+ plotter = pyvista.Plotter(shape=subplot_shape)
690
+ plotter.link_views()
691
+ plotter.add_camera_orientation_widget()
692
+ for index, (name, grid) in enumerate(grids.items()):
693
+ plotter.subplot(index // subplot_shape[1], index % subplot_shape[1])
694
+ grid_scale, value_range = scales[name]
695
+ field = self._fields[name][0]
696
+ marker = markers[name]
697
+ if marker:
698
+ if field.space.dimension == 2:
699
+ plotter.add_mesh(marker, show_scalar_bar=False)
700
+ plotter.add_mesh(grid, opacity=0.25, clim=value_range)
701
+ plotter.view_xy()
702
+ else:
703
+ plotter.add_mesh(marker, opacity=0.25)
704
+ elif field.space.dimension == 3:
705
+ plotter.add_mesh_clip_plane(grid, show_edges=True, clim=value_range)
706
+ else:
707
+ plotter.add_mesh(grid, show_edges=True, clim=value_range)
708
+ plotter.show(interactive_update=animate)
709
+
710
+ frame = 0
711
+ while animate and not plotter.iren.interactor.GetDone():
712
+ frame += 1
713
+ set_frame_data(frame)
714
+ plotter.update()
715
+
716
def _plot_matplotlib(self, options: Dict[str, Any]):
    """Show every field stored in ``self._fields`` in a single matplotlib figure.

    ``self._fields`` maps a field name to ``(field, values)``, where ``values`` is a
    sequence of dof arrays, one per frame. Fields are placed on a grid of subplots
    with ``options.get("rows", 1)`` rows. Fields with more than one frame are animated.

    Per-field settings are read from ``options[name]``:

    - ``"clim"`` (consumed by ``_get_field_value_range``)
    - ``"contours"`` (``"levels"``)
    - ``"displacement"``
    - ``"streamlines"`` (``"density"``)
    - ``"arrows"`` (``"glyph_scale"``)
    - ``"xlim"`` and ``"ylim"``

    Drawing a frame assigns that frame's array to ``field.dof_values``.

    Raises:
        ValueError: if a field's ``space.dimension`` is neither 2 nor 3.
    """
    import matplotlib.animation as animation
    import matplotlib.pyplot as plt
    from matplotlib import cm

    def make_animation(fig, ax, cax, values, draw_func):
        # Redraw the field and rebuild its colorbar for frame ``i``.
        def animate(i):
            cs = draw_func(ax, values[i])

            cax.cla()
            fig.colorbar(cs, cax)

            return cs

        return animation.FuncAnimation(
            ax.figure,
            animate,
            interval=30,
            blit=False,
            frames=len(values),
        )

    def make_draw_func(field, args, plot_func, plot_opts):
        # Build a callback that clears ``axes`` and plots one frame of ``field``.
        def draw_fn(axes, values):
            axes.clear()

            field.dof_values = values
            cs = plot_func(field, axes=axes, **plot_opts)

            if "xlim" in args:
                axes.set_xlim(*args["xlim"])
            if "ylim" in args:
                axes.set_ylim(*args["ylim"])

            return cs

        return draw_fn

    def choose_plot(field, first_values, args, plot_opts):
        # Pick the plotting helper for one field. Helper-specific options are added
        # to ``plot_opts``; returns ``(plot_fn, needs_3d_axes)``.
        dimension = field.space.dimension

        if dimension == 2:
            if "contours" in args:
                plot_opts["levels"] = args["contours"].get("levels", None)
                return _plot_contours, False
            # Two-component vector fields have dedicated 2-D visualizations.
            if first_values.ndim == 2 and first_values.shape[1] == 2:
                if "displacement" in args:
                    return _plot_displaced_tri_mesh, False
                if "streamlines" in args:
                    plot_opts["density"] = args["streamlines"].get("density", 1.0)
                    return _plot_streamlines, False
                if "arrows" in args:
                    plot_opts["glyph_scale"] = args["arrows"].get("glyph_scale", 1.0)
                    return _plot_quivers, False
            # Fallback: the value (or magnitude) drawn as a height field.
            return _plot_surface, True

        if dimension == 3:
            # Both 3-D helpers call the (X, Y, Z, ...) overloads of quiver/scatter,
            # so they need axes with a 3-D projection.
            if "arrows" in args or "streamlines" in args:
                plot_opts["glyph_scale"] = args.get("arrows", {}).get("glyph_scale", 1.0)
                return _plot_quivers_3d, True
            return _plot_3d_scatter, True

        raise ValueError(f"Unsupported field space dimension for matplotlib plotting: {dimension}")

    # Holds the FuncAnimation objects created below. Matplotlib requires a live
    # reference to an animation for it to keep running.
    anims = []

    field_count = len(self._fields)
    subplot_rows = options.get("rows", 1)
    # Enough columns to fit every field: ceil(field_count / subplot_rows).
    subplot_shape = (subplot_rows, (field_count + subplot_rows - 1) // subplot_rows)

    for index, (name, (field, values)) in enumerate(self._fields.items()):
        args = options.get(name, {})
        first_values = values[0]

        plot_opts = {"cmap": cm.viridis}
        # One color range shared by every frame of this field.
        plot_opts["clim"] = self._get_field_value_range(values, args)

        plot_fn, plot_3d = choose_plot(field, first_values, args, plot_opts)

        subplot_kw = {"projection": "3d"} if plot_3d else {}
        axes = plt.subplot(*subplot_shape, index + 1, **subplot_kw)

        if not plot_3d:
            axes.set_aspect("equal")

        draw_fn = make_draw_func(field, args, plot_func=plot_fn, plot_opts=plot_opts)
        cs = draw_fn(axes, first_values)

        fig = plt.gcf()
        cax = fig.colorbar(cs).ax

        if len(values) > 1:
            anims.append(make_animation(fig, axes, cax, values, draw_func=draw_fn))

    plt.show()
812
+
813
+ @staticmethod
814
+ def _get_field_value_range(values, field_options: Dict[str, Any]):
815
+ value_range = field_options.get("clim", None)
816
+ if value_range is None:
817
+ value_range = (
818
+ min((np.min(_value_or_magnitude(v)) for v in values)),
819
+ max((np.max(_value_or_magnitude(v)) for v in values)),
820
+ )
821
+
822
+ return value_range
823
+
824
+
825
+ def _value_or_magnitude(values: np.ndarray):
826
+ if values.ndim == 1:
827
+ return values
828
+ return np.linalg.norm(values, axis=-1)
829
+
830
+
831
def _field_triangulation(field):
    """Return a ``matplotlib.tri.Triangulation`` of a 2-D field's nodes.

    Vertex coordinates are the first two columns of ``field.space.node_positions()``.
    Connectivity comes from ``field.space.node_triangulation()``.
    """
    from matplotlib.tri import Triangulation

    space = field.space
    coords = space.node_positions().numpy()
    return Triangulation(
        x=coords[:, 0],
        y=coords[:, 1],
        triangles=space.node_triangulation(),
    )
836
+
837
+
838
def _plot_surface(field, axes, **kwargs):
    """Draw a field's values (or vector magnitudes) as heights on 3-D ``axes``.

    The drawing primitive depends on what ``field.space`` provides:

    - ``node_grid``: ``plot_surface`` over that grid;
    - ``node_triangulation``: ``plot_trisurf``;
    - neither: a ``scatter`` of the nodes, colored by height.

    If ``kwargs`` has a ``"clim"`` entry, it also sets the z-axis limits. All
    ``kwargs`` are forwarded to the matplotlib call, whose artist is returned.
    """
    heights = _value_or_magnitude(field.dof_values.numpy())

    if "clim" in kwargs:
        axes.set_zlim(*kwargs["clim"])

    space = field.space

    if hasattr(space, "node_grid"):
        grid_x, grid_y = space.node_grid()
        return axes.plot_surface(
            grid_x, grid_y, heights.reshape(grid_x.shape), linewidth=0.1, antialiased=False, **kwargs
        )

    if not hasattr(space, "node_triangulation"):
        # No structure available: one marker per node.
        xs, ys = space.node_positions().numpy().T
        return axes.scatter(xs, ys, heights, c=heights, **kwargs)

    return axes.plot_trisurf(_field_triangulation(field), heights, linewidth=0.1, antialiased=False, **kwargs)
856
+
857
+
858
def _plot_displaced_tri_mesh(field, axes, **kwargs):
    """Draw a 2-D mesh deformed by the field's own vector values.

    Node coordinates are shifted by columns 0 and 1 of ``field.dof_values``. The
    triangles are filled (``tripcolor``) by the displacement's value or magnitude,
    and the edges are overlaid with a thin ``triplot``. ``kwargs`` go to
    ``tripcolor``, whose artist is returned.
    """
    displaced = _field_triangulation(field)
    offsets = field.dof_values.numpy()

    displaced.x += offsets[:, 0]
    displaced.y += offsets[:, 1]

    colored = axes.tripcolor(displaced, _value_or_magnitude(offsets), **kwargs)
    axes.triplot(displaced, lw=0.1)  # thin wireframe on top of the fill
    return colored
872
+
873
+
874
def _plot_quivers(field, axes, clim=None, glyph_scale=1.0, **kwargs):
    """Draw a 2-D vector field as arrows at its nodes, colored by vector magnitude.

    Args:
        field: field whose ``dof_values`` hold one 2-component vector per node.
        axes: 2-D matplotlib axes.
        clim: optional ``(vmin, vmax)`` applied to the arrow colors. When ``None``,
            matplotlib autoscales the colors to this call's magnitudes.
        glyph_scale: passed to ``quiver`` as ``scale=1/glyph_scale``, so larger
            values draw longer arrows.
        **kwargs: forwarded to ``Axes.quiver`` (e.g. ``cmap``).

    Returns:
        The ``Quiver`` artist.
    """
    X, Y = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()
    u = vel[:, 0].reshape(X.shape)
    v = vel[:, 1].reshape(X.shape)

    arrows = axes.quiver(X, Y, u, v, _value_or_magnitude(vel), scale=1.0 / glyph_scale, **kwargs)
    if clim is not None:
        # Pin the color limits so repeated calls (e.g. animation frames) share one
        # color scale instead of each autoscaling to its own data.
        arrows.set_clim(*clim)
    return arrows
882
+
883
+
884
def _plot_quivers_3d(field, axes, clim=None, cmap=None, glyph_scale=1.0, **kwargs):
    """Draw a 3-D vector field as arrows at its nodes, colored by vector magnitude.

    Args:
        field: field whose ``dof_values`` hold one 3-component vector per node.
        axes: matplotlib axes with a 3-D projection.
        clim: ``(vmin, vmax)`` used to map magnitudes to colors and to scale arrow
            lengths (vectors are divided by ``vmax - vmin``). Defaults to the
            magnitude range of the current values.
        cmap: colormap applied to the normalized magnitudes. Defaults to ``viridis``.
        glyph_scale: forwarded to ``quiver`` as ``length``.
        **kwargs: forwarded to ``axes.quiver``.

    Returns:
        The artist returned by ``axes.quiver``.
    """
    X, Y, Z = field.space.node_positions().numpy().T

    vel = field.dof_values.numpy()
    magnitude = _value_or_magnitude(vel)

    if clim is None:
        clim = (np.min(magnitude), np.max(magnitude))
    if cmap is None:
        from matplotlib import cm

        cmap = cm.viridis

    # A constant-magnitude field (e.g. all zeros) gives an empty range. Use a unit
    # span instead of dividing by zero, which would yield NaN colors and lengths.
    span = clim[1] - clim[0]
    if span == 0:
        span = 1.0

    colors = cmap((magnitude - clim[0]) / span)

    u = vel[:, 0].reshape(X.shape) / span
    v = vel[:, 1].reshape(X.shape) / span
    w = vel[:, 2].reshape(X.shape) / span

    return axes.quiver(X, Y, Z, u, v, w, colors=colors, length=glyph_scale, clim=clim, cmap=cmap, **kwargs)
896
+
897
+
898
def _plot_streamlines(field, axes, clim=None, **kwargs):
    """Draw streamlines of a 2-D vector field, colored by interpolated speed.

    The nodal velocity components are interpolated with ``CubicTriInterpolator``
    onto a regular 100x100 grid spanning the nodes' bounding box. The result is
    then passed to ``Axes.streamplot``.

    Args:
        field: field whose ``dof_values`` hold one 2-component vector per node.
        axes: 2-D matplotlib axes.
        clim: optional ``(vmin, vmax)`` used to normalize the colors. When ``None``
            (or if ``kwargs`` already has ``norm``), ``streamplot``'s own
            normalization applies.
        **kwargs: forwarded to ``Axes.streamplot`` (e.g. ``cmap``, ``density``).

    Returns:
        The streamlines' ``LineCollection`` (``StreamplotSet.lines``).
    """
    import matplotlib.tri as tr
    from matplotlib.colors import Normalize

    triangulation = _field_triangulation(field)

    vel = field.dof_values.numpy()

    itp_vx = tr.CubicTriInterpolator(triangulation, vel[:, 0])
    itp_vy = tr.CubicTriInterpolator(triangulation, vel[:, 1])

    X, Y = np.meshgrid(
        np.linspace(np.min(triangulation.x), np.max(triangulation.x), 100),
        np.linspace(np.min(triangulation.y), np.max(triangulation.y), 100),
    )

    u = itp_vx(X, Y)
    v = itp_vy(X, Y)
    C = np.sqrt(u * u + v * v)

    if clim is not None and "norm" not in kwargs:
        # Map speeds to colors on the caller's range so repeated calls (e.g.
        # animation frames) share one color scale.
        kwargs["norm"] = Normalize(vmin=clim[0], vmax=clim[1])

    plot = axes.streamplot(X, Y, u, v, color=C, **kwargs)
    return plot.lines
919
+
920
+
921
def _plot_contours(field, axes, clim=None, **kwargs):
    """Draw filled contours of a 2-D field's values (or magnitudes), with contour lines on top.

    ``clim`` is accepted but ignored. ``kwargs`` are forwarded to both
    ``tricontourf`` and ``tricontour``. The filled contour set is returned.
    """
    tri = _field_triangulation(field)
    scalar = _value_or_magnitude(field.dof_values.numpy())

    filled = axes.tricontourf(tri, scalar, **kwargs)
    axes.tricontour(tri, scalar, **kwargs)
    return filled
929
+
930
+
931
def _plot_3d_scatter(field, axes, **kwargs):
    """Scatter the nodes of a 3-D field, coloring each by its value (or vector magnitude).

    ``kwargs`` are forwarded to ``axes.scatter``, whose artist is returned.
    """
    positions = field.space.node_positions().numpy()
    xs, ys, zs = positions.T
    shade = _value_or_magnitude(field.dof_values.numpy()).reshape(xs.shape)
    return axes.scatter(xs, ys, zs, c=shade, **kwargs)