warp-lang 1.8.1-py3-none-win_amd64.whl → 1.9.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -230,7 +230,9 @@ struct tile_coord_t
             out.indices[i] = indices[i] + c.indices[i];
         }
         return out;
-    }
+    }
+
+    static constexpr int size() { return N; }
 };
 
 // This function deduces N = sizeof...(Ints)
@@ -338,7 +340,8 @@ using tile_stride_t = tile_tuple_t<V...>;
 
 // represents a tile stored in global memory with dynamic strides
 // used to represent the source and offset for tile loads to register/shared
-template <typename T, typename Shape_>
+// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
+template <typename T, typename Shape_, bool BoundsCheck=true>
 struct tile_global_t
 {
     using Type = T;
@@ -370,25 +373,33 @@ struct tile_global_t
 
     inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
     {
-        // element index
-        int index = 0;
-
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < Shape::N; ++i)
+        if constexpr (BoundsCheck)
         {
-            // global = offset + coord
-            int c = offset[i] + coord[i];
-
-            // handle out of bounds case
-            if (c >= data.shape[i])
-                return false;
-            else
-                index += data.strides[i]*c;
-        }
-
-        // array strides are in bytes so we convert to elements
-        out = index / sizeof(T);
-        return true;
+            // element index
+            int index = 0;
+
+            WP_PRAGMA_UNROLL
+            for (int i=0; i < Shape::N; ++i)
+            {
+                // global = offset + coord
+                int c = offset[i] + coord[i];
+
+                // handle out of bounds case
+                if (c >= data.shape[i])
+                    return false;
+                else
+                    index += data.strides[i]*c;
+            }
+
+            // array strides are in bytes so we convert to elements
+            out = index / sizeof(T);
+            return true;
+        }
+        else
+        {
+            out = index_from_coord(coord);
+            return true;
+        }
     }
 
     inline CUDA_CALLABLE T load(const Coord& coord) const
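The change above turns bounds checking into a compile-time property of tile_global_t: BoundsCheck is dispatched with `if constexpr`, so when it is false the index computation compiles down to straight stride arithmetic with no runtime shape test. A minimal standalone sketch of the same dispatch pattern, using hypothetical names rather than Warp source:

    #include <cassert>

    // Toy analogue of tile_global_t's BoundsCheck parameter: the template
    // argument selects a checked or unchecked 1D indexing path, and the
    // untaken branch is discarded at compile time by `if constexpr`.
    template <bool BoundsCheck>
    bool index_1d(int extent, int stride, int coord, int& out)
    {
        if constexpr (BoundsCheck)
        {
            if (coord >= extent)
                return false;        // reject out-of-bounds access
            out = stride * coord;
            return true;
        }
        else
        {
            out = stride * coord;    // caller guarantees validity
            return true;
        }
    }

    int main()
    {
        int i = -1;
        assert(!index_1d<true>(8, 1, 9, i));            // checked: rejected
        assert(index_1d<false>(8, 1, 3, i) && i == 3);  // unchecked: straight math
        return 0;
    }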
@@ -435,6 +446,7 @@ struct tile_global_t
     }
 };
 
+
 template <typename Shape_>
 struct tile_layout_register_t
 {
@@ -521,7 +533,8 @@ struct tile_register_t
         data[i] = value;
     }
 
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+    template <bool BoundsCheck>
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
     {
         copy_from_global(t);
         return *this;
@@ -529,7 +542,7 @@ struct tile_register_t
 
     // define the += operator which is used during backward pass codegen
     // when returning a register tile from a user defined function
-    inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
+    inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
    {
         grad_add(rhs);
         return *this;
@@ -645,10 +658,9 @@ struct tile_register_t
         data[i] += tile.data[i];
     }
 
-    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
-        apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
-
+        apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
     }
 
     inline CUDA_CALLABLE auto& grad_to_register()
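The grad_add fix above replaces an assignment with an accumulation, which is the correct reverse-mode behavior: a value consumed by several operations must sum the adjoints flowing back from each of them. A scalar sketch of why overwriting would be wrong:

    #include <cassert>

    // Scalar analogue of the grad_add change: if x feeds both f and g,
    // dL/dx is the sum of both backward contributions. Using `=` instead
    // of `+=` would silently drop whichever contribution arrived first.
    int main()
    {
        double adj_x = 0.0;
        double adj_from_f = 2.0;  // contribution from f's backward pass
        double adj_from_g = 3.0;  // contribution from g's backward pass

        adj_x += adj_from_f;      // accumulate
        adj_x += adj_from_g;      // accumulate

        assert(adj_x == 5.0);     // `=` would have produced 3.0
        return 0;
    }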
@@ -746,6 +758,7 @@ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, boo
 
     // one entry per-thread so no need for synchronization
     smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
+    assert(smem_base[WP_TILE_THREAD_IDX] >= 0);
 
 #ifdef __CUDA_ARCH__
     extern __shared__ char dynamic_smem_base[];
@@ -893,6 +906,28 @@ struct tile_shared_t
     {
     }
 
+    // we delete the copy constructor because in the case the shared tile is owning,
+    // this leads to a double deallocation.
+    // this also forces one to handle copies explicitly
+    inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+    {
+        static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+    }
+
+    // move constructor
+    inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+    {
+        other.data.ptr = nullptr;
+        other.grad.ptr = nullptr;
+    }
+
+    template <typename OtherT, typename OtherLayout, bool OtherOwner>
+    inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
+    {
+        static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+        static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+    }
+
     // initialize from an existing tile's memory
     inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
     {
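The new constructors above guard copying with a static_assert rather than `= delete`: because member functions of a class template are only instantiated when used, copying stays available for non-owning views while an owning tile, whose destructor releases its allocation, cannot be duplicated. A minimal sketch of the same ownership pattern, with hypothetical names:

    #include <cstdlib>
    #include <utility>

    // Toy analogue: Owner=true means the destructor frees the buffer, so a
    // copy would lead to a double free; the static_assert fires only if the
    // copy constructor is actually instantiated for an owning view.
    template <bool Owner>
    struct buffer_view
    {
        float* ptr;

        explicit buffer_view(float* p) : ptr(p) {}

        buffer_view(const buffer_view& other) : ptr(other.ptr)
        {
            static_assert(!Owner, "Copying is only supported for non-owning views.");
        }

        // moves transfer ownership and null the source
        buffer_view(buffer_view&& other) : ptr(other.ptr) { other.ptr = nullptr; }

        ~buffer_view() { if (Owner) free(ptr); }
    };

    int main()
    {
        buffer_view<true> owner((float*)malloc(16 * sizeof(float)));
        buffer_view<true> moved(std::move(owner));  // ok: ownership transferred
        buffer_view<false> view(moved.ptr);
        buffer_view<false> alias(view);             // ok: non-owning copy
        // buffer_view<true> dup(moved);            // compile error: double free risk
        return 0;
    }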
@@ -920,22 +955,52 @@ struct tile_shared_t
 
     // construct from another shared tile, this constructor
     // is invoked for reshape operations like `wp.tile_transpose()`
+    // or `wp::copy()`
     template <typename OtherT, typename OtherLayout, bool OtherOwner>
     inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
     {
         // check dimensions are compatible
         static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
 
-        // alias tile directly
-        data.ptr = rhs.data.ptr;
-        grad.ptr = rhs.grad.ptr;
-        initialized = rhs.initialized;
+
+        if (Owner)
+        {
+            // if the tile owns the data we need to copy
+            assign(rhs);
+        }
+        else
+        {
+            // alias tile directly
+            data.ptr = rhs.data.ptr;
+            grad.ptr = rhs.grad.ptr;
+            initialized = rhs.initialized;
+        }
 
         return *this;
-    }
+    }
+
+    inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
+    {
+        if (Owner)
+        {
+            // if the tile owns the data we need to copy
+            assign(rhs);
+        }
+        else
+        {
+            // alias tile directly
+            data.ptr = rhs.data.ptr;
+            grad.ptr = rhs.grad.ptr;
+            initialized = rhs.initialized;
+        }
+
+        return *this;
+    }
 
     // assign from a global tile (load)
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+
+    template <bool BoundsCheck>
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
     {
         copy_from_global(t);
         return *this;
@@ -958,6 +1023,21 @@ struct tile_shared_t
         return *this;
     }
 
+    // define the += operator which is used during backward pass codegen
+    // when returning a register tile from a user defined function
+    template<typename OtherLayout>
+    inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
+    {
+        grad_add(rhs);
+        return *this;
+    }
+
+    inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
+    {
+        grad_add(rhs);
+        return *this;
+    }
+
     // in-place zero
     inline CUDA_CALLABLE void zero()
     {
@@ -1039,6 +1119,27 @@ struct tile_shared_t
         WP_TILE_SYNC();
     }
 
+    // shared tile deep copy
+    template <typename OtherT, typename OtherLayout, bool OtherOwner>
+    inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
+    {
+        // check dimensions are compatible
+        static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+
+        if (initialized)
+            WP_TILE_SYNC();
+
+        WP_PRAGMA_UNROLL
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        {
+            auto c = Layout::coord_from_linear(i);
+            data(c) = tile.data(c);
+        }
+
+        initialized = true;
+        WP_TILE_SYNC();
+    }
+
     // in-place gradient zero
     inline CUDA_CALLABLE void grad_zero()
     {
@@ -1078,8 +1179,21 @@ struct tile_shared_t
         WP_TILE_SYNC();
     }
 
+    // accumulate gradients onto this tile from another shared tile
+    inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
+    {
+        WP_PRAGMA_UNROLL
+        for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+        {
+            auto c = Layout::coord_from_linear(i);
+            grad(c) += tile.grad(c);
+        }
+
+        WP_TILE_SYNC();
+    }
+
     // accumulate gradient onto this tile from a global array
-    CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+    inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
         WP_PRAGMA_UNROLL
         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
@@ -1103,7 +1217,7 @@ struct tile_shared_t
         }
 
         WP_TILE_SYNC();
-    }
+    }
 
     // copy shared tile to register
     inline CUDA_CALLABLE auto grad_to_register()
@@ -1172,7 +1286,7 @@ struct tile_shared_t
     {
         // alias of shared tile with 128bit type
         using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
-        tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
+        tile_shared_t<float4, SrcLayout, false> src128((float4*)data.ptr);
 
         assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
         assert(((uint64_t)(dest128))%sizeof(float4) == 0);
@@ -1251,7 +1365,7 @@ struct tile_shared_t
         const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
         const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
         const bool aligned_stride = (src.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
-
+
         float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
         const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
 
@@ -1262,7 +1376,7 @@ struct tile_shared_t
         {
             // alias of shared tile with 128bit type
             using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
-            tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
+            tile_shared_t<float4, DestLayout, false> dest128((float4*)data.ptr);
 
             assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
             assert(((uint64_t)(src128))%sizeof(float4) == 0);
@@ -1463,9 +1577,16 @@ void tile_register_t<T, L>::print() const
 // print entry points
 template <typename T, typename L>
 inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
+
+template <typename T, typename L>
+inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
+
 template <typename T, typename L, bool Owner>
 inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
 
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+
 template <typename T, typename L, bool O>
 inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
 {
@@ -1488,13 +1609,81 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
 {
 }
 
+// select specialization for shared tiles
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? b.copy_to_register() : a;
+}
 
-template <typename T, typename L>
-inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
-template <typename T, typename L, bool Owner>
-inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? b : a.copy_to_register();
+}
+
+template <typename C, typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
+}
 
+template <typename C, typename T, typename L, bool LOwner, bool ROwner>
+inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
+}
 
+// adj_select same as in builtin.h
+
+// where specialization for register/shared tiles
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? a : b.copy_to_register();
+}
+
+template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? a.copy_to_register() : b;
+}
+
+template <typename C, typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+}
+
+template <typename C, typename T, typename L, bool LOwner, bool ROwner>
+inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
+{
+    // The double NOT operator !! casts to bool without compiler warnings.
+    return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+}
+
+// adj_where same as in builtin.h
+
+// copy specialization for shared tiles, the lvalue this gets assigned to is owning, thus, this invokes the copy assign path
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
+{
+    return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
+}
+
+template <typename T, typename L, bool Owner>
+inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
+{
+    adj_src += adj_dest;
+    adj_dest.grad_zero();
+}
 
 // helpers to allocate shared tiles
 template <typename T, typename Shape, typename Strides, bool RequiresGrad>
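In the select()/where() overloads above, both branches of the ternary must share one type, so mixed register/shared arguments are unified by copying to registers, and owning shared tiles decay to non-owning aliases; the double NOT converts any scalar condition to bool without truncation warnings. A reduced sketch of the decay trick, with hypothetical names:

    #include <cstdio>

    // Toy analogue: view_t<true> and view_t<false> are distinct types, so a
    // ternary over them would not compile; both sides are first decayed to a
    // common non-owning view_t<false>.
    template <bool Owner>
    struct view_t { float* ptr; };

    template <bool LOwner, bool ROwner>
    view_t<false> where(int cond, const view_t<LOwner>& a, const view_t<ROwner>& b)
    {
        // !! converts the scalar condition to bool without warnings
        return (!!cond) ? view_t<false>{a.ptr} : view_t<false>{b.ptr};
    }

    int main()
    {
        float x = 1.0f, y = 2.0f;
        view_t<true> a{&x};   // stands in for an owning tile
        view_t<false> b{&y};  // stands in for a non-owning alias
        printf("%g\n", *where(7, a, b).ptr);  // non-zero condition selects a: 1
        return 0;
    }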
@@ -1727,10 +1916,66 @@ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
     T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
 
 // entry point for load operations, these just return a reference to a global memory array + coordinate
-template <unsigned... Shape, typename... Indices, typename T>
-inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
+template <typename T, bool BoundsCheck, unsigned... Shape, typename... Offset>
+inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Offset... offset)
+{
+    return tile_global_t<T, tile_shape_t<Shape...>, BoundsCheck>(src, tile_coord(offset...));
+}
+
+// used for indexed loads and stores
+template <typename T, int M, typename Coord>
+inline CUDA_CALLABLE bool compute_index(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Coord c, int& out)
 {
-    return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
+    int index = 0;
+
+    WP_PRAGMA_UNROLL
+    for (int i = 0; i < Coord::size(); ++i)
+    {
+        if (i == axis)
+        {
+            // global = offset_coord + index_mapped_coord
+            int index_along_axis = offset[i] + indices.data(c[i]);
+
+            // handle out of bounds case
+            if (index_along_axis >= src.shape[i])
+                return false;
+            else
+                index += src.strides[i] * index_along_axis;
+        }
+        else
+        {
+            // global = offset_coord + coord
+            int g = offset[i] + c[i];
+
+            // handle out of bounds case
+            if (g >= src.shape[i])
+                return false;
+            else
+                index += src.strides[i] * g;
+        }
+    }
+
+    // array strides are in bytes so we convert to elements
+    out = index / sizeof(T);
+    return true;
+}
+
+
+template <unsigned... Shape, int M, typename T, typename... Offset>
+inline CUDA_CALLABLE auto tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Offset... offset)
+{
+    auto out = tile_register_t<T, tile_layout_register_t<tile_shape_t<Shape...>>>();
+    auto offset_coord = tile_coord(offset...);
+
+    out.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(src, indices, axis, offset_coord, c, i))
+            out.data[reg] = src.data[i];
+        else
+            out.data[reg] = T(0);
+    });
+
+    return out;
 }
 
 // // entry point for tile store operations
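compute_index() and tile_load_indexed() above implement a gather: along one chosen axis the tile coordinate is remapped through an index tile, other axes are offset directly, and any lane that falls outside the array shape is zero-filled. A host-side sketch of the same addressing for a 2D row gather (axis 0), with hypothetical names:

    #include <cstdio>

    int main()
    {
        // source array: 4 rows x 3 columns
        const int rows = 4, cols = 3;
        float src[4][3];
        for (int i = 0; i < rows; ++i)
            for (int j = 0; j < cols; ++j)
                src[i][j] = float(i * 10 + j);

        int indices[2] = {2, 9};  // second index is out of bounds
        float tile[2][3];

        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < cols; ++j)
            {
                int row = indices[i];  // coordinate remapped along the gather axis
                tile[i][j] = (row < rows) ? src[row][j] : 0.0f;  // zero-fill OOB
            }

        printf("%g %g\n", tile[0][0], tile[1][0]);  // prints: 20 0
        return 0;
    }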
@@ -1741,38 +1986,90 @@ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
 // }
 
 // entry point for tile store operations
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z, w))); }
+
+template <typename T, int M, typename Tile, typename Coord>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+{
+    auto src_reg = src.copy_to_register();
+
+    src_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            dest.data[i] = src_reg.data[reg];
+    });
+}
+
+// entry point for tile index store operations
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
 
 
 // compiler struggles with these if they are one line
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z, w));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z, w));
     return src.atomic_add(global);
 }
 
+template <typename T, int M, typename Tile, typename Coord>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+{
+    auto src_reg = src.copy_to_register();
+    auto ret_reg = tile_register_like<Tile>();
+
+    src_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            ret_reg.data[reg] = wp::atomic_add(&dest.data[i], src_reg.data[reg]);
+        else
+            ret_reg.data[reg] = T(0);
+    });
+
+    return ret_reg;
+}
+
+// entry point for tile index atomic add operations
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
+
 
 //-------------------------------------
 // Adjoints
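tile_store_indexed() and tile_atomic_add_indexed() above are the scatter counterparts of the indexed load: lanes whose remapped coordinate is out of bounds are simply skipped rather than zero-filled, and the atomic variant returns a register tile holding each destination's value before the add. A host-side sketch of that contract, with hypothetical names:

    #include <cstdio>

    int main()
    {
        float dest[4] = {0.0f, 10.0f, 20.0f, 30.0f};
        int indices[3] = {1, 3, 7};        // last index is out of bounds
        float src[3]  = {1.0f, 1.0f, 1.0f};
        float prev[3] = {0.0f, 0.0f, 0.0f};

        for (int i = 0; i < 3; ++i)
        {
            int d = indices[i];
            if (d < 4)                     // mirror compute_index's bounds test
            {
                prev[i] = dest[d];         // value before the add (atomic return)
                dest[d] += src[i];
            }                              // out-of-bounds lanes are skipped
        }

        printf("%g %g %g\n", dest[1], dest[3], prev[0]);  // prints: 11 31 10
        return 0;
    }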
@@ -1791,7 +2088,6 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
     adj_ret.atomic_add_grad(dest);
 }
 
-
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
 template <typename T, typename AdjTile>
@@ -1801,7 +2097,44 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, ar
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
 
+template <typename T, int M, typename AdjTile, typename Coord>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset,
+                                                array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset,
+                                                AdjTile& adj_ret)
+{
+    // we allow users to override grad of src
+    if (adj_src.data)
+        src.grad = adj_src.data;
+
+    auto adj_ret_reg = adj_ret.grad_to_register();
 
+    adj_ret_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(src, indices, axis, offset, c, i))
+            wp::atomic_add(&src.grad[i], adj_ret_reg.data[reg]);
+    });
+}
+
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x), adj_src, adj_indices, adj_axis, tile_coord(0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y), adj_src, adj_indices, adj_axis, tile_coord(0, 0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z, w), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0, 0), adj_ret);
+}
 
 template <typename T, typename Tile, typename AdjTile, typename Coord>
 inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
@@ -1827,7 +2160,33 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z,
 template <typename T, typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
 
+template <typename T, int M, typename Tile, typename AdjTile, typename Coord>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset, AdjTile& adj_t)
+{
+    // we allow users to override grad of src
+    if (adj_dest.data)
+        dest.grad = adj_dest.data;
+
+    auto adj_t_reg = tile_register_like<Tile>();
 
+    adj_t_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            adj_t_reg.data[reg] += dest.grad[i];
+    });
+
+    // write adjoints back
+    adj_t.grad_add(adj_t_reg);
+}
+
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
 
 // adj_tile_atomic_add is an alias for adj_tile_store
 template <typename T, typename Tile, typename AdjTile, typename AdjRet>
@@ -1839,13 +2198,28 @@ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, in
 template <typename T, typename Tile, typename AdjTile, typename AdjRet>
 inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
 
+// adj_tile_atomic_add_indexed is an alias for adj_tile_store_indexed
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
 
 // unary map
-template <typename Tile, typename Fwd>
-inline CUDA_CALLABLE auto tile_map(Fwd op,
-    Tile &a)
+template <typename Tile, typename Fwd, typename ReturnTile>
+inline CUDA_CALLABLE auto tile_map(Fwd op, Tile &a, ReturnTile &r)
 {
-    auto out = tile_register_like<Tile>();
+    // verify shapes and sizes are compatible
+    using ShapeIn = typename Tile::Layout::Shape;
+    using ShapeOut = typename ReturnTile::Layout::Shape;
+
+    static_assert(ShapeIn::N == ShapeOut::N, "Number of tile dimensions must match for unary map");
+    static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for unary map");
+
+    auto out = tile_register_like<ReturnTile>();
     auto a_reg = a.copy_to_register();
 
     using Layout = typename decltype(out)::Layout;
@@ -1884,12 +2258,24 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
 }
 
 // binary map
-template <typename TileA, typename TileB, typename Fwd>
+template <typename TileA, typename TileB, typename Fwd, typename ReturnTile>
 inline CUDA_CALLABLE auto tile_map(Fwd op,
     TileA& a,
-    TileB& b)
+    TileB& b,
+    ReturnTile& r)
 {
-    auto out = tile_register_like<TileA>();
+    // verify shapes and sizes are compatible
+    using ShapeA = typename TileA::Layout::Shape;
+    using ShapeB = typename TileB::Layout::Shape;
+    using ShapeOut = typename ReturnTile::Layout::Shape;
+
+    static_assert(ShapeA::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+    static_assert(ShapeB::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+
+    static_assert(ShapeA::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+    static_assert(ShapeB::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+
+    auto out = tile_register_like<ReturnTile>();
 
     auto a_reg = a.copy_to_register();
     auto b_reg = b.copy_to_register();
@@ -1905,7 +2291,6 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
     return out;
 }
 
-
 template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_map(Fwd op,
     TileA &a,
1936
2321
  adj_b.grad_add(adj_b_reg);
1937
2322
  }
1938
2323
 
1939
- // wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
2324
+ // We wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
1940
2325
  // this is important because many of the builtin operators don't follow particular conventions on references for
1941
2326
  // the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
1942
- #define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
1943
- #define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
2327
+ // The r argument is a dummy return tile argument, because we can't template on the return tile type in a macro definition.
2328
+ // So if we want users to be able to define functions that return a tile type that is different from the input type,
2329
+ // we must pass an extra dummy return tile argument that is used define the return type of tile_map.
2330
+
2331
+ #define tile_unary_map(op, a, r) tile_map([](auto x) { return op(x);}, a, r)
2332
+ #define adj_tile_unary_map(op, a, r, adj_op, adj_a, adj_r, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
1944
2333
 
1945
- #define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
1946
- #define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
2334
+ #define tile_binary_map(op, a, b, r) tile_map([](auto x, auto y) { return op(x, y);}, a, b, r)
2335
+ #define adj_tile_binary_map(op, a, b, r, adj_op, adj_a, adj_b, adj_r, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
1947
2336
 
1948
2337
  // -tile (unary neg)
1949
2338
  template <typename Tile>
1950
- inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
2339
+ inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a, a); }
1951
2340
 
1952
2341
  template <typename Tile, typename AdjTile>
1953
- inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
2342
+ inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, a, wp::adj_neg, adj_a, adj_a, adj_ret); }
1954
2343
 
1955
2344
 
1956
2345
  // tile + tile
1957
2346
  template <typename TileA, typename TileB>
1958
2347
  inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
1959
2348
  {
1960
- return tile_binary_map(add, a, b);
2349
+ return tile_binary_map(add, a, b, a);
1961
2350
  }
1962
2351
 
1963
2352
  // add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
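The comment block in the hunk above explains the extra `r` parameter: a preprocessor macro cannot be templated on an output type, so tile_unary_map/tile_binary_map take a dummy tile whose only job is to let tile_map deduce ReturnTile. A reduced sketch of the idiom, with hypothetical names:

    #include <cstdio>

    // Toy analogue of the dummy return-tile argument: the macro cannot name
    // a template parameter, so the extra value argument `r` carries the
    // desired output type and is otherwise unused.
    template <typename In, typename Out>
    Out map_as(In x, Out /*r*/) { return static_cast<Out>(x) * 2; }

    #define unary_map(x, r) map_as(x, r)

    int main()
    {
        int i = 3;
        double dummy = 0.0;               // only its type matters
        double out = unary_map(i, dummy); // Out deduced as double
        printf("%f\n", out);              // prints: 6.000000
        return 0;
    }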
@@ -1984,20 +2373,20 @@ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register
 template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
 {
-    adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
+    adj_tile_binary_map(add, a, b, a, adj_add, adj_a, adj_b, adj_a, adj_c);
 }
 
 // tile - tile
 template <typename TileA, typename TileB>
 inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
 {
-    return tile_binary_map(sub, a, b);
+    return tile_binary_map(sub, a, b, a);
 }
 
 template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
 {
-    adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
+    adj_tile_binary_map(sub, a, b, a, adj_sub, adj_a, adj_b, adj_a, adj_c);
 }
 
 
@@ -2008,7 +2397,7 @@ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
     // promote scalar to a constant tile
     auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
 
-    return tile_binary_map(mul, a, s_tile);
+    return tile_binary_map(mul, a, s_tile, a);
 }
 
 template <typename Tile, typename AdjTile>
@@ -2024,7 +2413,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
     // initialize to constant
     s_tile = s;
 
-    adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
+    adj_tile_binary_map(mul, a, s_tile, a, adj_mul, adj_a, adj_s_tile, adj_a, adj_c);
 
     for (int i=0; i < Layout::NumRegs; ++i)
     {
@@ -2834,7 +3223,7 @@ template <typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
 {
     auto a = tile_transpose(adj_ret);
-    auto b = adj_t;
+    auto& b = adj_t;
 
     adj_t.assign(tile_add(a,b));
 }
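The one-token change above matters because of the copy rules introduced earlier in this file: with `auto`, b would be a fresh tile copied from adj_t, while `auto&` merely binds a reference to it. A minimal illustration of the difference:

    #include <cassert>

    int main()
    {
        int adj_t = 5;

        auto  b_copy = adj_t;  // independent copy
        auto& b_ref  = adj_t;  // alias of adj_t

        adj_t = 7;
        assert(b_copy == 5);   // unaffected by the change
        assert(b_ref  == 7);   // tracks adj_t
        return 0;
    }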