warp-lang 1.7.2__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180) hide show
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -803,7 +803,7 @@ struct tile_layout_strided_t
803
803
  }
804
804
 
805
805
  // checks whether a strided layout is unique, i.e.: if memory locations are only
806
- // every referred to by one element in the tile, this is a basic test that only
806
+ // ever referred to by one element in the tile, this is a basic test that only
807
807
  // checks for broadcast dimensions, it would be possible to do the full check
808
808
  // using sorted shape/strides in Python and add it as a template parameter to the type
809
809
  static constexpr bool is_unique()
@@ -912,33 +912,27 @@ struct tile_shared_t
912
912
  }
913
913
 
914
914
  // assign from a register tile
915
- template <typename Tile>
916
- inline CUDA_CALLABLE auto& operator=(const Tile& t)
915
+ inline CUDA_CALLABLE auto& operator=(const tile_register_t<Type, tile_layout_register_t<typename Layout::Shape>>& t)
917
916
  {
918
917
  assign(t);
919
918
  return *this;
920
919
  }
921
920
 
922
-
923
- /*
924
921
  // construct from another shared tile, this constructor
925
922
  // is invoked for reshape operations like `wp.tile_transpose()`
926
- template <typename OtherT, typename OtherLayout>
927
- inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
923
+ template <typename OtherT, typename OtherLayout, bool OtherOwner>
924
+ inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
928
925
  {
929
- using OtherTile = tile_shared_t<OtherT, OtherLayout>;
930
-
931
926
  // check dimensions are compatible
932
- static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
927
+ static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
933
928
 
934
929
  // alias tile directly
935
- data = rhs.data;
936
- grad = rhs.grad;
930
+ data.ptr = rhs.data.ptr;
931
+ grad.ptr = rhs.grad.ptr;
937
932
  initialized = rhs.initialized;
938
933
 
939
934
  return *this;
940
935
  }
941
- */
942
936
 
943
937
  // assign from a global tile (load)
944
938
  inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
@@ -989,6 +983,37 @@ struct tile_shared_t
989
983
  WP_TILE_SYNC();
990
984
  }
991
985
 
986
+ // add scalar value onto a single tile element
987
+ inline CUDA_CALLABLE void add_inplace(const typename Layout::Coord& c, const Type& x)
988
+ {
989
+ // since multiple threads may add to the same element
990
+ // we need to accumulate using atomic operations
991
+ wp::atomic_add(&data(c), x);
992
+
993
+ WP_TILE_SYNC();
994
+ }
995
+
996
+ // backward of inplace scalar addition
997
+ inline CUDA_CALLABLE void adj_add_inplace(const typename Layout::Coord& c, Type& adj_x)
998
+ {
999
+ adj_x += grad(c);
1000
+ }
1001
+
1002
+ // subtract scalar value from a single tile element
1003
+ inline CUDA_CALLABLE void sub_inplace(const typename Layout::Coord& c, const Type& x)
1004
+ {
1005
+ // since multiple threads may add to the same element
1006
+ // we need to accumulate using atomic operations
1007
+ wp::atomic_add(&data(c), -x);
1008
+
1009
+ WP_TILE_SYNC();
1010
+ }
1011
+
1012
+ // backward of inplace scalar subtraction
1013
+ inline CUDA_CALLABLE void adj_sub_inplace(const typename Layout::Coord& c, Type& adj_x)
1014
+ {
1015
+ adj_x -= grad(c);
1016
+ }
992
1017
 
993
1018
  // copy register tile to shared
994
1019
  template <typename Tile>
@@ -1472,10 +1497,10 @@ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const t
1472
1497
 
1473
1498
 
1474
1499
  // helpers to allocate shared tiles
1475
- template <typename T, typename Shape, bool RequiresGrad>
1500
+ template <typename T, typename Shape, typename Strides, bool RequiresGrad>
1476
1501
  inline CUDA_CALLABLE auto tile_alloc_empty()
1477
-
1478
- { constexpr int size = Shape::size();
1502
+ {
1503
+ constexpr int size = Shape::size();
1479
1504
  T* data = (T*)tile_alloc_shared(size*sizeof(T));
1480
1505
  T* grad = nullptr;
1481
1506
 
@@ -1503,7 +1528,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1503
1528
  WP_TILE_SYNC();
1504
1529
  }
1505
1530
 
1506
- return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
1531
+ return tile_shared_t<T, tile_layout_strided_t<Shape, Strides>>(data, grad);
1507
1532
  }
