warp_lang-1.6.2-py3-none-win_amd64.whl → warp_lang-1.7.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +7 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
@@ -19,9 +19,39 @@
 
 #include "builtin.h"
 
+#ifdef __clang__
+// disable warnings related to C++17 extensions on CPU JIT builds
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wc++17-extensions"
+#endif // __clang__
+
+// Check if the CUDA toolkit is available
+#if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
+
+// If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
+#ifdef __CUDACC_RTC__
+// NVRTC: Use built-in float4 (no need for extra definitions)
+#else
+// NVCC: Include vector_types.h to get float4
+#include <cuda_runtime.h>
+#endif
+
+#else
+// If CUDA is not available (e.g., macOS build), manually define float4
+struct alignas(16) float4 {
+    float x, y, z, w;
+};
+#endif
+
+// only used while building the warp core library
+#ifndef WP_TILE_BLOCK_DIM
+#define WP_TILE_BLOCK_DIM 256
+#endif
+
 #if !defined(__CUDA_ARCH__)
 #define WP_TILE_SHARED static
 #define WP_TILE_SYNC void
+
 #else
 #define WP_TILE_SHARED __shared__
 #define WP_TILE_SYNC __syncthreads
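Note: the float4 fallback introduced above matters only when tile.h is compiled without any CUDA toolkit (for example CPU-only JIT builds), where CUDA's built-in float4 is unavailable but the 128-bit copy helpers later in this file still name the type. A minimal sketch of the contract the manual definition has to satisfy (standalone C++; the static_asserts are illustrative, not part of the header):

    // stand-in for CUDA's float4 on non-CUDA builds; the 128-bit load/store
    // paths in tile.h assume exactly 16 bytes with 16-byte alignment
    struct alignas(16) float4
    {
        float x, y, z, w;
    };

    static_assert(sizeof(float4) == 16, "float4 must be 16 bytes");
    static_assert(alignof(float4) == 16, "float4 must be 16-byte aligned");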
@@ -46,6 +76,14 @@
 #define WP_USE_ASYNC_PIPELINE 0
 #define WP_USE_REGISTER_GEMM 0
 
+#if defined(__CUDACC_RTC__)
+#define WP_TILE_THREAD_IDX threadIdx.x
+#else
+#define WP_TILE_THREAD_IDX 0
+#endif //
+
+
+
 /* Tile Expressions
 
 [ ] Tiles
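Note: WP_TILE_THREAD_IDX abstracts the thread index so the same source compiles under NVRTC device builds (where it is threadIdx.x) and under host or CPU JIT builds (where it is the constant 0). Every tile loop below uses the same block-strided idiom; a standalone sketch of that idiom in plain C++ (names are illustrative, and it assumes the CPU path runs with a block dimension of 1 so the loop degenerates to a serial pass):

    // each thread starts at its own index and strides by the block size,
    // so the block as a whole visits 0..size-1 exactly once
    void fill_zero(float* data, int size, int thread_idx, int block_dim)
    {
        for (int i = thread_idx; i < size; i += block_dim)
            data[i] = 0.0f;
    }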
@@ -217,14 +255,14 @@ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
 }
 
 // helpers to construct a coord from a set of indices
-auto tile_coord(int i)
+inline auto tile_coord(int i)
 {
     auto c = tile_coord_t<1>();
     c.indices[0] = i;
     return c;
 }
 
-auto tile_coord(int i, int j)
+inline auto tile_coord(int i, int j)
 {
     auto c = tile_coord_t<2>();
     c.indices[0] = i;
@@ -232,7 +270,7 @@ auto tile_coord(int i, int j)
     return c;
 }
 
-auto tile_coord(int i, int j, int k)
+inline auto tile_coord(int i, int j, int k)
 {
     auto c = tile_coord_t<3>();
     c.indices[0] = i;
@@ -241,7 +279,7 @@ auto tile_coord(int i, int j, int k)
     return c;
 }
 
-auto tile_coord(int i, int j, int k, int l)
+inline auto tile_coord(int i, int j, int k, int l)
 {
     auto c = tile_coord_t<4>();
     c.indices[0] = i;
@@ -256,7 +294,7 @@ template <int... V>
 struct tile_tuple_t
 {
     static constexpr int N = sizeof...(V);
-    static_assert(N > 0);
+    static_assert(N > 0, "Expected N > 0");
 
     static constexpr int data[N] = { V... };
 
@@ -409,7 +447,7 @@ struct tile_layout_register_t
 
     static inline CUDA_CALLABLE int linear_from_register(int reg)
     {
-        return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
+        return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
     }
 
     static inline CUDA_CALLABLE int linear_from_coord(Coord c)
@@ -509,15 +547,6 @@ struct tile_register_t
         return data[reg];
     }
 
-    // Returns the number of valid registers for this tile
-    // i.e.: how many registers map to a valid coordinate.
-    // When a tile's size is not aligned to the block dimension
-    // some of the trailing registers may lie outside the valid range
-    inline CUDA_CALLABLE int valid() const
-    {
-        return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
-    }
-
     inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
     {
         for (int i=0; i < Layout::NumRegs; ++i)
@@ -544,7 +573,7 @@ struct tile_register_t
         // ensure any previously scheduled threads have finished reading from scratch
         WP_TILE_SYNC();
 
-        if (threadIdx.x == thread)
+        if (WP_TILE_THREAD_IDX == thread)
         {
             scratch = data[reg];
         }
@@ -565,7 +594,7 @@ struct tile_register_t
         const int thread = Layout::thread_from_linear(linear);
         const int reg = Layout::register_from_linear(linear);
 
-        if (threadIdx.x == thread)
+        if (WP_TILE_THREAD_IDX == thread)
         {
             data[reg] += adj_ret;
         }
@@ -668,7 +697,7 @@ struct tile_register_t
 // users can either specify a template explicitly or
 // pass in another concrete instance
 template<typename Tile>
-auto tile_register_like(Tile* t=NULL)
+auto tile_register_like(Tile* t=nullptr)
 {
     using T = typename Tile::Type;
     using L = typename Tile::Layout;
@@ -694,26 +723,39 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
     return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
 }
 
-inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
+inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
 {
     // we maintain a per-thread offset into dynamic
     // shared memory that allows us to keep track of
     // current use across dynamic function calls
-
+    WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
 
     if (init)
     {
-        smem_base[threadIdx.x] = 0;
-        return NULL;
+        smem_base[WP_TILE_THREAD_IDX] = 0;
+        return nullptr;
+    }
+    else if (check)
+    {
+        assert(smem_base[WP_TILE_THREAD_IDX] == 0);
+        return nullptr;
     }
     else
     {
-        const int offset = smem_base[threadIdx.x];
+        const int offset = smem_base[WP_TILE_THREAD_IDX];
 
         // one entry per-thread so no need for synchronization
-        smem_base[threadIdx.x] += tile_align(num_bytes);
+        smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
 
+#ifdef __CUDA_ARCH__
         extern __shared__ char dynamic_smem_base[];
+#else
+        // on CPU allocate a fixed 256k block to use for shared allocs
+        static const int max_cpu_shared = 256*1024;
+        static char dynamic_smem_base[max_cpu_shared];
+
+        assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
+#endif
        return &(dynamic_smem_base[offset]);
    }
}
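Note: tile_alloc_shared is a bump allocator over dynamic shared memory. Passing init=true resets the per-thread offset, the new check=true path asserts the offset has wound back to zero (a leak check), and because tile_align is sign-preserving, a negative num_bytes winds the offset back and releases space. A minimal single-threaded sketch of the same scheme (fixed arena; align16 stands in for tile_align and is an assumption, not the header's code):

    static char arena[256 * 1024];   // mirrors the fixed CPU fallback block
    static int offset = 0;

    // sign-preserving round-up to 16 bytes, so negative sizes wind back
    int align16(int num_bytes)
    {
        const int sign = (num_bytes < 0) ? -1 : 1;
        const int abs_bytes = sign * num_bytes;
        return sign * ((abs_bytes + 15) / 16) * 16;
    }

    void* bump_alloc(int num_bytes)  // negative num_bytes acts as a release
    {
        void* p = &arena[offset];
        offset += align16(num_bytes);
        return p;
    }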
@@ -847,12 +889,12 @@ struct tile_shared_t
     bool initialized;
 
     // default initialization (non-initialized)
-    inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL), initialized(false)
+    inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
     {
     }
 
     // initialize from an existing tile's memory
-    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL, bool initialized=true) : data(data), grad(grad), initialized(initialized)
+    inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
     {
     }
 
@@ -878,6 +920,7 @@ struct tile_shared_t
     }
 
 
+    /*
     // construct from another shared tile, this constructor
     // is invoked for reshape operations like `wp.tile_transpose()`
     template <typename OtherT, typename OtherLayout>
@@ -886,7 +929,7 @@ struct tile_shared_t
         using OtherTile = tile_shared_t<OtherT, OtherLayout>;
 
         // check dimensions are compatible
-        static_assert(Size == OtherTile::Size);
+        static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
 
         // alias tile directly
         data = rhs.data;
@@ -895,6 +938,7 @@ struct tile_shared_t
 
         return *this;
     }
+    */
 
     // assign from a global tile (load)
     inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
@@ -912,7 +956,7 @@ struct tile_shared_t
         if (initialized)
             WP_TILE_SYNC();
 
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = x;
 
         initialized = true;
@@ -923,7 +967,7 @@ struct tile_shared_t
     // in-place zero
     inline CUDA_CALLABLE void zero()
     {
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             data(i) = T(0);
 
         WP_TILE_SYNC();
@@ -973,7 +1017,7 @@ struct tile_shared_t
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
    {
-        for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
             grad(i) = T(0);
 
         WP_TILE_SYNC();
@@ -1013,7 +1057,7 @@ struct tile_shared_t
     CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             T g = global.load_grad(c);
@@ -1081,6 +1125,8 @@ struct tile_shared_t
     template <typename Global>
     inline CUDA_CALLABLE void copy_to_global(const Global& dest)
     {
+
+#if defined(__CUDA_ARCH__)
         // vectorized loads for specific input/output shapes
         if constexpr (Layout::Shape::N == 2)
         {
@@ -1109,7 +1155,7 @@ struct tile_shared_t
             const int stride_j = 1;
 
             WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
+            for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
             {
                 auto c = SrcLayout::coord_from_linear(i);
 
@@ -1120,17 +1166,18 @@ struct tile_shared_t
             }
         }
 
+#endif //defined(__CUDA_ARCH__)
+
         // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             dest.store(c, data(i));
         }
     }
 
-
-    void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
+    inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
     {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
 
@@ -1152,8 +1199,7 @@ struct tile_shared_t
 #endif
     }
 
-
-    void cp_async_commit_and_wait_all_128()
+    inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
     {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
         asm volatile(
@@ -1168,6 +1214,8 @@ struct tile_shared_t
         if (initialized)
             WP_TILE_SYNC();
 
+#if defined(__CUDA_ARCH__)
+
         // vectorized loads for specific input/output shapes
         if constexpr (Layout::Shape::N == 2)
         {
@@ -1196,7 +1244,7 @@ struct tile_shared_t
             const int stride_j = 1;
 
             WP_PRAGMA_UNROLL
-            for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
+            for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
             {
                 auto c = DestLayout::coord_from_linear(i);
 
@@ -1217,9 +1265,11 @@ struct tile_shared_t
             }
         }
 
+#endif //defined(__CUDA_ARCH__)
+
         // scalar bounds checked path
         WP_PRAGMA_UNROLL
-        for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
         {
             auto c = Layout::coord_from_linear(i);
             data(i) = src.load(c);
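Note: fencing the vectorized copy paths behind __CUDA_ARCH__ is what lets copy_to_global and the global-to-shared assignment compile on CPU: the float4-based 128-bit transfers and the cp.async helpers (themselves guarded for __CUDA_ARCH__ >= 800, i.e. Ampere and newer) are GPU-only, so CPU builds now fall through to the scalar bounds-checked loop that follows each fenced block.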
@@ -1332,7 +1382,7 @@ struct tile_shared_t
 
     inline CUDA_CALLABLE void print(bool reverse=false) const
     {
-        if (threadIdx.x != 0)
+        if (WP_TILE_THREAD_IDX != 0)
             return;
 
         if (reverse)
@@ -1359,13 +1409,13 @@ void tile_register_t<T, L>::print() const
     // create a temporary shared tile so that
     // we can print it deterministically
     WP_TILE_SHARED T smem[L::Size];
-    tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, NULL);
+    tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
 
     scratch.assign(*this);
 
     WP_TILE_SYNC();
 
-    if (threadIdx.x == 0)
+    if (WP_TILE_THREAD_IDX == 0)
     {
         scratch.print_values(scratch.data, 0);
 
@@ -1392,7 +1442,7 @@ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print()
 template <typename T, typename L, bool O>
 inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
 {
-    return …
+    return L::Shape::dim(0);
 }
 
 template <typename T, typename L, bool O, typename AdjTile>
@@ -1403,7 +1453,7 @@ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile&
 template <typename T, typename L>
 inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
 {
-    return …
+    return L::Shape::dim(0);
 }
 
 template <typename T, typename L, typename AdjTile>
@@ -1425,12 +1475,16 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
 
 {   constexpr int size = Shape::size();
     T* data = (T*)tile_alloc_shared(size*sizeof(T));
-    T* grad = NULL;
+    T* grad = nullptr;
 
 #if FP_CHECK
 
-
-
+    // initialize tile to quiet nan
+    uint32_t qnanbits = 0x7FC00000;
+    float qnan = *(float*)(&qnanbits);
+
+    for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
+        data[i] = T(qnan);
 
     WP_TILE_SYNC();
 
@@ -1441,7 +1495,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
     {
         grad = (T*)tile_alloc_shared(size*sizeof(T));
 
-        for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
+        for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
             grad[i] = T(0);
 
         WP_TILE_SYNC();
@@ -1450,30 +1504,6 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
     return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
 }
 
-template <typename T, int M, int N, bool RequiresGrad>
-inline CUDA_CALLABLE auto tile_alloc_zeros()
-{
-    // compute the total storage required for the tile (may be different from M*N) for broadcast tiles
-    constexpr int Len = M*N;
-    T* data = (T*)tile_alloc_shared(Len*sizeof(T));
-    T* grad = NULL;
-
-    for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
-        data[i] = T(0);
-
-    if (RequiresGrad)
-    {
-        grad = (T*)tile_alloc_shared(Len*sizeof(T));
-
-        for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
-            grad[i] = T(0);
-    }
-
-    WP_TILE_SYNC();
-
-    return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
-}
-
 
 //-----------------------------------------------------------------------------------------------------
 // High level entry points for each op (correspond to one Warp builtin)
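Note: with FP_CHECK enabled, freshly allocated tiles are now filled with a quiet NaN so reads of uninitialized elements surface immediately; 0x7FC00000 is the canonical single-precision quiet-NaN bit pattern (exponent all ones, top mantissa bit set). The header materializes it through a pointer cast; for reference, a strict-aliasing-safe equivalent (illustrative, not what the header does):

    #include <cstdint>
    #include <cstring>

    float quiet_nan()
    {
        const uint32_t bits = 0x7FC00000u;
        float f;
        std::memcpy(&f, &bits, sizeof f);  // same bit pattern, no aliasing UB
        return f;
    }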
@@ -1485,7 +1515,7 @@ inline CUDA_CALLABLE auto tile(const T& x)
     tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
 
     using Layout = typename decltype(result)::Layout;
-    static_assert(Layout::NumRegs == 1);
+    static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
 
     result.data[0] = x;
     return result;
@@ -1498,7 +1528,7 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
     tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
 
     using Layout = typename decltype(result)::Layout;
-    static_assert(Layout::NumRegs == Length);
+    static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
 
     for (int i=0; i < Length; ++i)
         result.data[i] = x[i];
@@ -1510,8 +1540,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
 {
-    static_assert(AdjTile::Layout::Shape::N == 1);
-    static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
+    static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
+    static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
 
     auto adj_reg = adj_ret.copy_to_register();
 
@@ -1521,9 +1551,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
 template <typename T, unsigned Length, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
 {
-    static_assert(AdjTile::Layout::Shape::N == 2);
-    static_assert(AdjTile::Layout::Shape::dim(0) == Length);
-    static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
+    static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
+    static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
+    static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
 
     auto adj_reg = adj_ret.copy_to_register();
 
@@ -1701,7 +1731,7 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, arr
     if (adj_dest.data)
         src.data.grad = adj_dest.data;
 
-    if (src.data.grad == NULL)
+    if (src.data.grad == nullptr)
         return;
 
     adj_t.grad_add(src);
@@ -1936,7 +1966,6 @@ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, i
 template<typename Tile, typename AdjTile>
 void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
 
-#if WP_USE_REGISTER_GEMM
 
 namespace partitioned_gemm
 {
@@ -2042,9 +2071,11 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
     auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
     auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
 
+    //static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
+
     const int length = partition_size(C_tile);
 
-    for (int t=threadIdx.x; t < length; t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
     {
         int i, j;
         partition_coord(C_tile, t, i, j);
@@ -2064,10 +2095,102 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
         partition_store(C_tile, i, j, sum);
     }
 }
-
-} // namespace partition_gemm
 
-
+template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
+inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
+{
+    for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
+    {
+        auto coord = LayoutC::coord_from_linear(t);
+
+        int i = coord[0];
+        int j = coord[1];
+
+        // accumulator
+        auto sum = C(coord)*scale;
+
+        WP_PRAGMA_UNROLL
+        for (int k=0; k < LayoutA::Shape::dim(1); k++)
+        {
+            const auto a = A(tile_coord(i, k));
+            const auto b = B(tile_coord(k, j));
+
+            sum = muladd<decltype(sum)>(a, b, sum);
+        }
+
+        C(coord) = sum;
+    }
+}
+
+template <typename TileA, typename TileL>
+inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
+{
+    using T = typename TileA::Type;
+    constexpr int n = TileA::Layout::Shape::dim(1);
+
+    for (int j=0; j < n; ++j)
+    {
+        T s = A.data(tile_coord(j, j));
+
+        for (int k=0; k < j; ++k)
+        {
+            T r = L.data(tile_coord(j, k));
+            s -= r * r;
+        }
+
+        s = wp::sqrt(s);
+        T invS = 1.0 / s;
+
+        L.data(tile_coord(j, j)) = s;
+
+        for (int i=j+1; i < n; ++i)
+        {
+            s = A.data(tile_coord(i, j));
+
+            for (int k=0; k < j; ++k)
+            {
+                s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
+            }
+
+            L.data(tile_coord(i, j)) = s * invS;
+        }
+
+        // zero out upper triangular portion
+        for (int k=j+1; k < n; ++k)
+        {
+            L.data(tile_coord(j,k)) = T(0.0);
+        }
+    }
+}
+
+template <typename TileL, typename TileX, typename TileY>
+inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
+{
+    using T = typename TileL::Type;
+    constexpr int n = TileL::Layout::Shape::dim(1);
+
+    for (int i=0; i < n; ++i)
+    {
+        T s = Y.data(tile_coord(i));
+
+        for (int j=0; j < i; ++j)
+            s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
+
+        X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+    }
+
+    for (int i=n-1; i >= 0; --i)
+    {
+        T s = X.data(tile_coord(i));
+
+        for (int j=i+1; j < n; ++j)
+            s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
+
+        X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+    }
+}
+
+} // namespace partition_gemm
 
 
 template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
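Note: scalar_cholesky above is the textbook Cholesky-Crout recurrence, computing L column by column via L(j,j) = sqrt(A(j,j) - sum_k L(j,k)^2) and L(i,j) = (A(i,j) - sum_k L(i,k)*L(j,k)) / L(j,j), and scalar_cholesky_solve is the matching forward/backward substitution pair. A self-contained check of the same recurrence on a plain array, using a classic 3x3 SPD example whose exact factor is {{2,0,0},{6,1,0},{-8,5,3}}:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const int n = 3;
        double A[3][3] = {{4, 12, -16}, {12, 37, -43}, {-16, -43, 98}};
        double L[3][3] = {};

        for (int j = 0; j < n; ++j)
        {
            double s = A[j][j];
            for (int k = 0; k < j; ++k)
                s -= L[j][k] * L[j][k];
            L[j][j] = std::sqrt(s);            // diagonal entry

            for (int i = j + 1; i < n; ++i)
            {
                s = A[i][j];
                for (int k = 0; k < j; ++k)
                    s -= L[i][k] * L[j][k];
                L[i][j] = s / L[j][j];         // below-diagonal entry
            }
        }

        assert(L[0][0] == 2 && L[1][0] == 6 && L[2][0] == -8);
        assert(L[1][1] == 1 && L[2][1] == 5 && L[2][2] == 3);
        return 0;
    }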
@@ -2077,19 +2200,19 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
     using ShapeB = typename TileB::Layout::Shape;
     using ShapeC = typename TileC::Layout::Shape;
 
-    static_assert(ShapeA::N == 2);
-    static_assert(ShapeB::N == 2);
-    static_assert(ShapeC::N == 2);
+    static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
+    static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
+    static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
 
-    static_assert(ShapeA::dim(1) == ShapeB::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(1) == ShapeB::dim(1));
+    static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
 
 
     using T = typename TileA::Type;
 
-#if …
-    partitioned_gemm::…
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
 #else
     fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
 #endif
@@ -2099,6 +2222,7 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
     return C;
 }
 
+
 // backward for the wp.tile_matmul(a, b, out) syntax
 template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
 void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
@@ -2106,8 +2230,17 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
 {
     using T = typename TileA::Type;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    auto At = tile_transpose(A);
+    auto Bt = tile_transpose(B);
+
+    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+#else
     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+#endif
+
     WP_TILE_SYNC();
 }
 
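Note: the non-MathDx backward path above encodes the standard matrix-product adjoints. For C = A·B with upstream gradient adj_C, differentiating dC = dA·B + A·dB gives adj_A += adj_C · B^T and adj_B += A^T · adj_C, which is exactly what the two scalar_matmul calls compute (the scale of 1.0 makes them accumulate into the existing gradients), mirroring the fun_backward_A/fun_backward_B GEMMs used when MathDx is available.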
@@ -2118,11 +2251,30 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
 {
     using T = typename TileA::Type;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+    auto At = tile_transpose(A);
+    auto Bt = tile_transpose(B);
+
+    partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+    partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+#else
     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+#endif
+
     WP_TILE_SYNC();
 }
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+#define tile_fft()
+#define tile_ifft()
+
+#define adj_tile_fft()
+#define adj_tile_ifft()
+
+#else
+
 // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
 // and remove the need for __align__(16) dtypes data[...]
 #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
@@ -2158,12 +2310,21 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
     tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
     } while (0)
 
+#endif // !defined(__CUDA_ARCH__)
+
 template <typename Fwd, typename TileA, typename TileL>
 TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
 {
     // Copy to L
     L = A;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+    partitioned_gemm::scalar_cholesky(A, L);
+
+#else
+
+
     // Call cholesky on L
     WP_TILE_SYNC();
 
@@ -2174,7 +2335,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
     // Zero-out the upper triangular part of L
 
     WP_PRAGMA_UNROLL
-    for (int i=threadIdx.x; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
+    for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
     {
         auto c = TileL::Layout::coord_from_linear(i);
 
@@ -2183,7 +2344,9 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
     }
 
     WP_TILE_SYNC();
-
+
+#endif
+
     return L;
 }
 
@@ -2200,6 +2363,12 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 
     Y = X;
 
+#if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+    partitioned_gemm::scalar_cholesky_solve(L, X, Y);
+
+#else
+
     // Call cholesky solve on L & y
 
     WP_TILE_SYNC();
@@ -2208,6 +2377,8 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 
     WP_TILE_SYNC();
 
+#endif
+
     return Y;
 }
 
@@ -2220,7 +2391,7 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
 template <typename Tile>
 inline CUDA_CALLABLE auto tile_transpose(Tile& t)
 {
-    static_assert(Tile::Layout::Shape::N == 2);
+    static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
 
     // alias incoming tile
     constexpr int M = Tile::Layout::Shape::dim(0);
@@ -2241,13 +2412,34 @@ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_
     adj_t.assign(tile_add(a,b));
 }
 
+template <int N, int StrideN, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
+}
+
 template <int M, int N, int StrideM, int StrideN, typename Tile>
 inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
-{
+{
     // alias incoming tile with new strides
     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
 }
 
+template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
+}
+
+template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
+inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+{
+    // alias incoming tile with new strides
+    return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
+}
+
 template <typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
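Note: the added tile_broadcast overloads extend stride aliasing from 2D to 1D, 3D, and 4D shapes. No data is copied: the returned tile reuses t.data.ptr and t.grad.ptr, only the compile-time shape/stride types change, and a stride of 0 along a broadcast dimension maps every index to the same element. A minimal sketch of the indexing rule (plain C++, illustrative names):

    #include <cstdio>

    // index a 2D "view" over a 1D row using explicit strides
    float view2d(const float* data, int i, int j, int stride_i, int stride_j)
    {
        return data[i * stride_i + j * stride_j];
    }

    int main()
    {
        float row[3] = {1.0f, 2.0f, 3.0f};
        // stride_i == 0 broadcasts the same row across both rows of the view
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 3; ++j)
                printf("%d %d -> %g\n", i, j, view2d(row, i, j, 0, 1));
        return 0;
    }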
@@ -2261,7 +2453,7 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
 
     // return new tile with same strides
     typename Tile::Type* data_ptr = &t.data(c);
-    typename Tile::Type* grad_ptr = NULL;
+    typename Tile::Type* grad_ptr = nullptr;
 
     if (t.grad.ptr)
         grad_ptr = &t.grad(c);
@@ -2306,7 +2498,7 @@ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offs
 {
     using Layout = typename TileB::Layout;
 
-    for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
     {
         auto c = Layout::coord_from_linear(t);
         dest.data(c + offset) = src.data(c);
@@ -2321,7 +2513,7 @@ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
 {
     using Layout = typename TileB::Layout;
 
-    for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
     {
         auto c = Layout::coord_from_linear(t);
         src.grad(c) += dest.grad(c + offset);
@@ -2360,14 +2552,14 @@ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
     using ShapeB = typename TileB::Layout::Shape;
     using ShapeC = typename TileC::Layout::Shape;
 
-    static_assert(ShapeA::dim(0) == ShapeA::dim(1));
-    static_assert(ShapeB::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeA::dim(0));
-    static_assert(ShapeC::dim(0) == ShapeC::dim(1));
+    static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
+    static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+    static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
 
     c = a;
 
-    for (int t=threadIdx.x; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
+    for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
     {
         c.data(tile_coord(t, t)) += b.data(tile_coord(t));
     }
@@ -2386,3 +2578,7 @@ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTil
 
 } // namespace wp
 
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif