warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (194)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +400 -198
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +65 -1
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +341 -224
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.2.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
@@ -1,378 +0,0 @@
1
- from typing import Any, Optional, Tuple, Union
2
-
3
- import warp as wp
4
- import warp.types
5
- from warp.optim.linear import LinearOperator, aslinearoperator, preconditioner
6
- from warp.sparse import BsrMatrix, bsr_get_diag, bsr_mv, bsr_transposed, bsr_zeros
7
-
8
-
9
def bsr_to_scipy(matrix: BsrMatrix) -> "scipy.sparse.bsr_array":  # noqa: F821
    """Convert a warp ``BsrMatrix`` into an equivalent SciPy sparse array.

    Matrices whose blocks are 1x1 are returned in CSR format; any other block
    shape is returned in BSR format. Only the first ``matrix.nnz`` entries of the
    column and value arrays are copied; anything stored past that is ignored.

    Args:
        matrix: warp block-sparse matrix to convert

    Returns:
        A SciPy CSR or BSR array (``*_matrix`` class on older SciPy releases that
        lack ``*_array``) with the same shape as ``matrix``.
    """
    try:
        from scipy.sparse import bsr_array, csr_array
    except ImportError:
        # Older SciPy releases only provide the *_matrix classes
        from scipy.sparse import bsr_matrix as bsr_array
        from scipy.sparse import csr_matrix as csr_array

    stored = matrix.nnz
    row_offsets = matrix.offsets.numpy()
    col_indices = matrix.columns.numpy()[:stored]

    if matrix.block_shape == (1, 1):
        # Scalar blocks: plain CSR is the natural SciPy counterpart
        data = matrix.values.numpy().flatten()[:stored]
        sparse_cls = csr_array
    else:
        # Expose each stored block as a dense (rows, cols) tile, as SciPy's BSR layout expects
        block_count = matrix.values.shape[0]
        data = matrix.values.numpy().reshape((block_count, *matrix.block_shape))[:stored]
        sparse_cls = bsr_array

    return sparse_cls((data, col_indices, row_offsets), shape=matrix.shape)
35
-
36
-
37
def scipy_to_bsr(
    sp: Union["scipy.sparse.bsr_array", "scipy.sparse.csr_array"],  # noqa: F821
    device=None,
    dtype=None,
) -> BsrMatrix:
    """Convert a SciPy CSR or BSR sparse array (or matrix) into a warp ``BsrMatrix``.

    Args:
        sp: SciPy sparse array/matrix in CSR or BSR format. CSR inputs produce a
            matrix with scalar blocks; BSR inputs keep their ``blocksize``.
        device: warp device on which to allocate the result
        dtype: scalar type of the result; inferred from ``sp.dtype`` when ``None``

    Returns:
        A new ``BsrMatrix`` holding a copy of the structure and values of ``sp``.
        ``sp`` itself is not modified.
    """
    if dtype is None:
        dtype = warp.types.np_dtype_to_warp_type[sp.dtype]

    # Column indices must be sorted within each row. Sort a copy (only if needed)
    # instead of mutating the caller's array in place.
    if not sp.has_sorted_indices:
        sp = sp.sorted_indices()

    # Dispatch on the format string rather than the class, so that both the legacy
    # ``csr_matrix`` and the newer ``csr_array`` classes are recognized as CSR.
    if sp.format == "csr":
        matrix = bsr_zeros(sp.shape[0], sp.shape[1], dtype, device=device)
        values = sp.data.flatten()
    else:
        block_shape = sp.blocksize
        block_type = wp.types.matrix(shape=block_shape, dtype=dtype)
        matrix = bsr_zeros(
            sp.shape[0] // block_shape[0],
            sp.shape[1] // block_shape[1],
            block_type,
            device=device,
        )
        # One dense (rows, cols) tile per stored block
        values = sp.data.reshape((-1, *block_shape))

    # ``sp.nnz`` counts stored scalars (blocks * rows * cols for BSR inputs), whereas
    # the warp matrix needs the number of stored blocks, i.e. the last row offset.
    matrix.nnz = int(sp.indptr[-1])
    matrix.values = wp.array(values, dtype=matrix.values.dtype, device=device)
    matrix.columns = wp.array(sp.indices, dtype=matrix.columns.dtype, device=device)
    matrix.offsets = wp.array(sp.indptr, dtype=matrix.offsets.dtype, device=device)

    return matrix
71
-
72
-
73
def get_linear_solver_func(method_name: str):
    """Return the ``warp.optim.linear`` iterative solver matching ``method_name``.

    Args:
        method_name: ``"bicgstab"``, ``"gmres"``, ``"cr"`` or ``"cg"``

    Returns:
        The corresponding solver function. Any unrecognized name falls back to ``cg``.
    """
    from warp.optim.linear import bicgstab, cg, cr, gmres

    candidates = (("bicgstab", bicgstab), ("gmres", gmres), ("cr", cr))
    for name, solver in candidates:
        if method_name == name:
            return solver

    # Conjugate Gradient is the default for anything else
    return cg
