warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -128,6 +128,10 @@ class Example:
128
128
  self.num_substeps = 10
129
129
  self.iterations = 4
130
130
  self.dt = self.frame_dt / self.num_substeps
131
+ # the BVH used by VBDIntegrator will be rebuilt every self.bvh_rebuild_frames
132
+ # When the simulated object deforms significantly, simply refitting the BVH can lead to deterioration of the BVH's
133
+ # quality, in this case we need to completely rebuild the tree to achieve better query efficiency.
134
+ self.bvh_rebuild_frames = 10
131
135
 
132
136
  self.num_frames = num_frames
133
137
  self.sim_time = 0.0
@@ -227,69 +231,62 @@ class Example:
227
231
  self.cuda_graph = None
228
232
  if self.use_cuda_graph:
229
233
  with wp.ScopedCapture() as capture:
230
- for _ in range(self.num_substeps):
231
- wp.launch(
232
- kernel=apply_rotation,
233
- dim=self.rot_point_indices.shape[0],
234
- inputs=[
235
- self.rot_point_indices,
236
- self.rot_axes,
237
- self.roots,
238
- self.roots_to_ps,
239
- self.t,
240
- self.rot_angular_velocity,
241
- self.dt,
242
- self.rot_end_time,
243
- ],
244
- outputs=[
245
- self.state0.particle_q,
246
- self.state1.particle_q,
247
- ],
248
- )
249
-
250
- self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
251
- (self.state0, self.state1) = (self.state1, self.state0)
234
+ self.integrate_frame_substeps()
252
235
 
253
236
  self.cuda_graph = capture.graph
254
237
 
255
- def step(self):
238
+ def integrate_frame_substeps(self):
239
+ for _ in range(self.num_substeps):
240
+ wp.launch(
241
+ kernel=apply_rotation,
242
+ dim=self.rot_point_indices.shape[0],
243
+ inputs=[
244
+ self.rot_point_indices,
245
+ self.rot_axes,
246
+ self.roots,
247
+ self.roots_to_ps,
248
+ self.t,
249
+ self.rot_angular_velocity,
250
+ self.dt,
251
+ self.rot_end_time,
252
+ ],
253
+ outputs=[
254
+ self.state0.particle_q,
255
+ self.state1.particle_q,
256
+ ],
257
+ )
258
+
259
+ self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
260
+ (self.state0, self.state1) = (self.state1, self.state0)
261
+
262
+ def advance_frame(self):
256
263
  with wp.ScopedTimer("step", print=False, dict=self.profiler):
257
264
  if self.use_cuda_graph:
258
265
  wp.capture_launch(self.cuda_graph)
259
266
  else:
260
- for _ in range(self.num_substeps):
261
- wp.launch(
262
- kernel=apply_rotation,
263
- dim=self.rot_point_indices.shape[0],
264
- inputs=[
265
- self.rot_point_indices,
266
- self.rot_axes,
267
- self.roots,
268
- self.roots_to_ps,
269
- self.t,
270
- self.rot_angular_velocity,
271
- self.dt,
272
- self.rot_end_time,
273
- ],
274
- outputs=[
275
- self.state0.particle_q,
276
- self.state1.particle_q,
277
- ],
278
- )
279
- self.integrator.simulate(self.model, self.state0, self.state1, self.dt)
280
-
281
- (self.state0, self.state1) = (self.state1, self.state0)
267
+ self.integrate_frame_substeps()
282
268
 
283
269
  self.sim_time += self.dt
284
270
 
271
+ def run(self):
272
+ for i in range(self.num_frames):
273
+ self.advance_frame()
274
+ self.render()
275
+ print(f"[{i:4d}/{self.num_frames}]")
276
+
277
+ if i != 0 and not i % self.bvh_rebuild_frames and self.use_cuda_graph:
278
+ self.integrator.rebuild_bvh(self.state0)
279
+ with wp.ScopedCapture() as capture:
280
+ self.integrate_frame_substeps()
281
+ self.cuda_graph = capture.graph
282
+
285
283
  def render(self):
286
284
  if self.renderer is None:
