warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/tile_reduce.h
CHANGED
|
@@ -24,6 +24,20 @@
|
|
|
24
24
|
namespace wp
|
|
25
25
|
{
|
|
26
26
|
|
|
27
|
+
|
|
28
|
+
// Index tracker for arg-max reductions: given the current champion and a new
// candidate, return the index that belongs to the larger value.
// A candidate must be strictly greater to take over, so on ties the existing
// champion index is kept.
template <typename T>
int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    if (current_value > champion_value)
        return current_index;

    return champion_index;
}
|
|
33
|
+
|
|
34
|
+
// Index tracker for arg-min reductions: given the current champion and a new
// candidate, return the index that belongs to the smaller value.
// Only a strictly smaller candidate takes over, so ties keep the champion.
template <typename T>
int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
{
    const bool candidate_wins = current_value < champion_value;
    return candidate_wins ? current_index : champion_index;
}
|
|
39
|
+
|
|
40
|
+
|
|
27
41
|
#if defined(__CUDA_ARCH__)
|
|
28
42
|
|
|
29
43
|
template <typename T>
|
|
@@ -62,6 +76,44 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
|
|
|
62
76
|
return output;
|
|
63
77
|
}
|
|
64
78
|
|
|
79
|
+
// Vector overload
|
|
80
|
+
template <unsigned Length, typename T>
|
|
81
|
+
inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
|
|
82
|
+
{
|
|
83
|
+
wp::vec_t<Length, T> result;
|
|
84
|
+
|
|
85
|
+
for (unsigned i=0; i < Length; ++i)
|
|
86
|
+
result.data[i] = __shfl_down_sync(mask, val.data[i], offset, WP_TILE_WARP_SIZE);
|
|
87
|
+
|
|
88
|
+
return result;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
// Quaternion overload
|
|
92
|
+
template <typename T>
|
|
93
|
+
inline CUDA_CALLABLE wp::quat_t<T> warp_shuffle_down(wp::quat_t<T> val, int offset, int mask)
|
|
94
|
+
{
|
|
95
|
+
wp::quat_t<T> result;
|
|
96
|
+
|
|
97
|
+
for (unsigned i=0; i < 4; ++i)
|
|
98
|
+
result.data[i] = __shfl_down_sync(mask, val.data[i], offset, WP_TILE_WARP_SIZE);
|
|
99
|
+
|
|
100
|
+
return result;
|
|
101
|
+
}
|
|
102
|
+
|
|
103
|
+
// Matrix overload
|
|
104
|
+
template <unsigned Rows, unsigned Cols, typename T>
|
|
105
|
+
inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
|
|
106
|
+
{
|
|
107
|
+
wp::mat_t<Rows, Cols, T> result;
|
|
108
|
+
|
|
109
|
+
for (unsigned i=0; i < Rows; ++i)
|
|
110
|
+
for (unsigned j=0; j < Cols; ++j)
|
|
111
|
+
result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);
|
|
112
|
+
|
|
113
|
+
return result;
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
|
|
65
117
|
template <typename T, typename Op>
|
|
66
118
|
inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
|
|
67
119
|
{
|
|
@@ -89,6 +141,52 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
|
|
|
89
141
|
return sum;
|
|
90
142
|
}
|
|
91
143
|
|
|
144
|
+
// Result of a tracked (arg-) reduction: the reduced value together with the
// index that `track` associated with it. Returned by warp_reduce_tracked.
template <typename T>
struct ValueAndIndex
{
    T value;    // reduced value
    int index;  // index associated with `value`
};
|
|
150
|
+
|
|
151
|
+
template <typename T, typename Op, typename OpTrack>
|
|
152
|
+
inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
|
|
153
|
+
{
|
|
154
|
+
T sum = val;
|
|
155
|
+
int index = idx;
|
|
156
|
+
|
|
157
|
+
if (mask == 0xFFFFFFFF)
|
|
158
|
+
{
|
|
159
|
+
// handle case where entire warp is active
|
|
160
|
+
for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
|
|
161
|
+
{
|
|
162
|
+
auto shfl_val = warp_shuffle_down(sum, offset, mask);
|
|
163
|
+
int shfl_idx = warp_shuffle_down(index, offset, mask);
|
|
164
|
+
index = track(sum, shfl_val, index, shfl_idx);
|
|
165
|
+
sum = f(sum, shfl_val);
|
|
166
|
+
}
|
|
167
|
+
}
|
|
168
|
+
else
|
|
169
|
+
{
|
|
170
|
+
// handle partial warp case
|
|
171
|
+
for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
|
|
172
|
+
{
|
|
173
|
+
T shfl_val = warp_shuffle_down(sum, offset, mask);
|
|
174
|
+
int shfl_index = warp_shuffle_down(index, offset, mask);
|
|
175
|
+
if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
|
|
176
|
+
{
|
|
177
|
+
index = track(sum, shfl_val, index, shfl_index);
|
|
178
|
+
sum = f(sum, shfl_val);
|
|
179
|
+
}
|
|
180
|
+
}
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
ValueAndIndex<T> result;
|
|
184
|
+
result.value = sum;
|
|
185
|
+
result.index = index;
|
|
186
|
+
|
|
187
|
+
return result;
|
|
188
|
+
}
|
|
189
|
+
|
|
92
190
|
// non-axis version which computes sum
|
|
93
191
|
// across the entire tile using the whole block
|
|
94
192
|
template <typename Tile, typename Op>
|
|
@@ -159,6 +257,85 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
159
257
|
return output;
|
|
160
258
|
}
|
|
161
259
|
|
|
260
|
+
|
|
261
|
+
// non-axis version which computes sum
|
|
262
|
+
// across the entire tile using the whole block
|
|
263
|
+
template <typename Tile, typename Op, typename OpTrack>
|
|
264
|
+
auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
|
|
265
|
+
{
|
|
266
|
+
using T = typename Tile::Type;
|
|
267
|
+
|
|
268
|
+
auto input = t.copy_to_register();
|
|
269
|
+
auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
|
|
270
|
+
|
|
271
|
+
const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
|
|
272
|
+
const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
|
|
273
|
+
const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
|
|
274
|
+
|
|
275
|
+
using Layout = typename decltype(input)::Layout;
|
|
276
|
+
|
|
277
|
+
int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
|
|
278
|
+
T thread_sum = input.data[0];
|
|
279
|
+
|
|
280
|
+
// thread reduction
|
|
281
|
+
WP_PRAGMA_UNROLL
|
|
282
|
+
for (int i=1; i < Layout::NumRegs; ++i)
|
|
283
|
+
{
|
|
284
|
+
int linear = Layout::linear_from_register(i);
|
|
285
|
+
if (!Layout::valid(linear))
|
|
286
|
+
break;
|
|
287
|
+
|
|
288
|
+
champion_index = track(thread_sum, input.data[i], champion_index, linear);
|
|
289
|
+
thread_sum = f(thread_sum, input.data[i]);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
// ensure that only threads with at least one valid item participate in the reduction
|
|
293
|
+
unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
|
|
294
|
+
|
|
295
|
+
// warp reduction
|
|
296
|
+
ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
|
|
297
|
+
|
|
298
|
+
// fixed size scratch pad for partial results in shared memory
|
|
299
|
+
WP_TILE_SHARED T partials[warp_count];
|
|
300
|
+
WP_TILE_SHARED int partials_idx[warp_count];
|
|
301
|
+
|
|
302
|
+
// count of active warps
|
|
303
|
+
WP_TILE_SHARED int active_warps;
|
|
304
|
+
if (threadIdx.x == 0)
|
|
305
|
+
active_warps = 0;
|
|
306
|
+
|
|
307
|
+
// ensure active_warps is initialized
|
|
308
|
+
WP_TILE_SYNC();
|
|
309
|
+
|
|
310
|
+
if (lane_index == 0)
|
|
311
|
+
{
|
|
312
|
+
partials[warp_index] = warp_sum.value;
|
|
313
|
+
partials_idx[warp_index] = warp_sum.index;
|
|
314
|
+
atomicAdd(&active_warps, 1);
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
// ensure partials are ready
|
|
318
|
+
WP_TILE_SYNC();
|
|
319
|
+
|
|
320
|
+
// reduce across block, todo: use warp_reduce() here
|
|
321
|
+
if (threadIdx.x == 0)
|
|
322
|
+
{
|
|
323
|
+
T block_sum = partials[0];
|
|
324
|
+
int block_champion_index = partials_idx[0];
|
|
325
|
+
|
|
326
|
+
WP_PRAGMA_UNROLL
|
|
327
|
+
for (int i=1; i < active_warps; ++i)
|
|
328
|
+
{
|
|
329
|
+
block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
|
|
330
|
+
block_sum = f(block_sum, partials[i]);
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
output.data[0] = block_champion_index;
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
return output;
|
|
337
|
+
}
|
|
338
|
+
|
|
162
339
|
#else
|
|
163
340
|
|
|
164
341
|
// CPU implementation
|
|
@@ -171,9 +348,9 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
171
348
|
auto input = t.copy_to_register();
|
|
172
349
|
auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
|
|
173
350
|
|
|
174
|
-
|
|
351
|
+
using Layout = typename decltype(input)::Layout;
|
|
175
352
|
|
|
176
|
-
|
|
353
|
+
T sum = input.data[0];
|
|
177
354
|
|
|
178
355
|
WP_PRAGMA_UNROLL
|
|
179
356
|
for (int i=1; i < Layout::NumRegs; ++i)
|
|
@@ -189,6 +366,34 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
189
366
|
return output;
|
|
190
367
|
}
|
|
191
368
|
|
|
369
|
+
template <typename Tile, typename Op, typename OpTrack>
|
|
370
|
+
auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
|
|
371
|
+
{
|
|
372
|
+
using T = typename Tile::Type;
|
|
373
|
+
|
|
374
|
+
auto input = t.copy_to_register();
|
|
375
|
+
auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
|
|
376
|
+
|
|
377
|
+
using Layout = typename decltype(input)::Layout;
|
|
378
|
+
|
|
379
|
+
int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
|
|
380
|
+
T sum = input.data[0];
|
|
381
|
+
|
|
382
|
+
WP_PRAGMA_UNROLL
|
|
383
|
+
for (int i=1; i < Layout::NumRegs; ++i)
|
|
384
|
+
{
|
|
385
|
+
int linear = Layout::linear_from_register(i);
|
|
386
|
+
if (!Layout::valid(linear))
|
|
387
|
+
break;
|
|
388
|
+
|
|
389
|
+
champion_index = track(sum, input.data[i], champion_index, linear);
|
|
390
|
+
sum = f(sum, input.data[i]);
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
output.data[0] = champion_index;
|
|
394
|
+
return output;
|
|
395
|
+
}
|
|
396
|
+
|
|
192
397
|
#endif // !defined(__CUDA_ARCH__)
|
|
193
398
|
|
|
194
399
|
inline void adj_tile_reduce_impl()
|
|
@@ -200,6 +405,9 @@ inline void adj_tile_reduce_impl()
|
|
|
200
405
|
#define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
|
|
201
406
|
#define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
|
|
202
407
|
|
|
408
|
+
// Arg-reduction entry point: wraps `op` (value combiner) and `opTrack` (index tracker) in lambdas.
#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
// Adjoint of an arg-reduction. The result is an integer index, so there is no gradient to propagate.
// This expands to the same adj_tile_reduce_impl() stub used by adj_tile_reduce; the previously
// referenced adj_tile_arg_reduce_impl() is not defined in this header. The parameter list is variadic
// so the macro accepts both the old 5-argument form and the full forward-plus-adjoint argument list.
#define adj_tile_arg_reduce(...) adj_tile_reduce_impl()
|
|
410
|
+
|
|
203
411
|
// convenience methods for specific reductions
|
|
204
412
|
|
|
205
413
|
template <typename Tile>
|
|
@@ -261,4 +469,31 @@ void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
|
261
469
|
|
|
262
470
|
|
|
263
471
|
|
|
472
|
+
template <typename Tile>
|
|
473
|
+
auto tile_argmax(Tile& t)
|
|
474
|
+
{
|
|
475
|
+
return tile_arg_reduce(max, argmax_tracker, t);
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
// Adjoint of tile_argmax. The forward result is an integer index, which is
// piecewise constant in the tile values, so the gradient with respect to `t`
// is zero. Leaving adj_t untouched is the complete adjoint, not a missing
// implementation.
template <typename Tile, typename AdjTile>
void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
483
|
+
|
|
484
|
+
template <typename Tile>
|
|
485
|
+
auto tile_argmin(Tile& t)
|
|
486
|
+
{
|
|
487
|
+
return tile_arg_reduce(min, argmin_tracker, t);
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
// Adjoint of tile_argmin. The forward result is an integer index, which is
// piecewise constant in the tile values, so the gradient with respect to `t`
// is zero. Leaving adj_t untouched is the complete adjoint, not a missing
// implementation.
template <typename Tile, typename AdjTile>
void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
|
|
264
499
|
} // namespace wp
|
warp/native/tile_scan.h
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
#include "tile.h"
|
|
21
|
+
|
|
22
|
+
#if defined(__clang__)
|
|
23
|
+
// disable warnings related to C++17 extensions on CPU JIT builds
|
|
24
|
+
#pragma clang diagnostic push
|
|
25
|
+
#pragma clang diagnostic ignored "-Wc++17-extensions"
|
|
26
|
+
#endif
|
|
27
|
+
|
|
28
|
+
namespace wp
|
|
29
|
+
{
|
|
30
|
+
|
|
31
|
+
#if defined(__CUDA_ARCH__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
template<typename T>
|
|
35
|
+
inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
|
|
36
|
+
{
|
|
37
|
+
//Computes an inclusive cumulative sum
|
|
38
|
+
#pragma unroll
|
|
39
|
+
for (int i = 1; i <= 32; i *= 2)
|
|
40
|
+
{
|
|
41
|
+
auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
|
|
42
|
+
|
|
43
|
+
if (lane >= i)
|
|
44
|
+
value = value + n;
|
|
45
|
+
}
|
|
46
|
+
return value;
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
template<typename T>
|
|
51
|
+
inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
|
|
52
|
+
{
|
|
53
|
+
WP_TILE_SHARED T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
|
|
54
|
+
|
|
55
|
+
value = scan_warp_inclusive(lane, value);
|
|
56
|
+
|
|
57
|
+
if (lane == 31)
|
|
58
|
+
{
|
|
59
|
+
sums[warp_index] = value;
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
WP_TILE_SYNC();
|
|
63
|
+
|
|
64
|
+
if (warp_index == 0)
|
|
65
|
+
{
|
|
66
|
+
T v = lane < num_warps ? sums[lane] : T(0);
|
|
67
|
+
v = scan_warp_inclusive(lane, v);
|
|
68
|
+
if (lane < num_warps)
|
|
69
|
+
sums[lane] = v;
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
WP_TILE_SYNC();
|
|
73
|
+
|
|
74
|
+
if (warp_index > 0)
|
|
75
|
+
{
|
|
76
|
+
value += sums[warp_index - 1];
|
|
77
|
+
}
|
|
78
|
+
|
|
79
|
+
return value;
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
template<typename T, bool exclusive>
|
|
83
|
+
inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
|
|
84
|
+
{
|
|
85
|
+
const int num_threads_in_block = blockDim.x;
|
|
86
|
+
const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;
|
|
87
|
+
|
|
88
|
+
WP_TILE_SHARED T offset;
|
|
89
|
+
if (threadIdx.x == 0)
|
|
90
|
+
offset = T(0);
|
|
91
|
+
|
|
92
|
+
WP_TILE_SYNC();
|
|
93
|
+
|
|
94
|
+
const int lane = WP_TILE_THREAD_IDX % WP_TILE_WARP_SIZE;
|
|
95
|
+
const int warp_index = WP_TILE_THREAD_IDX / WP_TILE_WARP_SIZE;
|
|
96
|
+
const int num_warps = num_threads_in_block / WP_TILE_WARP_SIZE;
|
|
97
|
+
|
|
98
|
+
for (int i = 0; i < num_iterations; ++i)
|
|
99
|
+
{
|
|
100
|
+
int element_index = WP_TILE_THREAD_IDX + i * num_threads_in_block;
|
|
101
|
+
T orig_value = element_index < num_elements ? values[element_index] : T(0);
|
|
102
|
+
T value = thread_block_scan_inclusive(lane, warp_index, num_warps, orig_value);
|
|
103
|
+
if (element_index < num_elements)
|
|
104
|
+
{
|
|
105
|
+
T new_value = value + offset;
|
|
106
|
+
if constexpr (exclusive)
|
|
107
|
+
new_value -= orig_value;
|
|
108
|
+
values[element_index] = new_value;
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
WP_TILE_SYNC();
|
|
112
|
+
|
|
113
|
+
if (threadIdx.x == num_threads_in_block - 1)
|
|
114
|
+
offset += value;
|
|
115
|
+
|
|
116
|
+
WP_TILE_SYNC();
|
|
117
|
+
}
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
template<typename Tile>
|
|
121
|
+
inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
|
|
122
|
+
{
|
|
123
|
+
using T = typename Tile::Type;
|
|
124
|
+
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
125
|
+
|
|
126
|
+
// create a temporary shared tile to hold the input values
|
|
127
|
+
WP_TILE_SHARED T smem[num_elements_to_scan];
|
|
128
|
+
tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
|
|
129
|
+
|
|
130
|
+
// copy input values to scratch space
|
|
131
|
+
scratch.assign(t);
|
|
132
|
+
|
|
133
|
+
T* values = &scratch.data(0);
|
|
134
|
+
thread_block_scan<T, false>(values, num_elements_to_scan);
|
|
135
|
+
|
|
136
|
+
auto result = scratch.copy_to_register();
|
|
137
|
+
|
|
138
|
+
WP_TILE_SYNC();
|
|
139
|
+
|
|
140
|
+
return result;
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
template<typename Tile>
|
|
144
|
+
inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
|
|
145
|
+
{
|
|
146
|
+
using T = typename Tile::Type;
|
|
147
|
+
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
148
|
+
|
|
149
|
+
// create a temporary shared tile to hold the input values
|
|
150
|
+
WP_TILE_SHARED T smem[num_elements_to_scan];
|
|
151
|
+
tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
|
|
152
|
+
|
|
153
|
+
// copy input values to scratch space
|
|
154
|
+
scratch.assign(t);
|
|
155
|
+
|
|
156
|
+
T* values = &scratch.data(0);
|
|
157
|
+
thread_block_scan<T, true>(values, num_elements_to_scan);
|
|
158
|
+
|
|
159
|
+
auto result = scratch.copy_to_register();
|
|
160
|
+
|
|
161
|
+
WP_TILE_SYNC();
|
|
162
|
+
|
|
163
|
+
return result;
|
|
164
|
+
}
|
|
165
|
+
|
|
166
|
+
#else
|
|
167
|
+
|
|
168
|
+
template<typename Tile>
|
|
169
|
+
inline auto tile_scan_inclusive_impl(Tile& t)
|
|
170
|
+
{
|
|
171
|
+
using T = typename Tile::Type;
|
|
172
|
+
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
173
|
+
|
|
174
|
+
auto input = t.copy_to_register();
|
|
175
|
+
auto output = tile_register_like<Tile>();
|
|
176
|
+
|
|
177
|
+
using Layout = typename decltype(input)::Layout;
|
|
178
|
+
|
|
179
|
+
T sum = T(0);
|
|
180
|
+
for (int i = 0; i < num_elements_to_scan; ++i)
|
|
181
|
+
{
|
|
182
|
+
sum += input.data[i];
|
|
183
|
+
output.data[i] = sum;
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
return output;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
template<typename Tile>
|
|
190
|
+
inline auto tile_scan_exclusive_impl(Tile& t)
|
|
191
|
+
{
|
|
192
|
+
using T = typename Tile::Type;
|
|
193
|
+
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
194
|
+
|
|
195
|
+
auto input = t.copy_to_register();
|
|
196
|
+
auto output = tile_register_like<Tile>();
|
|
197
|
+
|
|
198
|
+
using Layout = typename decltype(input)::Layout;
|
|
199
|
+
|
|
200
|
+
T sum = T(0);
|
|
201
|
+
for (int i = 0; i < num_elements_to_scan; ++i)
|
|
202
|
+
{
|
|
203
|
+
output.data[i] = sum;
|
|
204
|
+
sum += input.data[i];
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
return output;
|
|
208
|
+
}
|
|
209
|
+
|
|
210
|
+
#endif // !defined(__CUDA_ARCH__)
|
|
211
|
+
|
|
212
|
+
// Inclusive prefix sum of a tile; forwards to the CUDA or CPU implementation
// chosen by the __CUDA_ARCH__ branch.
template <typename Tile>
auto tile_scan_inclusive(Tile& t)
{
    auto scanned = tile_scan_inclusive_impl(t);
    return scanned;
}
|
|
217
|
+
|
|
218
|
+
// Adjoint of tile_scan_inclusive. Not implemented yet: adj_t is left unchanged,
// so no gradient flows back through an inclusive scan.
// todo: accumulate a reverse (suffix) inclusive scan of adj_ret into adj_t
template <typename Tile, typename AdjTile>
void adj_tile_scan_inclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
223
|
+
|
|
224
|
+
// Exclusive prefix sum of a tile; forwards to the CUDA or CPU implementation
// chosen by the __CUDA_ARCH__ branch.
template <typename Tile>
auto tile_scan_exclusive(Tile& t)
{
    auto scanned = tile_scan_exclusive_impl(t);
    return scanned;
}
|
|
229
|
+
|
|
230
|
+
// Adjoint of tile_scan_exclusive. Not implemented yet: adj_t is left unchanged,
// so no gradient flows back through an exclusive scan.
// todo: accumulate a reverse (suffix) exclusive scan of adj_ret into adj_t
template <typename Tile, typename AdjTile>
void adj_tile_scan_exclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
{
    (void)t;
    (void)adj_t;
    (void)adj_ret;
}
|
|
235
|
+
|
|
236
|
+
} // namespace wp
|
|
237
|
+
|
|
238
|
+
#if defined(__clang__)
|
|
239
|
+
#pragma clang diagnostic pop
|
|
240
|
+
#endif
|
warp/native/tuple.h
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
* SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
*
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
#pragma once
|
|
19
|
+
|
|
20
|
+
namespace wp
|
|
21
|
+
{
|
|
22
|
+
|
|
23
|
+
template <typename... Types>
|
|
24
|
+
struct tuple_t;
|
|
25
|
+
|
|
26
|
+
template <>
|
|
27
|
+
struct tuple_t<>
|
|
28
|
+
{
|
|
29
|
+
|
|
30
|
+
static constexpr int size() { return 0; }
|
|
31
|
+
|
|
32
|
+
// Base case: empty tuple.
|
|
33
|
+
template <typename Callable>
|
|
34
|
+
void apply(Callable&&) const { }
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
template <typename Head, typename... Tail>
|
|
38
|
+
struct tuple_t<Head, Tail...>
|
|
39
|
+
{
|
|
40
|
+
Head head;
|
|
41
|
+
tuple_t<Tail...> tail;
|
|
42
|
+
|
|
43
|
+
CUDA_CALLABLE inline tuple_t() {}
|
|
44
|
+
CUDA_CALLABLE inline tuple_t(Head h, Tail... t) : head(h), tail(t...) {}
|
|
45
|
+
|
|
46
|
+
static constexpr int size() { return 1 + tuple_t<Tail...>::size(); }
|
|
47
|
+
|
|
48
|
+
// Applies a callable to each element.
|
|
49
|
+
template <typename Callable>
|
|
50
|
+
void apply(Callable&& func) const
|
|
51
|
+
{
|
|
52
|
+
func(head); // Apply the callable to the current element.
|
|
53
|
+
tail.apply(func); // Recursively process the rest of the tuple.
|
|
54
|
+
}
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
// Tuple constructor.
|
|
58
|
+
template <typename... Args>
|
|
59
|
+
CUDA_CALLABLE inline tuple_t<Args...>
|
|
60
|
+
tuple(
|
|
61
|
+
Args... args
|
|
62
|
+
)
|
|
63
|
+
{
|
|
64
|
+
return tuple_t<Args...>(args...);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
// Helper to extract a value from the tuple.
|
|
68
|
+
// Can be replaced with simpler member function version when our CPU compiler
|
|
69
|
+
// backend supports constexpr if statements.
|
|
70
|
+
template <int N, typename Head, typename... Tail>
|
|
71
|
+
struct tuple_get
|
|
72
|
+
{
|
|
73
|
+
static CUDA_CALLABLE inline const auto&
|
|
74
|
+
value(
|
|
75
|
+
const tuple_t<Head, Tail...>& t
|
|
76
|
+
)
|
|
77
|
+
{
|
|
78
|
+
return tuple_get<N - 1, Tail...>::value(t.tail);
|
|
79
|
+
}
|
|
80
|
+
};
|
|
81
|
+
|
|
82
|
+
// Specialization for the base case N == 0. Simply return the head of the tuple.
|
|
83
|
+
template <typename Head, typename... Tail>
|
|
84
|
+
struct tuple_get<0, Head, Tail...>
|
|
85
|
+
{
|
|
86
|
+
static CUDA_CALLABLE inline const auto&
|
|
87
|
+
value(
|
|
88
|
+
const tuple_t<Head, Tail...>& t
|
|
89
|
+
)
|
|
90
|
+
{
|
|
91
|
+
return t.head;
|
|
92
|
+
}
|
|
93
|
+
};
|
|
94
|
+
|
|
95
|
+
template <int Index, typename... Args>
|
|
96
|
+
CUDA_CALLABLE inline auto
|
|
97
|
+
extract(
|
|
98
|
+
const tuple_t<Args...>& t
|
|
99
|
+
)
|
|
100
|
+
{
|
|
101
|
+
return tuple_get<Index, Args...>::value(t);
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
template <typename... Args>
|
|
105
|
+
CUDA_CALLABLE inline int
|
|
106
|
+
len(
|
|
107
|
+
const tuple_t<Args...>& t
|
|
108
|
+
)
|
|
109
|
+
{
|
|
110
|
+
return t.size();
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
template <typename... Args>
|
|
114
|
+
CUDA_CALLABLE inline void
|
|
115
|
+
adj_len(
|
|
116
|
+
const tuple_t<Args...>& t,
|
|
117
|
+
tuple_t<Args...>& adj_t,
|
|
118
|
+
int adj_ret
|
|
119
|
+
)
|
|
120
|
+
{
|
|
121
|
+
}
|
|
122
|
+
|
|
123
|
+
template <typename... Args>
|
|
124
|
+
CUDA_CALLABLE inline void
|
|
125
|
+
print(
|
|
126
|
+
const tuple_t<Args...>& t
|
|
127
|
+
)
|
|
128
|
+
{
|
|
129
|
+
t.apply([&](auto a) { print(a); });
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
template <typename... Args>
|
|
133
|
+
CUDA_CALLABLE inline void
|
|
134
|
+
adj_print(
|
|
135
|
+
const tuple_t<Args...>& t,
|
|
136
|
+
tuple_t<Args...>& adj_t
|
|
137
|
+
)
|
|
138
|
+
{
|
|
139
|
+
adj_t.apply([&](auto a) { print(a); });
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
// Sum of two empty tuples: the empty tuple (ends the element-wise recursion).
CUDA_CALLABLE inline tuple_t<> add(const tuple_t<>& a, const tuple_t<>& b)
{
    (void)a;
    (void)b;
    return tuple_t<>();
}
|
|
150
|
+
|
|
151
|
+
template <typename Head, typename... Tail>
|
|
152
|
+
CUDA_CALLABLE inline tuple_t<Head, Tail...>
|
|
153
|
+
add(
|
|
154
|
+
const tuple_t<Head, Tail...>& a,
|
|
155
|
+
const tuple_t<Head, Tail...>& b
|
|
156
|
+
)
|
|
157
|
+
{
|
|
158
|
+
tuple_t<Head, Tail...> out;
|
|
159
|
+
out.head = add(a.head, b.head);
|
|
160
|
+
out.tail = add(a.tail, b.tail);
|
|
161
|
+
return out;
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
// Adjoint of adding two empty tuples: nothing to accumulate (ends the recursion).
CUDA_CALLABLE inline void adj_add(const tuple_t<>& a, const tuple_t<>& b, tuple_t<>& adj_a, tuple_t<>& adj_b, const tuple_t<>& adj_ret)
{
    (void)a;
    (void)b;
    (void)adj_a;
    (void)adj_b;
    (void)adj_ret;
}
|
|
174
|
+
|
|
175
|
+
template <typename Head, typename... Tail>
|
|
176
|
+
CUDA_CALLABLE inline void
|
|
177
|
+
adj_add(
|
|
178
|
+
const tuple_t<Head, Tail...>& a,
|
|
179
|
+
const tuple_t<Head, Tail...>& b,
|
|
180
|
+
tuple_t<Head, Tail...>& adj_a,
|
|
181
|
+
tuple_t<Head, Tail...>& adj_b,
|
|
182
|
+
const tuple_t<Head, Tail...>& adj_ret
|
|
183
|
+
)
|
|
184
|
+
{
|
|
185
|
+
adj_add(a.head, b.head, adj_ret.head);
|
|
186
|
+
adj_add(a.tail, b.tail, adj_ret.tail);
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
} // namespace wp
|