1508
1533
 
1509
1534
 
@@ -1532,37 +1557,56 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1532
1557
  using Layout = typename decltype(result)::Layout;
1533
1558
  static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
1534
1559
 
1535
- for (int i=0; i < Length; ++i)
1560
+ for (unsigned i=0; i < Length; ++i)
1536
1561
  result.data[i] = x[i];
1537
1562
 
1538
1563
  return result;
1539
1564
  }
1540
1565
 
1541
- // construct a tile from a local SIMT value (one per-thread)
1542
- template <typename T, typename AdjTile>
1543
- inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1566
+ // overload for constructing a tile from a per-thread matrix
1567
+ template <unsigned Rows, unsigned Cols, typename T>
1568
+ inline CUDA_CALLABLE auto tile(const wp::mat_t<Rows, Cols, T>& x)
1544
1569
  {
1545
- static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
1546
- static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
1570
+ tile_register_t<T, tile_layout_register_t<tile_shape_t<Rows, Cols, WP_TILE_BLOCK_DIM>>> result;
1547
1571
 
1548
- auto adj_reg = adj_ret.copy_to_register();
1572
+ using Layout = typename decltype(result)::Layout;
1573
+ static_assert(Layout::NumRegs == Rows*Cols, "Expected Layout::NumRegs == Rows*Cols");
1574
+
1575
+ for (unsigned i=0; i < Rows; ++i)
1576
+ for (unsigned j=0; j < Cols; ++j)
1577
+ result.data[i*Cols + j] = x.data[i][j];
1549
1578
 
1550
- adj_x += adj_reg.data[0];
1579
+ return result;
1551
1580
  }
1552
1581
 
1553
- template <typename T, unsigned Length, typename AdjTile>
1554
- inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
1582
+ // it is sufficient to use a single adjoint for all tile overload funcs
1583
+ // it is also necessary, because we don't provide a dispatch_func for adjoint calls
1584
+ // so the compiler will default to choosing based on argument types
1585
+ template <typename T, typename AdjTile>
1586
+ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1555
1587
  {
1556
- static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
1557
- static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
1558
- static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
1559
-
1588
+ static_assert(AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM");
1589
+
1560
1590
  auto adj_reg = adj_ret.copy_to_register();
1561
1591
 
1562
- for (int i=0; i < Length; ++i)
1563
- adj_x[i] += adj_reg.data[i];
1592
+ if constexpr (AdjTile::Layout::Shape::N == 1)
1593
+ {
1594
+ adj_x += adj_reg.data[0];
1595
+ }
1596
+ else if constexpr (AdjTile::Layout::Shape::N == 2)
1597
+ {
1598
+ for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
1599
+ adj_x[i] += adj_reg.data[i];
1600
+ }
1601
+ else if constexpr (AdjTile::Layout::Shape::N == 3)
1602
+ {
1603
+ for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
1604
+ for (unsigned j=0; j < AdjTile::Layout::Shape::dim(1); ++j)
1605
+ adj_x.data[i][j] += adj_reg.data[i*AdjTile::Layout::Shape::dim(1) + j];
1606
+ }
1564
1607
  }
1565
1608
 
1609
+
1566
1610
  template <typename Tile>
1567
1611
  inline CUDA_CALLABLE auto untile(Tile& tile)
1568
1612
  {
@@ -1589,6 +1633,19 @@ inline CUDA_CALLABLE auto untile(Tile& tile)
1589
1633
 
1590
1634
  return v;
1591
1635
  }
1636
+
1637
+ // matrix case
1638
+ if constexpr(N == 3)
1639
+ {
1640
+ constexpr int Rows = Tile::Layout::Shape::dim(0);
1641
+ constexpr int Cols = Tile::Layout::Shape::dim(1);
1642
+ wp::mat_t<Rows, Cols, typename Tile::Type> m;
1643
+ for (int i=0; i < Rows; ++i)
1644
+ for (int j=0; j < Cols; ++j)
1645
+ m.data[i][j] = reg.data[i*Cols + j];
1646
+
1647
+ return m;
1648
+ }
1592
1649
  }
1593
1650
 
1594
1651
  template <typename Tile, typename Value>
@@ -1612,6 +1669,16 @@ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
1612
1669
  adj.data[i] += adj_ret[i];
1613
1670
  }
1614
1671
 
