warp-lang 1.8.1-py3-none-manylinux_2_34_aarch64.whl → 1.9.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang has been flagged as possibly problematic.
Files changed (134)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -230,7 +230,9 @@ struct tile_coord_t
         out.indices[i] = indices[i] + c.indices[i];
     }
     return out;
-    }
+    }
+
+    static constexpr int size() { return N; }
 };
 
 // This function deduces N = sizeof...(Ints)
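The hunk above adds a constexpr `size()` accessor to `tile_coord_t`, which later code in this diff (e.g. `compute_index`) uses to drive compile-time loops over the coordinate rank. A minimal standalone sketch of the pattern, with hypothetical names (`coord_t` and `linearize` are not Warp code):

```cpp
#include <cstdio>

template <int N>
struct coord_t
{
    int indices[N];
    static constexpr int size() { return N; }  // rank available at compile time
};

// The array-reference parameter is sized by Coord::size(), and the loop
// bound is a compile-time constant the compiler can fully unroll.
template <typename Coord>
int linearize(const Coord& c, const int (&strides)[Coord::size()])
{
    int index = 0;
    for (int i = 0; i < Coord::size(); ++i)
        index += strides[i] * c.indices[i];
    return index;
}

int main()
{
    coord_t<2> c = {{1, 2}};
    int strides[2] = {4, 1};                     // row-major strides, 4-wide rows
    std::printf("%d\n", linearize(c, strides));  // prints 6
    return 0;
}
```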
@@ -338,7 +340,8 @@ using tile_stride_t = tile_tuple_t<V...>;
 
 // represents a tile stored in global memory with dynamic strides
 // used to represent the source and offset for tile loads to register/shared
-template <typename T, typename Shape_>
+// BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
+template <typename T, typename Shape_, bool BoundsCheck=true>
 struct tile_global_t
 {
     using Type = T;
@@ -370,25 +373,33 @@ struct tile_global_t
 
     inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
     {
-        // element index
-        int index = 0;
-
-        WP_PRAGMA_UNROLL
-        for (int i=0; i < Shape::N; ++i)
+        if constexpr (BoundsCheck)
         {
-            // global = offset + coord
-            int c = offset[i] + coord[i];
-
-            // handle out of bounds case
-            if (c >= data.shape[i])
-                return false;
-            else
-                index += data.strides[i]*c;
-        }
-
-        // array strides are in bytes so we convert to elements
-        out = index / sizeof(T);
-        return true;
+            // element index
+            int index = 0;
+
+            WP_PRAGMA_UNROLL
+            for (int i=0; i < Shape::N; ++i)
+            {
+                // global = offset + coord
+                int c = offset[i] + coord[i];
+
+                // handle out of bounds case
+                if (c >= data.shape[i])
+                    return false;
+                else
+                    index += data.strides[i]*c;
+            }
+
+            // array strides are in bytes so we convert to elements
+            out = index / sizeof(T);
+            return true;
+        }
+        else
+        {
+            out = index_from_coord(coord);
+            return true;
+        }
     }
 
     inline CUDA_CALLABLE T load(const Coord& coord) const
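The new `BoundsCheck` parameter is resolved with `if constexpr`, so the unchecked path carries no runtime comparison at all; the branch is discarded at compile time. A standalone sketch of the dispatch (hypothetical `element_index` helper, 1D case only; the real `index()` above handles N dimensions):

```cpp
#include <cstdio>

template <bool BoundsCheck>
bool element_index(int coord, int shape, int stride_bytes, int& out)
{
    if constexpr (BoundsCheck)
    {
        // checked path: reject out-of-range coordinates
        if (coord >= shape)
            return false;
        out = stride_bytes * coord / (int)sizeof(float);  // bytes -> elements
        return true;
    }
    else
    {
        // unchecked path: the comparison above is compiled out entirely
        out = stride_bytes * coord / (int)sizeof(float);
        return true;
    }
}

int main()
{
    int i;
    std::printf("%d\n", element_index<true>(10, 8, 4, i));  // 0: out of bounds
    element_index<false>(3, 8, 4, i);                       // caller guarantees in-bounds
    std::printf("%d\n", i);                                 // 3
    return 0;
}
```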
@@ -435,6 +446,7 @@ struct tile_global_t
     }
 };
 
+
 template <typename Shape_>
 struct tile_layout_register_t
 {
@@ -521,7 +533,8 @@ struct tile_register_t
         data[i] = value;
     }
 
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+    template <bool BoundsCheck>
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
     {
         copy_from_global(t);
         return *this;
@@ -647,8 +660,7 @@ struct tile_register_t
 
     CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
     {
-        apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
-
+        apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
     }
 
     inline CUDA_CALLABLE auto& grad_to_register()
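Note the one-character fix in `grad_add` above: the adjoint is now accumulated (`+=`) instead of overwritten (`=`). That matters because reverse-mode gradients sum over all uses of a value; a scalar sketch of why assignment drops contributions:

```cpp
#include <cstdio>

int main()
{
    // forward: x feeds two products, y1 = 3x and y2 = 5x,
    // so dL/dx must combine both paths: 3 + 5 = 8
    float adj_x = 0.0f;
    adj_x += 3.0f;  // contribution flowing back through y1
    adj_x += 5.0f;  // contribution through y2; '=' here would erase the first
    std::printf("%g\n", adj_x);  // 8
    return 0;
}
```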
@@ -935,7 +947,9 @@ struct tile_shared_t
     }
 
     // assign from a global tile (load)
-    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+
+    template <bool BoundsCheck>
+    inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
     {
         copy_from_global(t);
         return *this;
@@ -1103,7 +1117,7 @@ struct tile_shared_t
         }
 
         WP_TILE_SYNC();
-    }
+    }
 
     // copy shared tile to register
     inline CUDA_CALLABLE auto grad_to_register()
@@ -1172,7 +1186,7 @@ struct tile_shared_t
     {
         // alias of shared tile with 128bit type
        	using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
-        tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
+        tile_shared_t<float4, SrcLayout, false> src128((float4*)data.ptr);
 
         assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
         assert(((uint64_t)(dest128))%sizeof(float4) == 0);
@@ -1251,7 +1265,7 @@ struct tile_shared_t
     const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
     const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
     const bool aligned_stride = (src.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
-
+
     float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
     const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
 
@@ -1262,7 +1276,7 @@ struct tile_shared_t
     {
         // alias of shared tile with 128bit type
         using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
-        tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
+        tile_shared_t<float4, DestLayout, false> dest128((float4*)data.ptr);
 
         assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
         assert(((uint64_t)(src128))%sizeof(float4) == 0);
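The two hunks above alias the shared tile as `float4` so copies move 128 bits per instruction; the aliased tiles now pass `false` as the third `tile_shared_t` template argument (that parameter appears as `Owner` elsewhere in this header, so the alias presumably does not own the storage), and alignment is asserted before the reinterpretation. A standalone sketch of the alignment precondition being checked (hypothetical helper; CUDA's `float4` is emulated with a struct here):

```cpp
#include <cstdint>
#include <cstdio>

struct float4 { float x, y, z, w; };  // stand-in for CUDA's 16-byte vector type

// A float buffer can be reinterpreted as float4 only if the base pointer
// and the byte count are both multiples of sizeof(float4) == 16.
bool can_copy_as_float4(const float* ptr, int elements)
{
    const bool aligned_ptr  = ((uint64_t)ptr) % sizeof(float4) == 0;
    const bool aligned_size = (elements * sizeof(float)) % sizeof(float4) == 0;
    return aligned_ptr && aligned_size;
}

int main()
{
    alignas(16) float buf[8] = {};
    std::printf("%d\n", can_copy_as_float4(buf, 8));      // 1: 32 bytes, 16B-aligned base
    std::printf("%d\n", can_copy_as_float4(buf + 1, 4));  // 0: base misaligned by 4 bytes
    return 0;
}
```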
@@ -1727,10 +1741,66 @@ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
                                           T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
 
 // entry point for load operations, these just return a reference to a global memory array + coordinate
-template <unsigned... Shape, typename... Indices, typename T>
-inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
+template <typename T, bool BoundsCheck, unsigned... Shape, typename... Offset>
+inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Offset... offset)
+{
+    return tile_global_t<T, tile_shape_t<Shape...>, BoundsCheck>(src, tile_coord(offset...));
+}
+
+// used for indexed loads and stores
+template <typename T, int M, typename Coord>
+inline CUDA_CALLABLE bool compute_index(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Coord c, int& out)
+{
+    int index = 0;
+
+    WP_PRAGMA_UNROLL
+    for (int i = 0; i < Coord::size(); ++i)
+    {
+        if (i == axis)
+        {
+            // global = offset_coord + index_mapped_coord
+            int index_along_axis = offset[i] + indices.data(c[i]);
+
+            // handle out of bounds case
+            if (index_along_axis >= src.shape[i])
+                return false;
+            else
+                index += src.strides[i] * index_along_axis;
+        }
+        else
+        {
+            // global = offset_coord + coord
+            int g = offset[i] + c[i];
+
+            // handle out of bounds case
+            if (g >= src.shape[i])
+                return false;
+            else
+                index += src.strides[i] * g;
+        }
+    }
+
+    // array strides are in bytes so we convert to elements
+    out = index / sizeof(T);
+    return true;
+}
+
+
+template <unsigned... Shape, int M, typename T, typename... Offset>
+inline CUDA_CALLABLE auto tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Offset... offset)
 {
-    return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
+    auto out = tile_register_t<T, tile_layout_register_t<tile_shape_t<Shape...>>>();
+    auto offset_coord = tile_coord(offset...);
+
+    out.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(src, indices, axis, offset_coord, c, i))
+            out.data[reg] = src.data[i];
+        else
+            out.data[reg] = T(0);
+    });
+
+    return out;
 }
 
 // // entry point for tile store operations
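`compute_index` above is the core of the new indexed (gather/scatter) tile operations: along one chosen axis the tile coordinate is remapped through an index tile, every other axis uses plain offset + coordinate, and the byte-valued array strides are divided by `sizeof(T)` at the end. A standalone 2D sketch of the same addressing (hypothetical `gather_index`, plain arrays instead of tiles):

```cpp
#include <cstdio>

// Gather addressing for a 2D array: along `axis` the coordinate is looked
// up in an index buffer, elsewhere it is offset + coordinate; strides are
// in bytes, so the final index is converted to elements.
bool gather_index(const int shape[2], const int strides_bytes[2], int elem_size,
                  const int* indices, int axis, const int offset[2],
                  const int coord[2], int& out)
{
    int index = 0;
    for (int i = 0; i < 2; ++i)
    {
        int c = (i == axis) ? offset[i] + indices[coord[i]]
                            : offset[i] + coord[i];
        if (c >= shape[i])
            return false;  // out of bounds: caller substitutes T(0)
        index += strides_bytes[i] * c;
    }
    out = index / elem_size;  // bytes -> elements
    return true;
}

int main()
{
    int shape[2]   = {4, 3};     // 4x3 row-major float array
    int strides[2] = {12, 4};    // bytes: 3 floats per row
    int rows[2]    = {3, 1};     // gather rows 3 and 1
    int offset[2]  = {0, 0};
    int coord[2]   = {0, 2};     // tile element (0,2) -> array (rows[0], 2)
    int i;
    if (gather_index(shape, strides, 4, rows, 0, offset, coord, i))
        std::printf("%d\n", i);  // 11 == 3*3 + 2
    return 0;
}
```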
@@ -1741,38 +1811,90 @@ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
 // }
 
 // entry point for tile store operations
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
-template <typename T, typename Tile>
-inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z))); }
+template <typename T, bool BoundsCheck, typename Tile>
+inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z, w))); }
+
+template <typename T, int M, typename Tile, typename Coord>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+{
+    auto src_reg = src.copy_to_register();
+
+    src_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            dest.data[i] = src_reg.data[reg];
+    });
+}
+
+// entry point for tile index store operations
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
 
 
 // compiler struggles with these if they are one line
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z));
     return src.atomic_add(global);
 }
-template <typename T, typename Tile>
+template <typename T, bool BoundsCheck, typename Tile>
 inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) {
-    tile_global_t<T, typename Tile::Layout::Shape> global(dest, tile_coord(x, y, z, w));
+    tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z, w));
     return src.atomic_add(global);
 }
 
+template <typename T, int M, typename Tile, typename Coord>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+{
+    auto src_reg = src.copy_to_register();
+    auto ret_reg = tile_register_like<Tile>();
+
+    src_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            ret_reg.data[reg] = wp::atomic_add(&dest.data[i], src_reg.data[reg]);
+        else
+            ret_reg.data[reg] = T(0);
+    });
+
+    return ret_reg;
+}
+
+// entry point for tile index atomic add operations
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+
+template <typename T, int M, typename Tile>
+inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
+
 
 //-------------------------------------
 // Adjoints
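Two details of the scatter entry points above are worth noting: out-of-bounds lanes are silently skipped on store, and the atomic variant exists because an index tile may target the same destination row from several lanes, where plain stores would race. A serial sketch of the accumulate-vs-overwrite difference (illustration only; on the GPU `wp::atomic_add` supplies the read-modify-write):

```cpp
#include <cstdio>

int main()
{
    float dest[4] = {};
    int   rows[3] = {1, 2, 1};    // duplicate destination: two lanes hit row 1
    float src[3]  = {10.f, 20.f, 30.f};

    for (int k = 0; k < 3; ++k)
        dest[rows[k]] += src[k];  // atomic-add semantics: contributions sum

    // plain stores would leave dest[1] as either 10 or 30 depending on
    // which lane wrote last; accumulation gives the deterministic 40
    std::printf("%g %g\n", dest[1], dest[2]);  // 40 20
    return 0;
}
```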
@@ -1791,7 +1913,6 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
     adj_ret.atomic_add_grad(dest);
 }
 
-
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
 template <typename T, typename AdjTile>
@@ -1801,7 +1922,44 @@ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, ar
 template <typename T, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load( src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
 
+template <typename T, int M, typename AdjTile, typename Coord>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset,
+                                                array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset,
+                                                AdjTile& adj_ret)
+{
+    // we allow users to override grad of src
+    if (adj_src.data)
+        src.grad = adj_src.data;
 
+    auto adj_ret_reg = adj_ret.grad_to_register();
+
+    adj_ret_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(src, indices, axis, offset, c, i))
+            wp::atomic_add(&src.grad[i], adj_ret_reg.data[reg]);
+    });
+}
+
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x), adj_src, adj_indices, adj_axis, tile_coord(0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y), adj_src, adj_indices, adj_axis, tile_coord(0, 0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0), adj_ret);
+}
+template <typename T, int M, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret)
+{
+    adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z, w), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0, 0), adj_ret);
+}
 
 template <typename T, typename Tile, typename AdjTile, typename Coord>
 inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
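The adjoint above encodes the usual autodiff rule that the reverse of a gather is a scatter-add: each incoming adjoint element is atomically added into `src.grad` at the location the forward pass read from, and duplicate indices naturally accumulate. A serial 1D sketch (hypothetical, plain arrays):

```cpp
#include <cstdio>

int main()
{
    int   idx[3]     = {2, 0, 2};        // forward: out[k] = src[idx[k]]
    float adj_out[3] = {1.f, 1.f, 1.f};  // incoming adjoint of the gather
    float adj_src[4] = {};

    for (int k = 0; k < 3; ++k)
        adj_src[idx[k]] += adj_out[k];   // wp::atomic_add on the GPU

    std::printf("%g %g %g %g\n", adj_src[0], adj_src[1], adj_src[2], adj_src[3]);
    // 1 0 2 0: the duplicated index 2 receives two contributions
    return 0;
}
```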
@@ -1827,7 +1985,33 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z,
 template <typename T, typename Tile, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
 
+template <typename T, int M, typename Tile, typename AdjTile, typename Coord>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset, AdjTile& adj_t)
+{
+    // we allow users to override grad of src
+    if (adj_dest.data)
+        dest.grad = adj_dest.data;
+
+    auto adj_t_reg = tile_register_like<Tile>();
+
+    adj_t_reg.apply([&](int reg, auto c) {
+        int i;
+        if (compute_index(dest, indices, axis, offset, c, i))
+            adj_t_reg.data[reg] += dest.grad[i];
+    });
+
+    // write adjoints back
+    adj_t.grad_add(adj_t_reg);
+}
 
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile>
+inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
 
 // adj_tile_atomic_add is an alias for adj_tile_store
 template <typename T, typename Tile, typename AdjTile, typename AdjRet>
@@ -1839,13 +2023,28 @@ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, in
 template <typename T, typename Tile, typename AdjTile, typename AdjRet>
 inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
 
+// adj_tile_atomic_add_indexed is an alias for adj_tile_store_indexed
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
 
 // unary map
-template <typename Tile, typename Fwd>
-inline CUDA_CALLABLE auto tile_map(Fwd op,
-                                   Tile &a)
+template <typename Tile, typename Fwd, typename ReturnTile>
+inline CUDA_CALLABLE auto tile_map(Fwd op, Tile &a, ReturnTile &r)
 {
-    auto out = tile_register_like<Tile>();
+    // verify shapes and sizes are compatible
+    using ShapeIn = typename Tile::Layout::Shape;
+    using ShapeOut = typename ReturnTile::Layout::Shape;
+
+    static_assert(ShapeIn::N == ShapeOut::N, "Number of tile dimensions must match for unary map");
+    static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for unary map");
+
+    auto out = tile_register_like<ReturnTile>();
     auto a_reg = a.copy_to_register();
 
     using Layout = typename decltype(out)::Layout;
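`tile_map` now takes an extra `ReturnTile` parameter (`r` is only a type carrier, never read) so the output element type can differ from the input's, with `static_assert`s enforcing matching rank and size. A standalone sketch of this dummy-return-tile pattern (hypothetical `tiny_tile` type, not Warp's):

```cpp
#include <cstdio>

template <typename T, int N>
struct tiny_tile
{
    using Type = T;
    T data[N];
    static constexpr int size() { return N; }
};

// r is never read; it only fixes the return type, letting int -> float
// maps (and similar) type-check while shapes must still agree.
template <typename TileIn, typename Fwd, typename TileOut>
TileOut tile_map_sketch(Fwd op, const TileIn& a, const TileOut& /*r*/)
{
    static_assert(TileIn::size() == TileOut::size(), "Tile sizes must match for unary map");
    TileOut out{};
    for (int i = 0; i < TileIn::size(); ++i)
        out.data[i] = op(a.data[i]);
    return out;
}

int main()
{
    tiny_tile<int, 3>   a{{1, 2, 3}};
    tiny_tile<float, 3> r{};  // dummy return tile: requests an int -> float map
    auto out = tile_map_sketch([](int x) { return x * 0.5f; }, a, r);
    std::printf("%g %g %g\n", out.data[0], out.data[1], out.data[2]);  // 0.5 1 1.5
    return 0;
}
```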
@@ -1884,12 +2083,24 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
 }
 
 // binary map
-template <typename TileA, typename TileB, typename Fwd>
+template <typename TileA, typename TileB, typename Fwd, typename ReturnTile>
 inline CUDA_CALLABLE auto tile_map(Fwd op,
                                    TileA& a,
-                                   TileB& b)
+                                   TileB& b,
+                                   ReturnTile& r)
 {
-    auto out = tile_register_like<TileA>();
+    // verify shapes and sizes are compatible
+    using ShapeA = typename TileA::Layout::Shape;
+    using ShapeB = typename TileB::Layout::Shape;
+    using ShapeOut = typename ReturnTile::Layout::Shape;
+
+    static_assert(ShapeA::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+    static_assert(ShapeB::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+
+    static_assert(ShapeA::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+    static_assert(ShapeB::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+
+    auto out = tile_register_like<ReturnTile>();
 
     auto a_reg = a.copy_to_register();
     auto b_reg = b.copy_to_register();
@@ -1905,7 +2116,6 @@ inline CUDA_CALLABLE auto tile_map(Fwd op,
     return out;
 }
 
-
 template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_map(Fwd op,
                                        TileA &a,
@@ -1936,28 +2146,32 @@ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
     adj_b.grad_add(adj_b_reg);
 }
 
-// wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
+// We wrap the operator in a lambda so that we don't have to do overload resolution for things like e.g.: wp.sin()
 // this is important because many of the builtin operators don't follow particular conventions on references for
 // the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
-#define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
-#define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
+// The r argument is a dummy return tile argument, because we can't template on the return tile type in a macro definition.
+// So if we want users to be able to define functions that return a tile type that is different from the input type,
+// we must pass an extra dummy return tile argument that is used define the return type of tile_map.
+
+#define tile_unary_map(op, a, r) tile_map([](auto x) { return op(x);}, a, r)
+#define adj_tile_unary_map(op, a, r, adj_op, adj_a, adj_r, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
 
-#define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
-#define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
+#define tile_binary_map(op, a, b, r) tile_map([](auto x, auto y) { return op(x, y);}, a, b, r)
+#define adj_tile_binary_map(op, a, b, r, adj_op, adj_a, adj_b, adj_r, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
 
 // -tile (unary neg)
 template <typename Tile>
-inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
+inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a, a); }
 
 template <typename Tile, typename AdjTile>
-inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
+inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, a, wp::adj_neg, adj_a, adj_a, adj_ret); }
 
 
 // tile + tile
 template <typename TileA, typename TileB>
 inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
 {
-    return tile_binary_map(add, a, b);
+    return tile_binary_map(add, a, b, a);
 }
 
 // add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
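The macro comment above explains why the operator is wrapped in a generic lambda rather than passed directly: a name like `wp::sin` denotes an overload set, which cannot be deduced into a single template parameter, while a lambda defers overload resolution to its call site. A minimal illustration of the same mechanism using `std::sin` (standard-library names, not Warp's):

```cpp
#include <cmath>
#include <cstdio>

template <typename Fwd>
double apply(Fwd op, double x) { return op(x); }

int main()
{
    // apply(std::sin, 0.5) would fail to compile: std::sin names an
    // overload set, so the template parameter Fwd cannot be deduced.
    // Wrapping in a generic lambda resolves the overload at the call site:
    std::printf("%f\n", apply([](auto v) { return std::sin(v); }, 0.5));
    return 0;
}
```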
@@ -1984,20 +2198,20 @@ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register
 template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
 {
-    adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
+    adj_tile_binary_map(add, a, b, a, adj_add, adj_a, adj_b, adj_a, adj_c);
 }
 
 // tile - tile
 template <typename TileA, typename TileB>
 inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
 {
-    return tile_binary_map(sub, a, b);
+    return tile_binary_map(sub, a, b, a);
 }
 
 template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
 inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
 {
-    adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
+    adj_tile_binary_map(sub, a, b, a, adj_sub, adj_a, adj_b, adj_a, adj_c);
 }
 
 
@@ -2008,7 +2222,7 @@ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
     // promote scalar to a constant tile
     auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
 
-    return tile_binary_map(mul, a, s_tile);
+    return tile_binary_map(mul, a, s_tile, a);
 }
 
 template <typename Tile, typename AdjTile>
@@ -2024,7 +2238,7 @@ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
     // initialize to constant
     s_tile = s;
 
-    adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
+    adj_tile_binary_map(mul, a, s_tile, a, adj_mul, adj_a, adj_s_tile, adj_a, adj_c);
 
     for (int i=0; i < Layout::NumRegs; ++i)
     {
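The `tile_mul` hunks reuse the generic binary map by first promoting the scalar to a constant register tile, so tile-times-scalar needs no dedicated broadcast code path. A sketch of the promotion idea (hypothetical `tiny_tile`/`splat`, following the conventions of the earlier sketches):

```cpp
#include <cstdio>

template <typename T, int N>
struct tiny_tile
{
    T data[N];
    static constexpr int size() { return N; }
};

// Promote a scalar to a constant tile so the generic binary map can be
// reused for tile * scalar without special-casing broadcasts.
template <typename T, int N>
tiny_tile<T, N> splat(T s)
{
    tiny_tile<T, N> t{};
    for (int i = 0; i < N; ++i)
        t.data[i] = s;
    return t;
}

int main()
{
    tiny_tile<float, 3> a{{1.f, 2.f, 3.f}};
    auto s = splat<float, 3>(10.f);          // the "constant tile"
    for (int i = 0; i < 3; ++i)              // binary map: elementwise mul
        std::printf("%g ", a.data[i] * s.data[i]);
    std::printf("\n");                       // 10 20 30
    return 0;
}
```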