warp-lang 1.7.2rc1-py3-none-manylinux_2_34_aarch64.whl → 1.8.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/tile_reduce.h CHANGED
@@ -24,6 +24,20 @@
 namespace wp
 {
 
+
+template <typename T>
+int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value > champion_value ? current_index : champion_index;
+}
+
+template <typename T>
+int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value < champion_value ? current_index : champion_index;
+}
+
+
 #if defined(__CUDA_ARCH__)
 
 template <typename T>
@@ -62,6 +76,44 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
     return output;
 }
 
+// Vector overload
+template <unsigned Length, typename T>
+inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
+{
+    wp::vec_t<Length, T> result;
+
+    for (unsigned i=0; i < Length; ++i)
+        result.data[i] = __shfl_down_sync(mask, val.data[i], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+// Quaternion overload
+template <typename T>
+inline CUDA_CALLABLE wp::quat_t<T> warp_shuffle_down(wp::quat_t<T> val, int offset, int mask)
+{
+    wp::quat_t<T> result;
+
+    for (unsigned i=0; i < 4; ++i)
+        result.data[i] = __shfl_down_sync(mask, val.data[i], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+// Matrix overload
+template <unsigned Rows, unsigned Cols, typename T>
+inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
+{
+    wp::mat_t<Rows, Cols, T> result;
+
+    for (unsigned i=0; i < Rows; ++i)
+        for (unsigned j=0; j < Cols; ++j)
+            result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+
 template <typename T, typename Op>
 inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
 {
@@ -89,6 +141,52 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
     return sum;
 }
 
+template <typename T>
+struct ValueAndIndex
+{
+    T value;
+    int index;
+};
+
+template <typename T, typename Op, typename OpTrack>
+inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
+{
+    T sum = val;
+    int index = idx;
+
+    if (mask == 0xFFFFFFFF)
+    {
+        // handle case where entire warp is active
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            auto shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_idx = warp_shuffle_down(index, offset, mask);
+            index = track(sum, shfl_val, index, shfl_idx);
+            sum = f(sum, shfl_val);
+        }
+    }
+    else
+    {
+        // handle partial warp case
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            T shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_index = warp_shuffle_down(index, offset, mask);
+            if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
+            {
+                index = track(sum, shfl_val, index, shfl_index);
+                sum = f(sum, shfl_val);
+            }
+        }
+    }
+
+    ValueAndIndex<T> result;
+    result.value = sum;
+    result.index = index;
+
+    return result;
+}
+
 // non-axis version which computes sum
 // across the entire tile using the whole block
 template <typename Tile, typename Op>
@@ -159,6 +257,85 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+
+// non-axis version which computes sum
+// across the entire tile using the whole block
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
+    const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
+    const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T thread_sum = input.data[0];
+
+    // thread reduction
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(thread_sum, input.data[i], champion_index, linear);
+        thread_sum = f(thread_sum, input.data[i]);
+    }
+
+    // ensure that only threads with at least one valid item participate in the reduction
+    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
+
+    // warp reduction
+    ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
+
+    // fixed size scratch pad for partial results in shared memory
+    WP_TILE_SHARED T partials[warp_count];
+    WP_TILE_SHARED int partials_idx[warp_count];
+
+    // count of active warps
+    WP_TILE_SHARED int active_warps;
+    if (threadIdx.x == 0)
+        active_warps = 0;
+
+    // ensure active_warps is initialized
+    WP_TILE_SYNC();
+
+    if (lane_index == 0)
+    {
+        partials[warp_index] = warp_sum.value;
+        partials_idx[warp_index] = warp_sum.index;
+        atomicAdd(&active_warps, 1);
+    }
+
+    // ensure partials are ready
+    WP_TILE_SYNC();
+
+    // reduce across block, todo: use warp_reduce() here
+    if (threadIdx.x == 0)
+    {
+        T block_sum = partials[0];
+        int block_champion_index = partials_idx[0];
+
+        WP_PRAGMA_UNROLL
+        for (int i=1; i < active_warps; ++i)
+        {
+            block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
+            block_sum = f(block_sum, partials[i]);
+        }
+
+        output.data[0] = block_champion_index;
+    }
+
+    return output;
+}
+
 #else
 
 // CPU implementation
@@ -171,9 +348,9 @@ auto tile_reduce_impl(Op f, Tile& t)
     auto input = t.copy_to_register();
     auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
 
-    using Layout = typename decltype(input)::Layout;
+    using Layout = typename decltype(input)::Layout;
 
-    T sum = input.data[0];
+    T sum = input.data[0];
 
     WP_PRAGMA_UNROLL
     for (int i=1; i < Layout::NumRegs; ++i)
@@ -189,6 +366,34 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T sum = input.data[0];
+
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(sum, input.data[i], champion_index, linear);
+        sum = f(sum, input.data[i]);
+    }
+
+    output.data[0] = champion_index;
+    return output;
+}
+
 #endif // !defined(__CUDA_ARCH__)
 
 inline void adj_tile_reduce_impl()
@@ -200,6 +405,9 @@ inline void adj_tile_reduce_impl()
 #define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
 #define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
 
+#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
+#define adj_tile_arg_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_arg_reduce_impl()
+
 // convenience methods for specific reductions
 
 template <typename Tile>
@@ -261,4 +469,31 @@ void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 
 
 
+template <typename Tile>
+auto tile_argmax(Tile& t)
+{
+    return tile_arg_reduce(max, argmax_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_argmin(Tile& t)
+{
+    return tile_arg_reduce(min, argmin_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+
+
+
 } // namespace wp
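Note on the tile_argmax/tile_argmin additions above: the reduction carries a "champion" index alongside the running value, and at each combine step the tracker decides whether the incoming element displaces the current champion before the value itself is folded in. Below is a minimal, CPU-only C++ sketch of that value-plus-index pattern; the sequential loop and helper names are illustrative only and are not part of the Warp API.

#include <cstdio>
#include <vector>

// Illustrative combine step mirroring argmax_tracker from the diff:
// pick the incoming index only if its value beats the current champion.
template <typename T>
int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    return current_value > champion_value ? current_index : champion_index;
}

// Sequential stand-in for the warp/block reduction: fold values with max
// while tracking the index of the winning element.
template <typename T>
int arg_reduce_max(const std::vector<T>& values)
{
    int champion_index = 0;
    T champion_value = values[0];
    for (int i = 1; i < (int)values.size(); ++i)
    {
        champion_index = argmax_tracker(champion_value, values[i], champion_index, i);
        champion_value = champion_value > values[i] ? champion_value : values[i];
    }
    return champion_index;
}

int main()
{
    std::vector<float> v = {3.0f, 7.5f, -1.0f, 7.5f, 2.0f};
    // Ties keep the earlier champion because the tracker uses a strict '>',
    // matching the tracker shown in the diff above.
    printf("argmax = %d\n", arg_reduce_max(v)); // prints 1
    return 0;
}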
warp/native/tile_scan.h ADDED
@@ -0,0 +1,240 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tile.h"
+
+#if defined(__clang__)
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif
+
+namespace wp
+{
+
+#if defined(__CUDA_ARCH__)
+
+
+template<typename T>
+inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
+{
+    //Computes an inclusive cumulative sum
+    #pragma unroll
+    for (int i = 1; i <= 32; i *= 2)
+    {
+        auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
+
+        if (lane >= i)
+            value = value + n;
+    }
+    return value;
+}
+
+
+template<typename T>
+inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
+{
+    WP_TILE_SHARED T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
+
+    value = scan_warp_inclusive(lane, value);
+
+    if (lane == 31)
+    {
+        sums[warp_index] = value;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index == 0)
+    {
+        T v = lane < num_warps ? sums[lane] : T(0);
+        v = scan_warp_inclusive(lane, v);
+        if (lane < num_warps)
+            sums[lane] = v;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index > 0)
+    {
+        value += sums[warp_index - 1];
+    }
+
+    return value;
+}
+
+template<typename T, bool exclusive>
+inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
+{
+    const int num_threads_in_block = blockDim.x;
+    const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;
+
+    WP_TILE_SHARED T offset;
+    if (threadIdx.x == 0)
+        offset = T(0);
+
+    WP_TILE_SYNC();
+
+    const int lane = WP_TILE_THREAD_IDX % WP_TILE_WARP_SIZE;
+    const int warp_index = WP_TILE_THREAD_IDX / WP_TILE_WARP_SIZE;
+    const int num_warps = num_threads_in_block / WP_TILE_WARP_SIZE;
+
+    for (int i = 0; i < num_iterations; ++i)
+    {
+        int element_index = WP_TILE_THREAD_IDX + i * num_threads_in_block;
+        T orig_value = element_index < num_elements ? values[element_index] : T(0);
+        T value = thread_block_scan_inclusive(lane, warp_index, num_warps, orig_value);
+        if (element_index < num_elements)
+        {
+            T new_value = value + offset;
+            if constexpr (exclusive)
+                new_value -= orig_value;
+            values[element_index] = new_value;
+        }
+
+        WP_TILE_SYNC();
+
+        if (threadIdx.x == num_threads_in_block - 1)
+            offset += value;
+
+        WP_TILE_SYNC();
+    }
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, false>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, true>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+#else
+
+template<typename Tile>
+inline auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        sum += input.data[i];
+        output.data[i] = sum;
+    }
+
+    return output;
+}
+
+template<typename Tile>
+inline auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        output.data[i] = sum;
+        sum += input.data[i];
+    }
+
+    return output;
+}
+
+#endif // !defined(__CUDA_ARCH__)
+
+template <typename Tile>
+auto tile_scan_inclusive(Tile& t)
+{
+    return tile_scan_inclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_inclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_scan_exclusive(Tile& t)
+{
+    return tile_scan_exclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_exclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+} // namespace wp
+
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
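The CPU fallback in tile_scan.h above makes the inclusive/exclusive distinction easy to see: an inclusive scan writes the running sum after adding the current element, while an exclusive scan writes it before. A small standalone C++ sketch of the two variants follows; the helper names are illustrative only and are not the Warp API.

#include <cstdio>
#include <vector>

// Inclusive scan: out[i] = x[0] + ... + x[i]
std::vector<int> scan_inclusive(const std::vector<int>& x)
{
    std::vector<int> out(x.size());
    int sum = 0;
    for (size_t i = 0; i < x.size(); ++i)
    {
        sum += x[i];
        out[i] = sum;
    }
    return out;
}

// Exclusive scan: out[i] = x[0] + ... + x[i-1], with out[0] = 0
std::vector<int> scan_exclusive(const std::vector<int>& x)
{
    std::vector<int> out(x.size());
    int sum = 0;
    for (size_t i = 0; i < x.size(); ++i)
    {
        out[i] = sum;
        sum += x[i];
    }
    return out;
}

int main()
{
    std::vector<int> x = {1, 2, 3, 4};
    auto inc = scan_inclusive(x); // {1, 3, 6, 10}
    auto exc = scan_exclusive(x); // {0, 1, 3, 6}
    printf("inclusive: %d %d %d %d\n", inc[0], inc[1], inc[2], inc[3]);
    printf("exclusive: %d %d %d %d\n", exc[0], exc[1], exc[2], exc[3]);
    return 0;
}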
warp/native/tuple.h ADDED
@@ -0,0 +1,189 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace wp
+{
+
+template <typename... Types>
+struct tuple_t;
+
+template <>
+struct tuple_t<>
+{
+
+    static constexpr int size() { return 0; }
+
+    // Base case: empty tuple.
+    template <typename Callable>
+    void apply(Callable&&) const { }
+};
+
+template <typename Head, typename... Tail>
+struct tuple_t<Head, Tail...>
+{
+    Head head;
+    tuple_t<Tail...> tail;
+
+    CUDA_CALLABLE inline tuple_t() {}
+    CUDA_CALLABLE inline tuple_t(Head h, Tail... t) : head(h), tail(t...) {}
+
+    static constexpr int size() { return 1 + tuple_t<Tail...>::size(); }
+
+    // Applies a callable to each element.
+    template <typename Callable>
+    void apply(Callable&& func) const
+    {
+        func(head); // Apply the callable to the current element.
+        tail.apply(func); // Recursively process the rest of the tuple.
+    }
+};
+
+// Tuple constructor.
+template <typename... Args>
+CUDA_CALLABLE inline tuple_t<Args...>
+tuple(
+    Args... args
+)
+{
+    return tuple_t<Args...>(args...);
+}
+
+// Helper to extract a value from the tuple.
+// Can be replaced with simpler member function version when our CPU compiler
+// backend supports constexpr if statements.
+template <int N, typename Head, typename... Tail>
+struct tuple_get
+{
+    static CUDA_CALLABLE inline const auto&
+    value(
+        const tuple_t<Head, Tail...>& t
+    )
+    {
+        return tuple_get<N - 1, Tail...>::value(t.tail);
+    }
+};
+
+// Specialization for the base case N == 0. Simply return the head of the tuple.
+template <typename Head, typename... Tail>
+struct tuple_get<0, Head, Tail...>
+{
+    static CUDA_CALLABLE inline const auto&
+    value(
+        const tuple_t<Head, Tail...>& t
+    )
+    {
+        return t.head;
+    }
+};
+
+template <int Index, typename... Args>
+CUDA_CALLABLE inline auto
+extract(
+    const tuple_t<Args...>& t
+)
+{
+    return tuple_get<Index, Args...>::value(t);
+}
+
+template <typename... Args>
+CUDA_CALLABLE inline int
+len(
+    const tuple_t<Args...>& t
+)
+{
+    return t.size();
+}
+
+template <typename... Args>
+CUDA_CALLABLE inline void
+adj_len(
+    const tuple_t<Args...>& t,
+    tuple_t<Args...>& adj_t,
+    int adj_ret
+)
+{
+}
+
+template <typename... Args>
+CUDA_CALLABLE inline void
+print(
+    const tuple_t<Args...>& t
+)
+{
+    t.apply([&](auto a) { print(a); });
+}
+
+template <typename... Args>
+CUDA_CALLABLE inline void
+adj_print(
+    const tuple_t<Args...>& t,
+    tuple_t<Args...>& adj_t
+)
+{
+    adj_t.apply([&](auto a) { print(a); });
+}
+
+CUDA_CALLABLE inline tuple_t<>
+add(
+    const tuple_t<>& a,
+    const tuple_t<>& b
+)
+{
+    return tuple_t<>();
+}
+
+template <typename Head, typename... Tail>
+CUDA_CALLABLE inline tuple_t<Head, Tail...>
+add(
+    const tuple_t<Head, Tail...>& a,
+    const tuple_t<Head, Tail...>& b
+)
+{
+    tuple_t<Head, Tail...> out;
+    out.head = add(a.head, b.head);
+    out.tail = add(a.tail, b.tail);
+    return out;
+}
+
+CUDA_CALLABLE inline void
+adj_add(
+    const tuple_t<>& a,
+    const tuple_t<>& b,
+    tuple_t<>& adj_a,
+    tuple_t<>& adj_b,
+    const tuple_t<>& adj_ret
+)
+{
+}
+
+template <typename Head, typename... Tail>
+CUDA_CALLABLE inline void
+adj_add(
+    const tuple_t<Head, Tail...>& a,
+    const tuple_t<Head, Tail...>& b,
+    tuple_t<Head, Tail...>& adj_a,
+    tuple_t<Head, Tail...>& adj_b,
+    const tuple_t<Head, Tail...>& adj_ret
+)
+{
+    adj_add(a.head, b.head, adj_ret.head);
+    adj_add(a.tail, b.tail, adj_ret.tail);
+}
+
+} // namespace wp
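tuple_t above is a classic recursive head/tail tuple, and tuple_get<N> peels one head per level of recursion until N reaches 0. The following standalone C++ sketch reproduces just that pattern, without the CUDA_CALLABLE qualifiers or the adjoint functions; it is an illustration of the technique, not the Warp header itself.

#include <cstdio>

// Recursive head/tail tuple, mirroring the structure of wp::tuple_t.
template <typename... Types>
struct tuple_t;

template <>
struct tuple_t<>
{
    static constexpr int size() { return 0; }
};

template <typename Head, typename... Tail>
struct tuple_t<Head, Tail...>
{
    Head head;
    tuple_t<Tail...> tail;

    tuple_t(Head h, Tail... t) : head(h), tail(t...) {}
    static constexpr int size() { return 1 + tuple_t<Tail...>::size(); }
};

// tuple_get<N> recurses into the tail until N reaches 0, then returns the head.
template <int N, typename Head, typename... Tail>
struct tuple_get
{
    static const auto& value(const tuple_t<Head, Tail...>& t)
    {
        return tuple_get<N - 1, Tail...>::value(t.tail);
    }
};

template <typename Head, typename... Tail>
struct tuple_get<0, Head, Tail...>
{
    static const auto& value(const tuple_t<Head, Tail...>& t) { return t.head; }
};

int main()
{
    tuple_t<int, float, double> t(1, 2.0f, 3.0);
    printf("size = %d\n", t.size());                                    // 3
    printf("t[1] = %f\n", tuple_get<1, int, float, double>::value(t));  // 2.0
    return 0;
}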