1672
+ // matrix case
1673
+ if constexpr(N == 3)
1674
+ {
1675
+ constexpr int Rows = Tile::Layout::Shape::dim(0);
1676
+ constexpr int Cols = Tile::Layout::Shape::dim(1);
1677
+ for (int i=0; i < Rows; ++i)
1678
+ for (int j=0; j < Cols; ++j)
1679
+ adj.data[i*Cols + j] += adj_ret.data[i][j];
1680
+ }
1681
+
1615
1682
  adj_tile.assign(adj);
1616
1683
  }
1617
1684
 
@@ -1893,6 +1960,27 @@ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
1893
1960
  return tile_binary_map(add, a, b);
1894
1961
  }
1895
1962
 
1963
+ // add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
1964
+ template <typename T, typename L>
1965
+ inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_register_t<T, L>& b) {
1966
+ return tile_add(a, b);
1967
+ }
1968
+
1969
+ template <typename T, typename L, bool Owner>
1970
+ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b) {
1971
+ return tile_add(a, b);
1972
+ }
1973
+
1974
+ template <typename T, typename L, bool Owner>
1975
+ inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_shared_t<T, L, Owner>& b) {
1976
+ return tile_add(a, b);
1977
+ }
1978
+
1979
+ template <typename T, typename L, bool Owner>
1980
+ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register_t<T, L>& b) {
1981
+ return tile_add(a, b);
1982
+ }
1983
+
1896
1984
  template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
1897
1985
  inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
1898
1986
  {
@@ -1961,6 +2049,126 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
1961
2049
  }
1962
2050
 
1963
2051
 
2052
+ template <typename TileA, typename TileB>
2053
+ inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
2054
+ {
2055
+ using ShapeA = typename TileA::Layout::Shape;
2056
+ using ShapeB = typename TileB::Layout::Shape;
2057
+
2058
+ // verify shapes and sizes are compatible
2059
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
2060
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
2061
+
2062
+ auto a_reg = a.copy_to_register();
2063
+ auto b_reg = b.copy_to_register();
2064
+
2065
+ using Layout = typename decltype(b_reg)::Layout;
2066
+
2067
+ WP_PRAGMA_UNROLL
2068
+ for (int i=0; i < Layout::NumRegs; ++i)
2069
+ {
2070
+ const int linear = Layout::linear_from_register(i);
2071
+
2072
+ if(!Layout::valid(linear))
2073
+ break;
2074
+
2075
+ a_reg.data[i] += b_reg.data[i];
2076
+ }
2077
+
2078
+ a.assign(a_reg);
2079
+ }
2080
+
2081
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2082
+ inline CUDA_CALLABLE void adj_tile_add_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
2083
+ {
2084
+ using ShapeA = typename TileA::Layout::Shape;
2085
+ using ShapeB = typename TileB::Layout::Shape;
2086
+
2087
+ // verify shapes and sizes are compatible
2088
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
2089
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
2090
+
2091
+ // allocate storage for adjoints
2092
+ auto adj_a_reg = adj_a.grad_to_register();
2093
+ auto adj_b_reg = tile_register_like<TileB>();
2094
+
2095
+ using Layout = typename decltype(adj_a_reg)::Layout;
2096
+
2097
+ WP_PRAGMA_UNROLL
2098
+ for (int i=0; i < Layout::NumRegs; ++i)
2099
+ {
2100
+ const int linear = Layout::linear_from_register(i);
2101
+
2102
+ if(!Layout::valid(linear))
2103
+ break;
2104
+
2105
+ adj_b_reg.data[i] += adj_a_reg.data[i];
2106
+ }
2107
+
2108
+ adj_b.grad_add(adj_b_reg);
2109
+ }
2110
+
2111
+ template <typename TileA, typename TileB>
2112
+ inline CUDA_CALLABLE void tile_sub_inplace(TileA& a, TileB& b)
2113
+ {
2114
+ using ShapeA = typename TileA::Layout::Shape;
2115
+ using ShapeB = typename TileB::Layout::Shape;
2116
+
2117
+ // verify shapes and sizes are compatible
2118
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
2119
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
2120
+
2121
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
2122
+ auto a_reg = a.copy_to_register();
2123
+ auto b_reg = b.copy_to_register();
2124
+
2125
+ using Layout = typename decltype(a_reg)::Layout;
2126
+
2127
+ WP_PRAGMA_UNROLL
2128
+ for (int i=0; i < Layout::NumRegs; ++i)
2129
+ {
2130
+ const int linear = Layout::linear_from_register(i);
2131
+
2132
+ if(!Layout::valid(linear))
2133
+ break;
2134
+
2135
+ a_reg.data[i] -= b_reg.data[i];
2136
+ }
2137
+
2138
+ a.assign(a_reg);
2139
+ }
2140
+
2141
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2142
+ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
2143
+ {
2144
+ using ShapeA = typename TileA::Layout::Shape;
2145
+ using ShapeB = typename TileB::Layout::Shape;
2146
+
2147
+ // verify shapes and sizes are compatible
2148
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
2149
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
2150
+
2151
+ // allocate storage for adjoints
2152
+ auto adj_a_reg = adj_a.grad_to_register();
2153
+ auto adj_b_reg = tile_register_like<TileB>();
2154
+
2155
+ using Layout = typename decltype(adj_a_reg)::Layout;
2156
+
2157
+ WP_PRAGMA_UNROLL
2158
+ for (int i=0; i < Layout::NumRegs; ++i)
2159
+ {
2160
+ const int linear = Layout::linear_from_register(i);
2161
+
2162
+ if(!Layout::valid(linear))
2163
+ break;
2164
+
2165
+ adj_b_reg.data[i] -= adj_a_reg.data[i];
2166
+ }
2167
+
2168
+ adj_b.grad_add(adj_b_reg);
2169
+ }
2170
+
2171
+
1964
2172
  template<typename Tile>
