warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl
This diff shows the contents of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/tile_reduce.h
CHANGED
@@ -24,6 +24,20 @@
 namespace wp
 {
 
+
+template <typename T>
+int argmax_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value > champion_value ? current_index : champion_index;
+}
+
+template <typename T>
+int argmin_tracker(T champion_value, T current_value, int champion_index, int current_index)
+{
+    return current_value < champion_value ? current_index : champion_index;
+}
+
+
 #if defined(__CUDA_ARCH__)
 
 template <typename T>
@@ -62,6 +76,32 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
     return output;
 }
 
+// Vector overload
+template <unsigned Length, typename T>
+inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
+{
+    wp::vec_t<Length, T> result;
+
+    for (unsigned i=0; i < Length; ++i)
+        result[i] = __shfl_down_sync(mask, val[i], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+// Matrix overload
+template <unsigned Rows, unsigned Cols, typename T>
+inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
+{
+    wp::mat_t<Rows, Cols, T> result;
+
+    for (unsigned i=0; i < Rows; ++i)
+        for (unsigned j=0; j < Cols; ++j)
+            result.data[i][j] = __shfl_down_sync(mask, val.data[i][j], offset, WP_TILE_WARP_SIZE);
+
+    return result;
+}
+
+
 template <typename T, typename Op>
 inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
 {
@@ -89,6 +129,52 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
     return sum;
 }
 
+template <typename T>
+struct ValueAndIndex
+{
+    T value;
+    int index;
+};
+
+template <typename T, typename Op, typename OpTrack>
+inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f, OpTrack track, unsigned int mask)
+{
+    T sum = val;
+    int index = idx;
+
+    if (mask == 0xFFFFFFFF)
+    {
+        // handle case where entire warp is active
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            auto shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_idx = warp_shuffle_down(index, offset, mask);
+            index = track(sum, shfl_val, index, shfl_idx);
+            sum = f(sum, shfl_val);
+        }
+    }
+    else
+    {
+        // handle partial warp case
+        for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
+        {
+            T shfl_val = warp_shuffle_down(sum, offset, mask);
+            int shfl_index = warp_shuffle_down(index, offset, mask);
+            if ((mask & (1 << ((threadIdx.x + offset)%WP_TILE_WARP_SIZE))) != 0)
+            {
+                index = track(sum, shfl_val, index, shfl_index);
+                sum = f(sum, shfl_val);
+            }
+        }
+    }
+
+    ValueAndIndex<T> result;
+    result.value = sum;
+    result.index = index;
+
+    return result;
+}
+
 // non-axis version which computes sum
 // across the entire tile using the whole block
 template <typename Tile, typename Op>
@@ -120,6 +206,7 @@ auto tile_reduce_impl(Op f, Tile& t)
 
     // ensure that only threads with at least one valid item participate in the reduction
     unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
+    bool warp_is_active = mask != 0;
 
     // warp reduction
    T warp_sum = warp_reduce(thread_sum, f, mask);
@@ -135,7 +222,7 @@ auto tile_reduce_impl(Op f, Tile& t)
     // ensure active_warps is initialized
     WP_TILE_SYNC();
 
-    if (lane_index == 0)
+    if (lane_index == 0 && warp_is_active)
     {
         partials[warp_index] = warp_sum;
         atomicAdd(&active_warps, 1);
@@ -159,6 +246,86 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+
+// non-axis version which computes sum
+// across the entire tile using the whole block
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    const int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
+    const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
+    const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T thread_sum = input.data[0];
+
+    // thread reduction
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(thread_sum, input.data[i], champion_index, linear);
+        thread_sum = f(thread_sum, input.data[i]);
+    }
+
+    // ensure that only threads with at least one valid item participate in the reduction
+    unsigned int mask = __ballot_sync(__activemask(), Layout::valid(Layout::linear_from_register(0)));
+    bool warp_is_active = mask != 0;
+
+    // warp reduction
+    ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
+
+    // fixed size scratch pad for partial results in shared memory
+    WP_TILE_SHARED T partials[warp_count];
+    WP_TILE_SHARED int partials_idx[warp_count];
+
+    // count of active warps
+    WP_TILE_SHARED int active_warps;
+    if (threadIdx.x == 0)
+        active_warps = 0;
+
+    // ensure active_warps is initialized
+    WP_TILE_SYNC();
+
+    if (lane_index == 0 && warp_is_active)
+    {
+        partials[warp_index] = warp_sum.value;
+        partials_idx[warp_index] = warp_sum.index;
+        atomicAdd(&active_warps, 1);
+    }
+
+    // ensure partials are ready
+    WP_TILE_SYNC();
+
+    // reduce across block, todo: use warp_reduce() here
+    if (threadIdx.x == 0)
+    {
+        T block_sum = partials[0];
+        int block_champion_index = partials_idx[0];
+
+        WP_PRAGMA_UNROLL
+        for (int i=1; i < active_warps; ++i)
+        {
+            block_champion_index = track(block_sum, partials[i], block_champion_index, partials_idx[i]);
+            block_sum = f(block_sum, partials[i]);
+        }
+
+        output.data[0] = block_champion_index;
+    }
+
+    return output;
+}
+
 #else
 
 // CPU implementation
@@ -171,9 +338,9 @@ auto tile_reduce_impl(Op f, Tile& t)
     auto input = t.copy_to_register();
     auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
 
-
+    using Layout = typename decltype(input)::Layout;
 
-
+    T sum = input.data[0];
 
     WP_PRAGMA_UNROLL
     for (int i=1; i < Layout::NumRegs; ++i)
@@ -189,6 +356,34 @@ auto tile_reduce_impl(Op f, Tile& t)
     return output;
 }
 
+template <typename Tile, typename Op, typename OpTrack>
+auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
+{
+    using T = typename Tile::Type;
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_t<int, tile_layout_register_t<tile_shape_t<1>>>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    int champion_index = Layout::NumRegs > 0 ? Layout::linear_from_register(0) : -1;
+    T sum = input.data[0];
+
+    WP_PRAGMA_UNROLL
+    for (int i=1; i < Layout::NumRegs; ++i)
+    {
+        int linear = Layout::linear_from_register(i);
+        if (!Layout::valid(linear))
+            break;
+
+        champion_index = track(sum, input.data[i], champion_index, linear);
+        sum = f(sum, input.data[i]);
+    }
+
+    output.data[0] = champion_index;
+    return output;
+}
+
 #endif // !defined(__CUDA_ARCH__)
 
 inline void adj_tile_reduce_impl()
@@ -200,6 +395,9 @@ inline void adj_tile_reduce_impl()
 #define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
 #define adj_tile_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_reduce_impl()
 
+#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
+#define adj_tile_arg_reduce(op, a, adj_op, adj_a, adj_ret) adj_tile_arg_reduce_impl()
+
 // convenience methods for specific reductions
 
 template <typename Tile>
@@ -214,25 +412,26 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
     using T = typename Tile::Type;
 
-
-
-    for (int i=0; i < Tile::Layout::Size; ++i)
-    {
-        adj_t(i) += adj_ret.data[0];
+    auto adj_reg = adj_ret.grad_to_register();
 
-
+#if !defined(__CUDA_ARCH__)
+    T scratch = adj_reg.data[0];
 #else
     // broadcast incoming adjoint to block
     WP_TILE_SHARED T scratch;
     if (WP_TILE_THREAD_IDX == 0)
-        scratch =
+        scratch = adj_reg.data[0];
 
     WP_TILE_SYNC();
+#endif
 
-
-
+    auto adj_ret_reg = tile_register_like<Tile>();
+    using Layout = typename decltype(adj_ret_reg)::Layout;
+    for (int i=0; i < Layout::NumRegs; ++i)
+    {
+        adj_ret_reg.data[i] += scratch;
+    }
     adj_t.grad_add(adj_ret_reg);
-#endif
 }
 
 template <typename Tile>
@@ -261,4 +460,31 @@ void adj_tile_min(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 
 
 
+template <typename Tile>
+auto tile_argmax(Tile& t)
+{
+    return tile_arg_reduce(max, argmax_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmax(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_argmin(Tile& t)
+{
+    return tile_arg_reduce(min, argmin_tracker, t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+
+
+
 } // namespace wp
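The net effect of these changes is a pair of new tile builtins, tile_argmax() and tile_argmin(), which return a one-element int tile holding the linearized index of the winning element; their adjoints are stubs, so the ops are forward-only in this release. Below is a minimal usage sketch, assuming the native entry points are exposed to kernels as wp.tile_argmax()/wp.tile_argmin() alongside the existing wp.tile_max()/wp.tile_min(); the kernel, array names, and launch parameters are illustrative and not taken from the package.

import numpy as np
import warp as wp

TILE_SIZE = 256

@wp.kernel
def find_peak(x: wp.array(dtype=float), peak_index: wp.array(dtype=int)):
    # cooperatively load one tile of values per block
    t = wp.tile_load(x, shape=TILE_SIZE)
    # one-element int tile holding the index of the largest value
    idx = wp.tile_argmax(t)  # assumed Python exposure of the new native builtin
    wp.tile_store(peak_index, idx)

wp.init()
x = wp.array(np.random.rand(TILE_SIZE), dtype=float)
peak_index = wp.zeros(1, dtype=int)
wp.launch_tiled(find_peak, dim=[1], inputs=[x, peak_index], block_dim=64)
print(peak_index.numpy())  # position of the maximum element of x

On CUDA devices the index is carried alongside the value through the shuffle-based warp_reduce_tracked() path above, followed by a shared-memory pass across warps; on CPU the arg-reduction falls back to the serial per-register loop.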
warp/native/tile_scan.h
ADDED
@@ -0,0 +1,240 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tile.h"
+
+#if defined(__clang__)
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif
+
+namespace wp
+{
+
+#if defined(__CUDA_ARCH__)
+
+
+template<typename T>
+inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
+{
+    // Computes an inclusive cumulative sum
+    #pragma unroll
+    for (int i = 1; i <= 32; i *= 2)
+    {
+        auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
+
+        if (lane >= i)
+            value = value + n;
+    }
+    return value;
+}
+
+
+template<typename T>
+inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
+{
+    WP_TILE_SHARED T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
+
+    value = scan_warp_inclusive(lane, value);
+
+    if (lane == 31)
+    {
+        sums[warp_index] = value;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index == 0)
+    {
+        T v = lane < num_warps ? sums[lane] : T(0);
+        v = scan_warp_inclusive(lane, v);
+        if (lane < num_warps)
+            sums[lane] = v;
+    }
+
+    WP_TILE_SYNC();
+
+    if (warp_index > 0)
+    {
+        value += sums[warp_index - 1];
+    }
+
+    return value;
+}
+
+template<typename T, bool exclusive>
+inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
+{
+    const int num_threads_in_block = blockDim.x;
+    const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;
+
+    WP_TILE_SHARED T offset;
+    if (threadIdx.x == 0)
+        offset = T(0);
+
+    WP_TILE_SYNC();
+
+    const int lane = WP_TILE_THREAD_IDX % WP_TILE_WARP_SIZE;
+    const int warp_index = WP_TILE_THREAD_IDX / WP_TILE_WARP_SIZE;
+    const int num_warps = num_threads_in_block / WP_TILE_WARP_SIZE;
+
+    for (int i = 0; i < num_iterations; ++i)
+    {
+        int element_index = WP_TILE_THREAD_IDX + i * num_threads_in_block;
+        T orig_value = element_index < num_elements ? values[element_index] : T(0);
+        T value = thread_block_scan_inclusive(lane, warp_index, num_warps, orig_value);
+        if (element_index < num_elements)
+        {
+            T new_value = value + offset;
+            if constexpr (exclusive)
+                new_value -= orig_value;
+            values[element_index] = new_value;
+        }
+
+        WP_TILE_SYNC();
+
+        if (threadIdx.x == num_threads_in_block - 1)
+            offset += value;
+
+        WP_TILE_SYNC();
+    }
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, false>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+template<typename Tile>
+inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    // create a temporary shared tile to hold the input values
+    WP_TILE_SHARED T smem[num_elements_to_scan];
+    tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
+
+    // copy input values to scratch space
+    scratch.assign(t);
+
+    T* values = &scratch.data(0);
+    thread_block_scan<T, true>(values, num_elements_to_scan);
+
+    auto result = scratch.copy_to_register();
+
+    WP_TILE_SYNC();
+
+    return result;
+}
+
+#else
+
+template<typename Tile>
+inline auto tile_scan_inclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        sum += input.data[i];
+        output.data[i] = sum;
+    }
+
+    return output;
+}
+
+template<typename Tile>
+inline auto tile_scan_exclusive_impl(Tile& t)
+{
+    using T = typename Tile::Type;
+    constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
+
+    auto input = t.copy_to_register();
+    auto output = tile_register_like<Tile>();
+
+    using Layout = typename decltype(input)::Layout;
+
+    T sum = T(0);
+    for (int i = 0; i < num_elements_to_scan; ++i)
+    {
+        output.data[i] = sum;
+        sum += input.data[i];
+    }
+
+    return output;
+}
+
+#endif // !defined(__CUDA_ARCH__)
+
+template <typename Tile>
+auto tile_scan_inclusive(Tile& t)
+{
+    return tile_scan_inclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_inclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+template <typename Tile>
+auto tile_scan_exclusive(Tile& t)
+{
+    return tile_scan_exclusive_impl(t);
+}
+
+template <typename Tile, typename AdjTile>
+void adj_tile_scan_exclusive(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+{
+    // todo: not implemented
+}
+
+} // namespace wp
+
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
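tile_scan.h adds inclusive and exclusive cumulative sums over a tile: on CUDA a warp-level shuffle scan (scan_warp_inclusive) is combined with a shared-memory pass across warps, while the CPU path is a plain serial loop. A minimal usage sketch follows, assuming the builtins surface in Python as wp.tile_scan_inclusive()/wp.tile_scan_exclusive(); the kernel and names below are illustrative, not taken from the package.

import numpy as np
import warp as wp

TILE_SIZE = 128

@wp.kernel
def prefix_sum(x: wp.array(dtype=float), out: wp.array(dtype=float)):
    t = wp.tile_load(x, shape=TILE_SIZE)
    # inclusive scan: s[i] = x[0] + ... + x[i]
    s = wp.tile_scan_inclusive(t)  # assumed Python exposure of the new native builtin
    wp.tile_store(out, s)

wp.init()
x = wp.array(np.ones(TILE_SIZE), dtype=float)
out = wp.zeros(TILE_SIZE, dtype=float)
wp.launch_tiled(prefix_sum, dim=[1], inputs=[x, out], block_dim=64)
print(out.numpy())  # 1, 2, 3, ..., TILE_SIZE

The exclusive variant shifts the result by one element (output[i] is the sum of the inputs before i, with output[0] = 0). As with the arg-reductions above, both adjoints are still marked "todo: not implemented", so the scans are forward-only in this release.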