warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang has been flagged as potentially problematic; consult the registry's advisory page for details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/tile_radix_sort.h
CHANGED
|
@@ -921,7 +921,7 @@ template <typename K, typename V, typename KeyToUint>
|
|
|
921
921
|
void radix_sort_pairs_cpu_core(K* keys, K* aux_keys, V* values, V* aux_values, int n)
|
|
922
922
|
{
|
|
923
923
|
KeyToUint converter;
|
|
924
|
-
|
|
924
|
+
unsigned int tables[2][1 << 16];
|
|
925
925
|
memset(tables, 0, sizeof(tables));
|
|
926
926
|
|
|
927
927
|
// build histograms
|
warp/native/tile_reduce.h
CHANGED
|
@@ -19,6 +19,12 @@
|
|
|
19
19
|
|
|
20
20
|
#include "tile.h"
|
|
21
21
|
|
|
22
|
+
#ifdef __clang__
|
|
23
|
+
// disable warnings related to C++17 extensions on CPU JIT builds
|
|
24
|
+
#pragma clang diagnostic push
|
|
25
|
+
#pragma clang diagnostic ignored "-Wc++17-extensions"
|
|
26
|
+
#endif // __clang__
|
|
27
|
+
|
|
22
28
|
#define WP_TILE_WARP_SIZE 32
|
|
23
29
|
|
|
24
30
|
namespace wp
|
|
@@ -76,7 +82,7 @@ inline CUDA_CALLABLE T warp_shuffle_down(T val, int offset, int mask)
|
|
|
76
82
|
return output;
|
|
77
83
|
}
|
|
78
84
|
|
|
79
|
-
//
|
|
85
|
+
// vector overload
|
|
80
86
|
template <unsigned Length, typename T>
|
|
81
87
|
inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T> val, int offset, int mask)
|
|
82
88
|
{
|
|
@@ -88,7 +94,7 @@ inline CUDA_CALLABLE wp::vec_t<Length, T> warp_shuffle_down(wp::vec_t<Length, T>
|
|
|
88
94
|
return result;
|
|
89
95
|
}
|
|
90
96
|
|
|
91
|
-
//
|
|
97
|
+
// matrix overload
|
|
92
98
|
template <unsigned Rows, unsigned Cols, typename T>
|
|
93
99
|
inline CUDA_CALLABLE wp::mat_t<Rows, Cols, T> warp_shuffle_down(wp::mat_t<Rows, Cols, T> val, int offset, int mask)
|
|
94
100
|
{
|
|
@@ -117,7 +123,7 @@ inline CUDA_CALLABLE T warp_reduce(T val, Op f, unsigned int mask)
|
|
|
117
123
|
}
|
|
118
124
|
else
|
|
119
125
|
{
|
|
120
|
-
// handle partial warp case
|
|
126
|
+
// handle partial warp case - works for contiguous masks
|
|
121
127
|
for (int offset=WP_TILE_WARP_SIZE/2; offset > 0; offset /= 2)
|
|
122
128
|
{
|
|
123
129
|
T shfl_val = warp_shuffle_down(sum, offset, mask);
|
|
@@ -175,6 +181,51 @@ inline CUDA_CALLABLE ValueAndIndex<T> warp_reduce_tracked(T val, int idx, Op f,
|
|
|
175
181
|
return result;
|
|
176
182
|
}
|
|
177
183
|
|
|
184
|
+
// combines per-thread reduction results across warps and the entire block
|
|
185
|
+
// assumes each thread has already reduced its local data to thread_sum
|
|
186
|
+
// returns the block-wide reduced value (only valid in thread 0)
|
|
187
|
+
template <typename T, typename Op>
|
|
188
|
+
inline CUDA_CALLABLE T block_combine_thread_results(T thread_sum, bool thread_has_data, Op f,
|
|
189
|
+
T* partials, int& active_warps)
|
|
190
|
+
{
|
|
191
|
+
constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
192
|
+
const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
|
|
193
|
+
const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
|
|
194
|
+
|
|
195
|
+
// determine which threads have data
|
|
196
|
+
unsigned int mask = __ballot_sync(0xFFFFFFFF, thread_has_data);
|
|
197
|
+
bool warp_is_active = mask != 0;
|
|
198
|
+
|
|
199
|
+
// warp reduction
|
|
200
|
+
T warp_sum;
|
|
201
|
+
if (thread_has_data)
|
|
202
|
+
warp_sum = warp_reduce(thread_sum, f, mask);
|
|
203
|
+
|
|
204
|
+
// lane 0 of each active warp writes to shared memory and increments counter
|
|
205
|
+
if (lane_index == 0 && warp_is_active)
|
|
206
|
+
{
|
|
207
|
+
partials[warp_index] = warp_sum;
|
|
208
|
+
atomicAdd(&active_warps, 1);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
// sync to ensure all warps have written their partials
|
|
212
|
+
WP_TILE_SYNC();
|
|
213
|
+
|
|
214
|
+
// thread 0 performs final reduction across active warps
|
|
215
|
+
T block_sum;
|
|
216
|
+
if (threadIdx.x == 0)
|
|
217
|
+
{
|
|
218
|
+
block_sum = partials[0];
|
|
219
|
+
|
|
220
|
+
for (int w = 1; w < active_warps; ++w)
|
|
221
|
+
{
|
|
222
|
+
block_sum = f(block_sum, partials[w]);
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
return block_sum;
|
|
227
|
+
}
|
|
228
|
+
|
|
178
229
|
// non-axis version which computes sum
|
|
179
230
|
// across the entire tile using the whole block
|
|
180
231
|
template <typename Tile, typename Op>
|
|
@@ -185,15 +236,14 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
185
236
|
auto input = t.copy_to_register();
|
|
186
237
|
auto output = tile_register_t<T, tile_layout_register_t<tile_shape_t<1>>>();
|
|
187
238
|
|
|
188
|
-
|
|
189
|
-
const int warp_index = threadIdx.x/WP_TILE_WARP_SIZE;
|
|
190
|
-
const int lane_index = threadIdx.x%WP_TILE_WARP_SIZE;
|
|
191
|
-
|
|
192
|
-
T thread_sum = input.data[0];
|
|
239
|
+
constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
193
240
|
|
|
194
241
|
using Layout = typename decltype(input)::Layout;
|
|
195
242
|
|
|
196
|
-
// thread
|
|
243
|
+
// step 1: each thread reduces its own registers locally
|
|
244
|
+
T thread_sum = input.data[0];
|
|
245
|
+
bool thread_has_data = Layout::valid(Layout::linear_from_register(0));
|
|
246
|
+
|
|
197
247
|
WP_PRAGMA_UNROLL
|
|
198
248
|
for (int i=1; i < Layout::NumRegs; ++i)
|
|
199
249
|
{
|
|
@@ -204,48 +254,190 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
204
254
|
thread_sum = f(thread_sum, input.data[i]);
|
|
205
255
|
}
|
|
206
256
|
|
|
207
|
-
//
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
// warp reduction
|
|
212
|
-
T warp_sum = warp_reduce(thread_sum, f, mask);
|
|
213
|
-
|
|
214
|
-
// fixed size scratch pad for partial results in shared memory
|
|
215
|
-
WP_TILE_SHARED T partials[warp_count];
|
|
216
|
-
|
|
217
|
-
// count of active warps
|
|
218
|
-
WP_TILE_SHARED int active_warps;
|
|
257
|
+
// shared memory for cross-warp reduction
|
|
258
|
+
__shared__ T partials[warp_count];
|
|
259
|
+
__shared__ int active_warps;
|
|
260
|
+
|
|
219
261
|
if (threadIdx.x == 0)
|
|
220
262
|
active_warps = 0;
|
|
221
263
|
|
|
222
|
-
// ensure active_warps is initialized
|
|
223
264
|
WP_TILE_SYNC();
|
|
224
265
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
partials[warp_index] = warp_sum;
|
|
228
|
-
atomicAdd(&active_warps, 1);
|
|
229
|
-
}
|
|
266
|
+
// step 2-3: combine thread results across warps and block
|
|
267
|
+
T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);
|
|
230
268
|
|
|
231
|
-
// ensure partials are ready
|
|
232
|
-
WP_TILE_SYNC();
|
|
233
|
-
|
|
234
|
-
// reduce across block, todo: use warp_reduce() here
|
|
235
269
|
if (threadIdx.x == 0)
|
|
236
|
-
{
|
|
237
|
-
T block_sum = partials[0];
|
|
238
|
-
|
|
239
|
-
WP_PRAGMA_UNROLL
|
|
240
|
-
for (int i=1; i < active_warps; ++i)
|
|
241
|
-
block_sum = f(block_sum, partials[i]);
|
|
242
|
-
|
|
243
270
|
output.data[0] = block_sum;
|
|
244
|
-
}
|
|
245
271
|
|
|
246
272
|
return output;
|
|
247
273
|
}
|
|
248
274
|
|
|
275
|
+
template <int Axis, typename Op, typename Tile>
|
|
276
|
+
auto tile_reduce_axis_impl(Op f, Tile& t)
|
|
277
|
+
{
|
|
278
|
+
using T = typename Tile::Type;
|
|
279
|
+
using InputShape = typename Tile::Layout::Shape;
|
|
280
|
+
using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
|
|
281
|
+
|
|
282
|
+
constexpr int reduce_dim_size = InputShape::dim(Axis);
|
|
283
|
+
constexpr int output_size = OutputShape::size();
|
|
284
|
+
|
|
285
|
+
// special case: 1D input delegates to block-wide tile_reduce_impl for optimal performance
|
|
286
|
+
if constexpr (InputShape::N == 1)
|
|
287
|
+
{
|
|
288
|
+
return tile_reduce_impl(f, t);
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
// shared memory buffer for the output (used by all tiers)
|
|
292
|
+
__shared__ T output_buffer[output_size];
|
|
293
|
+
|
|
294
|
+
// create output layout for coordinate conversion (used by all tiers)
|
|
295
|
+
using OutputLayout = tile_layout_strided_t<OutputShape>;
|
|
296
|
+
|
|
297
|
+
if constexpr (reduce_dim_size <= 32)
|
|
298
|
+
{
|
|
299
|
+
// Tier 1: Single thread per output element (optimal for small reductions)
|
|
300
|
+
|
|
301
|
+
// each thread processes output elements, performing reduction along the axis
|
|
302
|
+
for (int out_idx = WP_TILE_THREAD_IDX; out_idx < output_size; out_idx += WP_TILE_BLOCK_DIM)
|
|
303
|
+
{
|
|
304
|
+
// convert output linear index to output coordinates
|
|
305
|
+
auto out_coord = OutputLayout::coord_from_linear(out_idx);
|
|
306
|
+
|
|
307
|
+
// initialize accumulator with first element along the reduction axis
|
|
308
|
+
T accumulator = t.data(tile_coord_insert_axis<Axis>(out_coord, 0));
|
|
309
|
+
|
|
310
|
+
// reduce across the axis
|
|
311
|
+
for (int i = 1; i < reduce_dim_size; ++i)
|
|
312
|
+
{
|
|
313
|
+
accumulator = f(accumulator, t.data(tile_coord_insert_axis<Axis>(out_coord, i)));
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
// store to output buffer
|
|
317
|
+
output_buffer[out_idx] = accumulator;
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
// sync before reading output
|
|
321
|
+
WP_TILE_SYNC();
|
|
322
|
+
}
|
|
323
|
+
else if constexpr (reduce_dim_size <= 256)
|
|
324
|
+
{
|
|
325
|
+
// Tier 2: Warp-based reduction (one warp per output element)
|
|
326
|
+
constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
327
|
+
const int warp_index = threadIdx.x / WP_TILE_WARP_SIZE;
|
|
328
|
+
const int lane_index = threadIdx.x % WP_TILE_WARP_SIZE;
|
|
329
|
+
|
|
330
|
+
constexpr int chunks_per_slice = (reduce_dim_size + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
331
|
+
|
|
332
|
+
// shared memory: one accumulator per warp
|
|
333
|
+
__shared__ T warp_partials[warp_count];
|
|
334
|
+
|
|
335
|
+
// each warp processes output slices
|
|
336
|
+
for (int out_idx = warp_index; out_idx < output_size; out_idx += warp_count)
|
|
337
|
+
{
|
|
338
|
+
auto out_coord = OutputLayout::coord_from_linear(out_idx);
|
|
339
|
+
|
|
340
|
+
// process the reduction axis in chunks of 32
|
|
341
|
+
for (int chunk = 0; chunk < chunks_per_slice; ++chunk)
|
|
342
|
+
{
|
|
343
|
+
int axis_idx = chunk * WP_TILE_WARP_SIZE + lane_index;
|
|
344
|
+
bool valid = axis_idx < reduce_dim_size;
|
|
345
|
+
|
|
346
|
+
T val;
|
|
347
|
+
if (valid)
|
|
348
|
+
{
|
|
349
|
+
auto in_coord = tile_coord_insert_axis<Axis>(out_coord, axis_idx);
|
|
350
|
+
val = t.data(in_coord);
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
// warp reduce this chunk (only valid lanes participate)
|
|
354
|
+
unsigned int mask = __ballot_sync(0xFFFFFFFF, valid);
|
|
355
|
+
T chunk_result = warp_reduce(val, f, mask);
|
|
356
|
+
|
|
357
|
+
// lane 0 accumulates the chunk result
|
|
358
|
+
if (lane_index == 0)
|
|
359
|
+
{
|
|
360
|
+
if (chunk == 0)
|
|
361
|
+
warp_partials[warp_index] = chunk_result;
|
|
362
|
+
else
|
|
363
|
+
warp_partials[warp_index] = f(warp_partials[warp_index], chunk_result);
|
|
364
|
+
}
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
// lane 0 writes final result for this output element
|
|
368
|
+
if (lane_index == 0)
|
|
369
|
+
output_buffer[out_idx] = warp_partials[warp_index];
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
// sync before reading output
|
|
373
|
+
WP_TILE_SYNC();
|
|
374
|
+
}
|
|
375
|
+
else
|
|
376
|
+
{
|
|
377
|
+
// Tier 3: Block-level reduction (entire block collaborates on each output element)
|
|
378
|
+
constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1) / WP_TILE_WARP_SIZE;
|
|
379
|
+
|
|
380
|
+
// shared memory for cross-warp reduction
|
|
381
|
+
__shared__ T partials[warp_count];
|
|
382
|
+
__shared__ int active_warps;
|
|
383
|
+
|
|
384
|
+
// process each output element sequentially with full block cooperation
|
|
385
|
+
for (int out_idx = 0; out_idx < output_size; ++out_idx)
|
|
386
|
+
{
|
|
387
|
+
auto out_coord = OutputLayout::coord_from_linear(out_idx);
|
|
388
|
+
|
|
389
|
+
// step 1: each thread reduces its strided subset of the slice locally
|
|
390
|
+
bool thread_has_data = threadIdx.x < reduce_dim_size;
|
|
391
|
+
T thread_sum;
|
|
392
|
+
|
|
393
|
+
if (thread_has_data)
|
|
394
|
+
{
|
|
395
|
+
// initialize with first element
|
|
396
|
+
auto in_coord = tile_coord_insert_axis<Axis>(out_coord, threadIdx.x);
|
|
397
|
+
thread_sum = t.data(in_coord);
|
|
398
|
+
|
|
399
|
+
// reduce remaining elements with stride
|
|
400
|
+
for (int i = threadIdx.x + WP_TILE_BLOCK_DIM; i < reduce_dim_size; i += WP_TILE_BLOCK_DIM)
|
|
401
|
+
{
|
|
402
|
+
auto in_coord = tile_coord_insert_axis<Axis>(out_coord, i);
|
|
403
|
+
T val = t.data(in_coord);
|
|
404
|
+
thread_sum = f(thread_sum, val);
|
|
405
|
+
}
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
// initialize active warp counter
|
|
409
|
+
if (threadIdx.x == 0)
|
|
410
|
+
active_warps = 0;
|
|
411
|
+
|
|
412
|
+
WP_TILE_SYNC();
|
|
413
|
+
|
|
414
|
+
// step 2-3: combine thread results across warps and block
|
|
415
|
+
T block_sum = block_combine_thread_results(thread_sum, thread_has_data, f, partials, active_warps);
|
|
416
|
+
|
|
417
|
+
if (threadIdx.x == 0)
|
|
418
|
+
output_buffer[out_idx] = block_sum;
|
|
419
|
+
|
|
420
|
+
// sync before next output element
|
|
421
|
+
WP_TILE_SYNC();
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
// copy from shared memory buffer to register tile (common to all tiers)
|
|
426
|
+
auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
|
|
427
|
+
using OutputRegLayout = typename decltype(output)::Layout;
|
|
428
|
+
|
|
429
|
+
WP_PRAGMA_UNROLL
|
|
430
|
+
for (int i = 0; i < OutputRegLayout::NumRegs; ++i)
|
|
431
|
+
{
|
|
432
|
+
int linear = OutputRegLayout::linear_from_register(i);
|
|
433
|
+
if (!OutputRegLayout::valid(linear))
|
|
434
|
+
break;
|
|
435
|
+
|
|
436
|
+
output.data[i] = output_buffer[linear];
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
return output;
|
|
440
|
+
}
|
|
249
441
|
|
|
250
442
|
// non-axis version which computes sum
|
|
251
443
|
// across the entire tile using the whole block
|
|
@@ -286,11 +478,11 @@ auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
|
|
|
286
478
|
ValueAndIndex<T> warp_sum = warp_reduce_tracked(thread_sum, champion_index, f, track, mask);
|
|
287
479
|
|
|
288
480
|
// fixed size scratch pad for partial results in shared memory
|
|
289
|
-
|
|
290
|
-
|
|
481
|
+
__shared__ T partials[warp_count];
|
|
482
|
+
__shared__ int partials_idx[warp_count];
|
|
291
483
|
|
|
292
484
|
// count of active warps
|
|
293
|
-
|
|
485
|
+
__shared__ int active_warps;
|
|
294
486
|
if (threadIdx.x == 0)
|
|
295
487
|
active_warps = 0;
|
|
296
488
|
|
|
@@ -356,6 +548,65 @@ auto tile_reduce_impl(Op f, Tile& t)
|
|
|
356
548
|
return output;
|
|
357
549
|
}
|
|
358
550
|
|
|
551
|
+
template <int Axis, typename Op, typename Tile>
|
|
552
|
+
auto tile_reduce_axis_impl(Op f, Tile& t)
|
|
553
|
+
{
|
|
554
|
+
using T = typename Tile::Type;
|
|
555
|
+
using InputShape = typename Tile::Layout::Shape;
|
|
556
|
+
using OutputShape = typename tile_shape_remove_dim<Axis, InputShape>::type;
|
|
557
|
+
|
|
558
|
+
constexpr int reduce_dim_size = InputShape::dim(Axis);
|
|
559
|
+
|
|
560
|
+
// CPU version - work directly with register tiles, no thread coordination needed
|
|
561
|
+
auto input = t.copy_to_register();
|
|
562
|
+
auto output = tile_register_t<T, tile_layout_register_t<OutputShape>>();
|
|
563
|
+
using OutputLayout = typename decltype(output)::Layout;
|
|
564
|
+
|
|
565
|
+
// iterate through each output element and reduce along the axis
|
|
566
|
+
constexpr int output_size = OutputShape::size();
|
|
567
|
+
for (int out_idx = 0; out_idx < output_size; ++out_idx)
|
|
568
|
+
{
|
|
569
|
+
T accumulator;
|
|
570
|
+
|
|
571
|
+
// special case for 1D input (reduces to single value)
|
|
572
|
+
if constexpr (InputShape::N == 1)
|
|
573
|
+
{
|
|
574
|
+
accumulator = input.data[0];
|
|
575
|
+
for (int i = 1; i < reduce_dim_size; ++i)
|
|
576
|
+
{
|
|
577
|
+
// input is in registers, linear access
|
|
578
|
+
accumulator = f(accumulator, input.data[i]);
|
|
579
|
+
}
|
|
580
|
+
}
|
|
581
|
+
else
|
|
582
|
+
{
|
|
583
|
+
// multi-dimensional case
|
|
584
|
+
auto out_coord = OutputLayout::coord_from_linear(out_idx);
|
|
585
|
+
|
|
586
|
+
// get input coordinates by inserting axis values
|
|
587
|
+
auto coord_0 = tile_coord_insert_axis<Axis>(out_coord, 0);
|
|
588
|
+
int input_linear_0 = tile_layout_register_t<InputShape>::linear_from_coord(coord_0);
|
|
589
|
+
int input_reg_0 = tile_layout_register_t<InputShape>::register_from_linear(input_linear_0);
|
|
590
|
+
accumulator = input.data[input_reg_0];
|
|
591
|
+
|
|
592
|
+
// reduce across the axis
|
|
593
|
+
for (int i = 1; i < reduce_dim_size; ++i)
|
|
594
|
+
{
|
|
595
|
+
auto coord_i = tile_coord_insert_axis<Axis>(out_coord, i);
|
|
596
|
+
int input_linear_i = tile_layout_register_t<InputShape>::linear_from_coord(coord_i);
|
|
597
|
+
int input_reg_i = tile_layout_register_t<InputShape>::register_from_linear(input_linear_i);
|
|
598
|
+
accumulator = f(accumulator, input.data[input_reg_i]);
|
|
599
|
+
}
|
|
600
|
+
}
|
|
601
|
+
|
|
602
|
+
// store to output register
|
|
603
|
+
int output_reg = OutputLayout::register_from_linear(out_idx);
|
|
604
|
+
output.data[output_reg] = accumulator;
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
return output;
|
|
608
|
+
}
|
|
609
|
+
|
|
359
610
|
template <typename Tile, typename Op, typename OpTrack>
|
|
360
611
|
auto tile_arg_reduce_impl(Op f, OpTrack track, Tile& t)
|
|
361
612
|
{
|
|
@@ -391,15 +642,25 @@ inline void adj_tile_reduce_impl()
|
|
|
391
642
|
// todo: general purpose reduction gradients not implemented
|
|
392
643
|
}
|
|
393
644
|
|
|
645
|
+
inline void adj_tile_reduce_axis_impl()
|
|
646
|
+
{
|
|
647
|
+
// todo: axis-specific reduction gradients not implemented
|
|
648
|
+
}
|
|
649
|
+
|
|
394
650
|
// entry point for Python code-gen, wraps op in a lambda to perform overload resolution
|
|
395
651
|
#define tile_reduce(op, t) tile_reduce_impl([](auto x, auto y) { return op(x, y);}, t)
|
|
396
|
-
#define adj_tile_reduce(op,
|
|
652
|
+
#define adj_tile_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_reduce_impl()
|
|
397
653
|
|
|
398
654
|
#define tile_arg_reduce(op, opTrack, t) tile_arg_reduce_impl([](auto x, auto y) { return op(x, y);}, [](auto a, auto b, auto c, auto d) { return opTrack(a, b, c, d); }, t)
|
|
399
|
-
#define adj_tile_arg_reduce(op,
|
|
655
|
+
#define adj_tile_arg_reduce(op, t, adj_op, adj_t, adj_ret) adj_tile_arg_reduce_impl()
|
|
656
|
+
|
|
657
|
+
// axis-specific reduction entry points
|
|
658
|
+
#define tile_reduce_axis(op, t, axis) tile_reduce_axis_impl<axis>([](auto x, auto y) { return op(x, y);}, t)
|
|
659
|
+
#define adj_tile_reduce_axis(op, t, axis, adj_op, adj_t, adj_axis, adj_ret) adj_tile_reduce_axis_impl()
|
|
400
660
|
|
|
401
661
|
// convenience methods for specific reductions
|
|
402
662
|
|
|
663
|
+
// whole-tile sum
|
|
403
664
|
template <typename Tile>
|
|
404
665
|
auto tile_sum(Tile& t)
|
|
405
666
|
{
|
|
@@ -418,7 +679,7 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
|
418
679
|
T scratch = adj_reg.data[0];
|
|
419
680
|
#else
|
|
420
681
|
// broadcast incoming adjoint to block
|
|
421
|
-
|
|
682
|
+
__shared__ T scratch;
|
|
422
683
|
if (WP_TILE_THREAD_IDX == 0)
|
|
423
684
|
scratch = adj_reg.data[0];
|
|
424
685
|
|
|
@@ -434,6 +695,90 @@ void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
|
434
695
|
adj_t.grad_add(adj_ret_reg);
|
|
435
696
|
}
|
|
436
697
|
|
|
698
|
+
// axis-specific sum
|
|
699
|
+
template <int Axis, typename Tile>
|
|
700
|
+
auto tile_sum(Tile& t)
|
|
701
|
+
{
|
|
702
|
+
return tile_reduce_axis_impl<Axis>([](auto x, auto y) { return add(x, y); }, t);
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
// special case adjoint for axis-specific summation
|
|
706
|
+
template<int Axis, typename Tile, typename AdjTile>
|
|
707
|
+
void adj_tile_sum(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
708
|
+
{
|
|
709
|
+
using InputShape = typename Tile::Layout::Shape;
|
|
710
|
+
|
|
711
|
+
if constexpr (InputShape::N == 1)
|
|
712
|
+
{
|
|
713
|
+
// 1D -> scalar case: broadcast scalar to 1D
|
|
714
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), 0>(adj_ret);
|
|
715
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
716
|
+
}
|
|
717
|
+
else if constexpr (InputShape::N == 2)
|
|
718
|
+
{
|
|
719
|
+
if constexpr (Axis == 0)
|
|
720
|
+
{
|
|
721
|
+
// broadcast from (D1,) to (D0, D1) with strides (0, 1)
|
|
722
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 0, 1>(adj_ret);
|
|
723
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
724
|
+
}
|
|
725
|
+
else // Axis == 1
|
|
726
|
+
{
|
|
727
|
+
// broadcast from (D0,) to (D0, D1) with strides (1, 0)
|
|
728
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), 1, 0>(adj_ret);
|
|
729
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
730
|
+
}
|
|
731
|
+
}
|
|
732
|
+
else if constexpr (InputShape::N == 3)
|
|
733
|
+
{
|
|
734
|
+
if constexpr (Axis == 0)
|
|
735
|
+
{
|
|
736
|
+
// broadcast from (D1, D2) to (D0, D1, D2) with strides (0, D2, 1)
|
|
737
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), 0, InputShape::dim(2), 1>(adj_ret);
|
|
738
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
739
|
+
}
|
|
740
|
+
else if constexpr (Axis == 1)
|
|
741
|
+
{
|
|
742
|
+
// broadcast from (D0, D2) to (D0, D1, D2) with strides (D2, 0, 1)
|
|
743
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(2), 0, 1>(adj_ret);
|
|
744
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
745
|
+
}
|
|
746
|
+
else // Axis == 2
|
|
747
|
+
{
|
|
748
|
+
// broadcast from (D0, D1) to (D0, D1, D2) with strides (D1, 1, 0)
|
|
749
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(1), 1, 0>(adj_ret);
|
|
750
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
751
|
+
}
|
|
752
|
+
}
|
|
753
|
+
else if constexpr (InputShape::N == 4)
|
|
754
|
+
{
|
|
755
|
+
if constexpr (Axis == 0)
|
|
756
|
+
{
|
|
757
|
+
// broadcast from (D1, D2, D3) to (D0, D1, D2, D3) with strides (0, D2*D3, D3, 1)
|
|
758
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), 0, InputShape::dim(2)*InputShape::dim(3), InputShape::dim(3), 1>(adj_ret);
|
|
759
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
760
|
+
}
|
|
761
|
+
else if constexpr (Axis == 1)
|
|
762
|
+
{
|
|
763
|
+
// broadcast from (D0, D2, D3) to (D0, D1, D2, D3) with strides (D2*D3, 0, D3, 1)
|
|
764
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(2)*InputShape::dim(3), 0, InputShape::dim(3), 1>(adj_ret);
|
|
765
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
766
|
+
}
|
|
767
|
+
else if constexpr (Axis == 2)
|
|
768
|
+
{
|
|
769
|
+
// broadcast from (D0, D1, D3) to (D0, D1, D2, D3) with strides (D1*D3, D3, 0, 1)
|
|
770
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(3), InputShape::dim(3), 0, 1>(adj_ret);
|
|
771
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
772
|
+
}
|
|
773
|
+
else // Axis == 3
|
|
774
|
+
{
|
|
775
|
+
// broadcast from (D0, D1, D2) to (D0, D1, D2, D3) with strides (D1*D2, D2, 1, 0)
|
|
776
|
+
auto broadcasted = tile_broadcast<InputShape::dim(0), InputShape::dim(1), InputShape::dim(2), InputShape::dim(3), InputShape::dim(1)*InputShape::dim(2), InputShape::dim(2), 1, 0>(adj_ret);
|
|
777
|
+
tile_add_inplace(adj_t, broadcasted);
|
|
778
|
+
}
|
|
779
|
+
}
|
|
780
|
+
}
|
|
781
|
+
|
|
437
782
|
template <typename Tile>
|
|
438
783
|
auto tile_max(Tile& t)
|
|
439
784
|
{
|
|
@@ -485,6 +830,9 @@ void adj_tile_argmin(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
|
485
830
|
}
|
|
486
831
|
|
|
487
832
|
|
|
833
|
+
} // namespace wp
|
|
488
834
|
|
|
489
835
|
|
|
490
|
-
|
|
836
|
+
#ifdef __clang__
|
|
837
|
+
#pragma clang diagnostic pop
|
|
838
|
+
#endif
|
warp/native/tile_scan.h
CHANGED
|
@@ -50,7 +50,7 @@ inline CUDA_CALLABLE T scan_warp_inclusive(int lane, T value)
|
|
|
50
50
|
template<typename T>
|
|
51
51
|
inline CUDA_CALLABLE T thread_block_scan_inclusive(int lane, int warp_index, int num_warps, T value)
|
|
52
52
|
{
|
|
53
|
-
|
|
53
|
+
__shared__ T sums[1024 / WP_TILE_WARP_SIZE]; // 1024 is the maximum number of threads per block
|
|
54
54
|
|
|
55
55
|
value = scan_warp_inclusive(lane, value);
|
|
56
56
|
|
|
@@ -85,7 +85,7 @@ inline CUDA_CALLABLE void thread_block_scan(T* values, int num_elements)
|
|
|
85
85
|
const int num_threads_in_block = blockDim.x;
|
|
86
86
|
const int num_iterations = (num_elements + num_threads_in_block - 1) / num_threads_in_block;
|
|
87
87
|
|
|
88
|
-
|
|
88
|
+
__shared__ T offset;
|
|
89
89
|
if (threadIdx.x == 0)
|
|
90
90
|
offset = T(0);
|
|
91
91
|
|
|
@@ -124,7 +124,7 @@ inline CUDA_CALLABLE auto tile_scan_inclusive_impl(Tile& t)
|
|
|
124
124
|
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
125
125
|
|
|
126
126
|
// create a temporary shared tile to hold the input values
|
|
127
|
-
|
|
127
|
+
__shared__ T smem[num_elements_to_scan];
|
|
128
128
|
tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
|
|
129
129
|
|
|
130
130
|
// copy input values to scratch space
|
|
@@ -147,7 +147,7 @@ inline CUDA_CALLABLE auto tile_scan_exclusive_impl(Tile& t)
|
|
|
147
147
|
constexpr int num_elements_to_scan = Tile::Layout::Shape::size();
|
|
148
148
|
|
|
149
149
|
// create a temporary shared tile to hold the input values
|
|
150
|
-
|
|
150
|
+
__shared__ T smem[num_elements_to_scan];
|
|
151
151
|
tile_shared_t<T, tile_layout_strided_t<typename Tile::Layout::Shape>, false> scratch(smem, nullptr);
|
|
152
152
|
|
|
153
153
|
// copy input values to scratch space
|