1965
2173
  typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
1966
2174
  template<typename Tile>
@@ -1970,7 +2178,6 @@ typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extrac
1970
2178
  template<typename Tile>
1971
2179
  typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
1972
2180
 
1973
-
1974
2181
  template<typename Tile, typename AdjTile>
1975
2182
  void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
1976
2183
  template<typename Tile, typename AdjTile>
@@ -1981,6 +2188,42 @@ template<typename Tile, typename AdjTile>
1981
2188
  void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
1982
2189
 
1983
2190
 
2191
+ template<typename Tile>
2192
+ void tile_add_inplace(Tile& t, int i, typename Tile::Type value) { t.add_inplace(tile_coord(i), value); }
2193
+ template<typename Tile>
2194
+ void tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.add_inplace(tile_coord(i,j), value); }
2195
+ template<typename Tile>
2196
+ void tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k), value); }
2197
+ template<typename Tile>
2198
+ void tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k,l), value); }
2199
+
2200
+ template<typename Tile>
2201
+ void tile_sub_inplace(Tile& t, int i, typename Tile::Type value) { t.sub_inplace(tile_coord(i), value); }
2202
+ template<typename Tile>
2203
+ void tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j), value); }
2204
+ template<typename Tile>
2205
+ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k), value); }
2206
+ template<typename Tile>
2207
+ void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
2208
+
2209
+ template<typename Tile, typename AdjTile>
2210
+ void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
2211
+ template<typename Tile, typename AdjTile>
2212
+ void adj_tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j), adj_value); }
2213
+ template<typename Tile, typename AdjTile>
2214
+ void adj_tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k), adj_value); }
2215
+ template<typename Tile, typename AdjTile>
2216
+ void adj_tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k, l), adj_value); }
2217
+
2218
+ template<typename Tile, typename AdjTile>
2219
+ void adj_tile_sub_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i), adj_value); }
2220
+ template<typename Tile, typename AdjTile>
2221
+ void adj_tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j), adj_value); }
2222
+ template<typename Tile, typename AdjTile>
2223
+ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k), adj_value); }
2224
+ template<typename Tile, typename AdjTile>
2225
+ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
2226
+
1984
2227
  namespace partitioned_gemm
1985
2228
  {
1986
2229
 
@@ -2177,33 +2420,98 @@ inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
2177
2420
  }
2178
2421
  }
2179
2422
 
2423
+ // Writes into X
2180
2424
  template <typename TileL, typename TileX, typename TileY>
