warp-lang 1.2.2-py3-none-manylinux2014_aarch64.whl → 1.3.1-py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang has been flagged as a potentially problematic release.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1412 -888
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +91 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.1.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.1.dist-info}/top_level.txt +0 -0
warp/fem/utils.py CHANGED
@@ -1,14 +1,10 @@
- from typing import Any, Tuple
+ from typing import Any, Tuple, Union

  import numpy as np

  import warp as wp
- from warp.fem.cache import (
-     Temporary,
-     TemporaryStore,
-     borrow_temporary,
-     borrow_temporary_like,
- )
+ import warp.fem.cache as cache
+ from warp.fem.types import NULL_NODE_INDEX
  from warp.utils import array_scan, radix_sort_pairs, runlength_encode


@@ -115,121 +111,331 @@ def skew_part(x: wp.mat33):
      return wp.vec3(a, b, c)


+ @wp.func
+ def householder_qr_decomposition(A: Any):
+     """
+     QR decomposition of a square matrix using Householder reflections
+
+     Returns Q and R such that Q R = A, Q orthonormal (such that QQ^T = Id), R upper triangular
+     """
+
+     x = type(A[0])()
+     Q = wp.identity(n=type(x).length, dtype=A.dtype)
+
+     zero = x.dtype(0.0)
+     two = x.dtype(2.0)
+
+     for i in range(type(x).length):
+         for k in range(type(x).length):
+             x[k] = wp.select(k < i, A[k, i], zero)
+
+         alpha = wp.length(x) * wp.sign(x[i])
+         x[i] += alpha
+         two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
+
+         A -= wp.outer(two_over_x_sq * x, x * A)
+         Q -= wp.outer(Q * x, two_over_x_sq * x)
+
+     return Q, A
+
+
+ @wp.func
+ def householder_make_hessenberg(A: Any):
+     """Transforms a square matrix to Hessenberg form (single lower diagonal) using Householder reflections
+
+     Returns:
+         Q and H such that Q H Q^T = A, Q orthonormal, H under Hessenberg form
+         If A is symmetric, H will be tridiagonal
+     """
+
+     x = type(A[0])()
+     Q = wp.identity(n=type(x).length, dtype=A.dtype)
+
+     zero = x.dtype(0.0)
+     two = x.dtype(2.0)
+
+     for i in range(1, type(x).length):
+         for k in range(type(x).length):
+             x[k] = wp.select(k < i, A[k, i - 1], zero)
+
+         alpha = wp.length(x) * wp.sign(x[i])
+         x[i] += alpha
+         two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
+
+         # apply on both sides
+         A -= wp.outer(two_over_x_sq * x, x * A)
+         A -= wp.outer(A * x, two_over_x_sq * x)
+         Q -= wp.outer(Q * x, two_over_x_sq * x)
+
+     return Q, A
+
+
+ @wp.func
+ def solve_triangular(R: Any, b: Any):
+     """Solves for R x = b where R is an upper triangular matrix
+
+     Returns x
+     """
+     zero = b.dtype(0)
+     x = type(b)(b.dtype(0))
+     for i in range(b.length, 0, -1):
+         j = i - 1
+         r = b[j] - wp.dot(R[j], x)
+         x[j] = wp.select(R[j, j] == zero, r / R[j, j], zero)
+
+     return x
+
+
+ @wp.func
+ def inverse_qr(A: Any):
+     # Computes a square matrix inverse using QR factorization
+
+     Q, R = householder_qr_decomposition(A)
+
+     A_inv = type(A)()
+     for i in range(type(A[0]).length):
+         A_inv[i] = solve_triangular(R, Q[i])  # ith column of Q^T
+
+     return wp.transpose(A_inv)
+
+
+ @wp.func
+ def symmetric_eigenvalues_qr(A: Any, tol: Any):
+     """
+     Computes the eigenvalues and eigenvectors of a square symmetric matrix A using the QR algorithm
+
+     Args:
+         A: square symmetric matrix
+         tol: Tolerance for the diagonalization residual (squared L2 norm of off-diagonal terms)
+
+     Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
+     """
+
+     two = A.dtype(2.0)
+     zero = A.dtype(0.0)
+
+     # temp storage for matrix rows
+     ri = type(A[0])()
+     rn = type(ri)()
+
+     # tridiagonal storage for R
+     R_L = type(ri)()
+     R_L = type(ri)(zero)
+     R_U = type(ri)(zero)
+
+     # so that we can use the type length in expression
+     # this will prevent unrolling by warp, but should be ok for native code
+     m = int(0)
+     for _ in range(type(ri).length):
+         m += 1
+
+     # Put A under Hessenberg form (tridiagonal)
+     Q, H = householder_make_hessenberg(A)
+     Q = wp.transpose(Q)  # algorithm below works on transposed Q, as rows are easier to index
+
+     for _ in range(16 * m):  # failsafe, usually converges faster than that
+         # Initialize R with current H
+         R_D = wp.get_diag(H)
+         for i in range(1, type(ri).length):
+             R_L[i - 1] = H[i, i - 1]
+             R_U[i - 1] = H[i - 1, i]
+
+         # compute QR decomposition, directly transform H and eigenvectors
+         for n in range(1, m):
+             i = n - 1
+
+             # compute reflection
+             xi = R_D[i]
+             xn = R_L[i]
+
+             xii = xi * xi
+             xnn = xn * xn
+             alpha = wp.sqrt(xii + xnn) * wp.sign(xi)
+
+             xi += alpha
+             xii = xi * xi
+             xin = xi * xn
+
+             two_over_x_sq = wp.select(alpha == zero, two / (xii + xnn), zero)
+             xii *= two_over_x_sq
+             xin *= two_over_x_sq
+             xnn *= two_over_x_sq
+
+             # Left-multiply R and Q, multiply H on both sides
+             # Note that R should get non-zero coefficients on the second upper diagonal,
+             # but those won't get read afterwards, so we can ignore them
+
+             R_D[n] -= R_U[i] * xin + R_D[n] * xnn
+             R_U[n] -= R_U[n] * xnn
+
+             ri = Q[i]
+             rn = Q[n]
+             Q[i] -= ri * xii + rn * xin
+             Q[n] -= ri * xin + rn * xnn
+
+             # H is multiplied on both sides, but stays tridiagonal except for a moving bulge
+             # Note: we could reduce the stencil to the 4 columns we consider below,
+             # but unlikely to be worth it for our small matrix sizes
+             ri = H[i]
+             rn = H[n]
+             H[i] -= ri * xii + rn * xin
+             H[n] -= ri * xin + rn * xnn
+
+             # multiply on right, manually. We just need to consider 4 rows
+             if i > 0:
+                 ci = H[i - 1, i]
+                 cn = H[i - 1, n]
+                 H[i - 1, i] -= ci * xii + cn * xin
+                 H[i - 1, n] -= ci * xin + cn * xnn
+
+             for k in range(2):
+                 ci = H[i + k, i]
+                 cn = H[i + k, n]
+                 H[i + k, i] -= ci * xii + cn * xin
+                 H[i + k, n] -= ci * xin + cn * xnn
+
+             if n + 1 < m:
+                 ci = H[n + 1, i]
+                 cn = H[n + 1, n]
+                 H[n + 1, i] -= ci * xii + cn * xin
+                 H[n + 1, n] -= ci * xin + cn * xnn
+
+         # Terminate if the upper diagonal of R is near zero
+         if wp.length_sq(R_U) < tol:
+             break
+
+     return wp.get_diag(H), Q
+
+
  def compress_node_indices(
-     node_count: int, node_indices: wp.array(dtype=int), temporary_store: TemporaryStore = None
- ) -> Tuple[Temporary, Temporary, int, Temporary]:
+     node_count: int,
+     node_indices: wp.array(dtype=int),
+     return_unique_nodes=False,
+     temporary_store: cache.TemporaryStore = None,
+ ) -> Union[Tuple[cache.Temporary, cache.Temporary], Tuple[cache.Temporary, cache.Temporary, int, cache.Temporary]]:
      """
      Compress an unsorted list of node indices into:
      - a node_offsets array, giving for each node the start offset of corresponding indices in sorted_array_indices
      - a sorted_array_indices array, listing the indices in the input array corresponding to each node
+
+     Plus if `return_unique_nodes` is ``True``,
      - the number of unique node indices
      - a unique_node_indices array containing the sorted list of unique node indices (i.e. the list of indices i for which node_offsets[i] < node_offsets[i+1])
+
+     Node indices equal to NULL_NODE_INDEX will be ignored
      """

      index_count = node_indices.size
+     device = node_indices.device

-     sorted_node_indices_temp = borrow_temporary(
-         temporary_store, shape=2 * index_count, dtype=int, device=node_indices.device
-     )
-     sorted_array_indices_temp = borrow_temporary_like(sorted_node_indices_temp, temporary_store)
+     with wp.ScopedDevice(device):
+         sorted_node_indices_temp = cache.borrow_temporary(temporary_store, shape=2 * index_count, dtype=int)
+         sorted_array_indices_temp = cache.borrow_temporary_like(sorted_node_indices_temp, temporary_store)

-     sorted_node_indices = sorted_node_indices_temp.array
-     sorted_array_indices = sorted_array_indices_temp.array
+         sorted_node_indices = sorted_node_indices_temp.array
+         sorted_array_indices = sorted_array_indices_temp.array

-     wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)
+         wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)

-     indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
-     wp.launch(
-         kernel=_iota_kernel,
-         dim=index_count,
-         inputs=[sorted_array_indices, indices_per_element],
-         device=sorted_array_indices.device,
-     )
+         indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
+         wp.launch(
+             kernel=_iota_kernel,
+             dim=index_count,
+             inputs=[sorted_array_indices, indices_per_element],
+         )

-     # Sort indices
-     radix_sort_pairs(sorted_node_indices, sorted_array_indices, count=index_count)
+         # Sort indices
+         radix_sort_pairs(sorted_node_indices, sorted_array_indices, count=index_count)

-     # Build prefix sum of number of elements per node
-     unique_node_indices_temp = borrow_temporary(
-         temporary_store, shape=index_count, dtype=int, device=node_indices.device
-     )
-     node_element_counts_temp = borrow_temporary(
-         temporary_store, shape=index_count, dtype=int, device=node_indices.device
-     )
+         # Build prefix sum of number of elements per node
+         unique_node_indices_temp = cache.borrow_temporary(temporary_store, shape=index_count, dtype=int)
+         node_element_counts_temp = cache.borrow_temporary(temporary_store, shape=index_count, dtype=int)

-     unique_node_indices = unique_node_indices_temp.array
-     node_element_counts = node_element_counts_temp.array
+         unique_node_indices = unique_node_indices_temp.array
+         node_element_counts = node_element_counts_temp.array

-     unique_node_count_dev = borrow_temporary(temporary_store, shape=(1,), dtype=int, device=sorted_node_indices.device)
-     runlength_encode(
-         sorted_node_indices,
-         unique_node_indices,
-         node_element_counts,
-         value_count=index_count,
-         run_count=unique_node_count_dev.array,
-     )
+         unique_node_count_dev = cache.borrow_temporary(temporary_store, shape=(1,), dtype=int)

-     # Transfer unique node count to host
-     if node_indices.device.is_cuda:
-         unique_node_count_host = borrow_temporary(temporary_store, shape=(1,), dtype=int, pinned=True, device="cpu")
-         wp.copy(src=unique_node_count_dev.array, dest=unique_node_count_host.array, count=1)
-         wp.synchronize_stream(wp.get_stream(node_indices.device))
-         unique_node_count_dev.release()
+         runlength_encode(
+             sorted_node_indices,
+             unique_node_indices,
+             node_element_counts,
+             value_count=index_count,
+             run_count=unique_node_count_dev.array,
+         )
+
+         # Scatter seen run counts to global array of element count per node
+         node_offsets_temp = cache.borrow_temporary(temporary_store, shape=(node_count + 1), dtype=int)
+         node_offsets = node_offsets_temp.array
+
+         node_offsets.zero_()
+         wp.launch(
+             kernel=_scatter_node_counts,
+             dim=node_count + 1,  # +1 to accommodate possible NULL node,
+             inputs=[node_element_counts, unique_node_indices, node_offsets, unique_node_count_dev.array],
+         )
+
+         if device.is_cuda and return_unique_nodes:
+             unique_node_count_host = cache.borrow_temporary(
+                 temporary_store, shape=(1,), dtype=int, pinned=True, device="cpu"
+             )
+             wp.copy(src=unique_node_count_dev.array, dest=unique_node_count_host.array, count=1)
+             copy_event = cache.capture_event(device)
+
+         # Prefix sum of number of elements per node
+         array_scan(node_offsets, node_offsets, inclusive=True)
+
+         sorted_node_indices_temp.release()
+         node_element_counts_temp.release()
+
+         if not return_unique_nodes:
+             unique_node_count_dev.release()
+             return node_offsets_temp, sorted_array_indices_temp
+
+         if device.is_cuda:
+             cache.synchronize_event(copy_event)
+             unique_node_count_dev.release()
+         else:
+             unique_node_count_host = unique_node_count_dev
          unique_node_count = int(unique_node_count_host.array.numpy()[0])
          unique_node_count_host.release()
-     else:
-         unique_node_count = int(unique_node_count_dev.array.numpy()[0])
-         unique_node_count_dev.release()
+         return node_offsets_temp, sorted_array_indices_temp, unique_node_count, unique_node_indices_temp

-     # Scatter seen run counts to global array of element count per node
-     node_offsets_temp = borrow_temporary(
-         temporary_store, shape=(node_count + 1), device=node_element_counts.device, dtype=int
-     )
-     node_offsets = node_offsets_temp.array

-     node_offsets.zero_()
-     wp.launch(
-         kernel=_scatter_node_counts,
-         dim=unique_node_count,
-         inputs=[node_element_counts, unique_node_indices, node_offsets],
-         device=node_offsets.device,
-     )
+ def host_read_at_index(array: wp.array, index: int = -1, temporary_store: cache.TemporaryStore = None) -> int:
+     """Returns the value of the array element at the given index on host"""

-     # Prefix sum of number of elements per node
-     array_scan(node_offsets, node_offsets, inclusive=True)
+     if index < 0:
+         index += array.shape[0]

-     sorted_node_indices_temp.release()
-     node_element_counts_temp.release()
+     if array.device.is_cuda:
+         temp = cache.borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
+         wp.copy(dest=temp.array, src=array, src_offset=index, count=1)
+         wp.synchronize_stream(wp.get_stream(array.device))
+         return temp.array.numpy()[0]

-     return node_offsets_temp, sorted_array_indices_temp, unique_node_count, unique_node_indices_temp
+     return array.numpy()[index]


  def masked_indices(
-     mask: wp.array, missing_index=-1, temporary_store: TemporaryStore = None
- ) -> Tuple[Temporary, Temporary]:
+     mask: wp.array, missing_index=-1, temporary_store: cache.TemporaryStore = None
+ ) -> Tuple[cache.Temporary, cache.Temporary]:
      """
      From an array of boolean masks (must be either 0 or 1), returns:
      - The list of indices for which the mask is 1
      - A map associating to each element of the input mask array its local index if non-zero, or missing_index if zero.
      """

-     offsets_temp = borrow_temporary_like(mask, temporary_store)
+     offsets_temp = cache.borrow_temporary_like(mask, temporary_store)
      offsets = offsets_temp.array

      wp.utils.array_scan(mask, offsets, inclusive=True)

      # Get back total counts on host
-     if offsets.device.is_cuda:
-         masked_count_temp = borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
-         wp.copy(dest=masked_count_temp.array, src=offsets, src_offset=offsets.shape[0] - 1, count=1)
-         wp.synchronize_stream(wp.get_stream(offsets.device))
-         masked_count = int(masked_count_temp.array.numpy()[0])
-         masked_count_temp.release()
-     else:
-         masked_count = int(offsets.numpy()[-1])
+     masked_count = int(host_read_at_index(offsets, temporary_store=temporary_store))

      # Convert counts to indices
-     indices_temp = borrow_temporary(temporary_store, shape=masked_count, device=mask.device, dtype=int)
+     indices_temp = cache.borrow_temporary(temporary_store, shape=masked_count, device=mask.device, dtype=int)

      wp.launch(
          kernel=_masked_indices_kernel,
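Note on the API change above: `compress_node_indices` now returns only the node offsets and sorted positions by default, and produces the unique-node count and index list only when `return_unique_nodes=True`. A minimal usage sketch (the input values and the node count of 5 are illustrative, not taken from the package):

    import warp as wp
    from warp.fem.utils import compress_node_indices

    wp.init()

    # 8 unsorted node references over 5 nodes (made-up data)
    node_indices = wp.array([3, 0, 3, 1, 4, 0, 3, 1], dtype=int, device="cpu")

    # Default: per-node offsets and the sorted positions into the input array
    offsets, sorted_positions = compress_node_indices(5, node_indices)

    # Opt in to the extra outputs
    offsets, sorted_positions, unique_count, unique_nodes = compress_node_indices(
        5, node_indices, return_unique_nodes=True
    )

    # Results are Temporary handles wrapping warp arrays
    print(offsets.array.numpy())  # prefix sum of per-node counts, length node_count + 1
    print(unique_count)           # 4 distinct node indices in this example

Also visible above: the host-side read in `masked_indices` is now routed through the new `host_read_at_index` helper, which hides the pinned-copy path used for CUDA arrays.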
@@ -262,10 +468,22 @@ def _iota_kernel(indices: wp.array(dtype=int), divisor: int):

  @wp.kernel
  def _scatter_node_counts(
-     unique_counts: wp.array(dtype=int), unique_node_indices: wp.array(dtype=int), node_counts: wp.array(dtype=int)
+     unique_counts: wp.array(dtype=int),
+     unique_node_indices: wp.array(dtype=int),
+     node_counts: wp.array(dtype=int),
+     unique_node_count: wp.array(dtype=int),
  ):
      i = wp.tid()
-     node_counts[1 + unique_node_indices[i]] = unique_counts[i]
+
+     if i >= unique_node_count[0]:
+         return
+
+     node_index = unique_node_indices[i]
+     if node_index == NULL_NODE_INDEX:
+         wp.atomic_sub(unique_node_count, 0, 1)
+         return
+
+     node_counts[1 + node_index] = unique_counts[i]


  @wp.kernel
@@ -467,7 +685,7 @@ def grid_to_hexes(Nx: int, Ny: int, Nz: int):
          Nz: Resolution of the grid along `z` dimension

      Returns:
-         Array of shape (Nx * Ny * Nz, 8) containing vertex indices for each hexaedron
+         Array of shape (Nx * Ny * Nz, 8) containing vertex indices for each hexahedron
      """

      hex_vtx = np.array(
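The Householder/QR routines added at the top of this file are `@wp.func`s, so they are meant to be called from Warp kernels rather than from Python directly. A minimal sketch of diagonalizing a symmetric 3x3 matrix with `symmetric_eigenvalues_qr` (the kernel, array names, matrix values, and tolerance are illustrative, not taken from the package):

    import numpy as np
    import warp as wp
    from warp.fem.utils import symmetric_eigenvalues_qr

    wp.init()

    @wp.kernel
    def eig_kernel(As: wp.array(dtype=wp.mat33), Ds: wp.array(dtype=wp.vec3), Ps: wp.array(dtype=wp.mat33)):
        i = wp.tid()
        # D holds the eigenvalues, P one eigenvector per row, such that A = P^T D P
        D, P = symmetric_eigenvalues_qr(As[i], 1.0e-6)
        Ds[i] = D
        Ps[i] = P

    A = np.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]], dtype=np.float32)
    As = wp.array([A], dtype=wp.mat33)
    Ds = wp.zeros(1, dtype=wp.vec3)
    Ps = wp.zeros(1, dtype=wp.mat33)

    wp.launch(eig_kernel, dim=1, inputs=[As, Ds, Ps])
    print(Ds.numpy())  # eigenvalues (ordering not guaranteed)
    print(Ps.numpy())  # eigenvectors, one per row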
warp/native/array.h CHANGED
@@ -207,6 +207,22 @@ struct array_t
          strides[3] = sizeof(T);
      }

+     CUDA_CALLABLE array_t(uint64 data, int size, uint64 grad=0)
+         : array_t((T*)(data), size, (T*)(grad))
+     {}
+
+     CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, uint64 grad=0)
+         : array_t((T*)(data), dim0, dim1, (T*)(grad))
+     {}
+
+     CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, uint64 grad=0)
+         : array_t((T*)(data), dim0, dim1, dim2, (T*)(grad))
+     {}
+
+     CUDA_CALLABLE array_t(uint64 data, int dim0, int dim1, int dim2, int dim3, uint64 grad=0)
+         : array_t((T*)(data), dim0, dim1, dim2, dim3, (T*)(grad))
+     {}
+
      CUDA_CALLABLE inline bool empty() const { return !data; }

      T* data;
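The new `array_t` constructors accept the data and gradient buffers as raw 64-bit addresses rather than typed pointers. A related user-facing pattern on the Python side is wrapping an existing allocation by pointer instead of copying it; a small sketch with NumPy-owned CPU memory (this assumes the `ptr=` form of `wp.array` used in Warp's interop paths; the buffer itself is illustrative):

    import numpy as np
    import warp as wp

    wp.init()

    # Keep a reference to the NumPy buffer so the memory stays alive
    data = np.arange(8, dtype=np.float32)

    # Wrap the existing allocation by raw address; dtype/shape/device must describe the buffer exactly
    a = wp.array(ptr=data.ctypes.data, dtype=wp.float32, shape=(8,), device="cpu")

    print(a.numpy())  # same memory, no copy: [0. 1. 2. 3. 4. 5. 6. 7.]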
warp/native/builtin.h CHANGED
@@ -1145,21 +1145,6 @@ struct launch_bounds_t
      size_t size; // total number of threads
  };

- #ifndef __CUDACC__
- static size_t s_threadIdx;
- #endif
-
- inline CUDA_CALLABLE size_t grid_index()
- {
- #ifdef __CUDACC__
-     // Need to cast at least one of the variables being multiplied so that type promotion happens before the multiplication
-     size_t grid_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-     return grid_index;
- #else
-     return s_threadIdx;
- #endif
- }
-
  inline CUDA_CALLABLE int tid(size_t index)
  {
      // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
warp/native/cuda_util.cpp CHANGED
@@ -24,14 +24,11 @@
  #include <stack>

  // the minimum CUDA version required from the driver
- #define WP_CUDA_DRIVER_VERSION 11030
+ #define WP_CUDA_DRIVER_VERSION 11040

  // the minimum CUDA Toolkit version required to build Warp
  #define WP_CUDA_TOOLKIT_VERSION 11050

- #define WP_CUDA_VERSION_MAJOR(version) (version / 1000)
- #define WP_CUDA_VERSION_MINOR(version) ((version % 1000) / 10)
-
  // check if the CUDA Toolkit is too old
  #if CUDA_VERSION < WP_CUDA_TOOLKIT_VERSION
  #error Building Warp requires CUDA Toolkit version 11.5 or higher
@@ -108,6 +105,17 @@ bool ContextGuard::always_restore = false;

  CudaTimingState* g_cuda_timing_state = NULL;

+
+ static inline int get_major(int version)
+ {
+     return version / 1000;
+ }
+
+ static inline int get_minor(int version)
+ {
+     return (version % 1000) / 10;
+ }
+
  static bool get_driver_entry_point(const char* name, void** pfn)
  {
      if (!pfn_cuGetProcAddress || !name || !pfn)
@@ -163,8 +171,8 @@ bool init_cuda_driver()
          if (driver_version < WP_CUDA_DRIVER_VERSION)
          {
              fprintf(stderr, "Warp CUDA error: Warp requires CUDA driver %d.%d or higher, but the current driver only supports CUDA %d.%d\n",
-                 WP_CUDA_VERSION_MAJOR(WP_CUDA_DRIVER_VERSION), WP_CUDA_VERSION_MINOR(WP_CUDA_DRIVER_VERSION),
-                 WP_CUDA_VERSION_MAJOR(driver_version), WP_CUDA_VERSION_MINOR(driver_version));
+                 get_major(WP_CUDA_DRIVER_VERSION), get_minor(WP_CUDA_DRIVER_VERSION),
+                 get_major(driver_version), get_minor(driver_version));
              return false;
          }
      }
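The version-decoding macros are replaced by the `get_major`/`get_minor` helpers, but the encoding itself is unchanged: the driver and toolkit versions are packed as major * 1000 + minor * 10. The same arithmetic in Python, using the values from the defines above:

    def get_major(version: int) -> int:
        return version // 1000

    def get_minor(version: int) -> int:
        return (version % 1000) // 10

    # WP_CUDA_DRIVER_VERSION was bumped from 11030 to 11040: driver 11.4 is now required
    print(get_major(11040), get_minor(11040))  # 11 4

    # WP_CUDA_TOOLKIT_VERSION stays 11050: CUDA Toolkit 11.5 is still needed to build Warp
    print(get_major(11050), get_minor(11050))  # 11 5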