warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170) hide show
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,380 @@
1
+ from typing import Union, Any, Tuple, Optional
2
+
3
+ import warp as wp
4
+ import warp.types
5
+
6
+ from warp.sparse import BsrMatrix, bsr_zeros, bsr_transposed, bsr_mv, bsr_get_diag
7
+ from warp.optim.linear import preconditioner, LinearOperator, aslinearoperator
8
+
9
+
10
def bsr_to_scipy(matrix: BsrMatrix) -> "scipy.sparse.bsr_array":
    """Converts a Warp ``BsrMatrix`` to a SciPy sparse array.

    Matrices whose block shape is ``(1, 1)`` are returned in CSR form; any other
    block shape yields a BSR array (despite the return annotation, the scalar
    case does not produce a ``bsr_array``). Only the first ``matrix.nnz``
    entries of the values and columns arrays are exported.

    Args:
        matrix: the matrix to convert

    Returns:
        A SciPy CSR or BSR array with the same shape as ``matrix``
    """
    try:
        from scipy.sparse import bsr_array, csr_array
    except ImportError:
        # Workaround for older SciPy releases that only ship the ``*_matrix`` classes
        from scipy.sparse import bsr_matrix as bsr_array
        from scipy.sparse import csr_matrix as csr_array

    stored = matrix.nnz
    host_values = matrix.values.numpy()

    if matrix.block_shape == (1, 1):
        # Scalar blocks: export as a plain CSR array with a flat data vector
        sparse_type = csr_array
        data = host_values.flatten()
    else:
        # One (rows, cols) block per stored entry
        sparse_type = bsr_array
        data = host_values.reshape((matrix.values.shape[0], *matrix.block_shape))

    return sparse_type(
        (data[:stored], matrix.columns.numpy()[:stored], matrix.offsets.numpy()),
        shape=matrix.shape,
    )
35
+
36
+
37
def scipy_to_bsr(
    sp: Union["scipy.sparse.bsr_array", "scipy.sparse.csr_array"],
    device=None,
    dtype=None,
) -> BsrMatrix:
    """Converts a SciPy CSR or BSR sparse array to a Warp ``BsrMatrix``.

    Args:
        sp: SciPy sparse array (or legacy sparse matrix) in CSR or BSR format.
            Its indices are sorted in place by this function.
        device: Warp device on which to allocate the resulting matrix
        dtype: Warp scalar type of the result. If ``None``, it is looked up from
            ``sp.dtype`` through ``warp.types.np_dtype_to_warp_type``.

    Returns:
        The converted matrix; CSR input yields scalar blocks, BSR input yields
        blocks of shape ``sp.blocksize``.
    """
    try:
        from scipy.sparse import csr_array
    except ImportError:
        # WAR for older scipy
        from scipy.sparse import csr_matrix as csr_array

    if dtype is None:
        dtype = warp.types.np_dtype_to_warp_type[sp.dtype]

    sp.sort_indices()

    # With recent SciPy, ``csr_matrix`` is not a subclass of ``csr_array``, so the
    # isinstance test alone would send legacy CSR matrices down the BSR path.
    # The ``format`` attribute identifies the storage scheme for both families.
    is_csr = isinstance(sp, csr_array) or getattr(sp, "format", None) == "csr"

    if is_csr:
        matrix = bsr_zeros(sp.shape[0], sp.shape[1], dtype, device=device)
    else:
        block_shape = sp.blocksize
        block_type = wp.types.matrix(shape=block_shape, dtype=dtype)
        matrix = bsr_zeros(
            sp.shape[0] // block_shape[0],
            sp.shape[1] // block_shape[1],
            block_type,
            device=device,
        )

    # ``sp.nnz`` counts stored *scalars* for BSR input (blocks * R * C), whereas the
    # Warp matrix tracks the number of stored *blocks*. The last row offset is the
    # block count for BSR and equals ``sp.nnz`` for CSR.
    matrix.nnz = int(sp.indptr[-1])
    matrix.values = wp.array(sp.data.flatten(), dtype=matrix.values.dtype, device=device)
    matrix.columns = wp.array(sp.indices, dtype=matrix.columns.dtype, device=device)
    matrix.offsets = wp.array(sp.indptr, dtype=matrix.offsets.dtype, device=device)

    return matrix
71
+
72
+
73
def get_linear_solver_func(method_name: str):
    """Returns the ``warp.optim.linear`` solver selected by ``method_name``.

    ``"bicgstab"`` and ``"gmres"`` select the solver of the same name; any other
    value falls back to ``cg``.
    """
    from warp.optim.linear import cg, bicgstab, gmres

    named_solvers = (("bicgstab", bicgstab), ("gmres", gmres))
    for name, solver in named_solvers:
        if method_name == name:
            return solver

    # Unrecognized names default to conjugate gradient
    return cg
81
+
82
+
83
def bsr_cg(
    A: BsrMatrix,
    x: wp.array,
    b: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    use_diag_precond=True,
    mv_routine=None,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the linear system ``A x = b`` with an iterative solver, optionally using diagonal preconditioning

    Args:
        A: system left-hand side
        x: result vector and initial guess
        b: system right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        use_diag_precond: Whether to use diagonal preconditioning. Ignored when ``mv_routine`` is provided.
        mv_routine: Matrix-vector multiplication routine to use for multiplications with ``A``
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use, defaults to Conjugate Gradient

    Returns:
        Tuple (residual norm, iteration count)

    """

    from warp.optim.linear import preconditioner, LinearOperator

    if mv_routine is not None:
        # Custom product: wrap it as an operator; no preconditioner is built in this case
        A = LinearOperator(A.shape, A.dtype, A.device, matvec=mv_routine)
        M = None
    elif use_diag_precond:
        M = preconditioner(A, "diag")
    else:
        M = None

    func = get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    solver_args = dict(
        A=A,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=M,
        callback=None if quiet else print_callback,
    )
    end_iter, err, atol = func(**solver_args)

    if quiet:
        return err, end_iter

    # Report whether the solver reached the absolute tolerance it returned
    res_str = "OK" if err <= atol else "TRUNCATED"
    print(f"{func.__name__}: terminated after {end_iter} iterations with error = \t {err} ({res_str})")

    return err, end_iter
145
+
146
+
147
class SaddleSystem(LinearOperator):
    """Linear operator for the saddle-point linear system ``[A B^T; B 0]``

    If ``use_diag_precond`` is ``True``, also builds the corresponding diagonal
    preconditioner ``[diag(A); diag(B diag(A)^-1 B^T)]``, available through the
    ``preconditioner`` property (``None`` otherwise).

    Vectors handled by this operator store the ``A.nrow`` primal entries first,
    followed by the ``B.nrow`` multiplier entries; ``u_slice`` and ``p_slice``
    return views on each part.
    """

    def __init__(
        self,
        A: BsrMatrix,
        B: BsrMatrix,
        Bt: Optional[BsrMatrix] = None,
        use_diag_precond: bool = True,
    ):
        from warp.optim.linear import preconditioner, LinearOperator, aslinearoperator

        self._A = A
        self._B = B
        # Build the transpose only when the caller did not supply one
        self._Bt = bsr_transposed(B) if Bt is None else Bt

        # Element types of the primal and multiplier parts, one vector per block row
        self._u_dtype = wp.vec(length=A.block_shape[0], dtype=A.scalar_type)
        self._p_dtype = wp.vec(length=B.block_shape[0], dtype=B.scalar_type)
        # The multiplier part starts right after the A.nrow primal entries
        self._p_byte_offset = A.nrow * wp.types.type_size_in_bytes(self._u_dtype)

        system_size = A.shape[0] + B.shape[0]

        super().__init__(
            (system_size, system_size),
            dtype=A.scalar_type,
            device=A.device,
            matvec=self._saddle_mv,
        )

        self._preconditioner = self._diag_preconditioner() if use_diag_precond else None

    def _diag_preconditioner(self):
        """Builds the block-diagonal preconditioner as a ``LinearOperator``"""
        A, B = self._A, self._B

        # Primal part: diagonal preconditioner of A
        M_u = preconditioner(A, "diag")

        A_diag = bsr_get_diag(A)

        # Multiplier part: per-row inverse diagonal computed by the kernel
        schur_block_shape = (B.block_shape[0], B.block_shape[0])
        schur_inv_diag = wp.empty(
            dtype=wp.mat(shape=schur_block_shape, dtype=B.scalar_type),
            shape=B.nrow,
            device=self.device,
        )
        wp.launch(
            _compute_schur_inverse_diagonal,
            dim=B.nrow,
            device=A.device,
            inputs=[B.offsets, B.columns, B.values, A_diag, schur_inv_diag],
        )

        if schur_block_shape == (1, 1):
            # Downcast 1x1 mats to scalars
            schur_inv_diag = schur_inv_diag.view(dtype=B.scalar_type)

        M_p = aslinearoperator(schur_inv_diag)

        def precond_mv(x, y, z, alpha, beta):
            # Apply each preconditioner to its own part of the stacked vectors
            M_u.matvec(self.u_slice(x), self.u_slice(y), self.u_slice(z), alpha=alpha, beta=beta)
            M_p.matvec(self.p_slice(x), self.p_slice(y), self.p_slice(z), alpha=alpha, beta=beta)

        return LinearOperator(
            shape=self.shape,
            dtype=self.dtype,
            device=self.device,
            matvec=precond_mv,
        )

    @property
    def preconditioner(self):
        """Preconditioner built at construction, or ``None`` if it was disabled"""
        return self._preconditioner

    def u_slice(self, a: wp.array):
        """Returns a non-copying array over the primal part of the stacked vector ``a``"""
        view_args = dict(
            ptr=a.ptr,
            dtype=self._u_dtype,
            shape=self._A.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )
        return wp.array(**view_args)

    def p_slice(self, a: wp.array):
        """Returns a non-copying array over the multiplier part of the stacked vector ``a``"""
        view_args = dict(
            ptr=a.ptr + self._p_byte_offset,
            dtype=self._p_dtype,
            shape=self._B.nrow,
            strides=None,
            device=a.device,
            pinned=a.pinned,
            copy=False,
        )
        return wp.array(**view_args)

    def _saddle_mv(self, x, y, z, alpha, beta):
        """Matrix-vector product passed to the base operator; the result is written to ``z``"""
        # Seed z with y so that the products below can accumulate onto it with weight beta
        if y.ptr != z.ptr and beta != 0.0:
            wp.copy(src=y, dest=z)

        primal_in = self.u_slice(x)
        mult_in = self.p_slice(x)
        primal_out = self.u_slice(z)
        mult_out = self.p_slice(z)

        # Primal rows: A x_u, then the B^T x_p contribution added on top (beta=1)
        bsr_mv(self._A, primal_in, primal_out, alpha=alpha, beta=beta)
        bsr_mv(self._Bt, mult_in, primal_out, alpha=alpha, beta=1.0)
        # Multiplier rows: B x_u
        bsr_mv(self._B, primal_in, mult_out, alpha=alpha, beta=beta)
262
+
263
+
264
def bsr_solve_saddle(
    saddle_system: SaddleSystem,
    x_u: wp.array,
    x_p: wp.array,
    b_u: wp.array,
    b_p: wp.array,
    max_iters: int = 0,
    tol: float = 0.0001,
    check_every=10,
    quiet=False,
    method: str = "cg",
) -> Tuple[float, int]:
    """Solves the saddle-point linear system [A B^T; B 0] (x_u; x_p) = (b_u; b_p) using an iterative solver

    The preconditioner, if any, is the one held by ``saddle_system``.

    Args:
        saddle_system: Saddle point system
        x_u: primal part of the result vector and initial guess
        x_p: Lagrange multiplier part of the result vector and initial guess
        b_u: primal right-hand-side
        b_p: constraint right-hand-side
        max_iters: maximum number of iterations to perform before aborting. If set to zero, equal to the system size.
        tol: relative tolerance under which to stop the solve
        check_every: number of iterations every which to evaluate the current residual norm to compare against tolerance
        quiet: if True, do not print iteration residuals
        method: Iterative solver method to use, defaults to Conjugate Gradient (``"cg"``)

    Returns:
        Tuple (residual norm, iteration count)

    """
    # Stacked unknown and right-hand-side vectors covering both parts of the system
    x = wp.empty(dtype=saddle_system.scalar_type, shape=saddle_system.shape[0], device=saddle_system.device)
    b = wp.empty_like(x)

    # Gather the caller's separate arrays into the stacked vectors
    gather_pairs = (
        (x_u, saddle_system.u_slice(x)),
        (x_p, saddle_system.p_slice(x)),
        (b_u, saddle_system.u_slice(b)),
        (b_p, saddle_system.p_slice(b)),
    )
    for part, stacked_view in gather_pairs:
        wp.copy(src=part, dest=stacked_view)

    func = get_linear_solver_func(method_name=method)

    def print_callback(i, err, tol):
        print(f"{func.__name__}: at iteration {i} error = \t {err} \t tol: {tol}")

    end_iter, err, atol = func(
        A=saddle_system,
        b=b,
        x=x,
        maxiter=max_iters,
        tol=tol,
        check_every=check_every,
        M=saddle_system.preconditioner,
        callback=None if quiet else print_callback,
    )

    if not quiet:
        res_str = "OK" if err <= atol else "TRUNCATED"
        print(f"{func.__name__}: terminated after {end_iter} iterations with absolute error = \t {err} ({res_str})")

    # Scatter the stacked solution back into the caller's arrays
    wp.copy(dest=x_u, src=saddle_system.u_slice(x))
    wp.copy(dest=x_p, src=saddle_system.p_slice(x))

    return err, end_iter
328
+
329
+
330
# Warp kernel, launched with one thread per row of B. For each row it accumulates
# B_block * inverse(A_diag[col]) * transpose(B_block) over the blocks stored in that
# row, then writes a diagonal block built from the component-wise reciprocal of the
# accumulated block's diagonal to P_diag[row].
#
# B_offsets / B_indices / B_values: row offsets, column indices and block values of B
# A_diag: diagonal blocks of A, indexed by the column index of each B block
# P_diag: output array, one block per row of B
@wp.kernel
def _compute_schur_inverse_diagonal(
    B_offsets: wp.array(dtype=int),
    B_indices: wp.array(dtype=int),
    B_values: wp.array(dtype=Any),
    A_diag: wp.array(dtype=Any),
    P_diag: wp.array(dtype=Any),
):
    row = wp.tid()

    # Zero-valued block of the output type; seeds the accumulator and serves as
    # the reference for the "nothing accumulated" test below
    zero = P_diag.dtype(P_diag.dtype.dtype(0.0))

    schur = zero

    # Range of stored blocks belonging to this row
    beg = B_offsets[row]
    end = B_offsets[row + 1]

    for b in range(beg, end):
        B = B_values[b]
        col = B_indices[b]
        Ai = wp.inverse(A_diag[col])
        S = B * Ai * wp.transpose(B)
        schur += S

    schur_diag = wp.get_diag(schur)
    # Vector of ones with the same type as the diagonal
    id_diag = type(schur_diag)(schur_diag.dtype(1.0))

    # NOTE(review): per the Warp builtin reference, wp.select(cond, a, b) yields `a`
    # when cond is false and `b` when it is true, so an all-zero accumulated block
    # falls back to the ones vector instead of dividing by zero. A zero diagonal
    # entry in an otherwise non-zero block is not guarded; confirm that is acceptable.
    inv_diag = wp.select(schur == zero, wp.cw_div(id_diag, schur_diag), id_diag)
    P_diag[row] = wp.diag(inv_diag)
359
+
360
+
361
def invert_diagonal_bsr_mass_matrix(A: BsrMatrix):
    """Inverts each block of a block-diagonal mass matrix

    Launches the block inversion kernel over ``A.nrow`` entries of ``A.values``;
    the kernel overwrites the entries of the array it is given.

    Args:
        A: block-diagonal mass matrix whose value blocks are to be inverted
    """

    block_values = A.values
    if not wp.types.type_is_matrix(block_values.dtype):
        # The kernel operates on matrix-typed entries: reinterpret scalars as 1x1 blocks
        # (presumably the view shares storage with A.values; verify against wp.array.view)
        block_values = block_values.view(dtype=wp.mat(shape=(1, 1), dtype=A.scalar_type))

    # Scaling factor passed to the kernel: the block row count, in the matrix scalar type
    block_rows = A.scalar_type(A.block_shape[0])

    wp.launch(
        kernel=_block_diagonal_mass_invert,
        dim=A.nrow,
        inputs=[block_values, block_rows],
        device=block_values.device,
    )
375
+
376
+
377
# Warp kernel, one thread per entry of `values`. Each entry is overwritten in place
# with itself multiplied by `scale` and divided by wp.ddot of the entry with itself.
# NOTE(review): this equals the block inverse only for blocks of a particular form
# (e.g. a multiple of the identity when `scale` is the block size); confirm that
# callers only pass such blocks.
@wp.kernel
def _block_diagonal_mass_invert(values: wp.array(dtype=Any), scale: Any):
    i = wp.tid()
    values[i] = scale * values[i] / wp.ddot(values[i], values[i])