2181
- inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
2425
+ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX& X, TileY& Y)
2182
2426
  {
2183
- using T = typename TileL::Type;
2184
- constexpr int n = TileL::Layout::Shape::dim(1);
2427
+ using T = typename TileL::Type;
2428
+
2429
+ if constexpr (TileY::Layout::Shape::N == 1)
2430
+ {
2431
+ constexpr int n = TileL::Layout::Shape::dim(1);
2432
+
2433
+ for (int i=0; i < n; ++i)
2434
+ {
2435
+ T s = Y.data(tile_coord(i));
2436
+
2437
+ for (int j=0; j < i; ++j)
2438
+ s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
2185
2439
 
2186
- for (int i=0; i < n; ++i)
2440
+ T diag = L.data(tile_coord(i, i));
2441
+ X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
2442
+ }
2443
+ }
2444
+ else if constexpr (TileY::Layout::Shape::N == 2)
2187
2445
  {
2188
- T s = Y.data(tile_coord(i));
2446
+ constexpr int n = TileL::Layout::Shape::dim(1);
2447
+ constexpr int m = TileY::Layout::Shape::dim(1);
2448
+
2449
+ for (int k=0; k < m; ++k)
2450
+ {
2451
+ for (int i=0; i < n; ++i)
2452
+ {
2453
+ T s = Y.data(tile_coord(i,k));
2189
2454
 
2190
- for (int j=0; j < i; ++j)
2191
- s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
2455
+ for (int j=0; j < i; ++j)
2456
+ s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j,k));
2192
2457
 
2193
- X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2458
+ T diag = L.data(tile_coord(i, i));
2459
+ X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
2460
+ }
2461
+ }
2194
2462
  }
2463
+ }
2464
+
2465
+ // Reads and writes X
2466
+ template <typename TileL, typename TileX>
2467
+ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
2468
+ {
2469
+ using T = typename TileL::Type;
2470
+
2471
+ if constexpr (TileX::Layout::Shape::N == 1)
2472
+ {
2473
+ constexpr int n = TileL::Layout::Shape::dim(1);
2195
2474
 
2196
- for (int i=n-1; i >= 0; --i)
2475
+ for (int i=n-1; i >= 0; --i)
2476
+ {
2477
+ T s = X.data(tile_coord(i));
2478
+
2479
+ for (int j=i+1; j < n; ++j)
2480
+ s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
2481
+
2482
+ T diag = L.data(tile_coord(i, i));
2483
+ X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
2484
+ }
2485
+ }
2486
+ else if constexpr (TileX::Layout::Shape::N == 2)
2197
2487
  {
2198
- T s = X.data(tile_coord(i));
2488
+ constexpr int n = TileL::Layout::Shape::dim(1);
2489
+ constexpr int m = TileX::Layout::Shape::dim(1);
2199
2490
 
2200
- for (int j=i+1; j < n; ++j)
2201
- s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
2491
+ for (int k=0; k < m; ++k)
2492
+ {
2493
+ for (int i=n-1; i >= 0; --i)
2494
+ {
2495
+ T s = X.data(tile_coord(i,k));
2496
+
2497
+ for (int j=i+1; j < n; ++j)
2498
+ s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j,k));
2202
2499
 
2203
- X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2500
+ T diag = L.data(tile_coord(i, i));
2501
+ X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
2502
+ }
2503
+ }
2204
2504
  }
2205
2505
  }
2206
2506
 
2507
+ template <typename TileL, typename TileX, typename TileY>
2508
+ inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
2509
+ {
2510
+ scalar_cholesky_forward_substitution(L, X, Y);
2511
+ scalar_cholesky_back_substitution(L, X);
2512
+ }
2513
+
2514
+
2207
2515
  } // namespace partition_gemm
2208
2516
 
2209
2517
 
@@ -2223,12 +2531,14 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
2223
2531
  static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
2224
2532
 
2225
2533
 
2226
- using T = typename TileA::Type;
2534
+ using T = typename TileC::Type;
2227
2535
 
2228
2536
  #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2229
2537
  partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
2230
2538
  #else
2231
- fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
2539
+ T alpha = T(1.0);
2540
+ T beta = T(Add);
2541
+ fun_forward(&alpha, A.data.ptr, B.data.ptr, &beta, C.data.ptr);
2232
2542
  #endif
2233
2543
 
2234
2544
  WP_TILE_SYNC();
@@ -2242,17 +2552,22 @@ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename T
2242
2552
  void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
2243
2553
  Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
2244
2554
  {
2245
- using T = typename TileA::Type;
2555
+ using T_A = typename TileA::Type;
2556
+ using T_B = typename TileB::Type;
2246
2557
 
2247
2558
  #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2248
2559
  auto At = tile_transpose(A);
2249
2560
  auto Bt = tile_transpose(B);
2250
2561
 
2251
- partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
2252
- partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
2562
+ partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T_A(1.0));
2563
+ partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T_B(1.0));
2253
2564
  #else