83
-
84
-
85
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solve the linear system ``A x = b`` with an iterative solver from ``warp.optim.linear``.

    Args:
        A: system left-hand side
        x: result vector and initial guess
        b: system right-hand side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: whether to use diagonal preconditioning. Ignored when ``mv_routine`` is given.
        mv_routine: matrix-vector multiplication routine to use for products with ``A``. When given,
            only ``A.shape``, ``A.dtype`` and ``A.device`` are read from ``A``, and no preconditioner is used.
        quiet: if True, do not print iteration residuals
        method: iterative solver method to use, defaults to Conjugate Gradient

    Returns:
        Tuple (residual norm, iteration count)
    """
    if mv_routine is not None:
        # Wrap the user-provided product; no preconditioner can be derived from it
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
        M = None
    elif use_diag_precond:
        M = preconditioner(A, "diag")
    else:
        M = None

    solver = get_linear_solver_func(method_name=method)

    callback = None
    if not quiet:

        def callback(i, err, tol):
            print(f"{solver.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    end_iter, err, atol = solver(
        A=A,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=M,
        callback=callback,
    )

    if not quiet:
        status = "OK" if err <= atol else "TRUNCATED"
        print(f"{solver.__name__}: terminated after {end_iter} iterations with error = \t {err} ({status})")

    return err, end_iter
145
-
146
-
147
class SaddleSystem(LinearOperator):
    """Linear operator for the saddle-point system ``[A B^T; B 0]``.

    Vectors of this operator are stored as the primal part ``u`` (``A.nrow`` blocks of
    length ``A.block_shape[0]``) immediately followed by the constraint part ``p``
    (``B.nrow`` blocks of length ``B.block_shape[0]``). :meth:`u_slice` and
    :meth:`p_slice` return zero-copy views on these two parts.

    If ``use_diag_precond`` is ``True``, a diagonal preconditioner based on
    ``[diag(A); diag(B diag(A)^-1 B^T)]`` is also built and exposed as :attr:`preconditioner`.
    """

    def __init__(
        self,
        A: BsrMatrix,
        B: BsrMatrix,
        Bt: Optional[BsrMatrix] = None,
        use_diag_precond: bool = True,
    ):
        self._A = A
        self._B = B
        # Compute the transpose of B unless the caller already provides it
        self._Bt = bsr_transposed(B) if Bt is None else Bt

        self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
        self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
        # The constraint part starts right after the A.nrow primal blocks
        self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)

        size = A.shape[0] + B.shape[0]
        super().__init__((size, size), dtype=A.scalar_type, device=A.device, matvec=self._saddle_mv)

        self._preconditioner = self._diag_preconditioner() if use_diag_precond else None

    def _diag_preconditioner(self):
        """Build the block-diagonal preconditioner operator described in the class docstring."""
        A, B = self._A, self._B

        # Primal part: warp's diagonal preconditioner for A
        M_u = preconditioner(A, "diag")

        # Constraint part: inverse diagonal of the approximate Schur complement B diag(A)^-1 B^T
        A_diag = bsr_get_diag(A)
        schur_block_shape = (B.block_shape[0], B.block_shape[0])
        schur_dtype = wp.mat(shape=schur_block_shape, dtype=B.scalar_type)
        schur_inv_diag = wp.empty(dtype=schur_dtype, shape=B.nrow, device=self.device)
        wp.launch(
            _compute_schur_inverse_diagonal,
            dim=B.nrow,
            device=A.device,
            inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
        )
        if schur_block_shape == (1, 1):
            # View 1x1 matrices as plain scalars
            schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)
        M_p = aslinearoperator(schur_inv_diag)

        def precond_mv(x, y, z, alpha, beta):
            # Apply each diagonal part independently on its own slice
            M_u.matvec(self.u_slice(x), self.u_slice(y), self.u_slice(z), alpha=alpha, beta=beta)
            M_p.matvec(self.p_slice(x), self.p_slice(y), self.p_slice(z), alpha=alpha, beta=beta)

        return LinearOperator(
            shape=self.shape,
            dtype=self.dtype,
            device=self.device,
            matvec=precond_mv,
        )

    @property
    def preconditioner(self):
        """Preconditioner operator, or ``None`` if ``use_diag_precond`` was ``False``."""
        return self._preconditioner

    def _typed_view(self, a: wp.array, ptr, dtype, length: int):
        """Zero-copy contiguous view of ``length`` ``dtype`` elements at address ``ptr`` within ``a``."""
        return wp.array(
            ptr=ptr,
            dtype=dtype,
            shape=length,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )

    def u_slice(self, a: wp.array):
        """View of the primal part of ``a`` (assumes ``a`` is a contiguous saddle-sized vector)."""
        return self._typed_view(a, a.ptr, self._u_dtype, self._A.nrow)

    def p_slice(self, a: wp.array):
        """View of the constraint part of ``a`` (assumes ``a`` is a contiguous saddle-sized vector)."""
        return self._typed_view(a, a.ptr + self._p_byte_offset, self._p_dtype, self._B.nrow)

    def _saddle_mv(self, x, y, z, alpha, beta):
        """Compute ``z = alpha * [A B^T; B 0] x + beta * y``."""
        x_u, x_p = self.u_slice(x), self.p_slice(x)
        z_u, z_p = self.u_slice(z), self.p_slice(z)

        # bsr_mv scales and accumulates into its output, so seed z with y when they differ
        if y.ptr != z.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)

        # Primal rows: z_u = alpha * (A x_u + B^T x_p) + beta * z_u
        bsr_mv(self._A, x_u, z_u, alpha=alpha, beta=beta)
        bsr_mv(self._Bt, x_p, z_u, alpha=alpha, beta=1.0)
        # Constraint rows: z_p = alpha * B x_u + beta * z_p (the lower-right block is zero)
        bsr_mv(self._B, x_u, z_p, alpha=alpha, beta=beta)
260
-
261
-
262
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solve the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver.

    The solver uses ``saddle_system.preconditioner``. That is the diagonal preconditioner
    when the system was built with ``use_diag_precond=True``, and ``None`` otherwise.

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess; overwritten with the solution
        x_p: Lagrange multiplier part of the result vector and initial guess; overwritten with the solution
        b_u: primal right-hand side
        b_p: constraint right-hand side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: iterative solver method to use (``"cg"``, ``"cr"``, ``"bicgstab"`` or ``"gmres"``),
            defaults to Conjugate Gradient

    Returns:
        Tuple (residual norm, iteration count)
    """
    # Pack the split unknowns and right-hand side into contiguous saddle-sized vectors
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    wp.copy(src=x_u, dest=saddle_system.u_slice(x))
    wp.copy(src=x_p, dest=saddle_system.p_slice(x))
    wp.copy(src=b_u, dest=saddle_system.u_slice(b))
    wp.copy(src=b_p, dest=saddle_system.p_slice(b))

    func = get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    callback = None if quiet else print_callback

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=callback,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Scatter the solution back into the caller's split vectors
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
326
-
327
-
328
@wp.kernel
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    # One thread per row of B: accumulate sum_k B_k A_diag[col_k]^-1 B_k^T,
    # then store the inverse of its diagonal as a diagonal matrix
    row = wp.tid()

    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))
    schur = zero

    row_begin = B_offsets[row]
    row_end = B_offsets[row + 1]
    for k in range(row_begin, row_end):
        B_block = B_values[k]
        A_block_inv = wp.inverse(A_diag[B_indices[k]])
        schur += B_block * A_block_inv * wp.transpose(B_block)

    schur_diag = wp.get_diag(schur)
    ones = type(schur_diag)(schur_diag.dtype(1.0))

    # wp.select(cond, value_if_false, value_if_true): fall back to identity when the block is all zero
    P_diag[row] = wp.diag(wp.select(schur == zero, wp.cw_div(ones, schur_diag), ones))
