warp-lang 1.7.2rc1-py3-none-manylinux_2_34_aarch64.whl → 1.8.1-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/tile_reduce.h CHANGED
@@ -24,6 +24,20 @@
 namespace wp
 {
 
+
+template <typename T>
+int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value > champion_value ? current_index : champion_index;
+}
+
+template <typename T>
+int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value < champion_value ? current_index : champion_index;
+}
+
+
 #if defined(__CUDA_ARCH__)
 
 template <typename T>
@@ -62,6 +76,32 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
     return output;
 }
 
+// Vector overload
+template <unsigned Length, typename T>
+inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
+{
+    wp::vec_t<Length, T> result;
+
+    for (unsigned i=0; i < Length; ++i)
+        result[i] = __shfl_down_sync(mask, val[i], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+// Matrix overload
+template <unsigned Rows, unsigned Cols, typename T>
+inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
+{
+    wp::mat_t<Rows, Cols, T> result;
+
+    for (unsigned i=0; i < Rows; ++i)
+        for (unsigned j=0; j < Cols; ++j)
+            result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+
 template <typename T, typename Op>
 inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
 {
@@ -89,6 +129,52 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
     return sum;
 }
 
+template <typename T>
+struct ValueAndIndex
+{
+    T value;
+    int index;
+};
+
+template <typename T, typename Op, typename OpTrack>
+inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
+{
+    T sum = val;
+    int index = idx;
+
+    if (mask == 0xFFFFFFFF)
+    {
+        // handle case where entire warp is active
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            auto shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_idx = warp_shuffle_down(index, offset, mask);
+            index = track(sum, shfl_val, index, shfl_idx);
+            sum = f(sum, shfl_val);
+        }
+    }
+    else
+    {
+        // handle partial warp case
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            T shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_index = warp_shuffle_down(index, offset, mask);
+            if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
+            {
+                index = track(sum, shfl_val, index, shfl_index);
+                sum = f(sum, shfl_val);
+            }
+        }
+    }
+
+    ValueAndIndex<T> result;
+    result.value = sum;
+    result.index = index;
+
+    return result;
+}
+
 // non-axis version which computes sum
 // across the entire tile using the whole block
 template <typename Tile, typename Op>
@@ -120,6 +206,7 @@ auto tile_reduce_impl(Op f, Tile& t)
 
     // ensure that only threads with at least one valid item participate in the reduction
    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
+    bool warp_is_active = mask != 0;
 
     // warp reduction
     T warp_sum = warp_reduce(thread_sum, f, mask);
@@ -135,7 +222,7 @@ auto tile_reduce_impl(Op f, Tile& t)
     // ensure active_warps is initialized
     WP_TILE_SYNC();
 
-    if (lane_index == 0)
+    if (lane_index == 0 && warp_is_active)
     {
         partials[warp_index] = warp_sum;
         atomicAdd(&active_warps, 1);
@@ -159,6 +246,86 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+
+// non-axis version which computes sum
+// across the entire tile using the whole block
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
+    const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
+    const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T thread_sum = input.data[0];
+
+    // thread reduction
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(thread_sum, input.data[i], champion_index, linear);
+        thread_sum = f(thread_sum, input.data[i]);
+    }
+
+    // ensure that only threads with at least one valid item participate in the reduction
+    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
+    bool warp_is_active = mask != 0;
+
+    // warp reduction
+    ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
+
+    // fixed size scratch pad for partial results in shared memory
+    WP_TILE_SHARED T partials[warp_count];
+    WP_TILE_SHARED int partials_idx[warp_count];
+
+    // count of active warps
+    WP_TILE_SHARED int active_warps;
+    if (threadIdx.x == 0)
+        active_warps = 0;
+
+    // ensure active_warps is initialized
+    WP_TILE_SYNC();
+
+    if (lane_index == 0 && warp_is_active)
+    {
+        partials[warp_index] = warp_sum.value;
+        partials_idx[warp_index] = warp_sum.index;
+        atomicAdd(&active_warps, 1);
+    }
+
+    // ensure partials are ready
+    WP_TILE_SYNC();
+
+    // reduce across block, todo: use warp_reduce() here
+    if (threadIdx.x == 0)
+    {
+        T block_sum = partials[0];
+        int block_champion_index = partials_idx[0];
+
+        WP_PRAGMA_UNROLL
+        for (int i=1; i < active_warps; ++i)
+        {
+            block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
+            block_sum = f(block_sum, partials[i]);
+        }
+
+        output.data[0] = block_champion_index;
+    }
+
+    return output;
+}
+
 #else
 
 // CPU implementation
@@ -171,9 +338,9 @@ auto tile_reduce_impl(Op f, Tile& t)
     auto input = t.copy_to_register();
     auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
 
-    using Layout = typename decltype(input)::Layout;
+    using Layout = typename decltype(input)::Layout;
 
-    T sum = input.data[0];
+    T sum = input.data[0];
 
     WP_PRAGMA_UNROLL
     for (int i=1; i < Layout::NumRegs; ++i)
@@ -189,6 +356,34 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T sum = input.data[0];
+
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(sum, input.data[i], champion_index, linear);
+        sum = f(sum, input.data[i]);
+    }
+
+    output.data[0] = champion_index;
+    return output;
+}
+
 #endif // !defined(__CUDA_ARCH__)
 
 inline void adj_tile_reduce_impl()
@@ -200,6 +395,9 @@ inline void adj_tile_reduce_impl()
 #define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
 #define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
 
+#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
+#define adj_tile_arg_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_arg_reduce_impl()
+
 // convenience methods for specific reductions
 
 template <typename Tile>
@@ -214,25 +412,26 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
     using T = typename Tile::Type;
 
-    #if !defined(__CUDA_ARCH__)
-
-    for (int i=0; i < Tile::Layout::Size; ++i)
-    {
-        adj_t(i) += adj_ret.data[0];
+    auto adj_reg = adj_ret.grad_to_register();
 
-    }
+    #if !defined(__CUDA_ARCH__)
+    T scratch = adj_reg.data[0];
     #else
     // broadcast incoming adjoint to block
     WP_TILE_SHARED T scratch;
     if (WP_TILE_THREAD_IDX == 0)
-        scratch = adj_ret.data[0];
+        scratch = adj_reg.data[0];
 
     WP_TILE_SYNC();
+    #endif
 
-    // broadcast scalar across input dimensions (note zero strides)
-    auto adj_ret_reg = tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape, tile_stride_t<0, 0>>, false>(&scratch, nullptr).copy_to_register();
+    auto adj_ret_reg = tile_register_like<Tile>();
+    using Layout = typename decltype(adj_ret_reg)::Layout;
+    for (int i=0; i < Layout::NumRegs; ++i)
+    {
+        adj_ret_reg.data[i] += scratch;
+    }
     adj_t.grad_add(adj_ret_reg);
-    #endif
 }
 
 template <typename Tile>
@@ -261,4 +460,31 @@ void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 
 
 
+template <typename Tile>
+auto tile_argmax(Tile& t)
+{
+    return tile_arg_reduce(max, argmax_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_argmin(Tile& t)
+{
+    return tile_arg_reduce(min, argmin_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+
+
+
 } // namespace wp
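
The hunks above add a block-wide arg-reduction to the tile primitives: each thread folds its registers through the reduction op `f` while a tracker functor (argmax_tracker / argmin_tracker) carries the index of the current champion, warp_reduce_tracked() shuffles (value, index) pairs down each warp, and a final pass over the per-warp partials picks the winning tile-local linear index. A minimal usage sketch in Warp's Python frontend follows; the kernel-side spellings wp.tile_argmax()/wp.tile_argmin() are assumptions inferred from the new C++ entry points (the actual bindings land in builtins.py, which also changed in this release). Note that the result is an index within the tile, and the adjoints are stubs, so no gradient flows through it.

import numpy as np
import warp as wp

TILE = 128

@wp.kernel
def tile_argmax_kernel(x: wp.array(dtype=float), out: wp.array(dtype=int)):
    # cooperatively load the 1D array as a tile across the block
    t = wp.tile_load(x, shape=TILE)

    # assumed Python spelling of the new C++ tile_argmax():
    # returns a 1-element int tile holding the tile-local index of the maximum
    a = wp.tile_argmax(t)

    wp.tile_store(out, a)

x = wp.array(np.random.rand(TILE).astype(np.float32), dtype=float)
out = wp.zeros(1, dtype=int)

wp.launch_tiled(tile_argmax_kernel, dim=1, inputs=[x, out], block_dim=64)
print(out.numpy()[0])  # should match np.argmax(x.numpy())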
warp/native/tile_scan.h ADDED
@@ -0,0 +1,240 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tile.h"
+
+#if defined(__clang__)
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif
+
+namespace wp
+{
+
+#if defined(__CUDA_ARCH__)
+
+
+template<typename T>
+inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
+{
+    //Computes an inclusive cumulative sum
+    #pragma unroll
+    for (int i = 1; i <= 32; i *= 2)
+    {
+        auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
+
+        if (lane >= i)
+            value = value + n;
+    }
+    return value;
+}
+
+
+template<typename T>
+inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
+{
+    WP_TILE_SHARED T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
+
+    value = scan_warp_inclusive(lane, value);
+
+    if (lane == 31)
+    {
+        sums[warp_index] = value;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index == 0)
+    {
+        T v = lane < num_warps ? sums[lane] : T(0);
+        v = scan_warp_inclusive(lane, v);
+        if (lane < num_warps)
+            sums[lane] = v;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index > 0)
+    {
+        value += sums[warp_index - 1];
+    }
+
+    return value;
+}
+
+template<typename T, bool exclusive>
+inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
+{
+    const int num_threads_in_block = blockDim.x;
+    const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;
+
+    WP_TILE_SHARED T offset;
+    if (threadIdx.x == 0)
+        offset = T(0);
+
+    WP_TILE_SYNC();
+
+    const int lane = WP_TILE_THREAD_IDX % WP_TILE_WARP_SIZE;
+    const int warp_index = WP_TILE_THREAD_IDX / WP_TILE_WARP_SIZE;
+    const int num_warps = num_threads_in_block / WP_TILE_WARP_SIZE;
+
+    for (int i = 0; i < num_iterations; ++i)
+    {
+        int element_index = WP_TILE_THREAD_IDX + i * num_threads_in_block;
+        T orig_value = element_index < num_elements ? values[element_index] : T(0);
+        T value = thread_block_scan_inclusive(lane, warp_index, num_warps, orig_value);
+        if (element_index < num_elements)
+        {
+            T new_value = value + offset;
+            if constexpr (exclusive)
+                new_value -= orig_value;
+            values[element_index] = new_value;
+        }
+
+        WP_TILE_SYNC();
+
+        if (threadIdx.x == num_threads_in_block - 1)
+            offset += value;
+
+        WP_TILE_SYNC();
+    }
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, false>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, true>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+#else
+
+template<typename Tile>
+inline auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        sum += input.data[i];
+        output.data[i] = sum;
+    }
+
+    return output;
+}
+
+template<typename Tile>
+inline auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        output.data[i] = sum;
+        sum += input.data[i];
+    }
+
+    return output;
+}
+
+#endif // !defined(__CUDA_ARCH__)
+
+template <typename Tile>
+auto tile_scan_inclusive(Tile& t)
+{
+    return tile_scan_inclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_inclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_scan_exclusive(Tile& t)
+{
+    return tile_scan_exclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_exclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+} // namespace wp
+
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
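
tile_scan.h implements the scans with the standard three-phase block algorithm: an inclusive per-warp scan via __shfl_up_sync(), a scan of the per-warp totals by warp 0, then each warp adds the total of the warps before it. thread_block_scan() loops this when the tile holds more elements than the block has threads, carrying a running offset between passes, and subtracts the original value to turn the inclusive result into an exclusive one. A short forward-only usage sketch follows; wp.tile_scan_inclusive()/wp.tile_scan_exclusive() are assumed Python spellings of the C++ entry points, and as with the arg-reductions the adjoints are "todo: not implemented" stubs, so gradients do not propagate through these ops.

import numpy as np
import warp as wp

TILE = 8

@wp.kernel
def scan_kernel(x: wp.array(dtype=float),
                incl: wp.array(dtype=float),
                excl: wp.array(dtype=float)):
    t = wp.tile_load(x, shape=TILE)

    # assumed bindings for the new scan entry points
    s_in = wp.tile_scan_inclusive(t)  # [x0, x0+x1, x0+x1+x2, ...]
    s_ex = wp.tile_scan_exclusive(t)  # [0, x0, x0+x1, ...]

    wp.tile_store(incl, s_in)
    wp.tile_store(excl, s_ex)

x = wp.array(np.arange(1.0, TILE + 1.0, dtype=np.float32), dtype=float)
incl = wp.zeros(TILE, dtype=float)
excl = wp.zeros(TILE, dtype=float)

wp.launch_tiled(scan_kernel, dim=1, inputs=[x, incl, excl], block_dim=32)
print(incl.numpy())  # [ 1.  3.  6. 10. 15. 21. 28. 36.]
print(excl.numpy())  # [ 0.  1.  3.  6. 10. 15. 21. 28.]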