2254
- fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
2255
- fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
2565
+ T_A alpha_A = T_A(1.0);
2566
+ T_A beta_A = T_A(1.0);
2567
+ fun_backward_A(&alpha_A, adj_C.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
2568
+ T_B alpha_B = T_B(1.0);
2569
+ T_B beta_B = T_B(1.0);
2570
+ fun_backward_B(&alpha_B, A.data.ptr, adj_C.grad.ptr, &beta_B, adj_B.grad.ptr);
2256
2571
  #endif
2257
2572
 
2258
2573
  WP_TILE_SYNC();
@@ -2263,7 +2578,7 @@ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename T
2263
2578
  void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
2264
2579
  Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
2265
2580
  {
2266
- using T = typename TileA::Type;
2581
+ using T = typename TileC::Type;
2267
2582
 
2268
2583
  #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2269
2584
  auto At = tile_transpose(A);
@@ -2272,8 +2587,10 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2272
2587
  partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
2273
2588
  partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
2274
2589
  #else
2275
- fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
2276
- fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
2590
+ T alpha = T(1.0);
2591
+ T beta = T(1.0);
2592
+ fun_backward_A(&alpha, adj_C.grad.ptr, B.data.ptr, &beta, adj_A.grad.ptr);
2593
+ fun_backward_B(&alpha, A.data.ptr, adj_C.grad.ptr, &beta, adj_B.grad.ptr);
2277
2594
  #endif
2278
2595
 
2279
2596
  WP_TILE_SYNC();
@@ -2293,13 +2610,13 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2293
2610
  // and remove the need for __align__(16) dtypes data[...]
2294
2611
  #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
2295
2612
  do { \
2296
- void function_name(dtype*, dtype*); \
2613
+ void function_name(dtype*, char*); \
2297
2614
  char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
2298
2615
  __align__(16) dtype data[ept]; \
2299
2616
  for(int b = 0; b < (int)batch_size; b++) { \
2300
2617
  dtype* inout = Xinout.data + (int)b * (int)ept; \
2301
2618
  memcpy(data, inout, sizeof(dtype) * ept); \
2302
- function_name(data, (dtype*)buffer); \
2619
+ function_name(data, buffer); \
2303
2620
  memcpy(inout, data, sizeof(dtype) * ept); \
2304
2621
  WP_TILE_SYNC(); \
2305
2622
  } \
@@ -2328,7 +2645,15 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2328
2645
 
2329
2646
  template <typename Fwd, typename TileA, typename TileL>
2330
2647
  TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2331
- {
2648
+ {
2649
+ static_assert(TileA::Layout::Shape::N == 2, "Expected TileA::Layout::Shape::N == 2");
2650
+ static_assert(TileL::Layout::Shape::N == 2, "Expected TileL::Layout::Shape::N == 2");
2651
+
2652
+ static_assert(TileA::Layout::Shape::dim(0) == TileA::Layout::Shape::dim(1), "Expected TileA to be square");
2653
+ static_assert(TileL::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(1), "Expected TileL to be square");
2654
+ static_assert(TileA::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(0), "Expected A and L to have the same number of rows");
2655
+ static_assert(TileA::Layout::Shape::dim(1) == TileL::Layout::Shape::dim(1), "Expected A and L to have the same number of columns");
2656
+
2332
2657
  // Copy to L
2333
2658
  L = A;
2334
2659
 
@@ -2338,14 +2663,27 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2338
2663
 
2339
2664
  #else
2340
2665
 
2666
+ // TODO: for batched Cholesky, need one info per batch
2667
+ WP_TILE_SHARED int info[1];
2668
+
2669
+ if (WP_TILE_THREAD_IDX == 0) {
2670
+ info[0] = 0;
2671
+ }
2341
2672
 
2342
2673
  // Call cholesky on L
2343
2674
  WP_TILE_SYNC();
2344
2675
 
2345
- fun_forward(L.data.ptr, TileL::Layout::Shape::dim(0));
2676
+ fun_forward(L.data.ptr, info);
2346
2677
 
2347
2678
  WP_TILE_SYNC();
2348
2679
 
2680
+ // TODO: for batched Cholesky, check all batches
2681
+ #if defined(_DEBUG)
2682
+ if (WP_TILE_THREAD_IDX == 0 && info[0] != 0) {
2683
+ printf("Non-zero status in Cholesky factorization, got %d\n", info[0]);
2684
+ }
2685
+ #endif
2686
+
2349
2687
  // Zero-out the upper triangular part of L
2350
2688
 
2351
2689
  WP_PRAGMA_UNROLL
@@ -2371,11 +2709,11 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2371
2709
  } while (0)
2372
2710
 
2373
2711
  template <typename Fwd, typename TileL, typename TileX, typename TileY>
2374
- TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2712
+ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& Y, TileY& X)
2375
2713
  {
2376
- // Copy x to y
2714
+ // Copy y to x
2377
2715
 
2378
- Y = X;
2716
+ X = Y;
2379
2717
 
2380
2718
  #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2381
2719
 
@@ -2383,25 +2721,100 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2383
2721
 
2384
2722
  #else
2385
2723
 
2386
- // Call cholesky solve on L & y
2724
+ // Call cholesky solve on L & x
2387
2725
 
2388
2726
  WP_TILE_SYNC();
2389
2727
 
2390
- fun_forward(L.data.ptr, Y.data.ptr); \
2728
+ fun_forward(L.data.ptr, X.data.ptr); \
2391
2729
 
2392
2730
  WP_TILE_SYNC();
2393
2731
 
2394
2732
  #endif
2395
2733
 
2396
- return Y;
2734
+ return X;
2397
2735
  }
2398
2736
 
2399
- #define adj_tile_cholesky_solve(function_name, L, X, Y, \
2400
- adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \
2737
+ #define adj_tile_cholesky_solve(function_name, L, Y, X, \
2738
+ adj_function_name, adj_L, adj_Y, adj_X, adj_ret) \
2401
2739
  do { \
2402
2740
  assert(false); \
2403
2741
  } while (0)
2404
2742
 
2743
+
2744
+
2745
+
2746
+
2747
+
2748
+ template <typename Fwd, typename TileL, typename TileY, typename TileZ>
2749
+ TileZ& tile_lower_solve(Fwd fun_forward, TileL& L, TileY& y, TileZ& z)
2750
+ {
2751
+ // Copy y to z
2752
+ z = y;
2753
+
2754
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2755
+
2756
+ partitioned_gemm::scalar_cholesky_forward_substitution(L, z, y);
2757
+
2758
+ #else
2759
+
2760
+ // Call cholesky solve on L & z
2761
+
2762
+ WP_TILE_SYNC();
2763
+
2764
+ fun_forward(L.data.ptr, z.data.ptr);
2765
+
2766
+ WP_TILE_SYNC();
2767
+
2768
+ #endif
2769
+
2770
+ return z;
2771
+ }
2772
+
2773
+ #define adj_tile_lower_solve(function_name, L, y, z, \
2774
+ adj_function_name, adj_L, adj_y, adj_z, adj_ret) \
2775
+ do { \
2776
+ assert(false); \
2777
+ } while (0)
2778
+
2779
+
2780
+
2781
+ template <typename Fwd, typename TileU, typename TileZ, typename TileX>
2782
+ TileX& tile_upper_solve(Fwd fun_forward, TileU& U, TileZ& z, TileX& x)
2783
+ {
2784
+ // Copy z to x
2785
+ x = z;
2786
+
2787
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2788
+
2789
+ auto L = tile_transpose(U);
2790
+ partitioned_gemm::scalar_cholesky_back_substitution(L, x);
2791
+
2792
+ #else
2793
+
2794
+ // Call cholesky solve on U & x
2795
+
2796
+ WP_TILE_SYNC();
2797
+
2798
+ fun_forward(U.data.ptr, x.data.ptr);
2799
+
2800
+ WP_TILE_SYNC();
2801
+
2802
+ #endif
2803
+
2804
+ return x;
2805
+ }
2806
+
2807
+ #define adj_tile_upper_solve(function_name, U, z, x, \
2808
+ adj_function_name, adj_U, adj_z, adj_x, adj_ret) \
2809
+ do { \
2810
+ assert(false); \
2811
+ } while (0)
2812
+
2813
+
2814
+
2815
+
2816
+
2817
+
2405
2818
  template <typename Tile>