357
-
358
-
359
def invert_diagonal_bsr_mass_matrix(A: BsrMatrix):
    """Invert, in place, each block of a block-diagonal mass matrix.

    Each block ``M`` is replaced with ``n * M / (M : M)``, where ``n`` is the block
    size and ``:`` is the double-dot (Frobenius) product. This equals ``M^-1`` when
    ``M`` is a multiple of the identity; other blocks are not exactly inverted.
    Block ``i`` of ``A.values`` is presumed to be the diagonal block of row ``i``,
    since one thread is launched per row.
    """
    values = A.values
    if not wp.types.type_is_matrix(values.dtype):
        # Scalar-valued matrix: view each entry as a 1x1 matrix so the same kernel applies
        values = values.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))

    # Block dimension n, compensating for the trace factor in M : M for scaled-identity blocks
    block_size = A.scalar_type(A.block_shape[0])

    wp.launch(
        kernel=_block_diagonal_mass_invert,
        dim=A.nrow,
        inputs=[values, block_size],
        device=values.device,
    )
373
-
374
-
375
@wp.kernel
def _block_diagonal_mass_invert(values: wp.array(dtype=Any), scale: Any):
    # One thread per block: block <- scale * block / (block : block)
    block_index = wp.tid()
    block = values[block_index]
    values[block_index] = scale * block / wp.ddot(block, block)
