warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +282 -103
- warp/__init__.pyi +482 -110
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +93 -30
- warp/build_dll.py +47 -67
- warp/builtins.py +955 -137
- warp/codegen.py +312 -206
- warp/config.py +1 -1
- warp/context.py +1249 -784
- warp/examples/core/example_marching_cubes.py +1 -0
- warp/examples/core/example_render_opengl.py +100 -3
- warp/examples/fem/example_apic_fluid.py +98 -52
- warp/examples/fem/example_convection_diffusion_dg.py +25 -4
- warp/examples/fem/example_diffusion_mgpu.py +8 -3
- warp/examples/fem/utils.py +68 -22
- warp/fabric.py +1 -1
- warp/fem/cache.py +27 -19
- warp/fem/domain.py +2 -2
- warp/fem/field/nodal_field.py +2 -2
- warp/fem/field/virtual.py +264 -166
- warp/fem/geometry/geometry.py +5 -5
- warp/fem/integrate.py +129 -51
- warp/fem/space/restriction.py +4 -0
- warp/fem/space/shape/tet_shape_function.py +3 -10
- warp/jax_experimental/custom_call.py +1 -1
- warp/jax_experimental/ffi.py +2 -1
- warp/marching_cubes.py +708 -0
- warp/native/array.h +99 -4
- warp/native/builtin.h +82 -5
- warp/native/bvh.cpp +64 -28
- warp/native/bvh.cu +58 -58
- warp/native/bvh.h +2 -2
- warp/native/clang/clang.cpp +7 -7
- warp/native/coloring.cpp +8 -2
- warp/native/crt.cpp +2 -2
- warp/native/crt.h +3 -5
- warp/native/cuda_util.cpp +41 -10
- warp/native/cuda_util.h +10 -4
- warp/native/exports.h +1842 -1908
- warp/native/fabric.h +2 -1
- warp/native/hashgrid.cpp +37 -37
- warp/native/hashgrid.cu +2 -2
- warp/native/initializer_array.h +1 -1
- warp/native/intersect.h +2 -2
- warp/native/mat.h +1910 -116
- warp/native/mathdx.cpp +43 -43
- warp/native/mesh.cpp +24 -24
- warp/native/mesh.cu +26 -26
- warp/native/mesh.h +4 -2
- warp/native/nanovdb/GridHandle.h +179 -12
- warp/native/nanovdb/HostBuffer.h +8 -7
- warp/native/nanovdb/NanoVDB.h +517 -895
- warp/native/nanovdb/NodeManager.h +323 -0
- warp/native/nanovdb/PNanoVDB.h +2 -2
- warp/native/quat.h +331 -14
- warp/native/range.h +7 -1
- warp/native/reduce.cpp +10 -10
- warp/native/reduce.cu +13 -14
- warp/native/runlength_encode.cpp +2 -2
- warp/native/runlength_encode.cu +5 -5
- warp/native/scan.cpp +3 -3
- warp/native/scan.cu +4 -4
- warp/native/sort.cpp +10 -10
- warp/native/sort.cu +22 -22
- warp/native/sparse.cpp +8 -8
- warp/native/sparse.cu +13 -13
- warp/native/spatial.h +366 -17
- warp/native/temp_buffer.h +2 -2
- warp/native/tile.h +283 -69
- warp/native/vec.h +381 -14
- warp/native/volume.cpp +54 -54
- warp/native/volume.cu +1 -1
- warp/native/volume.h +2 -1
- warp/native/volume_builder.cu +30 -37
- warp/native/warp.cpp +150 -149
- warp/native/warp.cu +323 -192
- warp/native/warp.h +227 -226
- warp/optim/linear.py +736 -271
- warp/render/imgui_manager.py +289 -0
- warp/render/render_opengl.py +85 -6
- warp/sim/graph_coloring.py +2 -2
- warp/sparse.py +558 -175
- warp/tests/aux_test_module_aot.py +7 -0
- warp/tests/cuda/test_async.py +3 -3
- warp/tests/cuda/test_conditional_captures.py +101 -0
- warp/tests/geometry/test_marching_cubes.py +233 -12
- warp/tests/sim/test_coloring.py +6 -6
- warp/tests/test_array.py +56 -5
- warp/tests/test_codegen.py +3 -2
- warp/tests/test_context.py +8 -15
- warp/tests/test_enum.py +136 -0
- warp/tests/test_examples.py +2 -2
- warp/tests/test_fem.py +45 -2
- warp/tests/test_fixedarray.py +229 -0
- warp/tests/test_func.py +18 -15
- warp/tests/test_future_annotations.py +7 -5
- warp/tests/test_linear_solvers.py +30 -0
- warp/tests/test_map.py +1 -1
- warp/tests/test_mat.py +1518 -378
- warp/tests/test_mat_assign_copy.py +178 -0
- warp/tests/test_mat_constructors.py +574 -0
- warp/tests/test_module_aot.py +287 -0
- warp/tests/test_print.py +69 -0
- warp/tests/test_quat.py +140 -34
- warp/tests/test_quat_assign_copy.py +145 -0
- warp/tests/test_reload.py +2 -1
- warp/tests/test_sparse.py +71 -0
- warp/tests/test_spatial.py +140 -34
- warp/tests/test_spatial_assign_copy.py +160 -0
- warp/tests/test_struct.py +43 -3
- warp/tests/test_types.py +0 -20
- warp/tests/test_vec.py +179 -34
- warp/tests/test_vec_assign_copy.py +143 -0
- warp/tests/tile/test_tile.py +184 -18
- warp/tests/tile/test_tile_cholesky.py +605 -0
- warp/tests/tile/test_tile_load.py +169 -0
- warp/tests/tile/test_tile_mathdx.py +2 -558
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +1 -1
- warp/tests/tile/test_tile_shared_memory.py +5 -5
- warp/tests/unittest_suites.py +6 -0
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +108 -9
- warp/types.py +554 -264
- warp/utils.py +68 -86
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
- warp/native/marching.cpp +0 -19
- warp/native/marching.cu +0 -514
- warp/native/marching.h +0 -19
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h
CHANGED
|
@@ -230,7 +230,9 @@ struct tile_coord_t
|
|
|
230
230
|
out.indices[i] = indices[i] + c.indices[i];
|
|
231
231
|
}
|
|
232
232
|
return out;
|
|
233
|
-
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
static constexpr int size() { return N; }
|
|
234
236
|
};
|
|
235
237
|
|
|
236
238
|
// This function deduces N = sizeof...(Ints)
|
|
@@ -338,7 +340,8 @@ using tile_stride_t = tile_tuple_t<V...>;
|
|
|
338
340
|
|
|
339
341
|
// represents a tile stored in global memory with dynamic strides
|
|
340
342
|
// used to represent the source and offset for tile loads to register/shared
|
|
341
|
-
|
|
343
|
+
// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
|
|
344
|
+
template <typename T, typename Shape_, bool BoundsCheck=true>
|
|
342
345
|
struct tile_global_t
|
|
343
346
|
{
|
|
344
347
|
using Type = T;
|
|
@@ -370,25 +373,33 @@ struct tile_global_t
|
|
|
370
373
|
|
|
371
374
|
inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
|
|
372
375
|
{
|
|
373
|
-
|
|
374
|
-
int index = 0;
|
|
375
|
-
|
|
376
|
-
WP_PRAGMA_UNROLL
|
|
377
|
-
for (int i=0; i < Shape::N; ++i)
|
|
376
|
+
if constexpr (BoundsCheck)
|
|
378
377
|
{
|
|
379
|
-
//
|
|
380
|
-
int
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
378
|
+
// element index
|
|
379
|
+
int index = 0;
|
|
380
|
+
|
|
381
|
+
WP_PRAGMA_UNROLL
|
|
382
|
+
for (int i=0; i < Shape::N; ++i)
|
|
383
|
+
{
|
|
384
|
+
// global = offset + coord
|
|
385
|
+
int c = offset[i] + coord[i];
|
|
386
|
+
|
|
387
|
+
// handle out of bounds case
|
|
388
|
+
if (c >= data.shape[i])
|
|
389
|
+
return false;
|
|
390
|
+
else
|
|
391
|
+
index += data.strides[i]*c;
|
|
392
|
+
}
|
|
393
|
+
|
|
394
|
+
// array strides are in bytes so we convert to elements
|
|
395
|
+
out = index / sizeof(T);
|
|
396
|
+
return true;
|
|
397
|
+
}
|
|
398
|
+
else
|
|
399
|
+
{
|
|
400
|
+
out = index_from_coord(coord);
|
|
401
|
+
return true;
|
|
402
|
+
}
|
|
392
403
|
}
|
|
393
404
|
|
|
394
405
|
inline CUDA_CALLABLE T load(const Coord& coord) const
|
|
@@ -435,6 +446,7 @@ struct tile_global_t
|
|
|
435
446
|
}
|
|
436
447
|
};
|
|
437
448
|
|
|
449
|
+
|
|
438
450
|
template <typename Shape_>
|
|
439
451
|
struct tile_layout_register_t
|
|
440
452
|
{
|
|
@@ -521,7 +533,8 @@ struct tile_register_t
|
|
|
521
533
|
data[i] = value;
|
|
522
534
|
}
|
|
523
535
|
|
|
524
|
-
|
|
536
|
+
template <bool BoundsCheck>
|
|
537
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
|
|
525
538
|
{
|
|
526
539
|
copy_from_global(t);
|
|
527
540
|
return *this;
|
|
@@ -647,8 +660,7 @@ struct tile_register_t
|
|
|
647
660
|
|
|
648
661
|
CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
|
|
649
662
|
{
|
|
650
|
-
apply([&](int reg, auto c) {data[reg]
|
|
651
|
-
|
|
663
|
+
apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
|
|
652
664
|
}
|
|
653
665
|
|
|
654
666
|
inline CUDA_CALLABLE auto& grad_to_register()
|
|
@@ -935,7 +947,9 @@ struct tile_shared_t
|
|
|
935
947
|
}
|
|
936
948
|
|
|
937
949
|
// assign from a global tile (load)
|
|
938
|
-
|
|
950
|
+
|
|
951
|
+
template <bool BoundsCheck>
|
|
952
|
+
inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
|
|
939
953
|
{
|
|
940
954
|
copy_from_global(t);
|
|
941
955
|
return *this;
|
|
@@ -1103,7 +1117,7 @@ struct tile_shared_t
|
|
|
1103
1117
|
}
|
|
1104
1118
|
|
|
1105
1119
|
WP_TILE_SYNC();
|
|
1106
|
-
}
|
|
1120
|
+
}
|
|
1107
1121
|
|
|
1108
1122
|
// copy shared tile to register
|
|
1109
1123
|
inline CUDA_CALLABLE auto grad_to_register()
|
|
@@ -1172,7 +1186,7 @@ struct tile_shared_t
|
|
|
1172
1186
|
{
|
|
1173
1187
|
// alias of shared tile with 128bit type
|
|
1174
1188
|
using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1175
|
-
tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
|
|
1189
|
+
tile_shared_t<float4, SrcLayout, false> src128((float4*)data.ptr);
|
|
1176
1190
|
|
|
1177
1191
|
assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
|
|
1178
1192
|
assert(((uint64_t)(dest128))%sizeof(float4) == 0);
|
|
@@ -1251,7 +1265,7 @@ struct tile_shared_t
|
|
|
1251
1265
|
const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
|
|
1252
1266
|
const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
|
|
1253
1267
|
const bool aligned_stride = (src.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
|
|
1254
|
-
|
|
1268
|
+
|
|
1255
1269
|
float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
|
|
1256
1270
|
const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
|
|
1257
1271
|
|
|
@@ -1262,7 +1276,7 @@ struct tile_shared_t
|
|
|
1262
1276
|
{
|
|
1263
1277
|
// alias of shared tile with 128bit type
|
|
1264
1278
|
using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
|
|
1265
|
-
tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
|
|
1279
|
+
tile_shared_t<float4, DestLayout, false> dest128((float4*)data.ptr);
|
|
1266
1280
|
|
|
1267
1281
|
assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
|
|
1268
1282
|
assert(((uint64_t)(src128))%sizeof(float4) == 0);
|
|
@@ -1727,10 +1741,66 @@ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
|
|
|
1727
1741
|
T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
|
|
1728
1742
|
|
|
1729
1743
|
// entry point for load operations, these just return a reference to a global memory array + coordinate
|
|
1730
|
-
template <unsigned... Shape, typename...
|
|
1731
|
-
inline CUDA_CALLABLE auto tile_load(array_t<T>& src,
|
|
1744
|
+
template <typename T, bool BoundsCheck, unsigned... Shape, typename... Offset>
|
|
1745
|
+
inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Offset... offset)
|
|
1746
|
+
{
|
|
1747
|
+
return tile_global_t<T, tile_shape_t<Shape...>, BoundsCheck>(src, tile_coord(offset...));
|
|
1748
|
+
}
|
|
1749
|
+
|
|
1750
|
+
// used for indexed loads and stores
|
|
1751
|
+
template <typename T, int M, typename Coord>
|
|
1752
|
+
inline CUDA_CALLABLE bool compute_index(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Coord c, int& out)
|
|
1753
|
+
{
|
|
1754
|
+
int index = 0;
|
|
1755
|
+
|
|
1756
|
+
WP_PRAGMA_UNROLL
|
|
1757
|
+
for (int i = 0; i < Coord::size(); ++i)
|
|
1758
|
+
{
|
|
1759
|
+
if (i == axis)
|
|
1760
|
+
{
|
|
1761
|
+
// global = offset_coord + index_mapped_coord
|
|
1762
|
+
int index_along_axis = offset[i] + indices.data(c[i]);
|
|
1763
|
+
|
|
1764
|
+
// handle out of bounds case
|
|
1765
|
+
if (index_along_axis >= src.shape[i])
|
|
1766
|
+
return false;
|
|
1767
|
+
else
|
|
1768
|
+
index += src.strides[i] * index_along_axis;
|
|
1769
|
+
}
|
|
1770
|
+
else
|
|
1771
|
+
{
|
|
1772
|
+
// global = offset_coord + coord
|
|
1773
|
+
int g = offset[i] + c[i];
|
|
1774
|
+
|
|
1775
|
+
// handle out of bounds case
|
|
1776
|
+
if (g >= src.shape[i])
|
|
1777
|
+
return false;
|
|
1778
|
+
else
|
|
1779
|
+
index += src.strides[i] * g;
|
|
1780
|
+
}
|
|
1781
|
+
}
|
|
1782
|
+
|
|
1783
|
+
// array strides are in bytes so we convert to elements
|
|
1784
|
+
out = index / sizeof(T);
|
|
1785
|
+
return true;
|
|
1786
|
+
}
|
|
1787
|
+
|
|
1788
|
+
|
|
1789
|
+
template <unsigned... Shape, int M, typename T, typename... Offset>
|
|
1790
|
+
inline CUDA_CALLABLE auto tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Offset... offset)
|
|
1732
1791
|
{
|
|
1733
|
-
|
|
1792
|
+
auto out = tile_register_t<T, tile_layout_register_t<tile_shape_t<Shape...>>>();
|
|
1793
|
+
auto offset_coord = tile_coord(offset...);
|
|
1794
|
+
|
|
1795
|
+
out.apply([&](int reg, auto c) {
|
|
1796
|
+
int i;
|
|
1797
|
+
if (compute_index(src, indices, axis, offset_coord, c, i))
|
|
1798
|
+
out.data[reg] = src.data[i];
|
|
1799
|
+
else
|
|
1800
|
+
out.data[reg] = T(0);
|
|
1801
|
+
});
|
|
1802
|
+
|
|
1803
|
+
return out;
|
|
1734
1804
|
}
|
|
1735
1805
|
|
|
1736
1806
|
// // entry point for tile store operations
|
|
@@ -1741,38 +1811,90 @@ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
|
|
|
1741
1811
|
// }
|
|
1742
1812
|
|
|
1743
1813
|
// entry point for tile store operations
|
|
1744
|
-
template <typename T, typename Tile>
|
|
1745
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
|
|
1746
|
-
template <typename T, typename Tile>
|
|
1747
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
|
|
1748
|
-
template <typename T, typename Tile>
|
|
1749
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
|
|
1750
|
-
template <typename T, typename Tile>
|
|
1751
|
-
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
|
|
1814
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1815
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x))); }
|
|
1816
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1817
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y))); }
|
|
1818
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1819
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z))); }
|
|
1820
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1821
|
+
inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z, w))); }
|
|
1822
|
+
|
|
1823
|
+
template <typename T, int M, typename Tile, typename Coord>
|
|
1824
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
|
|
1825
|
+
{
|
|
1826
|
+
auto src_reg = src.copy_to_register();
|
|
1827
|
+
|
|
1828
|
+
src_reg.apply([&](int reg, auto c) {
|
|
1829
|
+
int i;
|
|
1830
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
1831
|
+
dest.data[i] = src_reg.data[reg];
|
|
1832
|
+
});
|
|
1833
|
+
}
|
|
1834
|
+
|
|
1835
|
+
// entry point for tile index store operations
|
|
1836
|
+
template <typename T, int M, typename Tile>
|
|
1837
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x), src); }
|
|
1838
|
+
template <typename T, int M, typename Tile>
|
|
1839
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y), src); }
|
|
1840
|
+
template <typename T, int M, typename Tile>
|
|
1841
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
|
|
1842
|
+
template <typename T, int M, typename Tile>
|
|
1843
|
+
inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
|
|
1752
1844
|
|
|
1753
1845
|
|
|
1754
1846
|
// compiler struggles with these if they are one line
|
|
1755
|
-
template <typename T, typename Tile>
|
|
1847
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1756
1848
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) {
|
|
1757
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x));
|
|
1849
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x));
|
|
1758
1850
|
return src.atomic_add(global);
|
|
1759
1851
|
}
|
|
1760
|
-
template <typename T, typename Tile>
|
|
1852
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1761
1853
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) {
|
|
1762
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y));
|
|
1854
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y));
|
|
1763
1855
|
return src.atomic_add(global);
|
|
1764
1856
|
}
|
|
1765
|
-
template <typename T, typename Tile>
|
|
1857
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1766
1858
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) {
|
|
1767
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z));
|
|
1859
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z));
|
|
1768
1860
|
return src.atomic_add(global);
|
|
1769
1861
|
}
|
|
1770
|
-
template <typename T, typename Tile>
|
|
1862
|
+
template <typename T, bool BoundsCheck, typename Tile>
|
|
1771
1863
|
inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) {
|
|
1772
|
-
tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z, w));
|
|
1864
|
+
tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z, w));
|
|
1773
1865
|
return src.atomic_add(global);
|
|
1774
1866
|
}
|
|
1775
1867
|
|
|
1868
|
+
template <typename T, int M, typename Tile, typename Coord>
|
|
1869
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
|
|
1870
|
+
{
|
|
1871
|
+
auto src_reg = src.copy_to_register();
|
|
1872
|
+
auto ret_reg = tile_register_like<Tile>();
|
|
1873
|
+
|
|
1874
|
+
src_reg.apply([&](int reg, auto c) {
|
|
1875
|
+
int i;
|
|
1876
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
1877
|
+
ret_reg.data[reg] = wp::atomic_add(&dest.data[i], src_reg.data[reg]);
|
|
1878
|
+
else
|
|
1879
|
+
ret_reg.data[reg] = T(0);
|
|
1880
|
+
});
|
|
1881
|
+
|
|
1882
|
+
return ret_reg;
|
|
1883
|
+
}
|
|
1884
|
+
|
|
1885
|
+
// entry point for tile index atomic add operations
|
|
1886
|
+
template <typename T, int M, typename Tile>
|
|
1887
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x), src); }
|
|
1888
|
+
|
|
1889
|
+
template <typename T, int M, typename Tile>
|
|
1890
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y), src); }
|
|
1891
|
+
|
|
1892
|
+
template <typename T, int M, typename Tile>
|
|
1893
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
|
|
1894
|
+
|
|
1895
|
+
template <typename T, int M, typename Tile>
|
|
1896
|
+
inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
|
|
1897
|
+
|
|
1776
1898
|
|
|
1777
1899
|
//-------------------------------------
|
|
1778
1900
|
// Adjoints
|
|
@@ -1791,7 +1913,6 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
|
|
|
1791
1913
|
adj_ret.atomic_add_grad(dest);
|
|
1792
1914
|
}
|
|
1793
1915
|
|
|
1794
|
-
|
|
1795
1916
|
template <typename T, typename AdjTile>
|
|
1796
1917
|
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
|
|
1797
1918
|
template <typename T, typename AdjTile>
|
|
@@ -1801,7 +1922,44 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, ar
|
|
|
1801
1922
|
template <typename T, typename AdjTile>
|
|
1802
1923
|
inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
|
|
1803
1924
|
|
|
1925
|
+
template <typename T, int M, typename AdjTile, typename Coord>
|
|
1926
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset,
|
|
1927
|
+
array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset,
|
|
1928
|
+
AdjTile& adj_ret)
|
|
1929
|
+
{
|
|
1930
|
+
// we allow users to override grad of src
|
|
1931
|
+
if (adj_src.data)
|
|
1932
|
+
src.grad = adj_src.data;
|
|
1804
1933
|
|
|
1934
|
+
auto adj_ret_reg = adj_ret.grad_to_register();
|
|
1935
|
+
|
|
1936
|
+
adj_ret_reg.apply([&](int reg, auto c) {
|
|
1937
|
+
int i;
|
|
1938
|
+
if (compute_index(src, indices, axis, offset, c, i))
|
|
1939
|
+
wp::atomic_add(&src.grad[i], adj_ret_reg.data[reg]);
|
|
1940
|
+
});
|
|
1941
|
+
}
|
|
1942
|
+
|
|
1943
|
+
template <typename T, int M, typename AdjTile>
|
|
1944
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_ret)
|
|
1945
|
+
{
|
|
1946
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x), adj_src, adj_indices, adj_axis, tile_coord(0), adj_ret);
|
|
1947
|
+
}
|
|
1948
|
+
template <typename T, int M, typename AdjTile>
|
|
1949
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_ret)
|
|
1950
|
+
{
|
|
1951
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y), adj_src, adj_indices, adj_axis, tile_coord(0, 0), adj_ret);
|
|
1952
|
+
}
|
|
1953
|
+
template <typename T, int M, typename AdjTile>
|
|
1954
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret)
|
|
1955
|
+
{
|
|
1956
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0), adj_ret);
|
|
1957
|
+
}
|
|
1958
|
+
template <typename T, int M, typename AdjTile>
|
|
1959
|
+
inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret)
|
|
1960
|
+
{
|
|
1961
|
+
adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z, w), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0, 0), adj_ret);
|
|
1962
|
+
}
|
|
1805
1963
|
|
|
1806
1964
|
template <typename T, typename Tile, typename AdjTile, typename Coord>
|
|
1807
1965
|
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
|
|
@@ -1827,7 +1985,33 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z,
|
|
|
1827
1985
|
template <typename T, typename Tile, typename AdjTile>
|
|
1828
1986
|
inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
|
|
1829
1987
|
|
|
1988
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename Coord>
|
|
1989
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset, AdjTile& adj_t)
|
|
1990
|
+
{
|
|
1991
|
+
// we allow users to override grad of src
|
|
1992
|
+
if (adj_dest.data)
|
|
1993
|
+
dest.grad = adj_dest.data;
|
|
1994
|
+
|
|
1995
|
+
auto adj_t_reg = tile_register_like<Tile>();
|
|
1996
|
+
|
|
1997
|
+
adj_t_reg.apply([&](int reg, auto c) {
|
|
1998
|
+
int i;
|
|
1999
|
+
if (compute_index(dest, indices, axis, offset, c, i))
|
|
2000
|
+
adj_t_reg.data[reg] += dest.grad[i];
|
|
2001
|
+
});
|
|
2002
|
+
|
|
2003
|
+
// write adjoints back
|
|
2004
|
+
adj_t.grad_add(adj_t_reg);
|
|
2005
|
+
}
|
|
1830
2006
|
|
|
2007
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2008
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
|
|
2009
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2010
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
|
|
2011
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2012
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
|
|
2013
|
+
template <typename T, int M, typename Tile, typename AdjTile>
|
|
2014
|
+
inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
|
|
1831
2015
|
|
|
1832
2016
|
// adj_tile_atomic_add is an alias for adj_tile_store
|
|
1833
2017
|
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
@@ -1839,13 +2023,28 @@ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, in
|
|
|
1839
2023
|
template <typename T, typename Tile, typename AdjTile, typename AdjRet>
|
|
1840
2024
|
inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
|
|
1841
2025
|
|
|
2026
|
+
// adj_tile_atomic_add_indexed is an alias for adj_tile_store_indexed
|
|
2027
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2028
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
|
|
2029
|
+
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2030
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
|
|
2031
|
+
// 3-D overload: forwards to adj_tile_store_indexed with tile_coord(x, y, z).
// The adjoint coordinate is passed as a literal tile_coord(0,0,0); adj_x, adj_y,
// adj_z and adj_ret are accepted but are not referenced in the body.
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2032
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
|
|
2033
|
+
// 4-D overload: forwards to adj_tile_store_indexed with tile_coord(x, y, z, w).
// The adjoint coordinate is passed as a literal tile_coord(0,0,0,0); adj_x, adj_y,
// adj_z, adj_w and adj_ret are accepted but are not referenced in the body.
template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
|
|
2034
|
+
inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
|
|
1842
2035
|
|
|
1843
2036
|
// unary map
|
|
1844
|
-
template <typename Tile, typename Fwd>
|
|
1845
|
-
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1846
|
-
Tile &a)
|
|
2037
|
+
template <typename Tile, typename Fwd, typename ReturnTile>
|
|
2038
|
+
inline CUDA_CALLABLE auto tile_map(Fwd op, Tile &a, ReturnTile &r)
|
|
1847
2039
|
{
|
|
1848
|
-
|
|
2040
|
+
// verify shapes and sizes are compatible
|
|
2041
|
+
using ShapeIn = typename Tile::Layout::Shape;
|
|
2042
|
+
using ShapeOut = typename ReturnTile::Layout::Shape;
|
|
2043
|
+
|
|
2044
|
+
static_assert(ShapeIn::N == ShapeOut::N, "Number of tile dimensions must match for unary map");
|
|
2045
|
+
static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for unary map");
|
|
2046
|
+
|
|
2047
|
+
auto out = tile_register_like<ReturnTile>();
|
|
1849
2048
|
auto a_reg = a.copy_to_register();
|
|
1850
2049
|
|
|
1851
2050
|
using Layout = typename decltype(out)::Layout;
|
|
@@ -1884,12 +2083,24 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1884
2083
|
}
|
|
1885
2084
|
|
|
1886
2085
|
// binary map
|
|
1887
|
-
template <typename TileA, typename TileB, typename Fwd>
|
|
2086
|
+
template <typename TileA, typename TileB, typename Fwd, typename ReturnTile>
|
|
1888
2087
|
inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
1889
2088
|
TileA& a,
|
|
1890
|
-
TileB& b
|
|
2089
|
+
TileB& b,
|
|
2090
|
+
ReturnTile& r)
|
|
1891
2091
|
{
|
|
1892
|
-
|
|
2092
|
+
// verify shapes and sizes are compatible
|
|
2093
|
+
using ShapeA = typename TileA::Layout::Shape;
|
|
2094
|
+
using ShapeB = typename TileB::Layout::Shape;
|
|
2095
|
+
using ShapeOut = typename ReturnTile::Layout::Shape;
|
|
2096
|
+
|
|
2097
|
+
static_assert(ShapeA::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
|
|
2098
|
+
static_assert(ShapeB::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
|
|
2099
|
+
|
|
2100
|
+
static_assert(ShapeA::size() == ShapeOut::size(), "Tile sizes must match for binary map");
|
|
2101
|
+
static_assert(ShapeB::size() == ShapeOut::size(), "Tile sizes must match for binary map");
|
|
2102
|
+
|
|
2103
|
+
auto out = tile_register_like<ReturnTile>();
|
|
1893
2104
|
|
|
1894
2105
|
auto a_reg = a.copy_to_register();
|
|
1895
2106
|
auto b_reg = b.copy_to_register();
|
|
@@ -1905,7 +2116,6 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
|
|
|
1905
2116
|
return out;
|
|
1906
2117
|
}
|
|
1907
2118
|
|
|
1908
|
-
|
|
1909
2119
|
template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
|
|
1910
2120
|
inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
1911
2121
|
TileA &a,
|
|
@@ -1936,28 +2146,32 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
|
|
|
1936
2146
|
adj_b.grad_add(adj_b_reg);
|
|
1937
2147
|
}
|
|
1938
2148
|
|
|
1939
|
-
// wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
|
|
2149
|
+
// We wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
|
|
1940
2150
|
// this is important because many of the builtin operators don't follow particular conventions on references for
|
|
1941
2151
|
// the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
|
|
1942
|
-
|
|
1943
|
-
|
|
2152
|
+
// The r argument is a dummy return tile argument, because we can't template on the return tile type in a macro definition.
|
|
2153
|
+
// So if we want users to be able to define functions that return a tile type that is different from the input type,
|
|
2154
|
+
// we must pass an extra dummy return tile argument that is used to define the return type of tile_map.
|
|
2155
|
+
|
|
2156
|
+
// Wraps `op` in a generic lambda and calls tile_map(lambda, a, r).
// `r` is forwarded unchanged as the return-tile argument of tile_map.
#define tile_unary_map(op, a, r) tile_map([](auto x) { return op(x);}, a, r)
|
|
2157
|
+
// Wraps `op` and `adj_op` in generic lambdas and calls adj_tile_map(fwd, a, adj, adj_a, adj_ret).
// NOTE: the `r` and `adj_r` macro parameters do not appear in the expansion; presumably
// they exist only so the argument list mirrors tile_unary_map - confirm against codegen.
#define adj_tile_unary_map(op, a, r, adj_op, adj_a, adj_r, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
|
|
1944
2158
|
|
|
1945
|
-
#define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
|
|
1946
|
-
#define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
|
|
2159
|
+
// Wraps `op` in a two-argument generic lambda and calls tile_map(lambda, a, b, r).
// `r` is forwarded unchanged as the return-tile argument of tile_map.
#define tile_binary_map(op, a, b, r) tile_map([](auto x, auto y) { return op(x, y);}, a, b, r)
|
|
2160
|
+
// Wraps `op` and `adj_op` in generic lambdas and calls
// adj_tile_map(fwd, a, b, adj, adj_a, adj_b, adj_ret).
// NOTE: the `r` and `adj_r` macro parameters do not appear in the expansion; presumably
// they exist only so the argument list mirrors tile_binary_map - confirm against codegen.
#define adj_tile_binary_map(op, a, b, r, adj_op, adj_a, adj_b, adj_r, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
|
|
1947
2161
|
|
|
1948
2162
|
// -tile (unary neg)
|
|
1949
2163
|
// Unary negation of a tile: maps wp::neg over `a` via tile_unary_map.
// The `-` line below is the removed two-argument form; the `+` line is its replacement,
// which passes `a` a second time to fill the macro's return-tile parameter `r`.
template <typename Tile>
|
|
1950
|
-
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
|
|
2164
|
+
inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a, a); }
|
|
1951
2165
|
|
|
1952
2166
|
// Adjoint of tile_neg: invokes adj_tile_unary_map with wp::neg / wp::adj_neg.
// The `-` line below is the removed form; the `+` line is its replacement, which
// supplies `a` and `adj_a` for the macro's extra `r` / `adj_r` parameters.
template <typename Tile, typename AdjTile>
|
|
1953
|
-
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
|
|
2167
|
+
inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, a, wp::adj_neg, adj_a, adj_a, adj_ret); }
|
|
1954
2168
|
|
|
1955
2169
|
|
|
1956
2170
|
// tile + tile
|
|
1957
2171
|
// tile + tile: maps `add` over tiles a and b via tile_binary_map and returns the result.
// The `-` line in the body is the removed three-argument call; the `+` line is its
// replacement, which passes `a` again to fill the macro's return-tile parameter `r`.
template <typename TileA, typename TileB>
|
|
1958
2172
|
inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
|
|
1959
2173
|
{
|
|
1960
|
-
return tile_binary_map(add, a, b);
|
|
2174
|
+
return tile_binary_map(add, a, b, a);
|
|
1961
2175
|
}
|
|
1962
2176
|
|
|
1963
2177
|
// add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
|
|
@@ -1984,20 +2198,20 @@ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register
|
|
|
1984
2198
|
// Adjoint of tile_add: invokes adj_tile_binary_map with add / adj_add, passing
// adj_a, adj_b and adj_c (the incoming adjoint of the result).
// The `-` line in the body is the removed call; the `+` line is its replacement, which
// supplies `a` and `adj_a` for the macro's extra `r` / `adj_r` parameters.
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1985
2199
|
inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1986
2200
|
{
|
|
1987
|
-
adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
|
|
2201
|
+
adj_tile_binary_map(add, a, b, a, adj_add, adj_a, adj_b, adj_a, adj_c);
|
|
1988
2202
|
}
|
|
1989
2203
|
|
|
1990
2204
|
// tile - tile
|
|
1991
2205
|
// tile - tile: maps `sub` over tiles a and b via tile_binary_map and returns the result.
// The `-` line in the body is the removed three-argument call; the `+` line is its
// replacement, which passes `a` again to fill the macro's return-tile parameter `r`.
template <typename TileA, typename TileB>
|
|
1992
2206
|
inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
|
|
1993
2207
|
{
|
|
1994
|
-
return tile_binary_map(sub, a, b);
|
|
2208
|
+
return tile_binary_map(sub, a, b, a);
|
|
1995
2209
|
}
|
|
1996
2210
|
|
|
1997
2211
|
// Adjoint of tile_sub: invokes adj_tile_binary_map with sub / adj_sub, passing
// adj_a, adj_b and adj_c (the incoming adjoint of the result).
// The `-` line in the body is the removed call; the `+` line is its replacement, which
// supplies `a` and `adj_a` for the macro's extra `r` / `adj_r` parameters.
template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
|
|
1998
2212
|
inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
|
|
1999
2213
|
{
|
|
2000
|
-
adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
|
|
2214
|
+
adj_tile_binary_map(sub, a, b, a, adj_sub, adj_a, adj_b, adj_a, adj_c);
|
|
2001
2215
|
}
|
|
2002
2216
|
|
|
2003
2217
|
|
|
@@ -2008,7 +2222,7 @@ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
|
|
|
2008
2222
|
// promote scalar to a constant tile
|
|
2009
2223
|
auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
|
|
2010
2224
|
|
|
2011
|
-
return tile_binary_map(mul, a, s_tile);
|
|
2225
|
+
return tile_binary_map(mul, a, s_tile, a);
|
|
2012
2226
|
}
|
|
2013
2227
|
|
|
2014
2228
|
template <typename Tile, typename AdjTile>
|
|
@@ -2024,7 +2238,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
|
|
|
2024
2238
|
// initialize to constant
|
|
2025
2239
|
s_tile = s;
|
|
2026
2240
|
|
|
2027
|
-
adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
|
|
2241
|
+
adj_tile_binary_map(mul, a, s_tile, a, adj_mul, adj_a, adj_s_tile, adj_a, adj_c);
|
|
2028
2242
|
|
|
2029
2243
|
for (int i=0; i < Layout::NumRegs; ++i)
|
|
2030
2244
|
{
|