2406
2819
  inline CUDA_CALLABLE auto tile_transpose(Tile& t)
2407
2820
  {
@@ -2457,10 +2870,11 @@ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
2457
2870
  template <typename Tile, typename AdjTile>
2458
2871
  inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
2459
2872
  {
2460
- // nop, since memory is aliased grads already accumulated
2873
+ // nop, since memory is aliased, grads already accumulated
2461
2874
  }
2462
2875
 
2463
- template <typename ReturnType, typename Tile, typename... Indices>
2876
+
2877
+ template <typename ReturnTile, typename Tile, typename... Indices>
2464
2878
  inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
2465
2879
  {
2466
2880
  auto c = tile_coord(indices...);
@@ -2472,7 +2886,104 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
2472
2886
  if (t.grad.ptr)
2473
2887
  grad_ptr = &t.grad(c);
2474
2888
 
2475
- return ReturnType(data_ptr, grad_ptr);
2889
+ return ReturnTile(data_ptr, grad_ptr);
2890
+ }
2891
+
2892
+
2893
+ template <typename ReturnTile, typename Tile>
2894
+ inline CUDA_CALLABLE auto tile_squeeze(Tile& t)
2895
+ {
2896
+ // ReturnTile layout is set in builtins.py
2897
+ typename Tile::Type* data_ptr = t.data.ptr;
2898
+ typename Tile::Type* grad_ptr = nullptr;
2899
+
2900
+ if (t.grad.ptr)
2901
+ grad_ptr = t.grad.ptr;
2902
+
2903
+ return ReturnTile(data_ptr, grad_ptr);
2904
+ }
2905
+
2906
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
2907
+ inline CUDA_CALLABLE void adj_tile_squeeze(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
2908
+ {
2909
+ // nop, since memory is aliased, grads already accumulated
2910
+ }
2911
+
2912
+
2913
+ template <typename ReturnTile, typename Tile>
2914
+ inline CUDA_CALLABLE auto tile_reshape(Tile& t)
2915
+ {
2916
+ // ReturnTile layout is set in builtins.py
2917
+ typename Tile::Type* data_ptr = t.data.ptr;
2918
+ typename Tile::Type* grad_ptr = nullptr;
2919
+
2920
+ if (t.grad.ptr)
2921
+ grad_ptr = t.grad.ptr;
2922
+
2923
+ return ReturnTile(data_ptr, grad_ptr);
2924
+ }
2925
+
2926
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
2927
+ inline CUDA_CALLABLE void adj_tile_reshape(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
2928
+ {
2929
+ // nop, since memory is aliased, grads already accumulated
2930
+ }
2931
+
2932
+
2933
+ template <typename ReturnTile, typename Tile>
2934
+ inline CUDA_CALLABLE auto tile_astype(Tile& t)
2935
+ {
2936
+ // verify shapes and sizes are compatible
2937
+ using ShapeIn = typename Tile::Layout::Shape;
2938
+ using ShapeOut = typename ReturnTile::Layout::Shape;
2939
+
2940
+ static_assert(ShapeIn::N == ShapeOut::N, "Tile shapes must match for data type casting");
2941
+ static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for data type casting");
2942
+
2943
+ // work with register tiles for type casting
2944
+ auto t_reg = t.copy_to_register();
2945
+ auto result = tile_register_like<ReturnTile>();
2946
+
2947
+ using Layout = typename decltype(result)::Layout;
2948
+
2949
+ WP_PRAGMA_UNROLL
2950
+ for (int i = 0; i < Layout::NumRegs; ++i)
2951
+ {
2952
+ const int linear = Layout::linear_from_register(i);
2953
+
2954
+ if(!Layout::valid(linear))
2955
+ break;
2956
+
2957
+ result.data[i] = static_cast<typename ReturnTile::Type>(t_reg.data[i]);
2958
+ }
2959
+
2960
+ return result;
2961
+ }
2962
+
2963
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
2964
+ inline CUDA_CALLABLE void adj_tile_astype(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
2965
+ {
2966
+ // gradients only flow between float conversions
2967
+ if constexpr((is_same<typename AdjTile::Type, wp::float16>::value ||
2968
+ is_same<typename AdjTile::Type, wp::float32>::value ||
2969
+ is_same<typename AdjTile::Type, wp::float64>::value) &&
2970
+ (is_same<typename AdjReturnTile::Type, wp::float16>::value ||
2971
+ is_same<typename AdjReturnTile::Type, wp::float32>::value ||
2972
+ is_same<typename AdjReturnTile::Type, wp::float64>::value))
2973
+ {
2974
+ auto adj_ret_reg = adj_ret.grad_to_register();
2975
+ auto adj_t_reg = tile_register_like<AdjTile>();
2976
+
2977
+ using Layout = typename decltype(adj_t_reg)::Layout;
2978
+
2979
+ WP_PRAGMA_UNROLL
2980
+ for (int i = 0; i < Layout::NumRegs; ++i)
2981
+ {
2982
+ adj_t_reg.data[i] += static_cast<typename AdjTile::Type>(adj_ret_reg.data[i]);
2983
+ }
2984
+
2985
+ adj_t.grad_add(adj_t_reg);
2986
+ }
2476
2987
  }
2477
2988
 
2478
2989