warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,605 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
# Initialize Warp eagerly: the skip decorator on TestTileCholesky queries
# wp.context.runtime.core (MathDx availability / CUDA toolkit version) when
# this module is imported.
wp.init()

# Compile-time tile dimensions captured by the kernels below.
TILE_M = wp.constant(8)  # size of the square TILE_M x TILE_M systems
TILE_N = wp.constant(4)  # currently unused in this file
TILE_K = wp.constant(8)  # currently unused in this file

# Threads per block for every tiled launch (passed as block_dim).
TILE_DIM = 32
31
+
32
+
33
@wp.kernel()
def tile_math_cholesky(
    gA: wp.array2d(dtype=wp.float64),
    gD: wp.array1d(dtype=wp.float64),
    gL: wp.array2d(dtype=wp.float64),
    gy: wp.array1d(dtype=wp.float64),
    gx: wp.array1d(dtype=wp.float64),
):
    """Factor ``A^T + diag(D) = L L^T`` and solve ``L L^T x = y``.

    The Cholesky factor is written to ``gL`` and the solution to ``gx``.
    """
    i, j = wp.tid()
    # Load A, D & y
    a = wp.tile_load(gA, shape=(TILE_M, TILE_M), storage="shared")
    d = wp.tile_load(gD, shape=TILE_M, storage="shared")
    y = wp.tile_load(gy, shape=TILE_M, storage="shared")
    # Ensure tile_diag_add() and tile_cholesky_solve() work with transposed matrices
    a_t = wp.tile_transpose(a)
    # Compute L such that L L^T = A^T + diag(D)
    b = wp.tile_diag_add(a_t, d)
    l = wp.tile_cholesky(b)
    # Solve for x in L L^T x = y
    x = wp.tile_cholesky_solve(l, y)
    # Store L & x
    wp.tile_store(gL, l)
    wp.tile_store(gx, x)