@@ -1,133 +0,0 @@
1
- from typing import Optional
2
-
3
- import numpy as np
4
-
5
- import warp as wp
6
- from warp.fem.utils import grid_to_hexes, grid_to_quads, grid_to_tets, grid_to_tris
7
-
8
-
9
def gen_trimesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Construct a triangular mesh by dividing each cell of a dense 2D grid into two triangles.

    Args:
        res: number of grid cells along each dimension, as an ``(Nx, Ny)`` pair
        bounds_lo: position of the lower bound of the axis-aligned grid; defaults to ``(0, 0)``
        bounds_hi: position of the upper bound of the axis-aligned grid; defaults to ``(1, 1)``

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec2``, triangle vertex indices as ``int``)
    """
    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    # Nx cells along x means Nx + 1 vertex coordinates (same for y)
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing and row-major flattening give vertex (i, j) the index i * (Ny + 1) + j
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = grid_to_tris(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
38
-
39
-
40
def gen_tetmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Construct a tetrahedral mesh by dividing each cell of a dense 3D grid into five tetrahedra.

    Args:
        res: number of grid cells along each dimension, as an ``(Nx, Ny, Nz)`` triple
        bounds_lo: position of the lower bound of the axis-aligned grid; defaults to ``(0, 0, 0)``
        bounds_hi: position of the upper bound of the axis-aligned grid; defaults to ``(1, 1, 1)``

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec3``, tetrahedron vertex indices as ``int``)
    """
    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    # N cells along an axis means N + 1 vertex coordinates
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing and row-major flattening give vertex (i, j, k) the index (i * (Ny + 1) + j) * (Nz + 1) + k
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = grid_to_tets(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)
71
-
72
-
73
def gen_quadmesh(res, bounds_lo: Optional[wp.vec2] = None, bounds_hi: Optional[wp.vec2] = None):
    """Construct a quadrilateral mesh from a dense 2D grid, one quadrilateral per grid cell.

    Args:
        res: number of grid cells along each dimension, as an ``(Nx, Ny)`` pair
        bounds_lo: position of the lower bound of the axis-aligned grid; defaults to ``(0, 0)``
        bounds_hi: position of the upper bound of the axis-aligned grid; defaults to ``(1, 1)``

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec2``, quadrilateral vertex indices as ``int``)
    """
    if bounds_lo is None:
        bounds_lo = wp.vec2(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec2(1.0)

    Nx = res[0]
    Ny = res[1]

    # Nx cells along x means Nx + 1 vertex coordinates (same for y)
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)

    # "ij" indexing and row-major flattening give vertex (i, j) the index i * (Ny + 1) + j
    positions = np.transpose(np.meshgrid(x, y, indexing="ij"), axes=(1, 2, 0)).reshape(-1, 2)

    vidx = grid_to_quads(Nx, Ny)

    return wp.array(positions, dtype=wp.vec2), wp.array(vidx, dtype=int)
101
-
102
-
103
def gen_hexmesh(res, bounds_lo: Optional[wp.vec3] = None, bounds_hi: Optional[wp.vec3] = None):
    """Construct a hexahedral mesh from a dense 3D grid, one hexahedron per grid cell.

    Args:
        res: number of grid cells along each dimension, as an ``(Nx, Ny, Nz)`` triple
        bounds_lo: position of the lower bound of the axis-aligned grid; defaults to ``(0, 0, 0)``
        bounds_hi: position of the upper bound of the axis-aligned grid; defaults to ``(1, 1, 1)``

    Returns:
        Tuple of warp arrays: (vertex positions as ``wp.vec3``, hexahedron vertex indices as ``int``)
    """
    if bounds_lo is None:
        bounds_lo = wp.vec3(0.0)

    if bounds_hi is None:
        bounds_hi = wp.vec3(1.0)

    Nx = res[0]
    Ny = res[1]
    Nz = res[2]

    # N cells along an axis means N + 1 vertex coordinates.
    # The z axis must use the third bound component (index 2), not the y component.
    x = np.linspace(bounds_lo[0], bounds_hi[0], Nx + 1)
    y = np.linspace(bounds_lo[1], bounds_hi[1], Ny + 1)
    z = np.linspace(bounds_lo[2], bounds_hi[2], Nz + 1)

    # "ij" indexing and row-major flattening give vertex (i, j, k) the index (i * (Ny + 1) + j) * (Nz + 1) + k
    positions = np.transpose(np.meshgrid(x, y, z, indexing="ij"), axes=(1, 2, 3, 0)).reshape(-1, 3)

    vidx = grid_to_hexes(Nx, Ny, Nz)

    return wp.array(positions, dtype=wp.vec3), wp.array(vidx, dtype=int)