warp-lang 1.8.1__py3-none-win_amd64.whl → 1.9.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +1904 -114
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +331 -101
- warp/builtins.py +1244 -160
- warp/codegen.py +317 -206
- warp/config.py +1 -1
- warp/context.py +1465 -789
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +25 -2
- warp/jax_experimental/ffi.py +22 -1
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +86 -9
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +40 -31
- warp/native/sort.h +2 -0
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +471 -82
- warp/native/vec.h +328 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +377 -216
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +99 -18
- warp/render/render_usd.py +1 -0
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +58 -5
- warp/tests/test_codegen.py +4 -3
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +49 -6
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +15 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +245 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +571 -267
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -230,7 +230,9 @@ struct tile_coord_t
|
|
|
230
230
|
out.indices[i] = indices[i] + c.indices[i];
|
|
231
231
|
}
|
|
232
232
|
return out;
|
|
233
|
-
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
static constexpr int size() { return N; }
|
|
234
236
|
};
|
|
235
237
|
|
|
236
238
|
// This function deduces N = sizeof...(Ints)
|
|
@@ -338,7 +340,8 @@ using tile_stride_t = tile_tuple_t<V...>;
|
|
|
338
340
|
|
|
339
341
|
// represents a tile stored in global memory with dynamic strides
|
|
340
342
|
// used to represent the source and offset for tile loads to register/shared
|
|
341
|
-
|
|
343
|
+
// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
|
|
344
|
+
template <typename T, typename Shape_, bool BoundsCheck=true>
|
|
342
345
|
struct tile_global_t
|
|
343
346
|
{
|
|
344
347
|
using Type = T;
|
|
@@ -370,25 +373,33 @@ struct tile_global_t
|
|
|
370
373
|
|
|
371
374
|
inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
|
|
372
375
|
{
|
|
373
|
-
|
|
374
|
-
int index = 0;
|
|
375
|
-
|
|
376
|
-
WP_PRAGMA_UNROLL
|
|
377
|
-
for (int i=0; i < Shape::N; ++i)
|
|
376
|
+
if constexpr (BoundsCheck)
|
|
378
377
|
{
|
|
379
|
-
//
|
|
380
|
-
int
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
378
|
+
// element index
|
|
379
|
+
int index = 0;
|
|
380
|
+
|
|
381
|
+
WP_PRAGMA_UNROLL
|
|
382
|
+
for (int i=0; i < Shape::N; ++i)
|
|
383
|
+
{
|
|
384
|
+
// global = offset + coord
|
|
385
|
+
int c = offset[i] + coord[i];
|
|
386
|
+
|
|
387
|
+
// handle out of bounds case
|
|
388
|
+
if (c >= data.shape[i])
|
|
389
|
+
return false;
|
|
390
|
+
else
|
|
391
|
+
index += data.strides[i]*c;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
// array strides are in bytes so we convert to elements
|
|
395
|
+
out = index / sizeof(T);
|
|
396
|
+
return true;
|
|
397
|
+
}
|
|
398
|
+
else
|
|
399
|
+
{
|
|
400
|
+
out = index_from_coord(coord);
|
|
401
|
+
return true;
|
|
402
|
+
}
|
|
392
403
|
}
|
|
393
404
|
|
|
394
405
|
inline CUDA_CALLABLE T load(const Coord& coord) const
|
|
@@ -435,6 +446,7 @@ struct tile_global_t
|
|
|
435
446
|
}
|
|
436
447
|
};
|
|
437
448
|
|
|
449
|
+
|
|
438
450
|
template <typename Shape_>
|
|
439
451
|
struct tile_layout_register_t
|
|
440
452
|
{
|
|
@@ -521,7 +533,8 @@ struct tile_register_t
|
|
|
521
533
|
data[i] = value;
|
|
522
534
|
}
|
|
523
535
|
|
|
524
|
-
|
|
536
|
+
template <bool BoundsCheck>
|
|
537
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
|
|
525
538
|
{
|
|
526
539
|
copy_from_global(t);
|
|
527
540
|
return *this;
|
|
@@ -529,7 +542,7 @@ struct tile_register_t
|
|
|
529
542
|
|
|
530
543
|
// define the += operator which is used during backward pass codegen
|
|
531
544
|
// when returning a register tile from a user defined function
|
|
532
|
-
inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
|
|
545
|
+
inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
|
|
533
546
|
{
|
|
534
547
|
grad_add(rhs);
|
|
535
548
|
return *this;
|
|
@@ -645,10 +658,9 @@ struct tile_register_t
|
|
|
645
658
|
data[i] += tile.data[i];
|
|
646
659
|
}
|
|
647
660
|
|
|
648
|
-
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
661
|
+
inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
649
662
|
{
|
|
650
|
-
apply([&](int reg, auto c) {data[reg]
|
|
651
|
-
|
|
663
|
+
apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
|
|
652
664
|
}
|
|
653
665
|
|
|
654
666
|
inline CUDA_CALLABLE auto& grad_to_register()
|
|
@@ -746,6 +758,7 @@ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, boo
|
|
|
746
758
|
|
|
747
759
|
// one entry per-thread so no need for synchronization
|
|
748
760
|
smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
|
|
761
|
+
assert(smem_base[WP_TILE_THREAD_IDX] >= 0);
|
|
749
762
|
|
|
750
763
|
#ifdef __CUDA_ARCH__
|
|
751
764
|
extern __shared__ char dynamic_smem_base[];
|
|
@@ -893,6 +906,28 @@ struct tile_shared_t
|
|
|
893
906
|
{
|
|
894
907
|
}
|
|
895
908
|
|
|
909
|
+
// we delete the copy constructor because in the case the shared tile is owning,
|
|
910
|
+
// this leads to a double deallocation.
|
|
911
|
+
// this also forces one to handle copies explicitly
|
|
912
|
+
inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
|
|
913
|
+
{
|
|
914
|
+
static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
// move constructor
|
|
918
|
+
inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
|
|
919
|
+
{
|
|
920
|
+
other.data.ptr = nullptr;
|
|
921
|
+
other.grad.ptr = nullptr;
|
|
922
|
+
}
|
|
923
|
+
|
|
924
|
+
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
925
|
+
inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
|
|
926
|
+
{
|
|
927
|
+
static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
|
|
928
|
+
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
929
|
+
}
|
|
930
|
+
|
|
896
931
|
// initialize from an existing tile's memory
|
|
897
932
|
inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
|
|
898
933
|
{
|
|
@@ -920,22 +955,52 @@ struct tile_shared_t
|
|
|
920
955
|
|
|
921
956
|
// construct from another shared tile, this constructor
|
|
922
957
|
// is invoked for reshape operations like `wp.tile_transpose()`
|
|
958
|
+
// or `wp::copy()`
|
|
923
959
|
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
924
960
|
inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
|
|
925
961
|
{
|
|
926
962
|
// check dimensions are compatible
|
|
927
963
|
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
928
964
|
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
965
|
+
|
|
966
|
+
if (Owner)
|
|
967
|
+
{
|
|
968
|
+
// if the tile owns the data we need to copy
|
|
969
|
+
assign(rhs);
|
|
970
|
+
}
|
|
971
|
+
else
|
|
972
|
+
{
|
|
973
|
+
// alias tile directly
|
|
974
|
+
data.ptr = rhs.data.ptr;
|
|
975
|
+
grad.ptr = rhs.grad.ptr;
|
|
976
|
+
initialized = rhs.initialized;
|
|
977
|
+
}
|
|
933
978
|
|
|
934
979
|
return *this;
|
|
935
|
-
}
|
|
980
|
+
}
|
|
981
|
+
|
|
982
|
+
inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
|
|
983
|
+
{
|
|
984
|
+
if (Owner)
|
|
985
|
+
{
|
|
986
|
+
// if the tile owns the data we need to copy
|
|
987
|
+
assign(rhs);
|
|
988
|
+
}
|
|
989
|
+
else
|
|
990
|
+
{
|
|
991
|
+
// alias tile directly
|
|
992
|
+
data.ptr = rhs.data.ptr;
|
|
993
|
+
grad.ptr = rhs.grad.ptr;
|
|
994
|
+
initialized = rhs.initialized;
|
|
995
|
+
}
|
|
996
|
+
|
|
997
|
+
return *this;
|
|
998
|
+
}
|
|
936
999
|
|
|
937
1000
|
// assign from a global tile (load)
|
|
938
|
-
|
|
1001
|
+
|
|
1002
|
+
template <bool BoundsCheck>
|
|
1003
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
|
|
939
1004
|
{
|
|
940
1005
|
copy_from_global(t);
|
|
941
1006
|
return *this;
|
|
@@ -958,6 +1023,21 @@ struct tile_shared_t
|
|
|
958
1023
|
return *this;
|
|
959
1024
|
}
|
|
960
1025
|
|
|
1026
|
+
// define the += operator which is used during backward pass codegen
|
|
1027
|
+
// when returning a register tile from a user defined function
|
|
1028
|
+
template<typename OtherLayout>
|
|
1029
|
+
inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
|
|
1030
|
+
{
|
|
1031
|
+
grad_add(rhs);
|
|
1032
|
+
return *this;
|
|
1033
|
+
}
|
|
1034
|
+
|
|
1035
|
+
inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
|
|
1036
|
+
{
|
|
1037
|
+
grad_add(rhs);
|
|
1038
|
+
return *this;
|
|
1039
|
+
}
|
|
1040
|
+
|
|
961
1041
|
// in-place zero
|
|
962
1042
|
inline CUDA_CALLABLE void zero()
|
|
963
1043
|
{
|
|
@@ -1039,6 +1119,27 @@ struct tile_shared_t
|
|
|
1039
1119
|
WP_TILE_SYNC();
|
|
1040
1120
|
}
|
|
1041
1121
|
|
|
1122
|
+
// shared tile deep copy
|
|
1123
|
+
template <typename OtherT, typename OtherLayout, bool OtherOwner>
|
|
1124
|
+
inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
|
|
1125
|
+
{
|
|
1126
|
+
// check dimensions are compatible
|
|
1127
|
+
static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
|
|
1128
|
+
|
|
1129
|
+
if (initialized)
|
|
1130
|
+
WP_TILE_SYNC();
|
|
1131
|
+
|
|
1132
|
+
WP_PRAGMA_UNROLL
|
|
1133
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1134
|
+
{
|
|
1135
|
+
auto c = Layout::coord_from_linear(i);
|
|
1136
|
+
data(c) = tile.data(c);
|
|
1137
|
+
}
|
|
1138
|
+
|
|
1139
|
+
initialized = true;
|
|
1140
|
+
WP_TILE_SYNC();
|
|
1141
|
+
}
|
|
1142
|
+
|
|
1042
1143
|
// in-place gradient zero
|
|
1043
1144
|
inline CUDA_CALLABLE void grad_zero()
|
|
1044
1145
|
{
|
|
@@ -1078,8 +1179,21 @@ struct tile_shared_t
|
|
|
1078
1179
|
WP_TILE_SYNC();
|
|
1079
1180
|
}
|
|
1080
1181
|
|
|
1182
|
+
// accumulate gradients onto this tile from another shared tile
|
|
1183
|
+
inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
|
|
1184
|
+
{
|
|
1185
|
+
WP_PRAGMA_UNROLL
|
|
1186
|
+
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
1187
|
+
{
|
|
1188
|
+
auto c = Layout::coord_from_linear(i);
|
|
1189
|
+
grad(c) += tile.grad(c);
|
|
1190
|
+
}
|
|
1191
|
+
|
|
1192
|
+
WP_TILE_SYNC();
|
|
1193
|
+
}
|
|
1194
|
+
|
|
1081
1195
|
// accumulate gradient onto this tile from a global array
|
|
1082
|
-
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1196
|
+
inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
1083
1197
|
{
|
|
1084
1198
|
WP_PRAGMA_UNROLL
|
|
1085
1199
|
for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
|
|
@@ -1103,7 +1217,7 @@ struct tile_shared_t
|
|
|
1103
1217
|
}
|
|
1104
1218
|
|
|
1105
1219
|
WP_TILE_SYNC();
|
|
1106
|
-
}
|
|
1220
|
+
}
|
|
1107
1221
|
|
|
1108
1222
|
// copy shared tile to register
|
|
1109
1223
|
inline CUDA_CALLABLE auto grad_to_register()
|
|
@@ -1172,7 +1286,7 @@ struct tile_shared_t
|
|
|
1172
1286
|
{
|
|
1173
1287
|
// alias of shared tile with 128bit type
|
|
1174
1288
|
using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1175
|
-
tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
|
|
1289
|
+
tile_shared_t<float4, SrcLayout, false> src128((float4*)data.ptr);
|
|
1176
1290
|
|
|
1177
1291
|
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
1178
1292
|
assert(((uint64_t)(dest128))%sizeof(float4) == 0);
|
|
@@ -1251,7 +1365,7 @@ struct tile_shared_t
|
|
|
1251
1365
|
const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
|
|
1252
1366
|
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1253
1367
|
const bool aligned_stride = (src.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
|
|
1254
|
-
|
|
1368
|
+
|
|
1255
1369
|
float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
|
|
1256
1370
|
const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
|
|
1257
1371
|
|
|
@@ -1262,7 +1376,7 @@ struct tile_shared_t
|
|
|
1262
1376
|
{
|
|
1263
1377
|
// alias of shared tile with 128bit type
|
|
1264
1378
|
using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1265
|
-
tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
|
|
1379
|
+
tile_shared_t<float4, DestLayout, false> dest128((float4*)data.ptr);
|
|
1266
1380
|
|
|
1267
1381
|
assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
|
|
1268
1382
|
assert(((uint64_t)(src128))%sizeof(float4) == 0);
|
|
@@ -1463,9 +1577,16 @@ void tile_register_t<T, L>::print() const
|
|
|
1463
1577
|
// print entry points
|
|
1464
1578
|
template <typename T, typename L>
|
|
1465
1579
|
inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
|
|
1580
|
+
|
|
1581
|
+
template <typename T, typename L>
|
|
1582
|
+
inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
|
|
1583
|
+
|
|
1466
1584
|
template <typename T, typename L, bool Owner>
|
|
1467
1585
|
inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
|
|
1468
1586
|
|
|
1587
|
+
template <typename T, typename L, bool Owner>
|
|
1588
|
+
inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
|
|
1589
|
+
|
|
1469
1590
|
template <typename T, typename L, bool O>
|
|
1470
1591
|
inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
|
|
1471
1592
|
{
|
|
@@ -1488,13 +1609,81 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
|
|
|
1488
1609
|
{
|
|
1489
1610
|
}
|
|
1490
1611
|
|
|
1612
|
+
// select specialization for shared tiles
|
|
1613
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1614
|
+
inline CUDA_CALLABLE auto select(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
|
|
1615
|
+
{
|
|
1616
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1617
|
+
return (!!cond) ? b.copy_to_register() : a;
|
|
1618
|
+
}
|
|
1491
1619
|
|
|
1492
|
-
template <typename T, typename
|
|
1493
|
-
inline CUDA_CALLABLE
|
|
1494
|
-
|
|
1495
|
-
|
|
1620
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1621
|
+
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
|
|
1622
|
+
{
|
|
1623
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1624
|
+
return (!!cond) ? b : a.copy_to_register();
|
|
1625
|
+
}
|
|
1626
|
+
|
|
1627
|
+
template <typename C, typename T, typename L, bool Owner>
|
|
1628
|
+
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
|
|
1629
|
+
{
|
|
1630
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1631
|
+
return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
|
|
1632
|
+
}
|
|
1496
1633
|
|
|
1634
|
+
template <typename C, typename T, typename L, bool LOwner, bool ROwner>
|
|
1635
|
+
inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
|
|
1636
|
+
{
|
|
1637
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1638
|
+
return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
|
|
1639
|
+
}
|
|
1497
1640
|
|
|
1641
|
+
// adj_select same as in builtin.h
|
|
1642
|
+
|
|
1643
|
+
// where specialization for register/shared tiles
|
|
1644
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1645
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
|
|
1646
|
+
{
|
|
1647
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1648
|
+
return (!!cond) ? a : b.copy_to_register();
|
|
1649
|
+
}
|
|
1650
|
+
|
|
1651
|
+
template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
|
|
1652
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
|
|
1653
|
+
{
|
|
1654
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1655
|
+
return (!!cond) ? a.copy_to_register() : b;
|
|
1656
|
+
}
|
|
1657
|
+
|
|
1658
|
+
template <typename C, typename T, typename L, bool Owner>
|
|
1659
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
|
|
1660
|
+
{
|
|
1661
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1662
|
+
return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
|
|
1663
|
+
}
|
|
1664
|
+
|
|
1665
|
+
template <typename C, typename T, typename L, bool LOwner, bool ROwner>
|
|
1666
|
+
inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
|
|
1667
|
+
{
|
|
1668
|
+
// The double NOT operator !! casts to bool without compiler warnings.
|
|
1669
|
+
return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
|
|
1670
|
+
}
|
|
1671
|
+
|
|
1672
|
+
// adj_where same as in builtin.h
|
|
1673
|
+
|
|
1674
|
+
// copy specialization for shared tiles, the lvalue this gets assigned to is owning, thus, this invokes the copy assign path
|
|
1675
|
+
template <typename T, typename L, bool Owner>
|
|
1676
|
+
inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
|
|
1677
|
+
{
|
|
1678
|
+
return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
|
|
1679
|
+
}
|
|
1680
|
+
|
|
1681
|
+
template <typename T, typename L, bool Owner>
|
|
1682
|
+
inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
|
|
1683
|
+
{
|
|
1684
|
+
adj_src += adj_dest;
|
|
1685
|
+
adj_dest.grad_zero();
|
|
1686
|
+
}
|
|
1498
1687
|
|
|
1499
1688
|
// helpers to allocate shared tiles
|
|
1500
1689
|
template <typename T, typename Shape, typename Strides, bool RequiresGrad>
|
|
@@ -1727,10 +1916,66 @@ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
|
|
|
1727
1916
|
T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
|
|
1728
1917
|
|
|
1729
1918
|
// entry point for load operations, these just return a reference to a global memory array + coordinate
|
|
1730
|
-
template <unsigned... Shape, typename...
|
|
1731
|
-
inline CUDA_CALLABLE auto tile_load(array_t<T>& src,
|
|
1919
|
+
template <typename T, bool BoundsCheck, unsigned... Shape, typename... Offset>
|
|
1920
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Offset... offset)
|
|
1921
|
+
{
|
|
1922
|
+
return tile_global_t<T, tile_shape_t<Shape...>, BoundsCheck>(src, tile_coord(offset...));
|
|
1923
|
+
}
|
|
1924
|
+
|
|
1925
|
+
// used for indexed loads and stores
|
|
1926
|
+
template <typename T, int M, typename Coord>
|
|
1927
|
+
inline CUDA_CALLABLE bool compute_index(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Coord c, int& out)
|
|
1732
1928
|
{
|
|
1733
|
-
|
|
1929
|
+
int index = 0;
|
|
1930
|
+
|
|
1931
|
+
WP_PRAGMA_UNROLL
|
|
1932
|
+
for (int i = 0; i < Coord::size(); ++i)
|
|
1933
|
+
{
|
|
1934
|
+
if (i == axis)
|
|
1935
|
+
{
|
|
1936
|
+
// global = offset_coord + index_mapped_coord
|
|
1937
|
+
int index_along_axis = offset[i] + indices.data(c[i]);
|
|
1938
|
+
|
|
1939
|
+
// handle out of bounds case
|
|
1940
|
+
if (index_along_axis >= src.shape[i])
|
|
1941
|
+
return false;
|
|
1942
|
+
else
|
|
1943
|
+
index += src.strides[i] * index_along_axis;
|
|
1944
|
+
}
|
|
1945
|
+
else
|
|
1946
|
+
{
|
|
1947
|
+
// global = offset_coord + coord
|
|
1948
|
+
int g = offset[i] + c[i];
|
|
1949
|
+
|
|
1950
|
+
// handle out of bounds case
|
|
1951
|
+
if (g >= src.shape[i])
|
|
1952
|
+
return false;
|
|
1953
|
+
else
|
|
1954
|
+
index += src.strides[i] * g;
|
|
1955
|
+
}
|
|
1956
|
+
}
|
|
1957
|
+
|
|
1958
|
+
// array strides are in bytes so we convert to elements
|
|
1959
|
+
out = index / sizeof(T);
|
|
1960
|
+
return true;
|
|
1961
|
+
}
|
|
1962
|
+
|
|
1963
|
+
|
|
1964
|
+
template <unsigned... Shape, int M, typename T, typename... Offset>
|
|
1965
|
+
inline CUDA_CALLABLE auto tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Offset... offset)
|
|
1966
|
+
{
|
|
1967
|
+
auto out = tile_register_t<T, tile_layout_register_t<tile_shape_t<Shape...>>>();
|
|
1968
|
+
auto offset_coord = tile_coord(offset...);
|
|
1969
|
+
|
|
1970
|
+
out.apply([&](int reg, auto c) {
|
|
1971
|
+
int i;
|
|
1972
|
+
if (compute_index(src, indices, axis, offset_coord, c, i))
|
|
1973
|
+
out.data[reg] = src.data[i];
|
|
1974
|
+
else
|
|
1975
|
+
out.data[reg] = T(0);
|
|
1976
|
+
});
|
|
1977
|
+
|
|
1978
|
+
return out;
|
|
1734
1979
|
}
|
|
1735
1980
|
|
|
1736
1981
|
// // entry point for tile store operations
|
|
@@ -1741,38 +1986,90 @@ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
|
|
|
1741
1986
|
// }
|
|
1742
1987
|
|
|
1743
1988
|
// entry point for tile store operations
|
|
1744
|
-
template <typename T, typename Tile>
|
|
1745
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1746
|
-
template <typename T, typename Tile>
|
|
1747
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
|
|
1748
|
-
template <typename T, typename Tile>
|
|
1749
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
|
|
1750
|
-
template <typename T, typename Tile>
|
|
1751
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
|
|
1989
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1990
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x))); }
|
|
1991
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1992
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y))); }
|
|
1993
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1994
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z))); }
|
|
1995
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1996
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z, w))); }
|
|
1997
|
+
|
|
1998
|
+
template <typename T, int M, typename Tile, typename Coord>
|
|
1999
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
|
|
2000
|
+
{
|
|
2001
|
+
auto src_reg = src.copy_to_register();
|
|
2002
|
+
|
|
2003
|
+
src_reg.apply([&](int reg, auto c) {
|
|
2004
|
+
int i;
|
|
2005
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
2006
|
+
dest.data[i] = src_reg.data[reg];
|
|
2007
|
+
});
|
|
2008
|
+
}
|
|
2009
|
+
|
|
2010
|
+
// entry point for tile index store operations
|
|
2011
|
+
template <typename T, int M, typename Tile>
|
|
2012
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x), src); }
|
|
2013
|
+
template <typename T, int M, typename Tile>
|
|
2014
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y), src); }
|
|
2015
|
+
template <typename T, int M, typename Tile>
|
|
2016
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
|
|
2017
|
+
template <typename T, int M, typename Tile>
|
|
2018
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
|
|
1752
2019
|
|
|
1753
2020
|
|
|
1754
2021
|
// compiler struggles with these if they are one line
|
|
1755
|
-
template <typename T, typename Tile>
|
|
2022
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1756
2023
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) {
|
|
1757
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x));
|
|
2024
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x));
|
|
1758
2025
|
return src.atomic_add(global);
|
|
1759
2026
|
}
|
|
1760
|
-
template <typename T, typename Tile>
|
|
2027
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1761
2028
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) {
|
|
1762
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y));
|
|
2029
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y));
|
|
1763
2030
|
return src.atomic_add(global);
|
|
1764
2031
|
}
|
|
1765
|
-
template <typename T, typename Tile>
|
|
2032
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1766
2033
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) {
|
|
1767
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z));
|
|
2034
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z));
|
|
1768
2035
|
return src.atomic_add(global);
|
|
1769
2036
|
}
|
|
1770
|
-
template <typename T, typename Tile>
|
|
2037
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1771
2038
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) {
|
|
1772
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z, w));
|
|
2039
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z, w));
|
|
1773
2040
|
return src.atomic_add(global);
|
|
1774
2041
|
}
|
|
1775
2042
|
|
|
2043
|
+
template <typename T, int M, typename Tile, typename Coord>
|
|
2044
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
|
|
2045
|
+
{
|
|
2046
|
+
auto src_reg = src.copy_to_register();
|
|
2047
|
+
auto ret_reg = tile_register_like<Tile>();
|
|
2048
|
+
|
|
2049
|
+
src_reg.apply([&](int reg, auto c) {
|
|
2050
|
+
int i;
|
|
2051
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
2052
|
+
ret_reg.data[reg] = wp::atomic_add(&dest.data[i], src_reg.data[reg]);
|
|
2053
|
+
else
|
|
2054
|
+
ret_reg.data[reg] = T(0);
|
|
2055
|
+
});
|
|
2056
|
+
|
|
2057
|
+
return ret_reg;
|
|
2058
|
+
}
|
|
2059
|
+
|
|
2060
|
+
// entry point for tile index atomic add operations
|
|
2061
|
+
template <typename T, int M, typename Tile>
|
|
2062
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x), src); }
|
|
2063
|
+
|
|
2064
|
+
template <typename T, int M, typename Tile>
|
|
2065
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y), src); }
|
|
2066
|
+
|
|
2067
|
+
template <typename T, int M, typename Tile>
|
|
2068
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
|
|
2069
|
+
|
|
2070
|
+
template <typename T, int M, typename Tile>
|
|
2071
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
|
|
2072
|
+
|
|
1776
2073
|
|
|
1777
2074
|
//-------------------------------------
|
|
1778
2075
|
// Adjoints
|
|
@@ -1791,7 +2088,6 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
|
|
|
1791
2088
|
adj_ret.atomic_add_grad(dest);
|
|
1792
2089
|
}
|
|
1793
2090
|
|
|
1794
|
-
|
|
1795
2091
|
template <typename T, typename AdjTile>
|
|
1796
2092
|
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
|
|
1797
2093
|
template <typename T, typename AdjTile>
|
|
@@ -1801,7 +2097,44 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, ar
|
|
|
1801
2097
|
template <typename T, typename AdjTile>
|
|
1802
2098
|
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
|
|
1803
2099
|
|
|
2100
|
+
template <typename T, int M, typename AdjTile, typename Coord>
|
|
2101
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset,
|
|
2102
|
+
array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset,
|
|
2103
|
+
AdjTile& adj_ret)
|
|
2104
|
+
{
|
|
2105
|
+
// we allow users to override grad of src
|
|
2106
|
+
if (adj_src.data)
|
|
2107
|
+
src.grad = adj_src.data;
|
|
2108
|
+
|
|
2109
|
+
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1804
2110
|
|
|
2111
|
+
adj_ret_reg.apply([&](int reg, auto c) {
|
|
2112
|
+
int i;
|
|
2113
|
+
if (compute_index(src, indices, axis, offset, c, i))
|
|
2114
|
+
wp::atomic_add(&src.grad[i], adj_ret_reg.data[reg]);
|
|
2115
|
+
});
|
|
2116
|
+
}
|
|
2117
|
+
|
|
2118
|
+
template <typename T, int M, typename AdjTile>
|
|
2119
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_ret)
|
|
2120
|
+
{
|
|
2121
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x), adj_src, adj_indices, adj_axis, tile_coord(0), adj_ret);
|
|
2122
|
+
}
|
|
2123
|
+
template <typename T, int M, typename AdjTile>
|
|
2124
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_ret)
|
|
2125
|
+
{
|
|
2126
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y), adj_src, adj_indices, adj_axis, tile_coord(0, 0), adj_ret);
|
|
2127
|
+
}
|
|
2128
|
+
template <typename T, int M, typename AdjTile>
|
|
2129
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret)
|
|
2130
|
+
{
|
|
2131
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0), adj_ret);
|
|
2132
|
+
}
|
|
2133
|
+
template <typename T, int M, typename AdjTile>
|
|
2134
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret)
|
|
2135
|
+
{
|
|
2136
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z, w), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0, 0), adj_ret);
|
|
2137
|
+
}
|
|
1805
2138
|
|
|
1806
2139
|
template <typename T, typename Tile, typename AdjTile, typename Coord>
|
|
1807
2140
|
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
|
|
@@ -1827,7 +2160,33 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z,
|
|
|
1827
2160
|
template <typename T, typename Tile, typename AdjTile>
|
|
1828
2161
|
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
|
|
1829
2162
|
|
|
2163
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename Coord>
|
|
2164
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset, AdjTile& adj_t)
|
|
2165
|
+
{
|
|
2166
|
+
// we allow users to override grad of dest
|
|
2167
|
+
if (adj_dest.data)
|
|
2168
|
+
dest.grad = adj_dest.data;
|
|
2169
|
+
|
|
2170
|
+
auto adj_t_reg = tile_register_like<Tile>();
|
|
1830
2171
|
|
|
2172
|
+
adj_t_reg.apply([&](int reg, auto c) {
|
|
2173
|
+
int i;
|
|
2174
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
2175
|
+
adj_t_reg.data[reg] += dest.grad[i];
|
|
2176
|
+
});
|
|
2177
|
+
|
|
2178
|
+
// write adjoints back
|
|
2179
|
+
adj_t.grad_add(adj_t_reg);
|
|
2180
|
+
}
|
|
2181
|
+
|
|
2182
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2183
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
|
|
2184
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2185
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
|
|
2186
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2187
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
|
|
2188
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2189
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
|
|
1831
2190
|
|
|
1832
2191
|
// adj_tile_atomic_add is an alias for adj_tile_store
|
|
1833
2192
|
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
@@ -1839,13 +2198,28 @@ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, in
|
|
|
1839
2198
|
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1840
2199
|
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
|
|
1841
2200
|
|
|
2201
|
+
// adj_tile_atomic_add_indexed is an alias for adj_tile_store_indexed
|
|
2202
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2203
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
|
|
2204
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2205
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
|
|
2206
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2207
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
|
|
2208
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2209
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
|
|
1842
2210
|
|
|
1843
2211
|
// unary map
|
|
1844
|
-
template <typename Tile, typename Fwd>
|
|
1845
|
-
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1846
|
-
Tile &a)
|
|
2212
|
+
template <typename Tile, typename Fwd, typename ReturnTile>
|
|
2213
|
+
inline CUDA_CALLABLE auto tile_map(Fwd op, Tile &a, ReturnTile &r)
|
|
1847
2214
|
{
|
|
1848
|
-
|
|
2215
|
+
// verify shapes and sizes are compatible
|
|
2216
|
+
using ShapeIn = typename Tile::Layout::Shape;
|
|
2217
|
+
using ShapeOut = typename ReturnTile::Layout::Shape;
|
|
2218
|
+
|
|
2219
|
+
static_assert(ShapeIn::N == ShapeOut::N, "Number of tile dimensions must match for unary map");
|
|
2220
|
+
static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for unary map");
|
|
2221
|
+
|
|
2222
|
+
auto out = tile_register_like<ReturnTile>();
|
|
1849
2223
|
auto a_reg = a.copy_to_register();
|
|
1850
2224
|
|
|
1851
2225
|
using Layout = typename decltype(out)::Layout;
|
|
@@ -1884,12 +2258,24 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1884
2258
|
}
|
|
1885
2259
|
|
|
1886
2260
|
// binary map
|
|
1887
|
-
template <typename TileA, typename TileB, typename Fwd>
|
|
2261
|
+
template <typename TileA, typename TileB, typename Fwd, typename ReturnTile>
|
|
1888
2262
|
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1889
2263
|
TileA& a,
|
|
1890
|
-
TileB& b
|
|
2264
|
+
TileB& b,
|
|
2265
|
+
ReturnTile& r)
|
|
1891
2266
|
{
|
|
1892
|
-
|
|
2267
|
+
// verify shapes and sizes are compatible
|
|
2268
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2269
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2270
|
+
using ShapeOut = typename ReturnTile::Layout::Shape;
|
|
2271
|
+
|
|
2272
|
+
static_assert(ShapeA::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
|
|
2273
|
+
static_assert(ShapeB::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
|
|
2274
|
+
|
|
2275
|
+
static_assert(ShapeA::size() == ShapeOut::size(), "Tile sizes must match for binary map");
|
|
2276
|
+
static_assert(ShapeB::size() == ShapeOut::size(), "Tile sizes must match for binary map");
|
|
2277
|
+
|
|
2278
|
+
auto out = tile_register_like<ReturnTile>();
|
|
1893
2279
|
|
|
1894
2280
|
auto a_reg = a.copy_to_register();
|
|
1895
2281
|
auto b_reg = b.copy_to_register();
|
|
@@ -1905,7 +2291,6 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
|
1905
2291
|
return out;
|
|
1906
2292
|
}
|
|
1907
2293
|
|
|
1908
|
-
|
|
1909
2294
|
template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
|
|
1910
2295
|
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
1911
2296
|
TileA &a,
|
|
@@ -1936,28 +2321,32 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1936
2321
|
adj_b.grad_add(adj_b_reg);
|
|
1937
2322
|
}
|
|
1938
2323
|
|
|
1939
|
-
// wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
|
|
2324
|
+
// We wrap the operator in a lambda so that we don't have to do overload resolution for builtins such as wp.sin();
|
|
1940
2325
|
// this is important because many of the builtin operators don't follow particular conventions on references for
|
|
1941
2326
|
// the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
|
|
1942
|
-
|
|
1943
|
-
|
|
2327
|
+
// The r argument is a dummy return tile argument, because we can't template on the return tile type in a macro definition.
|
|
2328
|
+
// So if we want users to be able to define functions that return a tile type that is different from the input type,
|
|
2329
|
+
// we must pass an extra dummy return tile argument that is used to define the return type of tile_map.
|
|
2330
|
+
|
|
2331
|
+
#define tile_unary_map(op, a, r) tile_map([](auto x) { return op(x);}, a, r)
|
|
2332
|
+
#define adj_tile_unary_map(op, a, r, adj_op, adj_a, adj_r, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
|
|
1944
2333
|
|
|
1945
|
-
#define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
|
|
1946
|
-
#define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
|
|
2334
|
+
#define tile_binary_map(op, a, b, r) tile_map([](auto x, auto y) { return op(x, y);}, a, b, r)
|
|
2335
|
+
#define adj_tile_binary_map(op, a, b, r, adj_op, adj_a, adj_b, adj_r, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
|
|
1947
2336
|
|
|
1948
2337
|
// -tile (unary neg)
|
|
1949
2338
|
template <typename Tile>
|
|
1950
|
-
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
|
|
2339
|
+
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a, a); }
|
|
1951
2340
|
|
|
1952
2341
|
template <typename Tile, typename AdjTile>
|
|
1953
|
-
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
|
|
2342
|
+
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, a, wp::adj_neg, adj_a, adj_a, adj_ret); }
|
|
1954
2343
|
|
|
1955
2344
|
|
|
1956
2345
|
// tile + tile
|
|
1957
2346
|
template <typename TileA, typename TileB>
|
|
1958
2347
|
inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
|
|
1959
2348
|
{
|
|
1960
|
-
return tile_binary_map(add, a, b);
|
|
2349
|
+
return tile_binary_map(add, a, b, a);
|
|
1961
2350
|
}
|
|
1962
2351
|
|
|
1963
2352
|
// add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
|
|
@@ -1984,20 +2373,20 @@ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register
|
|
|
1984
2373
|
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1985
2374
|
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1986
2375
|
{
|
|
1987
|
-
adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
|
|
2376
|
+
adj_tile_binary_map(add, a, b, a, adj_add, adj_a, adj_b, adj_a, adj_c);
|
|
1988
2377
|
}
|
|
1989
2378
|
|
|
1990
2379
|
// tile - tile
|
|
1991
2380
|
template <typename TileA, typename TileB>
|
|
1992
2381
|
inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
|
|
1993
2382
|
{
|
|
1994
|
-
return tile_binary_map(sub, a, b);
|
|
2383
|
+
return tile_binary_map(sub, a, b, a);
|
|
1995
2384
|
}
|
|
1996
2385
|
|
|
1997
2386
|
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1998
2387
|
inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1999
2388
|
{
|
|
2000
|
-
adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
|
|
2389
|
+
adj_tile_binary_map(sub, a, b, a, adj_sub, adj_a, adj_b, adj_a, adj_c);
|
|
2001
2390
|
}
|
|
2002
2391
|
|
|
2003
2392
|
|
|
@@ -2008,7 +2397,7 @@ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
|
|
|
2008
2397
|
// promote scalar to a constant tile
|
|
2009
2398
|
auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
|
|
2010
2399
|
|
|
2011
|
-
return tile_binary_map(mul, a, s_tile);
|
|
2400
|
+
return tile_binary_map(mul, a, s_tile, a);
|
|
2012
2401
|
}
|
|
2013
2402
|
|
|
2014
2403
|
template <typename Tile, typename AdjTile>
|
|
@@ -2024,7 +2413,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
|
|
|
2024
2413
|
// initialize to constant
|
|
2025
2414
|
s_tile = s;
|
|
2026
2415
|
|
|
2027
|
-
adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
|
|
2416
|
+
adj_tile_binary_map(mul, a, s_tile, a, adj_mul, adj_a, adj_s_tile, adj_a, adj_c);
|
|
2028
2417
|
|
|
2029
2418
|
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2030
2419
|
{
|
|
@@ -2834,7 +3223,7 @@ template <typename Tile, typename AdjTile>
|
|
|
2834
3223
|
inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
|
|
2835
3224
|
{
|
|
2836
3225
|
auto a = tile_transpose(adj_ret);
|
|
2837
|
-
auto b = adj_t;
|
|
3226
|
+
auto& b = adj_t;
|
|
2838
3227
|
|
|
2839
3228
|
adj_t.assign(tile_add(a,b));
|
|
2840
3229
|
}
|