287
285
  return
288
286
 
289
- with wp.ScopedTimer("render", print=False):
290
- self.renderer.begin_frame(self.sim_time)
291
- self.renderer.render(self.state0)
292
- self.renderer.end_frame()
287
+ self.renderer.begin_frame(self.sim_time)
288
+ self.renderer.render(self.state0)
289
+ self.renderer.end_frame()
293
290
 
294
291
 
295
292
  if __name__ == "__main__":
@@ -310,13 +307,10 @@ if __name__ == "__main__":
310
307
  with wp.ScopedDevice(args.device):
311
308
  example = Example(stage_path=args.stage_path, num_frames=args.num_frames)
312
309
 
313
- for i in range(example.num_frames):
314
- example.step()
315
- example.render()
316
- print(f"[{i:4d}/{example.num_frames}]")
310
+ example.run()
317
311
 
318
312
  frame_times = example.profiler["step"]
319
- print("\nAverage frame sim time: {:.2f} ms".format(sum(frame_times) / len(frame_times)))
313
+ print(f"\nAverage frame sim time: {sum(frame_times) / len(frame_times):.2f} ms")
320
314
 
321
315
  if example.renderer:
322
316
  example.renderer.save()
@@ -0,0 +1,502 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile Block Cholesky
18
+ #
19
+ # Shows how to write a kernel computing a blocked Cholesky factorization
20
+ # of a symmetric positive definite matrix using Warp Tile APIs.
21
+ #
22
+ ###########################################################################
23
+
24
+ from functools import lru_cache
25
+
26
+ import numpy as np
27
+
28
+ import warp as wp
29
+
30
+ wp.set_module_options({"enable_backward": False})
31
+
32
+
33
@lru_cache(maxsize=None)
def create_blocked_cholesky_kernel(block_size: int):
    """
    Build the Warp kernel that performs a blocked Cholesky factorization
    using tiles of ``block_size x block_size``.

    The result is memoized by ``lru_cache``, so calling this again with the
    same ``block_size`` returns the same kernel object.
    """

    @wp.kernel
    def blocked_cholesky_kernel(
        A: wp.array(dtype=float, ndim=2),
        L: wp.array(dtype=float, ndim=2),
        active_matrix_size_arr: wp.array(dtype=int, ndim=1),
    ):
        """
        Computes the Cholesky factorization of a symmetric positive definite matrix A in blocks.
        The lower-triangular factor (A = L L^T) is written into L; nothing is returned.

        A is assumed to support block reading.

        Args:
            A: Input matrix, read tile by tile on and below the block diagonal.
            L: Output array; receives the diagonal and sub-diagonal tiles of the factor.
            active_matrix_size_arr: Array whose first element holds the number of
                active equations. Only element 0 is read.
        """
        # Only tid_block is used below, as this thread's index inside the
        # cooperative padding loops; tid itself is unused.
        tid, tid_block = wp.tid()
        num_threads_per_block = wp.block_dim()

        active_matrix_size = active_matrix_size_arr[0]

        # Round up active_matrix_size to next multiple of block_size
        n = ((active_matrix_size + block_size - 1) // block_size) * block_size

        # Process the matrix in blocks along its leading dimension.
        for k in range(0, n, block_size):
            end = k + block_size

            # Load current diagonal block A[k:end, k:end]
            # and update with contributions from previously computed blocks.
            A_kk_tile = wp.tile_load(A, shape=(block_size, block_size), offset=(k, k), storage="shared")
            # The following if pads the matrix if it is not divisible by block_size
            if k + block_size > active_matrix_size:
                num_tile_elements = block_size * block_size
                # Ceiling division: how many strided passes the block needs to
                # touch every element of the tile.
                num_iterations = (num_tile_elements + num_threads_per_block - 1) // num_threads_per_block

                for i in range(num_iterations):
                    linear_index = tid_block + i * num_threads_per_block
                    # Wrap indices past the end of the tile back into range; an
                    # element may be visited twice, which is harmless because
                    # the same value is written each time.
                    linear_index = linear_index % num_tile_elements
                    row = linear_index // block_size
                    col = linear_index % block_size
                    value = A_kk_tile[row, col]
                    # Entries outside the active matrix are replaced by the
                    # identity (1 on the diagonal, 0 elsewhere).
                    if k + row >= active_matrix_size or k + col >= active_matrix_size:
                        value = wp.where(row == col, float(1), float(0))
                    A_kk_tile[row, col] = value

            if k > 0:
                # Subtract L[k, j] @ L[k, j]^T for every block column j left of k.
                for j in range(0, k, block_size):
                    L_block = wp.tile_load(L, shape=(block_size, block_size), offset=(k, j))
                    L_block_T = wp.tile_transpose(L_block)
                    L_L_T_block = wp.tile_matmul(L_block, L_block_T)
                    A_kk_tile -= L_L_T_block

            # Compute the Cholesky factorization for the block
            L_kk_tile = wp.tile_cholesky(A_kk_tile)
            wp.tile_store(L, L_kk_tile, offset=(k, k))

            # Process the blocks below the current block
            for i in range(end, n, block_size):
                A_ik_tile = wp.tile_load(A, shape=(block_size, block_size), offset=(i, k), storage="shared")
                # The following if pads the matrix if it is not divisible by block_size
                if i + block_size > active_matrix_size or k + block_size > active_matrix_size:
                    num_tile_elements = block_size * block_size
                    num_iterations = (num_tile_elements + num_threads_per_block - 1) // num_threads_per_block

                    for ii in range(num_iterations):
                        linear_index = tid_block + ii * num_threads_per_block
                        linear_index = linear_index % num_tile_elements
                        row = linear_index // block_size
                        col = linear_index % block_size
                        value = A_ik_tile[row, col]
                        # Same identity padding as above, tested against global
                        # indices. Since i >= end here, i + row > k + col, so
                        # padded entries of this off-diagonal tile become 0.
                        if i + row >= active_matrix_size or k + col >= active_matrix_size:
                            value = wp.where(i + row == k + col, float(1), float(0))
                        A_ik_tile[row, col] = value

                if k > 0:
                    # Subtract L[i, j] @ L[k, j]^T for every block column j left of k.
                    for j in range(0, k, block_size):
                        L_tile = wp.tile_load(L, shape=(block_size, block_size), offset=(i, j))
                        L_2_tile = wp.tile_load(L, shape=(block_size, block_size), offset=(k, j))
                        L_T_tile = wp.tile_transpose(L_2_tile)
                        L_L_T_tile = wp.tile_matmul(L_tile, L_T_tile)
                        A_ik_tile -= L_L_T_tile

                # The solve is applied to the transposed tile and the result is
                # transposed back, i.e. (L_kk^-1 A_ik^T)^T.
                # NOTE(review): presumably tile_lower_solve solves L_kk z = t - confirm.
                t = wp.tile_transpose(A_ik_tile)
                tmp = wp.tile_lower_solve(L_kk_tile, t)
                sol_tile = wp.tile_transpose(tmp)

                wp.tile_store(L, sol_tile, offset=(i, k))

    return blocked_cholesky_kernel
121
+
122
+
123
@lru_cache(maxsize=None)
def create_blocked_cholesky_solve_kernel(block_size: int):
    """
    Build the Warp kernel that solves ``A x = b`` from a blocked Cholesky
    factor, using tiles of ``block_size`` rows.

    The result is memoized by ``lru_cache``, so calling this again with the
    same ``block_size`` returns the same kernel object.
    """

    @wp.kernel
    def blocked_cholesky_solve_kernel(
        L: wp.array(dtype=float, ndim=2),
        b: wp.array(dtype=float, ndim=2),
        x: wp.array(dtype=float, ndim=2),
        y: wp.array(dtype=float, ndim=2),
        active_matrix_size_arr: wp.array(dtype=int, ndim=1),
    ):
        """
        Solves A x = b given the Cholesky factor L (A = L L^T) using
        blocked forward and backward substitution.

        Only a single right-hand side is handled: b, x and y are accessed as
        tiles of shape (block_size, 1) at column offset 0.

        Args:
            L: Cholesky factor, read tile by tile.
            b: Right-hand side; column 0 is read.
            x: Solution; column 0 is written (and re-read during back substitution).
            y: Scratch storage for the intermediate solution of L y = b.
            active_matrix_size_arr: Array whose first element holds the number of
                active equations. Only element 0 is read.
        """

        active_matrix_size = active_matrix_size_arr[0]

        # Round up active_matrix_size to next multiple of block_size
        n = ((active_matrix_size + block_size - 1) // block_size) * block_size

        # Forward substitution: solve L y = b
        for i in range(0, n, block_size):
            rhs_tile = wp.tile_load(b, shape=(block_size, 1), offset=(i, 0))
            if i > 0:
                # Remove the contribution of the already-solved blocks of y.
                for j in range(0, i, block_size):
                    L_block = wp.tile_load(L, shape=(block_size, block_size), offset=(i, j))
                    y_block = wp.tile_load(y, shape=(block_size, 1), offset=(j, 0))
                    Ly_block = wp.tile_matmul(L_block, y_block)
                    rhs_tile -= Ly_block
            L_tile = wp.tile_load(L, shape=(block_size, block_size), offset=(i, i))
            y_tile = wp.tile_lower_solve(L_tile, rhs_tile)
            wp.tile_store(y, y_tile, offset=(i, 0))

        # Backward substitution: solve L^T x = y
        # Blocks are visited from the last one upwards, so every x block read
        # in the inner loop has already been stored by an earlier iteration.
        for i in range(n - block_size, -1, -block_size):
            i_start = i
            i_end = i_start + block_size
            rhs_tile = wp.tile_load(y, shape=(block_size, 1), offset=(i_start, 0))
            if i_end < n:
                for j in range(i_end, n, block_size):
                    L_tile = wp.tile_load(L, shape=(block_size, block_size), offset=(j, i_start))
                    L_T_tile = wp.tile_transpose(L_tile)
                    x_tile = wp.tile_load(x, shape=(block_size, 1), offset=(j, 0))
                    L_T_x_tile = wp.tile_matmul(L_T_tile, x_tile)
                    rhs_tile -= L_T_x_tile
            L_tile = wp.tile_load(L, shape=(block_size, block_size), offset=(i_start, i_start))
            x_tile = wp.tile_upper_solve(wp.tile_transpose(L_tile), rhs_tile)
            wp.tile_store(x, x_tile, offset=(i_start, 0))

    return blocked_cholesky_solve_kernel
176
+
177
+
178
+ # TODO: Add batching support to solve multiple equation systems at once (one per thread block)
179
class BlockCholeskySolver:
    """
    A class for solving linear systems using the Cholesky factorization.

    The factor is kept in the workspace array ``self.L`` and reused by
    :meth:`solve`. The system size can be given either as a host integer
    (:meth:`factorize`) or as a 1-element Warp array (:meth:`factorize_dynamic`).
    """

    def __init__(self, max_num_equations: int, block_size=16, device="cuda"):
        """
        Allocate the workspace and fetch the kernels for ``block_size``.

        Args:
            max_num_equations: Largest system size to support; rounded up to a
                multiple of ``block_size`` before allocating.
            block_size: Edge length of the square tiles used by the kernels.
            device: Device on which the workspace is allocated and kernels are launched.
        """
        self.block_size = block_size

        # Round up max_num_equations to next multiple of block_size
        self.max_num_equations = self._padded_size(max_num_equations)
        self.device = device

        self.num_threads_per_block_factorize = 128
        self.num_threads_per_block_solve = 64
        # Host-side copy of the active size; -1 means "not known on the host".
        self.active_matrix_size_int = -1

        self.cholesky_kernel = create_blocked_cholesky_kernel(block_size)
        self.solve_kernel = create_blocked_cholesky_solve_kernel(block_size)

        # Allocate workspace arrays for factorization and solve
        self.L = wp.zeros(shape=(self.max_num_equations, self.max_num_equations), dtype=float, device=self.device)
        self.y = wp.zeros(shape=(self.max_num_equations, 1), dtype=float, device=self.device)  # temp memory
        self.active_matrix_size = wp.zeros(
            shape=(1,), dtype=int, device=self.device
        )  # array to hold active matrix size
        # Size array supplied by the caller of factorize_dynamic, if any.
        self.active_matrix_size_external = None

    def _padded_size(self, n: int) -> int:
        """Round ``n`` up to the next multiple of ``self.block_size``."""
        return ((n + self.block_size - 1) // self.block_size) * self.block_size

    def factorize(self, A: wp.array(dtype=float, ndim=2), num_active_equations: int):
        """
        Computes the Cholesky factorization of a symmetric positive definite matrix A in blocks.
        The lower-triangular factor L (A = L L^T) is stored in ``self.L``; nothing is returned.

        Args:
            A: Square matrix, at least as large as ``num_active_equations``
                rounded up to a multiple of the block size.
            num_active_equations: Number of equations to factorize.
        """

        assert num_active_equations <= self.max_num_equations, (
            f"Number of active equations ({num_active_equations}) exceeds maximum allowed ({self.max_num_equations})"
        )

        padded_n = self._padded_size(num_active_equations)

        # Verify input dimensions
        assert A.shape[0] == A.shape[1], "Matrix A must be square"
        assert A.shape[0] >= padded_n, f"Matrix A must be at least {padded_n}x{padded_n} to accommodate padding"

        # The copy overwrites the only element, so no prior zero-fill is needed.
        wp.copy(self.active_matrix_size, wp.array([num_active_equations], dtype=int, device=self.device))

        self.factorize_dynamic(A, self.active_matrix_size)

        # factorize_dynamic marks the size as external/unknown; restore the
        # host-known state so solve() can run its dimension checks.
        self.active_matrix_size_external = None
        self.active_matrix_size_int = num_active_equations

    def factorize_dynamic(self, A: wp.array(dtype=float, ndim=2), num_active_equations: wp.array(dtype=int, ndim=1)):
        """
        Computes the Cholesky factorization of a symmetric positive definite matrix A in blocks.
        The lower-triangular factor L (A = L L^T) is stored in ``self.L``; nothing is returned.

        The system size is read by the kernel from ``num_active_equations``,
        so no host-side dimension checks are performed here or in the
        following :meth:`solve`. The array is remembered and passed to the
        solve kernel as well.

        Args:
            A: Matrix to factorize.
            num_active_equations: Array whose first element holds the number of
                active equations.
        """

        self.active_matrix_size_external = num_active_equations
        self.active_matrix_size_int = -1

        wp.launch_tiled(
            self.cholesky_kernel,
            dim=1,
            inputs=[A, self.L, num_active_equations],
            block_dim=self.num_threads_per_block_factorize,
            device=self.device,
        )

    def solve(self, rhs: wp.array(dtype=float, ndim=2), result: wp.array(dtype=float, ndim=2)):
        """
        Solves A x = b given the Cholesky factor L (A = L L^T) using
        blocked forward and backward substitution.

        Only a single right-hand side is supported: ``rhs`` and ``result``
        must be 2-D arrays with exactly one column.

        Args:
            rhs: Right-hand side b as a column.
            result: Receives the solution x as a column.
        """

        # Do safety checks but they can only be done if the matrix size is known on the host
        if self.active_matrix_size_int > 0:
            n = self.active_matrix_size_int
            padded_n = self._padded_size(n)

            # Verify input dimensions
            assert rhs.shape[1] == 1, "Matrix b must be a column vector"
            assert rhs.shape[0] >= padded_n, f"Matrix b must be at least {padded_n}x1 to accommodate padding"

            assert result.shape[1] == 1, "Matrix result must be a column vector"
            assert result.shape[0] >= padded_n, f"Matrix result must be at least {padded_n}x1 to accommodate padding"

        # Use whichever size array the last factorization was driven by.
        if self.active_matrix_size_external is not None:
            matrix_size = self.active_matrix_size_external
        else:
            matrix_size = self.active_matrix_size

        # Then solve the system using blocked_cholesky_solve kernel
        wp.launch_tiled(
            self.solve_kernel,
            dim=1,
            inputs=[self.L, rhs, result, self.y, matrix_size],
            block_dim=self.num_threads_per_block_solve,
            device=self.device,
        )
281
+
282
+
283
class CholeskySolverNumPy:
    """
    A class for solving linear systems using the Cholesky factorization.

    NumPy counterpart of the Warp solver: the factor is kept in ``self.L``
    and reused by :meth:`solve`.
    """

    def __init__(self, max_num_equations: int):
        """
        Allocate the workspace.

        Args:
            max_num_equations: Largest system size this solver can hold.
        """
        self.max_num_equations = max_num_equations
        # Size of the system currently stored in self.L (0 until factorize succeeds).
        self.num_active_equations = 0

        # Allocate workspace arrays for factorization and solve
        self.L = np.zeros((self.max_num_equations, self.max_num_equations))
        self.y = np.zeros((self.max_num_equations, 1))  # temp memory

    def factorize(self, A: np.ndarray, num_active_equations: int):
        """
        Computes the Cholesky factorization of a symmetric positive definite matrix A.
        The lower-triangular factor L (A = L L^T) is stored in ``self.L``; nothing is returned.

        The solver state is only updated once the factorization has succeeded,
        so a failed call leaves the previously stored system usable.

        Args:
            A: Square matrix; its leading ``num_active_equations`` rows and
                columns are factorized.
            num_active_equations: Number of equations to factorize.

        Raises:
            AssertionError: If the size exceeds the maximum or A has the wrong shape.
            numpy.linalg.LinAlgError: If the leading block is not positive definite.
        """
        assert num_active_equations <= self.max_num_equations, (
            f"Number of active equations ({num_active_equations}) exceeds maximum allowed ({self.max_num_equations})"
        )

        # Verify input dimensions
        assert A.shape[0] == A.shape[1], "Matrix A must be square"
        assert A.shape[0] >= num_active_equations, (
            f"Matrix A must be at least {num_active_equations}x{num_active_equations}"
        )

        # Compute Cholesky factorization
        factor = np.linalg.cholesky(A[:num_active_equations, :num_active_equations])

        # Commit both pieces of state together, after everything that can fail.
        self.L[:num_active_equations, :num_active_equations] = factor
        self.num_active_equations = num_active_equations

    def solve(self, rhs: np.ndarray, result: np.ndarray):
        """
        Solves A x = b given the Cholesky factor L (A = L L^T) using
        forward and backward substitution.

        Only a single right-hand side is supported: ``rhs`` and ``result``
        must be 2-D arrays with exactly one column.

        Args:
            rhs: Right-hand side b as a column with at least as many rows as
                the active system.
            result: Receives the solution x in its leading rows.
        """
        assert self.num_active_equations <= self.max_num_equations, (
            f"Number of active equations ({self.num_active_equations}) exceeds maximum allowed ({self.max_num_equations})"
        )

        n = self.num_active_equations

        # Verify input dimensions
        assert rhs.shape[1] == 1, "Matrix b must be a column vector"
        assert rhs.shape[0] >= n, f"Matrix b must be at least {n}x1"

        assert result.shape[1] == 1, "Matrix result must be a column vector"
        assert result.shape[0] >= n, f"Matrix result must be at least {n}x1"

        # Forward substitution: L y = b
        self.y[:n] = np.linalg.solve(self.L[:n, :n], rhs[:n])

        # Backward substitution: L^T x = y
        result[:n] = np.linalg.solve(self.L[:n, :n].T, self.y[:n])
343
+
344
+
345
def test_cholesky_solver(n, warp_solver: BlockCholeskySolver, device: str = "cuda"):
    """
    Solve one random SPD system of size ``n`` with NumPy and with the Warp
    solver, and print factorization errors, residuals and their difference.

    The matrix and right-hand side handed to Warp are padded with random
    values up to a multiple of the solver's block size.
    """
    rng = np.random.default_rng(0)
    tile = warp_solver.block_size

    # Random symmetric positive definite system matrix
    seed_matrix = rng.standard_normal((n, n))
    A_full = seed_matrix @ seed_matrix.T + n * np.eye(n)

    # Size after rounding up to a whole number of tiles (ceiling division)
    padded_n = -(-n // tile) * tile
    extra = padded_n - n

    if extra <= 0:
        A_padded = A_full
    else:
        # Surround the real system with random filler: an SPD corner block
        # plus a symmetric pair of off-diagonal strips.
        A_padded = rng.standard_normal((padded_n, padded_n))
        A_padded[:n, :n] = A_full
        corner = rng.standard_normal((extra, extra))
        A_padded[n:, n:] = corner @ corner.T + extra * np.eye(extra)
        A_padded[n:, :n] = rng.standard_normal((extra, n))
        A_padded[:n, n:] = A_padded[n:, :n].T

    # Random right-hand side, embedded at the start of a random padded vector
    b = rng.standard_normal(n)
    b_padded = rng.standard_normal(padded_n)
    b_padded[:n] = b

    # --- NumPy reference ---
    print("\nSolving with NumPy:")
    x_ref = np.linalg.solve(A_full, b)
    L_ref = np.linalg.cholesky(A_full)

    ref_factor_err = np.linalg.norm(A_full - L_ref @ L_ref.T)
    ref_residual = np.linalg.norm(b - A_full @ x_ref)
    print(f"Cholesky factorization error: {ref_factor_err:.3e}")
    print(f"Solution residual norm: {ref_residual:.3e}")

    # --- Warp solver ---
    print("\nSolving with Warp kernels:")
    A_wp = wp.array(A_padded, dtype=wp.float32, device=device)
    b_wp = wp.array(b_padded, dtype=wp.float32, device=device).reshape((padded_n, 1))
    x_wp = wp.zeros_like(b_wp)

    warp_solver.factorize(A_wp, n)
    warp_solver.solve(b_wp, x_wp)
    wp.synchronize()

    # Bring the results back to the host, dropping the padded rows
    x_warp = x_wp.numpy()[:n].squeeze()
    L_active = warp_solver.L.numpy()[:n, :n]

    warp_factor_err = np.linalg.norm(A_full - L_active @ L_active.T)
    warp_residual = np.linalg.norm(b - A_full @ x_warp)
    solution_gap = np.linalg.norm(x_ref - x_warp)

    print(f"Warp Cholesky factorization error: {warp_factor_err:.3e}")
    print(f"Warp solution residual norm: {warp_residual:.3e}")
    print(f"Difference between CPU and GPU solutions: {solution_gap:.3e}")
407
+
408
+
409
@wp.kernel
def assign_int_kernel(arr: wp.array(dtype=int, ndim=1), value: int):
    """
    Assigns an integer value into the first element of an array.

    Args:
        arr: Destination array; only element 0 is written.
        value: Integer to store.
    """
    # The write does not depend on the thread index, so every launched
    # thread stores the same value to arr[0].
    arr[0] = value
413
+
414
+
415
def test_cholesky_solver_graph_capture():
    """
    Capture factorize + solve for every system size from 1 to 1000 into a
    single graph on a dedicated stream, then launch the graph once and time it.

    The system size is updated between iterations by a kernel writing into a
    1-element array, which the solver reads through ``factorize_dynamic``.
    """
    wp.clear_kernel_cache()

    max_equations = 1000
    identity_shift = np.eye(max_equations) * max_equations

    # Random SPD matrix and random right-hand side
    rng = np.random.default_rng(42)
    A_np = rng.standard_normal((max_equations, max_equations))
    A_np = A_np @ A_np.T + identity_shift  # Make SPD
    b_np = rng.standard_normal((max_equations, 1))

    device = "cuda"

    with wp.ScopedDevice(device):
        warp_solver = BlockCholeskySolver(max_equations, block_size=32)

        # Dimensions rounded up to a whole number of tiles (ceiling division)
        tile = warp_solver.block_size
        padded_n = -(-max_equations // tile) * tile

        # Zero-filled padded host buffers holding the original data in the leading part
        A_host = np.zeros((padded_n, padded_n), dtype=np.float32)
        A_host[:max_equations, :max_equations] = A_np
        b_host = np.zeros((padded_n, 1), dtype=np.float32)
        b_host[:max_equations, :] = b_np

        # Device copies, result buffer and the 1-element system-size array
        A_wp = wp.array(A_host, dtype=wp.float32, ndim=2)
        b_wp = wp.array(b_host, dtype=wp.float32, ndim=2)
        x_wp = wp.zeros_like(b_wp)
        n_wp = wp.array([1], dtype=wp.int32)

        # Dedicated stream for the capture
        stream = wp.Stream(device)

        with wp.ScopedStream(stream):
            wp.capture_begin()
            try:
                for size in range(1, max_equations + 1):
                    # Update the system size, then factorize and solve at that size
                    wp.launch(assign_int_kernel, dim=1, inputs=[n_wp, size])
                    warp_solver.factorize_dynamic(A_wp, n_wp)
                    warp_solver.solve(b_wp, x_wp)
            finally:
                # Always close the capture, even if recording failed
                graph = wp.capture_end()

        # Replay everything that was recorded
        with wp.ScopedTimer("Launch graph", cuda_filter=wp.TIMING_GRAPH):
            wp.capture_launch(graph, stream=stream)

        wp.synchronize()
        print("Finished!")
481
+
482
+
483
if __name__ == "__main__":
    wp.clear_kernel_cache()

    # Flip to True to run the graph-capture demo instead of the size sweep.
    test_graph_capture = False

    if not test_graph_capture:
        device = "cpu"

        # Test equation sys sizes
        sizes = [32, 70, 128, 192, 257, 320, 401, 1000]

        # One solver, sized for the largest system, is shared by all runs
        warp_solver = BlockCholeskySolver(max(sizes), block_size=16, device=device)

        for n in sizes:
            print(f"\nTesting system size n = {n}")
            test_cholesky_solver(n, warp_solver, device)

    else:
        test_cholesky_solver_graph_capture()
@@ -82,6 +82,7 @@ if __name__ == "__main__":
82
82
  print("A\\n (Warp):\n", Y_wp.numpy())
83
83
  print("A\\x (Numpy):\n", Y_np)
84
84
 
85
- assert np.allclose(Y_wp.numpy(), Y_np) and np.allclose(L_wp.numpy(), L_np)
85
+ np.testing.assert_allclose(Y_wp.numpy(), Y_np)
86
+ np.testing.assert_allclose(L_wp.numpy(), L_np)
86
87
 
87
88
  print("Example Tile Cholesky passed")
@@ -63,4 +63,4 @@ if __name__ == "__main__":
63
63
  wp.launch_tiled(conv_tiled, dim=[1, 1], inputs=[x_wp], outputs=[y_wp], block_dim=BLOCK_DIM)
64
64
 
65
65
  # Since filter is 1/N, conv_tiled is a ~no-op
66
- assert np.allclose(x_h, y_wp.numpy())
66
+ np.testing.assert_allclose(x_h, y_wp.numpy())
@@ -88,7 +88,7 @@ if __name__ == "__main__":
88
88
  f_np = cplx(f_h)
89
89
  y_test = cplx(y_wp.numpy())
90
90
  y_ref = np.fft.ifft(f_np * np.fft.fft(x_np))
91
- assert np.allclose(y_ref, y_test)
91
+ np.testing.assert_allclose(y_ref, y_test)
92
92
 
93
93
  try:
94
94
  import matplotlib.pyplot as plt
@@ -80,6 +80,6 @@ if __name__ == "__main__":
80
80
  block_dim=TILE_THREADS,
81
81
  )
82
82
 
83
- assert np.allclose(C_wp.numpy(), A @ B, atol=1.0e-4)
83
+ np.testing.assert_allclose(C_wp.numpy(), A @ B, atol=1.0e-4)
84
84
 
85
85
  print("Example matrix multiplication passed")
@@ -29,6 +29,8 @@
29
29
  #
30
30
  ###########################################################################
31
31
 
32
+ # ruff: noqa: RUF003
33
+
32
34
  import math
33
35
  import os
34
36