56
+
57
+
58
def test_tile_cholesky_cholesky(test, device):
    """Compare tile_cholesky / tile_cholesky_solve (single right-hand side) with NumPy."""
    # A is all ones; adding 8 on the diagonal makes A^T + diag(D) symmetric positive definite.
    A_h = np.ones((TILE_M, TILE_M), dtype=np.float64)
    D_h = np.full(TILE_M, 8.0, dtype=np.float64)
    L_h = np.zeros_like(A_h)
    Y_h = np.arange(TILE_M, dtype=np.float64)
    X_h = np.zeros_like(Y_h)

    # NumPy reference for the factor and the solution.
    A_np = A_h.T + np.diag(D_h)
    L_np = np.linalg.cholesky(A_np)
    X_np = np.linalg.solve(A_np, Y_h)

    A_wp, D_wp, L_wp, Y_wp, X_wp = (
        wp.array(host, requires_grad=True, dtype=wp.float64, device=device)
        for host in (A_h, D_h, L_h, Y_h, X_h)
    )

    wp.launch_tiled(
        tile_math_cholesky,
        dim=[1, 1],
        inputs=[A_wp, D_wp, L_wp, Y_wp, X_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    np.testing.assert_allclose(X_wp.numpy(), X_np)
    np.testing.assert_allclose(L_wp.numpy(), L_np)

    # TODO: implement and test backward pass
84
+
85
+
86
@wp.kernel()
def tile_math_cholesky_multiple_rhs(
    gA: wp.array2d(dtype=wp.float64),
    gD: wp.array1d(dtype=wp.float64),
    gL: wp.array2d(dtype=wp.float64),
    gy: wp.array2d(dtype=wp.float64),
    gx: wp.array2d(dtype=wp.float64),
    gz: wp.array2d(dtype=wp.float64),
):
    """Factor ``A^T + diag(D) = L L^T`` and solve ``L L^T X = y^T`` for a block of right-hand sides.

    Writes the factor to ``gL``, the solution ``X`` to ``gx`` and ``X @ X`` to ``gz``.
    """
    i, j = wp.tid()
    # Load A, D & y
    a = wp.tile_load(gA, shape=(TILE_M, TILE_M), storage="shared")
    d = wp.tile_load(gD, shape=TILE_M, storage="shared")
    y = wp.tile_load(gy, shape=(TILE_M, TILE_M), storage="shared")
    # Ensure tile_diag_add() and tile_cholesky_solve() work with transposed matrices
    a_t = wp.tile_transpose(a)
    # Compute L such that L L^T = A^T + diag(D)
    b = wp.tile_diag_add(a_t, d)
    l = wp.tile_cholesky(b)
    # Solve for X in L L^T X = y^T (each column of y^T is one right-hand side)
    y_t = wp.tile_transpose(y)
    x = wp.tile_cholesky_solve(l, y_t)
    # Ensure matmul receives correct layout information
    z = wp.tile_matmul(x, x)
    # Store L, X & X @ X
    wp.tile_store(gL, l)
    wp.tile_store(gx, x)
    wp.tile_store(gz, z)
114
+
115
+
116
def test_tile_cholesky_cholesky_multiple_rhs(test, device):
    """Compare tile_cholesky / tile_cholesky_solve with several right-hand sides against NumPy."""
    # Same SPD construction as the single-RHS test; Y holds TILE_M right-hand sides.
    A_h = np.ones((TILE_M, TILE_M), dtype=np.float64)
    D_h = np.full(TILE_M, 8.0, dtype=np.float64)
    L_h = np.zeros_like(A_h)
    Y_h = np.arange(TILE_M * TILE_M, dtype=np.float64).reshape(TILE_M, TILE_M)
    X_h = np.zeros_like(Y_h)
    Z_h = np.zeros_like(Y_h)

    # NumPy reference; the kernel solves against Y^T and multiplies the solution by itself.
    A_np = A_h.T + np.diag(D_h)
    L_np = np.linalg.cholesky(A_np)
    X_np = np.linalg.solve(A_np, Y_h.T)
    Z_np = X_np @ X_np

    A_wp, D_wp, L_wp, Y_wp, X_wp, Z_wp = (
        wp.array(host, requires_grad=True, dtype=wp.float64, device=device)
        for host in (A_h, D_h, L_h, Y_h, X_h, Z_h)
    )

    wp.launch_tiled(
        tile_math_cholesky_multiple_rhs,
        dim=[1, 1],
        inputs=[A_wp, D_wp, L_wp, Y_wp, X_wp, Z_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    for actual, expected in ((L_wp, L_np), (X_wp, X_np), (Z_wp, Z_np)):
        np.testing.assert_allclose(actual.numpy(), expected)

    # TODO: implement and test backward pass
150
+
151
+
152
@wp.kernel
def tile_math_forward_substitution(
    gL: wp.array2d(dtype=wp.float64), gx: wp.array1d(dtype=wp.float64), gz: wp.array1d(dtype=wp.float64)
):
    """Solve ``U^T z = x``, where ``gL`` holds an upper-triangular matrix ``U``; write z to ``gz``."""
    i, j = wp.tid()
    upper = wp.tile_load(gL, shape=(TILE_M, TILE_M), storage="shared")
    rhs = wp.tile_load(gx, shape=TILE_M, storage="shared")
    # U^T is lower triangular, so forward substitution applies.
    lower = wp.tile_transpose(upper)
    z = wp.tile_lower_solve(lower, rhs)
    wp.tile_store(gz, z)
165
+
166
+
167
def test_tile_cholesky_forward_substitution(test, device):
    """Compare tile_lower_solve on a transposed upper-triangular tile with NumPy."""
    rng = np.random.default_rng(42)
    # Upper-triangular input; the kernel transposes it into a lower-triangular system.
    U_h = np.triu(rng.random((TILE_M, TILE_M)))
    x_h = rng.random(TILE_M)
    z_h = np.zeros_like(x_h)

    expected = np.linalg.solve(U_h.T, x_h)

    U_wp, x_wp, z_wp = (
        wp.array(host, requires_grad=True, dtype=wp.float64, device=device) for host in (U_h, x_h, z_h)
    )

    wp.launch_tiled(
        tile_math_forward_substitution,
        dim=[1, 1],
        inputs=[U_wp, x_wp, z_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    np.testing.assert_allclose(z_wp.numpy(), expected)

    # TODO: implement and test backward pass
192
+
193
+
194
@wp.kernel
def tile_math_back_substitution(
    gL: wp.array2d(dtype=wp.float64), gx: wp.array1d(dtype=wp.float64), gz: wp.array1d(dtype=wp.float64)
):
    """Solve ``L^T z = x``, where ``gL`` holds a lower-triangular matrix ``L``; write z to ``gz``."""
    i, j = wp.tid()
    lower = wp.tile_load(gL, shape=(TILE_M, TILE_M), storage="shared")
    rhs = wp.tile_load(gx, shape=TILE_M, storage="shared")
    # L^T is upper triangular, so back substitution applies.
    upper = wp.tile_transpose(lower)
    z = wp.tile_upper_solve(upper, rhs)
    wp.tile_store(gz, z)
207
+
208
+
209
def test_tile_cholesky_back_substitution(test, device):
    """Compare tile_upper_solve on a transposed lower-triangular tile with NumPy."""
    rng = np.random.default_rng(42)
    # Lower-triangular input; the kernel transposes it into an upper-triangular system.
    L_h = np.tril(rng.random((TILE_M, TILE_M)))
    x_h = rng.random(TILE_M)
    z_h = np.zeros_like(x_h)

    expected = np.linalg.solve(L_h.T, x_h)

    L_wp, x_wp, z_wp = (
        wp.array(host, requires_grad=True, dtype=wp.float64, device=device) for host in (L_h, x_h, z_h)
    )

    wp.launch_tiled(
        tile_math_back_substitution,
        dim=[1, 1],
        inputs=[L_wp, x_wp, z_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    np.testing.assert_allclose(z_wp.numpy(), expected)

    # TODO: implement and test backward pass
234
+
235
+
236
@wp.kernel
def tile_math_forward_substitution_multiple_rhs(
    gL: wp.array2d(dtype=wp.float64),
    gx: wp.array2d(dtype=wp.float64),
    gz: wp.array2d(dtype=wp.float64),
    gc: wp.array2d(dtype=wp.float64),
):
    """Solve ``L z = x^T`` for lower-triangular ``L``; write z to ``gz`` and ``z @ z`` to ``gc``."""
    i, j = wp.tid()
    lower = wp.tile_load(gL, shape=(TILE_M, TILE_M), storage="shared")
    rhs = wp.tile_load(gx, shape=(TILE_M, TILE_M), storage="shared")
    # Each column of x^T is one right-hand side.
    z = wp.tile_lower_solve(lower, wp.tile_transpose(rhs))
    # Ensure matmul receives correct layout information
    c = wp.tile_matmul(z, z)
    wp.tile_store(gz, z)
    wp.tile_store(gc, c)
255
+
256
+
257
def test_tile_cholesky_forward_substitution_multiple_rhs(test, device):
    """Compare tile_lower_solve with several right-hand sides (and a follow-up matmul) against NumPy."""
    # Create test data
    rng = np.random.default_rng(42)
    L_h = np.tril(rng.random((TILE_M, TILE_M)))  # Lower triangular matrix
    x_h = rng.random((TILE_M, TILE_M))  # Multiple right-hand sides
    z_h = np.zeros_like(x_h)
    c_h = np.zeros_like(x_h)

    # Compute reference solution using numpy (the kernel solves against x^T)
    z_np = np.linalg.solve(L_h, x_h.T)
    c_np = z_np @ z_np

    # Create Warp arrays
    L_wp = wp.array(L_h, requires_grad=True, dtype=wp.float64, device=device)
    x_wp = wp.array(x_h, requires_grad=True, dtype=wp.float64, device=device)
    z_wp = wp.array(z_h, requires_grad=True, dtype=wp.float64, device=device)
    c_wp = wp.array(c_h, requires_grad=True, dtype=wp.float64, device=device)

    # Run kernel
    wp.launch_tiled(
        tile_math_forward_substitution_multiple_rhs,
        dim=[1, 1],
        inputs=[L_wp, x_wp, z_wp, c_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    # Verify results. These tolerances match np.allclose's defaults, but a failure
    # reports the mismatching entries instead of a bare "False is not true".
    np.testing.assert_allclose(z_wp.numpy(), z_np, rtol=1e-5, atol=1e-8, equal_nan=False)
    np.testing.assert_allclose(c_wp.numpy(), c_np, rtol=1e-5, atol=1e-8, equal_nan=False)

    # TODO: implement and test backward pass
290
+
291
+
292
@wp.kernel
def tile_math_back_substitution_multiple_rhs(
    gL: wp.array2d(dtype=wp.float64),
    gx: wp.array2d(dtype=wp.float64),
    gz: wp.array2d(dtype=wp.float64),
    gc: wp.array2d(dtype=wp.float64),
):
    """Solve ``L^T z = x^T`` for lower-triangular ``L``; write z to ``gz`` and ``z @ z`` to ``gc``."""
    i, j = wp.tid()
    lower = wp.tile_load(gL, shape=(TILE_M, TILE_M), storage="shared")
    rhs = wp.tile_load(gx, shape=(TILE_M, TILE_M), storage="shared")
    # L^T is upper triangular; each column of x^T is one right-hand side.
    upper = wp.tile_transpose(lower)
    rhs_t = wp.tile_transpose(rhs)
    z = wp.tile_upper_solve(upper, rhs_t)
    # Ensure matmul receives correct layout information
    c = wp.tile_matmul(z, z)
    wp.tile_store(gz, z)
    wp.tile_store(gc, c)
311
+
312
+
313
def test_tile_cholesky_back_substitution_multiple_rhs(test, device):
    """Compare tile_upper_solve with several right-hand sides (and a follow-up matmul) against NumPy."""
    # Create test data
    rng = np.random.default_rng(42)
    L_h = np.tril(rng.random((TILE_M, TILE_M)))  # Lower triangular matrix
    x_h = rng.random((TILE_M, TILE_M))  # Multiple right-hand sides
    z_h = np.zeros_like(x_h)
    c_h = np.zeros_like(x_h)

    # Compute reference solution using numpy (the kernel solves L^T z = x^T)
    z_np = np.linalg.solve(L_h.T, x_h.T)
    c_np = z_np @ z_np

    # Create Warp arrays
    L_wp = wp.array(L_h, requires_grad=True, dtype=wp.float64, device=device)
    x_wp = wp.array(x_h, requires_grad=True, dtype=wp.float64, device=device)
    z_wp = wp.array(z_h, requires_grad=True, dtype=wp.float64, device=device)
    c_wp = wp.array(c_h, requires_grad=True, dtype=wp.float64, device=device)

    # Run kernel
    wp.launch_tiled(
        tile_math_back_substitution_multiple_rhs,
        dim=[1, 1],
        inputs=[L_wp, x_wp, z_wp, c_wp],
        block_dim=TILE_DIM,
        device=device,
    )
    wp.synchronize_device(device)

    # Verify results. These tolerances match np.allclose's defaults, but a failure
    # reports the mismatching entries instead of a bare "False is not true".
    np.testing.assert_allclose(z_wp.numpy(), z_np, rtol=1e-5, atol=1e-8, equal_nan=False)
    np.testing.assert_allclose(c_wp.numpy(), c_np, rtol=1e-5, atol=1e-8, equal_nan=False)

    # TODO: implement and test backward pass
346
+
347
+
348
# Tests a complex composition of most libmathdx calls
def test_tile_cholesky_block_cholesky(test, device):
    """Factor and solve an SPD system with blocked tile kernels and compare against NumPy.

    The TILE_M x TILE_M matrix is processed in BLOCK_SIZE x BLOCK_SIZE tiles. This
    exercises tile_cholesky, tile_lower_solve, tile_upper_solve, tile_matmul and
    tile_transpose together.
    """
    BLOCK_SIZE = wp.constant(TILE_M // 2)

    @wp.kernel(module="unique")
    def block_cholesky_kernel(
        A: wp.array2d(dtype=float),
        L: wp.array2d(dtype=float),
    ):
        """
        Computes the Cholesky factorization of a symmetric positive definite matrix A in blocks.
        Writes the lower-triangular factor into L such that A = L L^T.
        """

        # Process the matrix in blocks along its leading dimension.
        for k in range(0, TILE_M, BLOCK_SIZE):
            end = k + BLOCK_SIZE

            # Load current diagonal block A[k:end, k:end]
            # and update with contributions from previously computed blocks.
            A_kk_tile = wp.tile_load(A, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(k, k), storage="shared")

            for j in range(0, k, BLOCK_SIZE):
                L_block = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(k, j))
                L_block_T = wp.tile_transpose(L_block)
                L_L_T_block = wp.tile_matmul(L_block, L_block_T)
                A_kk_tile -= L_L_T_block

            # Compute the Cholesky factorization for the block
            L_kk_tile = wp.tile_cholesky(A_kk_tile)
            wp.tile_store(L, L_kk_tile, offset=(k, k))

            # Process the blocks below the current block
            for i in range(end, TILE_M, BLOCK_SIZE):
                A_ik_tile = wp.tile_load(A, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(i, k), storage="shared")

                for j in range(0, k, BLOCK_SIZE):
                    L_tile = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(i, j))
                    L_2_tile = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(k, j))
                    L_T_tile = wp.tile_transpose(L_2_tile)
                    L_L_T_tile = wp.tile_matmul(L_tile, L_T_tile)
                    A_ik_tile -= L_L_T_tile

                # L_ik L_kk^T = A_ik is solved as the transposed system L_kk L_ik^T = A_ik^T.
                A_ik_T_tile = wp.tile_transpose(A_ik_tile)
                sol_T_tile = wp.tile_lower_solve(L_kk_tile, A_ik_T_tile)
                sol_tile = wp.tile_transpose(sol_T_tile)

                wp.tile_store(L, sol_tile, offset=(i, k))

    @wp.kernel(module="unique")
    def block_cholesky_solve_kernel(
        L: wp.array2d(dtype=float),
        b: wp.array2d(dtype=float),
        scratch: wp.array2d(dtype=float),
        x: wp.array2d(dtype=float),
    ):
        """
        Solves A x = b given the Cholesky factor L (A = L L^T) using
        blocked forward and backward substitution.
        """

        # Forward substitution: solve L y = b (y is written to scratch)
        for i in range(0, TILE_M, BLOCK_SIZE):
            rhs_tile = wp.tile_load(b, shape=(BLOCK_SIZE, 1), offset=(i, 0))
            for j in range(0, i, BLOCK_SIZE):
                L_block = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(i, j))
                y_block = wp.tile_load(scratch, shape=(BLOCK_SIZE, 1), offset=(j, 0))
                Ly_block = wp.tile_matmul(L_block, y_block)
                rhs_tile -= Ly_block
            L_tile = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(i, i))
            y_tile = wp.tile_lower_solve(L_tile, rhs_tile)
            wp.tile_store(scratch, y_tile, offset=(i, 0))

        # Backward substitution: solve L^T x = y, from the last block row upwards
        for i in range(TILE_M - BLOCK_SIZE, -1, -BLOCK_SIZE):
            i_start = i
            i_end = i_start + BLOCK_SIZE
            rhs_tile = wp.tile_load(scratch, shape=(BLOCK_SIZE, 1), offset=(i_start, 0))
            for j in range(i_end, TILE_M, BLOCK_SIZE):
                L_tile = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(j, i_start))
                L_T_tile = wp.tile_transpose(L_tile)
                # Blocks of x below this row were stored by earlier iterations.
                x_tile = wp.tile_load(x, shape=(BLOCK_SIZE, 1), offset=(j, 0))
                L_T_x_tile = wp.tile_matmul(L_T_tile, x_tile)
                rhs_tile -= L_T_x_tile
            L_tile = wp.tile_load(L, shape=(BLOCK_SIZE, BLOCK_SIZE), offset=(i_start, i_start))
            x_tile = wp.tile_upper_solve(wp.tile_transpose(L_tile), rhs_tile)
            wp.tile_store(x, x_tile, offset=(i_start, 0))

    # Check the blocked Cholesky decomposition.

    rng = np.random.default_rng(42)

    M = np.array(rng.random((TILE_M, TILE_M)), dtype=float)

    # M^T M + I is symmetric positive definite by construction.
    A_np = M.T @ M + np.eye(TILE_M, TILE_M)
    L_np = np.linalg.cholesky(A_np)

    A_wp = wp.array(A_np, dtype=float, device=device)
    L_wp = wp.zeros_like(A_wp)

    wp.launch_tiled(block_cholesky_kernel, dim=1, inputs=[A_wp], outputs=[L_wp], block_dim=TILE_DIM, device=device)

    assert_np_equal(L_wp.numpy(), L_np, tol=1e-6)

    # Check the blocked Cholesky solve.

    b_np = np.array(rng.random((TILE_M, 1)), dtype=float)
    b_wp = wp.array(b_np, dtype=float, device=device)

    scratch = wp.zeros_like(b_wp)

    x_np = np.linalg.solve(L_np.T, np.linalg.solve(L_np, b_np))
    x_wp = wp.zeros_like(b_wp)

    wp.launch_tiled(
        block_cholesky_solve_kernel,
        dim=1,
        inputs=[L_wp, b_wp, scratch],
        outputs=[x_wp],
        block_dim=TILE_DIM,
        device=device,
    )

    assert_np_equal(x_wp.numpy(), x_np, tol=1e-6)
474
+
475
+
476
@wp.kernel
def test_tile_lower_solve(L: wp.array2d(dtype=float), y: wp.array(dtype=float), x: wp.array(dtype=float)):
    """Solve ``L x = y`` for a lower-triangular tile ``L`` and store the result in ``x``."""
    L_tile = wp.tile_load(L, shape=(TILE_M, TILE_M))
    # The right-hand side comes from y; x is only the output.
    y_tile = wp.tile_load(y, shape=(TILE_M,))
    sol = wp.tile_lower_solve(L_tile, y_tile)
    wp.tile_store(x, sol)


@wp.kernel
def test_tile_upper_solve(L: wp.array2d(dtype=float), y: wp.array(dtype=float), x: wp.array(dtype=float)):
    """Solve ``U x = y`` for an upper-triangular tile passed as ``L`` and store the result in ``x``."""
    L_tile = wp.tile_load(L, shape=(TILE_M, TILE_M))
    # The right-hand side comes from y; x is only the output.
    y_tile = wp.tile_load(y, shape=(TILE_M,))
    sol = wp.tile_upper_solve(L_tile, y_tile)
    wp.tile_store(x, sol)


def test_tile_cholesky_singular_matrices(test, device):
    """Check that triangular solves against a singular matrix yield non-finite results."""
    if not wp.context.runtime.core.wp_is_mathdx_enabled():
        test.skipTest("MathDx is not enabled")

    rng = np.random.default_rng(42)
    L_np = np.tril(rng.random((TILE_M, TILE_M)))  # Lower triangular matrix
    L_np[-1, -1] = 0.0  # Zero pivot makes it singular
    y_np = rng.random(TILE_M)

    L_wp = wp.array2d(L_np, dtype=float, device=device)
    y_wp = wp.array(y_np, dtype=float, device=device)
    x_wp = wp.zeros_like(y_wp)

    wp.launch_tiled(
        test_tile_lower_solve, dim=1, inputs=[L_wp, y_wp], outputs=[x_wp], block_dim=TILE_DIM, device=device
    )

    # Dividing a nonzero residual by the zero pivot may give inf rather than NaN
    # (implementation-dependent), so accept any non-finite entry.
    test.assertFalse(np.isfinite(x_wp.numpy()).all())

    L_np = np.triu(rng.random((TILE_M, TILE_M)))  # Upper triangular matrix
    L_np[-1, -1] = 0.0  # Zero pivot makes it singular

    L_wp = wp.array2d(L_np, dtype=float, device=device)
    y_wp = wp.array(y_np, dtype=float, device=device)
    x_wp = wp.zeros_like(y_wp)

    wp.launch_tiled(
        test_tile_upper_solve, dim=1, inputs=[L_wp, y_wp], outputs=[x_wp], block_dim=TILE_DIM, device=device
    )

    test.assertFalse(np.isfinite(x_wp.numpy()).all())
523
+
524
+
525
all_devices = get_test_devices()
cuda_devices = get_cuda_test_devices()


# The suite is skipped only when MathDx is enabled and the CUDA toolkit is older
# than 12.6 (12060). The toolkit version is queried only when MathDx is enabled.
@unittest.skipUnless(
    not wp.context.runtime.core.wp_is_mathdx_enabled() or wp.context.runtime.core.wp_cuda_toolkit_version() >= 12060,
    "MathDx is enabled but CUDA toolkit version is less than 12.6",
)
class TestTileCholesky(unittest.TestCase):
    pass
536
+
537
+
538
# Attach each test function to TestTileCholesky for the listed devices, in this order.
for _test_name, _test_func, _test_devices in (
    ("test_tile_cholesky_cholesky", test_tile_cholesky_cholesky, all_devices),
    ("test_tile_cholesky_cholesky_multiple_rhs", test_tile_cholesky_cholesky_multiple_rhs, all_devices),
    ("test_tile_cholesky_forward_substitution", test_tile_cholesky_forward_substitution, cuda_devices),
    ("test_tile_cholesky_back_substitution", test_tile_cholesky_back_substitution, cuda_devices),
    (
        "test_tile_cholesky_forward_substitution_multiple_rhs",
        test_tile_cholesky_forward_substitution_multiple_rhs,
        cuda_devices,
    ),
    (
        "test_tile_cholesky_back_substitution_multiple_rhs",
        test_tile_cholesky_back_substitution_multiple_rhs,
        cuda_devices,
    ),
    ("test_tile_cholesky_block_cholesky", test_tile_cholesky_block_cholesky, cuda_devices),
    ("test_tile_cholesky_singular_matrices", test_tile_cholesky_singular_matrices, cuda_devices),
):
    add_function_test(TestTileCholesky, _test_name, _test_func, devices=_test_devices, check_output=False)

# Don't leave the loop variables behind as module attributes.
del _test_name, _test_func, _test_devices
601
+
602
+
603
if __name__ == "__main__":
    # Clear Warp's kernel cache before running the suite directly, then stop at the first failure.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)