warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -803,7 +803,7 @@ struct tile_layout_strided_t
|
|
|
803
803
|
}
|
|
804
804
|
|
|
805
805
|
// checks whether a strided layout is unique, i.e.: if memory locations are only
|
|
806
|
-
//
|
|
806
|
+
// ever referred to by one element in the tile, this is a basic test that only
|
|
807
807
|
// checks for broadcast dimensions, it would be possible to do the full check
|
|
808
808
|
// using sorted shape/strides in Python and add it as a template parameter to the type
|
|
809
809
|
static constexpr bool is_unique()
|
|
@@ -912,33 +912,27 @@ struct tile_shared_t
|
|
|
912
912
|
}
|
|
913
913
|
|
|
914
914
|
// assign from a register tile
|
|
915
|
-
|
|
916
|
-
inline CUDA_CALLABLE auto& operator=(const Tile& t)
|
|
915
|
+
inline CUDA_CALLABLE auto& operator=(const tile_register_t<Type, tile_layout_register_t<typename Layout::Shape>>& t)
|
|
917
916
|
{
|
|
918
917
|
assign(t);
|
|
919
918
|
return *this;
|
|
920
919
|
}
|
|
921
920
|
|
|
922
|
-
|
|
923
|
-
/*
|
|
924
921
|
// construct from another shared tile, this constructor
|
|
925
922
|
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
926
|
-
template <typename OtherT, typename OtherLayout>
|
|
927
|
-
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
|
|
923
|
+
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
924
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
|
|
928
925
|
{
|
|
929
|
-
using OtherTile = tile_shared_t<OtherT, OtherLayout>;
|
|
930
|
-
|
|
931
926
|
// check dimensions are compatible
|
|
932
|
-
static_assert(Size ==
|
|
927
|
+
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
933
928
|
|
|
934
929
|
// alias tile directly
|
|
935
|
-
data = rhs.data;
|
|
936
|
-
grad = rhs.grad;
|
|
930
|
+
data.ptr = rhs.data.ptr;
|
|
931
|
+
grad.ptr = rhs.grad.ptr;
|
|
937
932
|
initialized = rhs.initialized;
|
|
938
933
|
|
|
939
934
|
return *this;
|
|
940
935
|
}
|
|
941
|
-
*/
|
|
942
936
|
|
|
943
937
|
// assign from a global tile (load)
|
|
944
938
|
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
|
|
@@ -989,6 +983,37 @@ struct tile_shared_t
|
|
|
989
983
|
WP_TILE_SYNC();
|
|
990
984
|
}
|
|
991
985
|
|
|
986
|
+
// add scalar value onto a single tile element
|
|
987
|
+
inline CUDA_CALLABLE void add_inplace(const typename Layout::Coord& c, const Type& x)
|
|
988
|
+
{
|
|
989
|
+
// since multiple threads may add to the same element
|
|
990
|
+
// we need to accumulate using atomic operations
|
|
991
|
+
wp::atomic_add(&data(c), x);
|
|
992
|
+
|
|
993
|
+
WP_TILE_SYNC();
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
// backward of inplace scalar addition
|
|
997
|
+
inline CUDA_CALLABLE void adj_add_inplace(const typename Layout::Coord& c, Type& adj_x)
|
|
998
|
+
{
|
|
999
|
+
adj_x += grad(c);
|
|
1000
|
+
}
|
|
1001
|
+
|
|
1002
|
+
// subtract scalar value from a single tile element
|
|
1003
|
+
inline CUDA_CALLABLE void sub_inplace(const typename Layout::Coord& c, const Type& x)
|
|
1004
|
+
{
|
|
1005
|
+
// since multiple threads may add to the same element
|
|
1006
|
+
// we need to accumulate using atomic operations
|
|
1007
|
+
wp::atomic_add(&data(c), -x);
|
|
1008
|
+
|
|
1009
|
+
WP_TILE_SYNC();
|
|
1010
|
+
}
|
|
1011
|
+
|
|
1012
|
+
// backward of inplace scalar subtraction
|
|
1013
|
+
inline CUDA_CALLABLE void adj_sub_inplace(const typename Layout::Coord& c, Type& adj_x)
|
|
1014
|
+
{
|
|
1015
|
+
adj_x -= grad(c);
|
|
1016
|
+
}
|
|
992
1017
|
|
|
993
1018
|
// copy register tile to shared
|
|
994
1019
|
template <typename Tile>
|
|
@@ -1472,10 +1497,10 @@ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const t
|
|
|
1472
1497
|
|
|
1473
1498
|
|
|
1474
1499
|
// helpers to allocate shared tiles
|
|
1475
|
-
template <typename T, typename Shape, bool RequiresGrad>
|
|
1500
|
+
template <typename T, typename Shape, typename Strides, bool RequiresGrad>
|
|
1476
1501
|
inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
1477
|
-
|
|
1478
|
-
|
|
1502
|
+
{
|
|
1503
|
+
constexpr int size = Shape::size();
|
|
1479
1504
|
T* data = (T*)tile_alloc_shared(size*sizeof(T));
|
|
1480
1505
|
T* grad = nullptr;
|
|
1481
1506
|
|
|
@@ -1503,7 +1528,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
|
|
|
1503
1528
|
WP_TILE_SYNC();
|
|
1504
1529
|
}
|
|
1505
1530
|
|
|
1506
|
-
return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
|
|
1531
|
+
return tile_shared_t<T, tile_layout_strided_t<Shape, Strides>>(data, grad);
|
|
1507
1532
|
}
|
|
1508
1533
|
|
|
1509
1534
|
|
|
@@ -1532,37 +1557,56 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
|
|
|
1532
1557
|
using Layout = typename decltype(result)::Layout;
|
|
1533
1558
|
static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
|
|
1534
1559
|
|
|
1535
|
-
for (
|
|
1560
|
+
for (unsigned i=0; i < Length; ++i)
|
|
1536
1561
|
result.data[i] = x[i];
|
|
1537
1562
|
|
|
1538
1563
|
return result;
|
|
1539
1564
|
}
|
|
1540
1565
|
|
|
1541
|
-
//
|
|
1542
|
-
template <
|
|
1543
|
-
inline CUDA_CALLABLE
|
|
1566
|
+
// overload for constructing a tile from a per-thread matrix
|
|
1567
|
+
template <unsigned Rows, unsigned Cols, typename T>
|
|
1568
|
+
inline CUDA_CALLABLE auto tile(const wp::mat_t<Rows, Cols, T>& x)
|
|
1544
1569
|
{
|
|
1545
|
-
|
|
1546
|
-
static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
|
|
1570
|
+
tile_register_t<T, tile_layout_register_t<tile_shape_t<Rows, Cols, WP_TILE_BLOCK_DIM>>> result;
|
|
1547
1571
|
|
|
1548
|
-
|
|
1572
|
+
using Layout = typename decltype(result)::Layout;
|
|
1573
|
+
static_assert(Layout::NumRegs == Rows*Cols, "Expected Layout::NumRegs == Rows*Cols");
|
|
1574
|
+
|
|
1575
|
+
for (unsigned i=0; i < Rows; ++i)
|
|
1576
|
+
for (unsigned j=0; j < Cols; ++j)
|
|
1577
|
+
result.data[i*Cols + j] = x.data[i][j];
|
|
1549
1578
|
|
|
1550
|
-
|
|
1579
|
+
return result;
|
|
1551
1580
|
}
|
|
1552
1581
|
|
|
1553
|
-
|
|
1554
|
-
|
|
1582
|
+
// it is sufficient to use a single adjoint for all tile overload funcs
|
|
1583
|
+
// it is also necessary, because we don't provide a dispatch_func for adjoint calls
|
|
1584
|
+
// so the compiler will default to choosing based on argument types
|
|
1585
|
+
template <typename T, typename AdjTile>
|
|
1586
|
+
inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
|
|
1555
1587
|
{
|
|
1556
|
-
static_assert(AdjTile::Layout::Shape::N ==
|
|
1557
|
-
|
|
1558
|
-
static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
|
|
1559
|
-
|
|
1588
|
+
static_assert(AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM");
|
|
1589
|
+
|
|
1560
1590
|
auto adj_reg = adj_ret.copy_to_register();
|
|
1561
1591
|
|
|
1562
|
-
|
|
1563
|
-
|
|
1592
|
+
if constexpr (AdjTile::Layout::Shape::N == 1)
|
|
1593
|
+
{
|
|
1594
|
+
adj_x += adj_reg.data[0];
|
|
1595
|
+
}
|
|
1596
|
+
else if constexpr (AdjTile::Layout::Shape::N == 2)
|
|
1597
|
+
{
|
|
1598
|
+
for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
|
|
1599
|
+
adj_x[i] += adj_reg.data[i];
|
|
1600
|
+
}
|
|
1601
|
+
else if constexpr (AdjTile::Layout::Shape::N == 3)
|
|
1602
|
+
{
|
|
1603
|
+
for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
|
|
1604
|
+
for (unsigned j=0; j < AdjTile::Layout::Shape::dim(1); ++j)
|
|
1605
|
+
adj_x.data[i][j] += adj_reg.data[i*AdjTile::Layout::Shape::dim(1) + j];
|
|
1606
|
+
}
|
|
1564
1607
|
}
|
|
1565
1608
|
|
|
1609
|
+
|
|
1566
1610
|
template <typename Tile>
|
|
1567
1611
|
inline CUDA_CALLABLE auto untile(Tile& tile)
|
|
1568
1612
|
{
|
|
@@ -1589,6 +1633,19 @@ inline CUDA_CALLABLE auto untile(Tile& tile)
|
|
|
1589
1633
|
|
|
1590
1634
|
return v;
|
|
1591
1635
|
}
|
|
1636
|
+
|
|
1637
|
+
// matrix case
|
|
1638
|
+
if constexpr(N == 3)
|
|
1639
|
+
{
|
|
1640
|
+
constexpr int Rows = Tile::Layout::Shape::dim(0);
|
|
1641
|
+
constexpr int Cols = Tile::Layout::Shape::dim(1);
|
|
1642
|
+
wp::mat_t<Rows, Cols, typename Tile::Type> m;
|
|
1643
|
+
for (int i=0; i < Rows; ++i)
|
|
1644
|
+
for (int j=0; j < Cols; ++j)
|
|
1645
|
+
m.data[i][j] = reg.data[i*Cols + j];
|
|
1646
|
+
|
|
1647
|
+
return m;
|
|
1648
|
+
}
|
|
1592
1649
|
}
|
|
1593
1650
|
|
|
1594
1651
|
template <typename Tile, typename Value>
|
|
@@ -1612,6 +1669,16 @@ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
|
|
|
1612
1669
|
adj.data[i] += adj_ret[i];
|
|
1613
1670
|
}
|
|
1614
1671
|
|
|
1672
|
+
// matrix case
|
|
1673
|
+
if constexpr(N == 3)
|
|
1674
|
+
{
|
|
1675
|
+
constexpr int Rows = Tile::Layout::Shape::dim(0);
|
|
1676
|
+
constexpr int Cols = Tile::Layout::Shape::dim(1);
|
|
1677
|
+
for (int i=0; i < Rows; ++i)
|
|
1678
|
+
for (int j=0; j < Cols; ++j)
|
|
1679
|
+
adj.data[i*Cols + j] += adj_ret.data[i][j];
|
|
1680
|
+
}
|
|
1681
|
+
|
|
1615
1682
|
adj_tile.assign(adj);
|
|
1616
1683
|
}
|
|
1617
1684
|
|
|
@@ -1893,6 +1960,27 @@ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
|
|
|
1893
1960
|
return tile_binary_map(add, a, b);
|
|
1894
1961
|
}
|
|
1895
1962
|
|
|
1963
|
+
// add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
|
|
1964
|
+
template <typename T, typename L>
|
|
1965
|
+
inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_register_t<T, L>& b) {
|
|
1966
|
+
return tile_add(a, b);
|
|
1967
|
+
}
|
|
1968
|
+
|
|
1969
|
+
template <typename T, typename L, bool Owner>
|
|
1970
|
+
inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b) {
|
|
1971
|
+
return tile_add(a, b);
|
|
1972
|
+
}
|
|
1973
|
+
|
|
1974
|
+
template <typename T, typename L, bool Owner>
|
|
1975
|
+
inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_shared_t<T, L, Owner>& b) {
|
|
1976
|
+
return tile_add(a, b);
|
|
1977
|
+
}
|
|
1978
|
+
|
|
1979
|
+
template <typename T, typename L, bool Owner>
|
|
1980
|
+
inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register_t<T, L>& b) {
|
|
1981
|
+
return tile_add(a, b);
|
|
1982
|
+
}
|
|
1983
|
+
|
|
1896
1984
|
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1897
1985
|
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1898
1986
|
{
|
|
@@ -1961,6 +2049,126 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
|
|
|
1961
2049
|
}
|
|
1962
2050
|
|
|
1963
2051
|
|
|
2052
|
+
template <typename TileA, typename TileB>
|
|
2053
|
+
inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
|
|
2054
|
+
{
|
|
2055
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2056
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2057
|
+
|
|
2058
|
+
// verify shapes and sizes are compatible
|
|
2059
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
|
|
2060
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
|
|
2061
|
+
|
|
2062
|
+
auto a_reg = a.copy_to_register();
|
|
2063
|
+
auto b_reg = b.copy_to_register();
|
|
2064
|
+
|
|
2065
|
+
using Layout = typename decltype(b_reg)::Layout;
|
|
2066
|
+
|
|
2067
|
+
WP_PRAGMA_UNROLL
|
|
2068
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2069
|
+
{
|
|
2070
|
+
const int linear = Layout::linear_from_register(i);
|
|
2071
|
+
|
|
2072
|
+
if(!Layout::valid(linear))
|
|
2073
|
+
break;
|
|
2074
|
+
|
|
2075
|
+
a_reg.data[i] += b_reg.data[i];
|
|
2076
|
+
}
|
|
2077
|
+
|
|
2078
|
+
a.assign(a_reg);
|
|
2079
|
+
}
|
|
2080
|
+
|
|
2081
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2082
|
+
inline CUDA_CALLABLE void adj_tile_add_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
|
|
2083
|
+
{
|
|
2084
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2085
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2086
|
+
|
|
2087
|
+
// verify shapes and sizes are compatible
|
|
2088
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
|
|
2089
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
|
|
2090
|
+
|
|
2091
|
+
// allocate storage for adjoints
|
|
2092
|
+
auto adj_a_reg = adj_a.grad_to_register();
|
|
2093
|
+
auto adj_b_reg = tile_register_like<TileB>();
|
|
2094
|
+
|
|
2095
|
+
using Layout = typename decltype(adj_a_reg)::Layout;
|
|
2096
|
+
|
|
2097
|
+
WP_PRAGMA_UNROLL
|
|
2098
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2099
|
+
{
|
|
2100
|
+
const int linear = Layout::linear_from_register(i);
|
|
2101
|
+
|
|
2102
|
+
if(!Layout::valid(linear))
|
|
2103
|
+
break;
|
|
2104
|
+
|
|
2105
|
+
adj_b_reg.data[i] += adj_a_reg.data[i];
|
|
2106
|
+
}
|
|
2107
|
+
|
|
2108
|
+
adj_b.grad_add(adj_b_reg);
|
|
2109
|
+
}
|
|
2110
|
+
|
|
2111
|
+
template <typename TileA, typename TileB>
|
|
2112
|
+
inline CUDA_CALLABLE void tile_sub_inplace(TileA& a, TileB& b)
|
|
2113
|
+
{
|
|
2114
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2115
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2116
|
+
|
|
2117
|
+
// verify shapes and sizes are compatible
|
|
2118
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
|
|
2119
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
|
|
2120
|
+
|
|
2121
|
+
// work with register tiles for inplace operations, regardless of the storage type of the input tiles
|
|
2122
|
+
auto a_reg = a.copy_to_register();
|
|
2123
|
+
auto b_reg = b.copy_to_register();
|
|
2124
|
+
|
|
2125
|
+
using Layout = typename decltype(a_reg)::Layout;
|
|
2126
|
+
|
|
2127
|
+
WP_PRAGMA_UNROLL
|
|
2128
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2129
|
+
{
|
|
2130
|
+
const int linear = Layout::linear_from_register(i);
|
|
2131
|
+
|
|
2132
|
+
if(!Layout::valid(linear))
|
|
2133
|
+
break;
|
|
2134
|
+
|
|
2135
|
+
a_reg.data[i] -= b_reg.data[i];
|
|
2136
|
+
}
|
|
2137
|
+
|
|
2138
|
+
a.assign(a_reg);
|
|
2139
|
+
}
|
|
2140
|
+
|
|
2141
|
+
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
|
|
2142
|
+
inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
|
|
2143
|
+
{
|
|
2144
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2145
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2146
|
+
|
|
2147
|
+
// verify shapes and sizes are compatible
|
|
2148
|
+
static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
|
|
2149
|
+
static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
|
|
2150
|
+
|
|
2151
|
+
// allocate storage for adjoints
|
|
2152
|
+
auto adj_a_reg = adj_a.grad_to_register();
|
|
2153
|
+
auto adj_b_reg = tile_register_like<TileB>();
|
|
2154
|
+
|
|
2155
|
+
using Layout = typename decltype(adj_a_reg)::Layout;
|
|
2156
|
+
|
|
2157
|
+
WP_PRAGMA_UNROLL
|
|
2158
|
+
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2159
|
+
{
|
|
2160
|
+
const int linear = Layout::linear_from_register(i);
|
|
2161
|
+
|
|
2162
|
+
if(!Layout::valid(linear))
|
|
2163
|
+
break;
|
|
2164
|
+
|
|
2165
|
+
adj_b_reg.data[i] -= adj_a_reg.data[i];
|
|
2166
|
+
}
|
|
2167
|
+
|
|
2168
|
+
adj_b.grad_add(adj_b_reg);
|
|
2169
|
+
}
|
|
2170
|
+
|
|
2171
|
+
|
|
1964
2172
|
template<typename Tile>
|
|
1965
2173
|
typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
|
|
1966
2174
|
template<typename Tile>
|
|
@@ -1970,7 +2178,6 @@ typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extrac
|
|
|
1970
2178
|
template<typename Tile>
|
|
1971
2179
|
typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
|
|
1972
2180
|
|
|
1973
|
-
|
|
1974
2181
|
template<typename Tile, typename AdjTile>
|
|
1975
2182
|
void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
|
|
1976
2183
|
template<typename Tile, typename AdjTile>
|
|
@@ -1981,6 +2188,42 @@ template<typename Tile, typename AdjTile>
|
|
|
1981
2188
|
void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
|
|
1982
2189
|
|
|
1983
2190
|
|
|
2191
|
+
template<typename Tile>
|
|
2192
|
+
void tile_add_inplace(Tile& t, int i, typename Tile::Type value) { t.add_inplace(tile_coord(i), value); }
|
|
2193
|
+
template<typename Tile>
|
|
2194
|
+
void tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.add_inplace(tile_coord(i,j), value); }
|
|
2195
|
+
template<typename Tile>
|
|
2196
|
+
void tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k), value); }
|
|
2197
|
+
template<typename Tile>
|
|
2198
|
+
void tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k,l), value); }
|
|
2199
|
+
|
|
2200
|
+
template<typename Tile>
|
|
2201
|
+
void tile_sub_inplace(Tile& t, int i, typename Tile::Type value) { t.sub_inplace(tile_coord(i), value); }
|
|
2202
|
+
template<typename Tile>
|
|
2203
|
+
void tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j), value); }
|
|
2204
|
+
template<typename Tile>
|
|
2205
|
+
void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k), value); }
|
|
2206
|
+
template<typename Tile>
|
|
2207
|
+
void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
|
|
2208
|
+
|
|
2209
|
+
template<typename Tile, typename AdjTile>
|
|
2210
|
+
void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
|
|
2211
|
+
template<typename Tile, typename AdjTile>
|
|
2212
|
+
void adj_tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j), adj_value); }
|
|
2213
|
+
template<typename Tile, typename AdjTile>
|
|
2214
|
+
void adj_tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k), adj_value); }
|
|
2215
|
+
template<typename Tile, typename AdjTile>
|
|
2216
|
+
void adj_tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k, l), adj_value); }
|
|
2217
|
+
|
|
2218
|
+
template<typename Tile, typename AdjTile>
|
|
2219
|
+
void adj_tile_sub_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i), adj_value); }
|
|
2220
|
+
template<typename Tile, typename AdjTile>
|
|
2221
|
+
void adj_tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j), adj_value); }
|
|
2222
|
+
template<typename Tile, typename AdjTile>
|
|
2223
|
+
void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k), adj_value); }
|
|
2224
|
+
template<typename Tile, typename AdjTile>
|
|
2225
|
+
void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
|
|
2226
|
+
|
|
1984
2227
|
namespace partitioned_gemm
|
|
1985
2228
|
{
|
|
1986
2229
|
|
|
@@ -2177,33 +2420,98 @@ inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
|
|
|
2177
2420
|
}
|
|
2178
2421
|
}
|
|
2179
2422
|
|
|
2423
|
+
// Writes into X
|
|
2180
2424
|
template <typename TileL, typename TileX, typename TileY>
|
|
2181
|
-
inline CUDA_CALLABLE void
|
|
2425
|
+
inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX& X, TileY& Y)
|
|
2182
2426
|
{
|
|
2183
|
-
using T = typename TileL::Type;
|
|
2184
|
-
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2427
|
+
using T = typename TileL::Type;
|
|
2185
2428
|
|
|
2186
|
-
|
|
2429
|
+
if constexpr (TileY::Layout::Shape::N == 1)
|
|
2187
2430
|
{
|
|
2188
|
-
|
|
2431
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2432
|
+
|
|
2433
|
+
for (int i=0; i < n; ++i)
|
|
2434
|
+
{
|
|
2435
|
+
T s = Y.data(tile_coord(i));
|
|
2189
2436
|
|
|
2190
|
-
|
|
2191
|
-
|
|
2437
|
+
for (int j=0; j < i; ++j)
|
|
2438
|
+
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
|
|
2192
2439
|
|
|
2193
|
-
|
|
2440
|
+
T diag = L.data(tile_coord(i, i));
|
|
2441
|
+
X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
|
|
2442
|
+
}
|
|
2194
2443
|
}
|
|
2444
|
+
else if constexpr (TileY::Layout::Shape::N == 2)
|
|
2445
|
+
{
|
|
2446
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2447
|
+
constexpr int m = TileY::Layout::Shape::dim(1);
|
|
2448
|
+
|
|
2449
|
+
for (int k=0; k < m; ++k)
|
|
2450
|
+
{
|
|
2451
|
+
for (int i=0; i < n; ++i)
|
|
2452
|
+
{
|
|
2453
|
+
T s = Y.data(tile_coord(i,k));
|
|
2454
|
+
|
|
2455
|
+
for (int j=0; j < i; ++j)
|
|
2456
|
+
s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j,k));
|
|
2457
|
+
|
|
2458
|
+
T diag = L.data(tile_coord(i, i));
|
|
2459
|
+
X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
|
|
2460
|
+
}
|
|
2461
|
+
}
|
|
2462
|
+
}
|
|
2463
|
+
}
|
|
2464
|
+
|
|
2465
|
+
// Reads and writes X
|
|
2466
|
+
template <typename TileL, typename TileX>
|
|
2467
|
+
inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
|
|
2468
|
+
{
|
|
2469
|
+
using T = typename TileL::Type;
|
|
2470
|
+
|
|
2471
|
+
if constexpr (TileX::Layout::Shape::N == 1)
|
|
2472
|
+
{
|
|
2473
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2474
|
+
|
|
2475
|
+
for (int i=n-1; i >= 0; --i)
|
|
2476
|
+
{
|
|
2477
|
+
T s = X.data(tile_coord(i));
|
|
2195
2478
|
|
|
2196
|
-
|
|
2479
|
+
for (int j=i+1; j < n; ++j)
|
|
2480
|
+
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
|
|
2481
|
+
|
|
2482
|
+
T diag = L.data(tile_coord(i, i));
|
|
2483
|
+
X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
|
|
2484
|
+
}
|
|
2485
|
+
}
|
|
2486
|
+
else if constexpr (TileX::Layout::Shape::N == 2)
|
|
2197
2487
|
{
|
|
2198
|
-
|
|
2488
|
+
constexpr int n = TileL::Layout::Shape::dim(1);
|
|
2489
|
+
constexpr int m = TileX::Layout::Shape::dim(1);
|
|
2199
2490
|
|
|
2200
|
-
for (int
|
|
2201
|
-
|
|
2491
|
+
for (int k=0; k < m; ++k)
|
|
2492
|
+
{
|
|
2493
|
+
for (int i=n-1; i >= 0; --i)
|
|
2494
|
+
{
|
|
2495
|
+
T s = X.data(tile_coord(i,k));
|
|
2496
|
+
|
|
2497
|
+
for (int j=i+1; j < n; ++j)
|
|
2498
|
+
s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j,k));
|
|
2202
2499
|
|
|
2203
|
-
|
|
2500
|
+
T diag = L.data(tile_coord(i, i));
|
|
2501
|
+
X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
|
|
2502
|
+
}
|
|
2503
|
+
}
|
|
2204
2504
|
}
|
|
2205
2505
|
}
|
|
2206
2506
|
|
|
2507
|
+
template <typename TileL, typename TileX, typename TileY>
|
|
2508
|
+
inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
|
|
2509
|
+
{
|
|
2510
|
+
scalar_cholesky_forward_substitution(L, X, Y);
|
|
2511
|
+
scalar_cholesky_back_substitution(L, X);
|
|
2512
|
+
}
|
|
2513
|
+
|
|
2514
|
+
|
|
2207
2515
|
} // namespace partitioned_gemm
|
|
2208
2516
|
|
|
2209
2517
|
|
|
@@ -2223,12 +2531,14 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
|
|
|
2223
2531
|
static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
|
|
2224
2532
|
|
|
2225
2533
|
|
|
2226
|
-
using T = typename
|
|
2534
|
+
using T = typename TileC::Type;
|
|
2227
2535
|
|
|
2228
2536
|
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2229
2537
|
partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
|
|
2230
2538
|
#else
|
|
2231
|
-
|
|
2539
|
+
T alpha = T(1.0);
|
|
2540
|
+
T beta = T(Add);
|
|
2541
|
+
fun_forward(&alpha, A.data.ptr, B.data.ptr, &beta, C.data.ptr);
|
|
2232
2542
|
#endif
|
|
2233
2543
|
|
|
2234
2544
|
WP_TILE_SYNC();
|
|
@@ -2242,17 +2552,22 @@ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename T
|
|
|
2242
2552
|
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
|
|
2243
2553
|
Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
|
|
2244
2554
|
{
|
|
2245
|
-
using
|
|
2555
|
+
using T_A = typename TileA::Type;
|
|
2556
|
+
using T_B = typename TileB::Type;
|
|
2246
2557
|
|
|
2247
2558
|
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2248
2559
|
auto At = tile_transpose(A);
|
|
2249
2560
|
auto Bt = tile_transpose(B);
|
|
2250
2561
|
|
|
2251
|
-
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad,
|
|
2252
|
-
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad,
|
|
2562
|
+
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T_A(1.0));
|
|
2563
|
+
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T_B(1.0));
|
|
2253
2564
|
#else
|
|
2254
|
-
|
|
2255
|
-
|
|
2565
|
+
T_A alpha_A = T_A(1.0);
|
|
2566
|
+
T_A beta_A = T_A(1.0);
|
|
2567
|
+
fun_backward_A(&alpha_A, adj_C.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
|
|
2568
|
+
T_B alpha_B = T_B(1.0);
|
|
2569
|
+
T_B beta_B = T_B(1.0);
|
|
2570
|
+
fun_backward_B(&alpha_B, A.data.ptr, adj_C.grad.ptr, &beta_B, adj_B.grad.ptr);
|
|
2256
2571
|
#endif
|
|
2257
2572
|
|
|
2258
2573
|
WP_TILE_SYNC();
|
|
@@ -2263,7 +2578,7 @@ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename T
|
|
|
2263
2578
|
void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
|
|
2264
2579
|
Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
|
|
2265
2580
|
{
|
|
2266
|
-
using T = typename
|
|
2581
|
+
using T = typename TileC::Type;
|
|
2267
2582
|
|
|
2268
2583
|
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2269
2584
|
auto At = tile_transpose(A);
|
|
@@ -2272,8 +2587,10 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2272
2587
|
partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
|
|
2273
2588
|
partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
|
|
2274
2589
|
#else
|
|
2275
|
-
|
|
2276
|
-
|
|
2590
|
+
T alpha = T(1.0);
|
|
2591
|
+
T beta = T(1.0);
|
|
2592
|
+
fun_backward_A(&alpha, adj_C.grad.ptr, B.data.ptr, &beta, adj_A.grad.ptr);
|
|
2593
|
+
fun_backward_B(&alpha, A.data.ptr, adj_C.grad.ptr, &beta, adj_B.grad.ptr);
|
|
2277
2594
|
#endif
|
|
2278
2595
|
|
|
2279
2596
|
WP_TILE_SYNC();
|
|
@@ -2293,13 +2610,13 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2293
2610
|
// and remove the need for __align__(16) dtypes data[...]
|
|
2294
2611
|
#define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
|
|
2295
2612
|
do { \
|
|
2296
|
-
void function_name(dtype*,
|
|
2613
|
+
void function_name(dtype*, char*); \
|
|
2297
2614
|
char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
|
|
2298
2615
|
__align__(16) dtype data[ept]; \
|
|
2299
2616
|
for(int b = 0; b < (int)batch_size; b++) { \
|
|
2300
2617
|
dtype* inout = Xinout.data + (int)b * (int)ept; \
|
|
2301
2618
|
memcpy(data, inout, sizeof(dtype) * ept); \
|
|
2302
|
-
function_name(data,
|
|
2619
|
+
function_name(data, buffer); \
|
|
2303
2620
|
memcpy(inout, data, sizeof(dtype) * ept); \
|
|
2304
2621
|
WP_TILE_SYNC(); \
|
|
2305
2622
|
} \
|
|
@@ -2328,7 +2645,15 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
|
|
|
2328
2645
|
|
|
2329
2646
|
template <typename Fwd, typename TileA, typename TileL>
|
|
2330
2647
|
TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
2331
|
-
{
|
|
2648
|
+
{
|
|
2649
|
+
static_assert(TileA::Layout::Shape::N == 2, "Expected TileA::Layout::Shape::N == 2");
|
|
2650
|
+
static_assert(TileL::Layout::Shape::N == 2, "Expected TileL::Layout::Shape::N == 2");
|
|
2651
|
+
|
|
2652
|
+
static_assert(TileA::Layout::Shape::dim(0) == TileA::Layout::Shape::dim(1), "Expected TileA to be square");
|
|
2653
|
+
static_assert(TileL::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(1), "Expected TileL to be square");
|
|
2654
|
+
static_assert(TileA::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(0), "Expected A and L to have the same number of rows");
|
|
2655
|
+
static_assert(TileA::Layout::Shape::dim(1) == TileL::Layout::Shape::dim(1), "Expected A and L to have the same number of columns");
|
|
2656
|
+
|
|
2332
2657
|
// Copy to L
|
|
2333
2658
|
L = A;
|
|
2334
2659
|
|
|
@@ -2338,14 +2663,27 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
|
2338
2663
|
|
|
2339
2664
|
#else
|
|
2340
2665
|
|
|
2666
|
+
// TODO: for batched Cholesky, need one info per batch
|
|
2667
|
+
WP_TILE_SHARED int info[1];
|
|
2668
|
+
|
|
2669
|
+
if (WP_TILE_THREAD_IDX == 0) {
|
|
2670
|
+
info[0] = 0;
|
|
2671
|
+
}
|
|
2341
2672
|
|
|
2342
2673
|
// Call cholesky on L
|
|
2343
2674
|
WP_TILE_SYNC();
|
|
2344
2675
|
|
|
2345
|
-
fun_forward(L.data.ptr,
|
|
2676
|
+
fun_forward(L.data.ptr, info);
|
|
2346
2677
|
|
|
2347
2678
|
WP_TILE_SYNC();
|
|
2348
2679
|
|
|
2680
|
+
// TODO: for batched Cholesky, check all batches
|
|
2681
|
+
#if defined(_DEBUG)
|
|
2682
|
+
if (WP_TILE_THREAD_IDX == 0 && info[0] != 0) {
|
|
2683
|
+
printf("Non-zero status in Cholesky factorization, got %d\n", info[0]);
|
|
2684
|
+
}
|
|
2685
|
+
#endif
|
|
2686
|
+
|
|
2349
2687
|
// Zero-out the upper triangular part of L
|
|
2350
2688
|
|
|
2351
2689
|
WP_PRAGMA_UNROLL
|
|
@@ -2371,11 +2709,11 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
|
|
|
2371
2709
|
} while (0)
|
|
2372
2710
|
|
|
2373
2711
|
template <typename Fwd, typename TileL, typename TileX, typename TileY>
|
|
2374
|
-
TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX&
|
|
2712
|
+
TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& Y, TileY& X)
|
|
2375
2713
|
{
|
|
2376
|
-
// Copy
|
|
2714
|
+
// Copy y to x
|
|
2377
2715
|
|
|
2378
|
-
|
|
2716
|
+
X = Y;
|
|
2379
2717
|
|
|
2380
2718
|
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2381
2719
|
|
|
@@ -2383,24 +2721,99 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
|
|
|
2383
2721
|
|
|
2384
2722
|
#else
|
|
2385
2723
|
|
|
2386
|
-
// Call cholesky solve on L &
|
|
2724
|
+
// Call cholesky solve on L & x
|
|
2725
|
+
|
|
2726
|
+
WP_TILE_SYNC();
|
|
2727
|
+
|
|
2728
|
+
fun_forward(L.data.ptr, X.data.ptr); \
|
|
2729
|
+
|
|
2730
|
+
WP_TILE_SYNC();
|
|
2731
|
+
|
|
2732
|
+
#endif
|
|
2733
|
+
|
|
2734
|
+
return X;
|
|
2735
|
+
}
|
|
2736
|
+
|
|
2737
|
+
// Adjoint of tile_cholesky_solve: not implemented. Reaching it trips the
// assert (only in builds where assert is active; otherwise it is a no-op).
#define adj_tile_cholesky_solve(function_name, L, Y, X,                          \
                                adj_function_name, adj_L, adj_Y, adj_X, adj_ret) \
    do                                                                           \
    {                                                                            \
        assert(false); /* adjoint not implemented */                             \
    } while (0)
|
|
2742
|
+
|
|
2743
|
+
|
|
2744
|
+
|
|
2745
|
+
|
|
2746
|
+
|
|
2747
|
+
|
|
2748
|
+
template <typename Fwd, typename TileL, typename TileY, typename TileZ>
|
|
2749
|
+
TileZ& tile_lower_solve(Fwd fun_forward, TileL& L, TileY& y, TileZ& z)
|
|
2750
|
+
{
|
|
2751
|
+
// Copy y to z
|
|
2752
|
+
z = y;
|
|
2753
|
+
|
|
2754
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2755
|
+
|
|
2756
|
+
partitioned_gemm::scalar_cholesky_forward_substitution(L, z, y);
|
|
2757
|
+
|
|
2758
|
+
#else
|
|
2759
|
+
|
|
2760
|
+
// Call cholesky solve on L & z
|
|
2387
2761
|
|
|
2388
2762
|
WP_TILE_SYNC();
|
|
2389
2763
|
|
|
2390
|
-
fun_forward(L.data.ptr,
|
|
2764
|
+
fun_forward(L.data.ptr, z.data.ptr);
|
|
2391
2765
|
|
|
2392
2766
|
WP_TILE_SYNC();
|
|
2393
2767
|
|
|
2394
2768
|
#endif
|
|
2395
2769
|
|
|
2396
|
-
return
|
|
2770
|
+
return z;
|
|
2397
2771
|
}
|
|
2398
2772
|
|
|
2399
|
-
#define
|
|
2400
|
-
|
|
2773
|
+
// Adjoint of tile_lower_solve: not implemented. Reaching it trips the
// assert (only in builds where assert is active; otherwise it is a no-op).
#define adj_tile_lower_solve(function_name, L, y, z,                          \
                             adj_function_name, adj_L, adj_y, adj_z, adj_ret) \
    do                                                                        \
    {                                                                         \
        assert(false); /* adjoint not implemented */                          \
    } while (0)
|
|
2778
|
+
|
|
2779
|
+
|
|
2780
|
+
|
|
2781
|
+
template <typename Fwd, typename TileU, typename TileZ, typename TileX>
|
|
2782
|
+
TileX& tile_upper_solve(Fwd fun_forward, TileU& U, TileZ& z, TileX& x)
|
|
2783
|
+
{
|
|
2784
|
+
// Copy z to x
|
|
2785
|
+
x = z;
|
|
2786
|
+
|
|
2787
|
+
#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
|
|
2788
|
+
|
|
2789
|
+
auto L = tile_transpose(U);
|
|
2790
|
+
partitioned_gemm::scalar_cholesky_back_substitution(L, x);
|
|
2791
|
+
|
|
2792
|
+
#else
|
|
2793
|
+
|
|
2794
|
+
// Call cholesky solve on U & x
|
|
2795
|
+
|
|
2796
|
+
WP_TILE_SYNC();
|
|
2797
|
+
|
|
2798
|
+
fun_forward(U.data.ptr, x.data.ptr);
|
|
2799
|
+
|
|
2800
|
+
WP_TILE_SYNC();
|
|
2801
|
+
|
|
2802
|
+
#endif
|
|
2803
|
+
|
|
2804
|
+
return x;
|
|
2805
|
+
}
|
|
2806
|
+
|
|
2807
|
+
// Adjoint of tile_upper_solve: not implemented. Reaching it trips the
// assert (only in builds where assert is active; otherwise it is a no-op).
#define adj_tile_upper_solve(function_name, U, z, x,                          \
                             adj_function_name, adj_U, adj_z, adj_x, adj_ret) \
    do                                                                        \
    {                                                                         \
        assert(false); /* adjoint not implemented */                          \
    } while (0)
|
|
2812
|
+
|
|
2813
|
+
|
|
2814
|
+
|
|
2815
|
+
|
|
2816
|
+
|
|
2404
2817
|
|
|
2405
2818
|
template <typename Tile>
|
|
2406
2819
|
inline CUDA_CALLABLE auto tile_transpose(Tile& t)
|
|
@@ -2457,10 +2870,11 @@ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
|
|
|
2457
2870
|
template <typename Tile, typename AdjTile>
|
|
2458
2871
|
inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
2459
2872
|
{
|
|
2460
|
-
// nop, since memory is aliased grads already accumulated
|
|
2873
|
+
// nop, since memory is aliased, grads already accumulated
|
|
2461
2874
|
}
|
|
2462
2875
|
|
|
2463
|
-
|
|
2876
|
+
|
|
2877
|
+
template <typename ReturnTile, typename Tile, typename... Indices>
|
|
2464
2878
|
inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
|
|
2465
2879
|
{
|
|
2466
2880
|
auto c = tile_coord(indices...);
|
|
@@ -2472,7 +2886,104 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
|
|
|
2472
2886
|
if (t.grad.ptr)
|
|
2473
2887
|
grad_ptr = &t.grad(c);
|
|
2474
2888
|
|
|
2475
|
-
return
|
|
2889
|
+
return ReturnTile(data_ptr, grad_ptr);
|
|
2890
|
+
}
|
|
2891
|
+
|
|
2892
|
+
|
|
2893
|
+
template <typename ReturnTile, typename Tile>
|
|
2894
|
+
inline CUDA_CALLABLE auto tile_squeeze(Tile& t)
|
|
2895
|
+
{
|
|
2896
|
+
// ReturnTile layout is set in builtins.py
|
|
2897
|
+
typename Tile::Type* data_ptr = t.data.ptr;
|
|
2898
|
+
typename Tile::Type* grad_ptr = nullptr;
|
|
2899
|
+
|
|
2900
|
+
if (t.grad.ptr)
|
|
2901
|
+
grad_ptr = t.grad.ptr;
|
|
2902
|
+
|
|
2903
|
+
return ReturnTile(data_ptr, grad_ptr);
|
|
2904
|
+
}
|
|
2905
|
+
|
|
2906
|
+
template <typename Tile, typename AdjTile, typename AdjReturnTile>
|
|
2907
|
+
inline CUDA_CALLABLE void adj_tile_squeeze(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
|
|
2908
|
+
{
|
|
2909
|
+
// nop, since memory is aliased, grads already accumulated
|
|
2910
|
+
}
|
|
2911
|
+
|
|
2912
|
+
|
|
2913
|
+
template <typename ReturnTile, typename Tile>
|
|
2914
|
+
inline CUDA_CALLABLE auto tile_reshape(Tile& t)
|
|
2915
|
+
{
|
|
2916
|
+
// ReturnTile layout is set in builtins.py
|
|
2917
|
+
typename Tile::Type* data_ptr = t.data.ptr;
|
|
2918
|
+
typename Tile::Type* grad_ptr = nullptr;
|
|
2919
|
+
|
|
2920
|
+
if (t.grad.ptr)
|
|
2921
|
+
grad_ptr = t.grad.ptr;
|
|
2922
|
+
|
|
2923
|
+
return ReturnTile(data_ptr, grad_ptr);
|
|
2924
|
+
}
|
|
2925
|
+
|
|
2926
|
+
template <typename Tile, typename AdjTile, typename AdjReturnTile>
|
|
2927
|
+
inline CUDA_CALLABLE void adj_tile_reshape(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
|
|
2928
|
+
{
|
|
2929
|
+
// nop, since memory is aliased, grads already accumulated
|
|
2930
|
+
}
|
|
2931
|
+
|
|
2932
|
+
|
|
2933
|
+
template <typename ReturnTile, typename Tile>
|
|
2934
|
+
inline CUDA_CALLABLE auto tile_astype(Tile& t)
|
|
2935
|
+
{
|
|
2936
|
+
// verify shapes and sizes are compatible
|
|
2937
|
+
using ShapeIn = typename Tile::Layout::Shape;
|
|
2938
|
+
using ShapeOut = typename ReturnTile::Layout::Shape;
|
|
2939
|
+
|
|
2940
|
+
static_assert(ShapeIn::N == ShapeOut::N, "Tile shapes must match for data type casting");
|
|
2941
|
+
static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for data type casting");
|
|
2942
|
+
|
|
2943
|
+
// work with register tiles for type casting
|
|
2944
|
+
auto t_reg = t.copy_to_register();
|
|
2945
|
+
auto result = tile_register_like<ReturnTile>();
|
|
2946
|
+
|
|
2947
|
+
using Layout = typename decltype(result)::Layout;
|
|
2948
|
+
|
|
2949
|
+
WP_PRAGMA_UNROLL
|
|
2950
|
+
for (int i = 0; i < Layout::NumRegs; ++i)
|
|
2951
|
+
{
|
|
2952
|
+
const int linear = Layout::linear_from_register(i);
|
|
2953
|
+
|
|
2954
|
+
if(!Layout::valid(linear))
|
|
2955
|
+
break;
|
|
2956
|
+
|
|
2957
|
+
result.data[i] = static_cast<typename ReturnTile::Type>(t_reg.data[i]);
|
|
2958
|
+
}
|
|
2959
|
+
|
|
2960
|
+
return result;
|
|
2961
|
+
}
|
|
2962
|
+
|
|
2963
|
+
template <typename Tile, typename AdjTile, typename AdjReturnTile>
|
|
2964
|
+
inline CUDA_CALLABLE void adj_tile_astype(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
|
|
2965
|
+
{
|
|
2966
|
+
// gradients only flow between float conversions
|
|
2967
|
+
if constexpr((is_same<typename AdjTile::Type, wp::float16>::value ||
|
|
2968
|
+
is_same<typename AdjTile::Type, wp::float32>::value ||
|
|
2969
|
+
is_same<typename AdjTile::Type, wp::float64>::value) &&
|
|
2970
|
+
(is_same<typename AdjReturnTile::Type, wp::float16>::value ||
|
|
2971
|
+
is_same<typename AdjReturnTile::Type, wp::float32>::value ||
|
|
2972
|
+
is_same<typename AdjReturnTile::Type, wp::float64>::value))
|
|
2973
|
+
{
|
|
2974
|
+
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
2975
|
+
auto adj_t_reg = tile_register_like<AdjTile>();
|
|
2976
|
+
|
|
2977
|
+
using Layout = typename decltype(adj_t_reg)::Layout;
|
|
2978
|
+
|
|
2979
|
+
WP_PRAGMA_UNROLL
|
|
2980
|
+
for (int i = 0; i < Layout::NumRegs; ++i)
|
|
2981
|
+
{
|
|
2982
|
+
adj_t_reg.data[i] += static_cast<typename AdjTile::Type>(adj_ret_reg.data[i]);
|
|
2983
|
+
}
|
|
2984
|
+
|
|
2985
|
+
adj_t.grad_add(adj_t_reg);
|
|
2986
|
+
}
|
|
2476
2987
|
}
|
|
2477
2988
|
|
|
2478
2989
|
|
|
@@ -2504,21 +3015,41 @@ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const
|
|
|
2504
3015
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
2505
3016
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
|
|
2506
3017
|
{
|
|
3018
|
+
if (dest.grad.ptr == nullptr)
|
|
3019
|
+
{
|
|
3020
|
+
return;
|
|
3021
|
+
}
|
|
3022
|
+
|
|
2507
3023
|
adj_src += dest.grad(tile_coord(i));
|
|
2508
3024
|
}
|
|
2509
3025
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
2510
3026
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, Scalar& adj_src)
|
|
2511
3027
|
{
|
|
3028
|
+
if (dest.grad.ptr == nullptr)
|
|
3029
|
+
{
|
|
3030
|
+
return;
|
|
3031
|
+
}
|
|
3032
|
+
|
|
2512
3033
|
adj_src += dest.grad(tile_coord(i, j));
|
|
2513
3034
|
}
|
|
2514
3035
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
2515
3036
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
|
|
2516
3037
|
{
|
|
3038
|
+
if (dest.grad.ptr == nullptr)
|
|
3039
|
+
{
|
|
3040
|
+
return;
|
|
3041
|
+
}
|
|
3042
|
+
|
|
2517
3043
|
adj_src += dest.grad(tile_coord(i, j, k));
|
|
2518
3044
|
}
|
|
2519
3045
|
template <typename TileA, typename AdjTileA, typename Scalar>
|
|
2520
3046
|
inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
|
|
2521
3047
|
{
|
|
3048
|
+
if (dest.grad.ptr == nullptr)
|
|
3049
|
+
{
|
|
3050
|
+
return;
|
|
3051
|
+
}
|
|
3052
|
+
|
|
2522
3053
|
adj_src += dest.grad(tile_coord(i, j, k, l));
|
|
2523
3054
|
}
|
|
2524
3055
|
|
|
@@ -2601,7 +3132,6 @@ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
|
|
|
2601
3132
|
template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
|
|
2602
3133
|
inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
|
|
2603
3134
|
{
|
|
2604
|
-
assert(false);
|
|
2605
3135
|
}
|
|
2606
3136
|
|
|
2607
3137
|
|