warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED

@@ -19,9 +19,39 @@
 
 #include "builtin.h"
 
+#ifdef __clang__
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif // __clang__
+
+// Check if the CUDA toolkit is available
+#if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
+
+// If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
+#ifdef __CUDACC_RTC__
+// NVRTC: Use built-in float4 (no need for extra definitions)
+#else
+// NVCC: Include vector_types.h to get float4
+#include <cuda_runtime.h>
+#endif
+
+#else
+// If CUDA is not available (e.g., macOS build), manually define float4
+struct alignas(16) float4 {
+    float x, y, z, w;
+};
+#endif
+
+// only used while building the warp core library
+#ifndef WP_TILE_BLOCK_DIM
+#define WP_TILE_BLOCK_DIM 256
+#endif
+
 #if !defined(__CUDA_ARCH__)
 #define WP_TILE_SHARED static
 #define WP_TILE_SYNC void
+
 #else
 #define WP_TILE_SHARED __shared__
 #define WP_TILE_SYNC __syncthreads
@@ -46,6 +76,14 @@
 #define WP_USE_ASYNC_PIPELINE 0
 #define WP_USE_REGISTER_GEMM 0
 
+#if defined(__CUDACC_RTC__)
+#define WP_TILE_THREAD_IDX threadIdx.x
+#else
+#define WP_TILE_THREAD_IDX 0
+#endif
+
+
+
 /* Tile Expressions
 
 [ ] Tiles
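
Note: the WP_TILE_THREAD_IDX macro added above is what lets the rest of this header compile for both device and host: under NVRTC it expands to threadIdx.x, while on CPU JIT builds it collapses to the constant 0. A minimal sketch of the resulting loop behavior, assuming (as the "only used while building the warp core library" comment suggests) that CPU kernel builds define WP_TILE_BLOCK_DIM as 1 rather than the library-build default of 256:

// Sketch only: illustrates the strided-loop pattern used throughout tile.h.
// WP_TILE_BLOCK_DIM == 1 is an assumption about CPU kernel builds.
#define WP_TILE_BLOCK_DIM 1
#define WP_TILE_THREAD_IDX 0

void fill(float* data, int size, float value)
{
    // on the GPU this loop strides across the block (i = threadIdx.x,
    // threadIdx.x + blockDim.x, ...); on the CPU it degenerates to a
    // plain serial pass over every element
    for (int i = WP_TILE_THREAD_IDX; i < size; i += WP_TILE_BLOCK_DIM)
        data[i] = value;
}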
@@ -217,14 +255,14 @@ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
 }
 
 // helpers to construct a coord from a set of indices
-auto tile_coord(int i)
+inline auto tile_coord(int i)
 {
     auto c = tile_coord_t<1>();
     c.indices[0] = i;
     return c;
 }
 
-auto tile_coord(int i, int j)
+inline auto tile_coord(int i, int j)
 {
     auto c = tile_coord_t<2>();
     c.indices[0] = i;
@@ -232,7 +270,7 @@ auto tile_coord(int i, int j)
     return c;
 }
 
-auto tile_coord(int i, int j, int k)
+inline auto tile_coord(int i, int j, int k)
 {
     auto c = tile_coord_t<3>();
     c.indices[0] = i;
@@ -241,7 +279,7 @@ auto tile_coord(int i, int j, int k)
     return c;
 }
 
-auto tile_coord(int i, int j, int k, int l)
+inline auto tile_coord(int i, int j, int k, int l)
 {
     auto c = tile_coord_t<4>();
     c.indices[0] = i;
@@ -256,7 +294,7 @@ template <int... V>
 struct tile_tuple_t
 {
     static constexpr int N = sizeof...(V);
-    static_assert(N > 0);
+    static_assert(N > 0, "Expected N > 0");
 
     static constexpr int data[N] = { V... };
 
@@ -409,7 +447,7 @@ struct tile_layout_register_t
 
     static inline CUDA_CALLABLE int linear_from_register(int reg)
     {
-        return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
+        return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
    }
 
     static inline CUDA_CALLABLE int linear_from_coord(Coord c)
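
Note: for register tiles this mapping distributes a tile's linear elements round-robin across the block: element linear lives in register linear / WP_TILE_BLOCK_DIM of thread linear % WP_TILE_BLOCK_DIM, and linear_from_register() above is the inverse of that pair. A small self-checking sketch, with BLOCK standing in for WP_TILE_BLOCK_DIM:

// Round-robin register layout, worked through at compile time.
constexpr int BLOCK = 256; // stand-in for WP_TILE_BLOCK_DIM

constexpr int thread_from_linear(int linear)   { return linear % BLOCK; }
constexpr int register_from_linear(int linear) { return linear / BLOCK; }
constexpr int linear_from_register(int thread, int reg) { return thread + reg*BLOCK; }

// element 300 of a tile lives in register 1 of thread 44, and back again
static_assert(thread_from_linear(300) == 44, "");
static_assert(register_from_linear(300) == 1, "");
static_assert(linear_from_register(44, 1) == 300, "");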
@@ -509,15 +547,6 @@ struct tile_register_t
         return data[reg];
     }
 
-    // Returns the number of valid registers for this tile
-    // i.e.: how many registers map to a valid coordinate.
-    // When a tile's size is not aligned to the block dimension
-    // some of the trailing registers may lie outside the valid range
-    inline CUDA_CALLABLE int valid() const
-    {
-        return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
-    }
-
     inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
     {
         for (int i=0; i < Layout::NumRegs; ++i)
@@ -544,7 +573,7 @@ struct tile_register_t
         // ensure any previously scheduled threads have finished reading from scratch
         WP_TILE_SYNC();
 
-        if (threadIdx.x == thread)
+        if (WP_TILE_THREAD_IDX == thread)
         {
             scratch = data[reg];
         }
@@ -565,7 +594,7 @@ struct tile_register_t
         const int thread = Layout::thread_from_linear(linear);
         const int reg = Layout::register_from_linear(linear);
 
-        if (threadIdx.x == thread)
+        if (WP_TILE_THREAD_IDX == thread)
         {
             data[reg] += adj_ret;
         }
@@ -668,7 +697,7 @@ struct tile_register_t
 // users can either specify a template explicitly or
 // pass in another concrete instance
 template<typename Tile>
-auto tile_register_like(Tile* t=NULL)
+auto tile_register_like(Tile* t=nullptr)
 {
     using T = typename Tile::Type;
     using L = typename Tile::Layout;
@@ -694,26 +723,39 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
     return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
 }
 
-inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
+inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
 {
     // we maintain a per-thread offset into dynamic
     // shared memory that allows us to keep track of
     // current use across dynamic function calls
-    __shared__ int smem_base[WP_TILE_BLOCK_DIM];
+    WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
 
     if (init)
     {
-        smem_base[threadIdx.x] = 0;
-        return NULL;
+        smem_base[WP_TILE_THREAD_IDX] = 0;
+        return nullptr;
+    }
+    else if (check)
+    {
+        assert(smem_base[WP_TILE_THREAD_IDX] == 0);
+        return nullptr;
     }
     else
     {
-        const int offset = smem_base[threadIdx.x];
+        const int offset = smem_base[WP_TILE_THREAD_IDX];
 
         // one entry per-thread so no need for synchronization
-        smem_base[threadIdx.x] += tile_align(num_bytes);
+        smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
 
+#ifdef __CUDA_ARCH__
         extern __shared__ char dynamic_smem_base[];
+#else
+        // on CPU allocate a fixed 256k block to use for shared allocs
+        static const int max_cpu_shared = 256*1024;
+        static char dynamic_smem_base[max_cpu_shared];
+
+        assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
+#endif
         return &(dynamic_smem_base[offset]);
     }
 }
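
Note: tile_alloc_shared() is a per-thread bump allocator over a single shared-memory arena: each call advances a running offset by the aligned request size. Since tile_align() preserves the sign of its argument, a negative request appears to be how allocations are released (rewinding the offset), and the new check path asserts the offset has returned to zero. A standalone, CPU-only sketch of that pattern, with hypothetical names:

#include <cassert>

// Hypothetical sketch of the bump-allocator pattern used above.
static const int kArenaSize = 256 * 1024; // mirrors the 256k CPU block
static char g_arena[kArenaSize];
static int  g_offset = 0;

inline int align16(int num_bytes)
{
    // sign-preserving alignment, as in tile_align()
    int sign = (num_bytes < 0) ? -1 : 1;
    int abs_bytes = sign * num_bytes;
    return sign * ((abs_bytes + 15) / 16) * 16;
}

void* arena_alloc(int num_bytes)
{
    int offset = g_offset;
    g_offset += align16(num_bytes); // a negative size rewinds (frees)
    assert(g_offset >= 0 && g_offset <= kArenaSize);
    return &g_arena[offset];        // ignored by callers when freeing
}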
@@ -847,12 +889,12 @@ struct tile_shared_t
     bool initialized;
 
     // default initialization (non-initialized)
-    inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL), initialized(false)
+    inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
     {
     }
 
     // initialize from an existing tile's memory
-    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL, bool initialized=true) : data(data), grad(grad), initialized(initialized)
+    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
     {
     }
 
@@ -878,6 +920,7 @@ struct tile_shared_t
     }
 
 
+    /*
     // construct from another shared tile, this constructor
     // is invoked for reshape operations like `wp.tile_transpose()`
     template <typename OtherT, typename OtherLayout>
@@ -886,7 +929,7 @@ struct tile_shared_t
         using OtherTile = tile_shared_t<OtherT, OtherLayout>;
 
         // check dimensions are compatible
-        static_assert(Size == OtherTile::Size);
+        static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
 
         // alias tile directly
         data = rhs.data;
@@ -895,6 +938,7 @@ struct tile_shared_t
 
         return *this;
     }
+    */
 
     // assign from a global tile (load)
     inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
@@ -912,7 +956,7 @@ struct tile_shared_t
         if (initialized)
             WP_TILE_SYNC();
 
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = x;
 
         initialized = true;
@@ -923,7 +967,7 @@ struct tile_shared_t
     // in-place zero
     inline CUDA_CALLABLE void zero()
     {
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = T(0);
 
         WP_TILE_SYNC();
@@ -973,7 +1017,7 @@ struct tile_shared_t
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
     {
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             grad(i) = T(0);
 
         WP_TILE_SYNC();
@@ -1013,7 +1057,7 @@ struct tile_shared_t
     CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             T g = global.load_grad(c);
@@ -1081,23 +1125,25 @@ struct tile_shared_t
     template <typename Global>
     inline CUDA_CALLABLE void copy_to_global(const Global& dest)
     {
+
+#if defined(__CUDA_ARCH__)
         // vectorized loads for specific input/output shapes
         if constexpr (Layout::Shape::N == 2)
         {
             constexpr int lastdim = Layout::Shape::N-1;
             constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
             const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
-            const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
+            const int elements = min(Layout::Shape::dim(1), (dest.data.shape[lastdim] - dest.offset[lastdim]));
             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
 
             float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
             const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
 
-            if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
-            {
-                constexpr int M = Layout::Shape::dim(0);
-                constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+            constexpr int M = Layout::Shape::dim(0);
+            constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
 
+            if (contiguous_dest && contiguous_src && aligned_size && aligned_dst && N)
+            {
                 // alias of shared tile with 128bit type
                 using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
                 tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
@@ -1109,7 +1155,7 @@ struct tile_shared_t
                 const int stride_j = 1;
 
                 WP_PRAGMA_UNROLL
-                for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
+                for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
                 {
                     auto c = SrcLayout::coord_from_linear(i);
@@ -1120,17 +1166,18 @@ struct tile_shared_t
             }
         }
 
+#endif // defined(__CUDA_ARCH__)
+
         // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             dest.store(c, data(i));
         }
     }
 
-    __device__ __forceinline__
-    void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
+    inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
     {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
 
@@ -1152,8 +1199,7 @@ struct tile_shared_t
 #endif
     }
 
-    __device__ __forceinline__
-    void cp_async_commit_and_wait_all_128()
+    inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
     {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
         asm volatile(
@@ -1168,23 +1214,25 @@ struct tile_shared_t
         if (initialized)
             WP_TILE_SYNC();
 
+#if defined(__CUDA_ARCH__)
+
         // vectorized loads for specific input/output shapes
         if constexpr (Layout::Shape::N == 2)
         {
             constexpr int lastdim = Layout::Shape::N-1;
             constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
             const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
-            const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
+            const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
-
+
             float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
             const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
 
-            if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
-            {
-                constexpr int M = Layout::Shape::dim(0);
-                constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+            constexpr int M = Layout::Shape::dim(0);
+            constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
 
+            if (contiguous_dest && contiguous_src && aligned_size && aligned_src && N)
+            {
                 // alias of shared tile with 128bit type
                 using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
                 tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
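
Note: both copy directions gate the 128-bit fast path on the same checks: the row byte count must be a multiple of sizeof(float4), the base pointer must be 16-byte aligned, and the new N guard skips the path when a row is narrower than one float4 (N computes to 0). A small sketch of those criteria, with illustrative names:

#include <cstdint>

struct alignas(16) float4 { float x, y, z, w; };

// Illustrative only: mirrors the aligned_size/aligned_ptr tests used by the
// vectorized copy paths above.
template <typename T>
bool can_vectorize_row(const T* base, int elements)
{
    const bool aligned_size = (elements * sizeof(T)) % sizeof(float4) == 0;
    const bool aligned_ptr  = (uintptr_t)base % sizeof(float4) == 0;
    return aligned_size && aligned_ptr;
}

// e.g. a row of 32 floats is 128 bytes, moved as eight float4 transfers;
// a row of 3 floats (12 bytes) falls through to the scalar bounds-checked loop.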
@@ -1196,7 +1244,7 @@ struct tile_shared_t
                 const int stride_j = 1;
 
                 WP_PRAGMA_UNROLL
-                for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
+                for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
                 {
                     auto c = DestLayout::coord_from_linear(i);
 
@@ -1217,9 +1265,11 @@ struct tile_shared_t
             }
         }
 
+#endif // defined(__CUDA_ARCH__)
+
         // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             data(i) = src.load(c);
@@ -1332,7 +1382,7 @@
 
     inline CUDA_CALLABLE void print(bool reverse=false) const
     {
-        if (threadIdx.x != 0)
+        if (WP_TILE_THREAD_IDX != 0)
             return;
 
         if (reverse)
@@ -1359,13 +1409,13 @@ void tile_register_t<T, L>::print() const
     // create a temporary shared tile so that
     // we can print it deterministically
     WP_TILE_SHARED T smem[L::Size];
-    tile_shared_t<T, tile_layout_strided_t<typename L::Shape>> scratch(smem, NULL);
+    tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
 
     scratch.assign(*this);
 
     WP_TILE_SYNC();
 
-    if (threadIdx.x == 0)
+    if (WP_TILE_THREAD_IDX == 0)
     {
         scratch.print_values(scratch.data, 0);
 
@@ -1392,7 +1442,7 @@ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print()
 template <typename T, typename L, bool O>
 inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
 {
-    return Tile::Layout::Shape::dim(0);
+    return L::Shape::dim(0);
 }
 
 template <typename T, typename L, bool O, typename AdjTile>
@@ -1403,7 +1453,7 @@ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile&
 template <typename T, typename L>
 inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
 {
-    return Tile::Layout::Shape::dim(0);
+    return L::Shape::dim(0);
 }
 
 template <typename T, typename L, typename AdjTile>
@@ -1425,12 +1475,16 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
 
 { constexpr int size = Shape::size();
     T* data = (T*)tile_alloc_shared(size*sizeof(T));
-    T* grad = NULL;
+    T* grad = nullptr;
 
 #if FP_CHECK
 
-    for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
-        data[i] = T(nanf(""));
+    // initialize tile to quiet nan
+    uint32_t qnanbits = 0x7FC00000;
+    float qnan = *(float*)(&qnanbits);
+
+    for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
+        data[i] = T(qnan);
 
     WP_TILE_SYNC();
 
@@ -1441,7 +1495,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
     {
         grad = (T*)tile_alloc_shared(size*sizeof(T));
 
-        for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
             grad[i] = T(0);
 
         WP_TILE_SYNC();
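
Note: the FP_CHECK path above now fills fresh tiles with an explicit quiet-NaN bit pattern (0x7FC00000, which sets the exponent bits and the top mantissa bit) instead of calling nanf(""), avoiding a dependency on that runtime function across compilation environments. The diff builds the value with a pointer cast; a strict-aliasing-safe host equivalent uses memcpy, as in this sketch:

#include <cstdint>
#include <cstring>
#include <cmath>
#include <cassert>

// Host-side sketch of the quiet-NaN construction used above.
float quiet_nan_from_bits()
{
    const uint32_t qnanbits = 0x7FC00000u;
    float qnan;
    memcpy(&qnan, &qnanbits, sizeof(qnan)); // well-defined type punning
    return qnan;
}

int main()
{
    assert(std::isnan(quiet_nan_from_bits()));
    return 0;
}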
@@ -1450,30 +1504,6 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
     return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
 }
 
-template <typename T, int M, int N, bool RequiresGrad>
-inline CUDA_CALLABLE auto tile_alloc_zeros()
-{
-    // compute the total storage required for the tile (may be different from M*N) for broadcast tiles
-    constexpr int Len = M*N;
-    T* data = (T*)tile_alloc_shared(Len*sizeof(T));
-    T* grad = NULL;
-
-    for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
-        data[i] = T(0);
-
-    if (RequiresGrad)
-    {
-        grad = (T*)tile_alloc_shared(Len*sizeof(T));
-
-        for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
-            grad[i] = T(0);
-    }
-
-    WP_TILE_SYNC();
-
-    return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
-}
-
 
 //-----------------------------------------------------------------------------------------------------
 // High level entry points for each op (correspond to one Warp builtin)
@@ -1485,7 +1515,7 @@ inline CUDA_CALLABLE auto tile(const T& x)
     tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
 
     using Layout = typename decltype(result)::Layout;
-    static_assert(Layout::NumRegs == 1);
+    static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
 
     result.data[0] = x;
     return result;
@@ -1498,7 +1528,7 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
     tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
 
     using Layout = typename decltype(result)::Layout;
-    static_assert(Layout::NumRegs == Length);
+    static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
 
     for (int i=0; i < Length; ++i)
         result.data[i] = x[i];
@@ -1510,8 +1540,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
 {
-    static_assert(AdjTile::Layout::Shape::N == 1);
-    static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
+    static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
+    static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
 
     auto adj_reg = adj_ret.copy_to_register();
 
@@ -1521,9 +1551,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
 template <typename T, unsigned Length, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
 {
-    static_assert(AdjTile::Layout::Shape::N == 2);
-    static_assert(AdjTile::Layout::Shape::dim(0) == Length);
-    static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
+    static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
+    static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
+    static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
 
     auto adj_reg = adj_ret.copy_to_register();
 
@@ -1701,7 +1731,7 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, arr
     if (adj_dest.data)
         src.data.grad = adj_dest.data;
 
-    if (src.data.grad == NULL)
+    if (src.data.grad == nullptr)
         return;
 
     adj_t.grad_add(src);
@@ -1936,7 +1966,6 @@ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, i
 template<typename Tile, typename AdjTile>
 void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
 
-#if WP_USE_REGISTER_GEMM
 
 namespace partitioned_gemm
 {
@@ -2042,9 +2071,11 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
     auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
     auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
 
+    //static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
+
     const int length = partition_size(C_tile);
 
-    for (int t=threadIdx.x; t < length; t += blockDim.x)
+    for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
     {
         int i, j;
         partition_coord(C_tile, t, i, j);
@@ -2064,10 +2095,102 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
         partition_store(C_tile, i, j, sum);
     }
 }
-
-} // namespace partition_gemm
 
-#endif // WP_USE_REGISTER_GEMM
+template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
+{
+    for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
+    {
+        auto coord = LayoutC::coord_from_linear(t);
+
+        int i = coord[0];
+        int j = coord[1];
+
+        // accumulator
+        auto sum = C(coord)*scale;
+
+        WP_PRAGMA_UNROLL
+        for (int k=0; k < LayoutA::Shape::dim(1); k++)
+        {
+            const auto a = A(tile_coord(i, k));
+            const auto b = B(tile_coord(k, j));
+
+            sum = muladd<decltype(sum)>(a, b, sum);
+        }
+
+        C(coord) = sum;
+    }
+}
+
+template <typename TileA, typename TileL>
+inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
+{
+    using T = typename TileA::Type;
+    constexpr int n = TileA::Layout::Shape::dim(1);
+
+    for (int j=0; j < n; ++j)
+    {
+        T s = A.data(tile_coord(j, j));
+
+        for (int k=0; k < j; ++k)
+        {
+            T r = L.data(tile_coord(j, k));
+            s -= r * r;
+        }
+
+        s = wp::sqrt(s);
+        T invS = 1.0 / s;
+
+        L.data(tile_coord(j, j)) = s;
+
+        for (int i=j+1; i < n; ++i)
+        {
+            s = A.data(tile_coord(i, j));
+
+            for (int k=0; k < j; ++k)
+            {
+                s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
+            }
+
+            L.data(tile_coord(i, j)) = s * invS;
+        }
+
+        // zero out upper triangular portion
+        for (int k=j+1; k < n; ++k)
+        {
+            L.data(tile_coord(j,k)) = T(0.0);
+        }
+    }
+}
+
+template <typename TileL, typename TileX, typename TileY>
+inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
+{
+    using T = typename TileL::Type;
+    constexpr int n = TileL::Layout::Shape::dim(1);
+
+    for (int i=0; i < n; ++i)
+    {
+        T s = Y.data(tile_coord(i));
+
+        for (int j=0; j < i; ++j)
+            s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
+
+        X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+    }
+
+    for (int i=n-1; i >= 0; --i)
+    {
+        T s = X.data(tile_coord(i));
+
+        for (int j=i+1; j < n; ++j)
+            s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
+
+        X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+    }
+}
+
+} // namespace partition_gemm
 
 
 template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
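
Note: scalar_cholesky() above is the standard Cholesky recurrence, L[j][j] = sqrt(A[j][j] - sum_k L[j][k]^2) and L[i][j] = (A[i][j] - sum_k L[i][k]*L[j][k]) / L[j][j]. A tiny plain-array version for sanity-checking the recurrence (not the tile-based code itself):

#include <cmath>
#include <cstdio>

// Cholesky of an n x n SPD matrix, same recurrence as scalar_cholesky()
// but without the tile storage abstraction.
template <int n>
void cholesky(const float A[n][n], float L[n][n])
{
    for (int j = 0; j < n; ++j)
    {
        float s = A[j][j];
        for (int k = 0; k < j; ++k)
            s -= L[j][k] * L[j][k];

        L[j][j] = std::sqrt(s);

        for (int i = j + 1; i < n; ++i)
        {
            s = A[i][j];
            for (int k = 0; k < j; ++k)
                s -= L[i][k] * L[j][k];
            L[i][j] = s / L[j][j];
        }

        for (int k = j + 1; k < n; ++k)
            L[j][k] = 0.0f; // zero the upper triangle, as above
    }
}

int main()
{
    // A = [[4, 2], [2, 3]] factors as L = [[2, 0], [1, sqrt(2)]]
    const float A[2][2] = {{4.0f, 2.0f}, {2.0f, 3.0f}};
    float L[2][2];
    cholesky<2>(A, L);
    printf("%f %f\n%f %f\n", L[0][0], L[0][1], L[1][0], L[1][1]);
    return 0;
}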
@@ -2077,19 +2200,19 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
     using ShapeB = typename TileB::Layout::Shape;
     using ShapeC = typename TileC::Layout::Shape;
 
-    static_assert(ShapeA::N == 2);
-    static_assert(ShapeB::N == 2);
-    static_assert(ShapeC::N == 2);
+    static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
+    static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
+    static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
 
-    static_assert(ShapeA::dim(1) == ShapeB::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(1) == ShapeB::dim(1));
+    static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
 
 
     using T = typename TileA::Type;
 
-#if WP_USE_REGISTER_GEMM
-    partitioned_gemm::matmul(A, B, C);
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
 #else
     fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
 #endif
@@ -2099,6 +2222,7 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
     return C;
 }
 
+
 // backward for the wp.tile_matmul(a, b, out) syntax
 template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
 void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
@@ -2106,8 +2230,17 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
 {
     using T = typename TileA::Type;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    auto At = tile_transpose(A);
+    auto Bt = tile_transpose(B);
+
+    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+#else
     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+#endif
+
     WP_TILE_SYNC();
 }
 
@@ -2118,11 +2251,30 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
 {
     using T = typename TileA::Type;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    auto At = tile_transpose(A);
+    auto Bt = tile_transpose(B);
+
+    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+#else
     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+#endif
+
     WP_TILE_SYNC();
 }
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+#define tile_fft()
+#define tile_ifft()
+
+#define adj_tile_fft()
+#define adj_tile_ifft()
+
+#else
+
 // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
 // and remove the need for __align__(16) dtypes data[...]
 #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
@@ -2158,12 +2310,21 @@
         tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
     } while (0)
 
+#endif // !defined(__CUDA_ARCH__)
+
 template <typename Fwd, typename TileA, typename TileL>
 TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
 {
     // Copy to L
     L = A;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+    partitioned_gemm::scalar_cholesky(A, L);
+
+#else
+
+
     // Call cholesky on L
     WP_TILE_SYNC();
 
@@ -2174,7 +2335,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
     // Zero-out the upper triangular part of L
 
     WP_PRAGMA_UNROLL
-    for (int i=threadIdx.x; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
+    for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
     {
         auto c = TileL::Layout::coord_from_linear(i);
 
@@ -2183,7 +2344,9 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
     }
 
     WP_TILE_SYNC();
-
+
+#endif
+
     return L;
 }
 
@@ -2200,6 +2363,12 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 
     Y = X;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+    partitioned_gemm::scalar_cholesky_solve(L, X, Y);
+
+#else
+
     // Call cholesky solve on L & y
 
     WP_TILE_SYNC();
@@ -2208,6 +2377,8 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 
     WP_TILE_SYNC();
 
+#endif
+
     return Y;
 }
 
@@ -2220,7 +2391,7 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 template <typename Tile>
 inline CUDA_CALLABLE auto tile_transpose(Tile& t)
 {
-    static_assert(Tile::Layout::Shape::N == 2);
+    static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
 
     // alias incoming tile
     constexpr int M = Tile::Layout::Shape::dim(0);
@@ -2241,13 +2412,34 @@ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_
     adj_t.assign(tile_add(a,b));
 }
 
+template <int N, int StrideN, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
+}
+
 template <int M, int N, int StrideM, int StrideN, typename Tile>
 inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
-{
+{
     // alias incoming tile with new strides
     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
 }
 
+template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
+}
+
+template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
+}
+
 template <typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
@@ -2261,7 +2453,7 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
 
     // return new tile with same strides
     typename Tile::Type* data_ptr = &t.data(c);
-    typename Tile::Type* grad_ptr = NULL;
+    typename Tile::Type* grad_ptr = nullptr;
 
     if (t.grad.ptr)
         grad_ptr = &t.grad(c);
@@ -2306,7 +2498,7 @@ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offs
 {
     using Layout = typename TileB::Layout;
 
-    for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
     {
         auto c = Layout::coord_from_linear(t);
         dest.data(c + offset) = src.data(c);
@@ -2321,7 +2513,7 @@ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
 {
     using Layout = typename TileB::Layout;
 
-    for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
     {
         auto c = Layout::coord_from_linear(t);
         src.grad(c) += dest.grad(c + offset);
@@ -2360,14 +2552,14 @@ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
     using ShapeB = typename TileB::Layout::Shape;
     using ShapeC = typename TileC::Layout::Shape;
 
-    static_assert(ShapeA::dim(0) == ShapeA::dim(1));
-    static_assert(ShapeB::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeC::dim(1));
+    static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
+    static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
 
     c = a;
 
-    for (int t=threadIdx.x; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
     {
         c.data(tile_coord(t, t)) += b.data(tile_coord(t));
     }
@@ -2386,3 +2578,7 @@ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTil
 
 } // namespace wp
 
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
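
Note: the new 1D/3D/4D tile_broadcast overloads follow the same pattern as the existing 2D one: broadcasting is pure aliasing, with repeated dimensions given a stride of 0 so that many output coordinates map to the same underlying element, and no data is copied. A minimal sketch of stride-based broadcast indexing under that assumption:

#include <cassert>

// Illustrative only: a length-3 row broadcast to a 4 x 3 "tile" by giving
// the new leading dimension a stride of 0.
struct Broadcast2D
{
    const float* data;
    int stride_m, stride_n;

    float at(int i, int j) const { return data[i * stride_m + j * stride_n]; }
};

int main()
{
    const float row[3] = {1.0f, 2.0f, 3.0f};
    Broadcast2D tile{row, /*stride_m=*/0, /*stride_n=*/1};

    // every row aliases the same underlying three values
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 3; ++j)
            assert(tile.at(i, j) == row[j]);
    return 0;
}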