warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -43,18 +43,10 @@
  };
  #endif
 
- // only used while building the warp core library
- #ifndef WP_TILE_BLOCK_DIM
- #define WP_TILE_BLOCK_DIM 256
- #endif
-
- #if !defined(__CUDA_ARCH__)
- #define WP_TILE_SHARED static
- #define WP_TILE_SYNC void
-
- #else
- #define WP_TILE_SHARED __shared__
+ #if defined(__CUDA_ARCH__)
  #define WP_TILE_SYNC __syncthreads
+ #else
+ #define WP_TILE_SYNC void
  #endif
 
  #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
@@ -140,7 +132,6 @@
  [ ] LayerNorm
  [ ] SoftMax
  [ ] GEMM
- [ ] warp.sim (CRBA)
  [ ] Batched MLP
  [ ] Layer norm
  [ ] FNO + Burgers equation
@@ -149,7 +140,6 @@
  [ ] MeshCNN (Modulus, Oliver)
  [ ] BioNemo (Ali)
  [ ] Skinning (David/Or/Vismay)
- [ ] warp.sim (VBD)
  [ ] Error checking
  [ ] Ensure functions passed to tile_map() are compatible with tile type
  [ ] Ensure that args passed to tile ops are compatible
@@ -213,6 +203,12 @@ struct is_same<T, T> {
  static constexpr bool value = true;
  };
 
+ // Helper for dependent static_assert failures
+ template <typename T>
+ struct always_false {
+ static constexpr bool value = false;
+ };
+
 
  template <int N>
  struct tile_coord_t
@@ -338,6 +334,113 @@ template <int... V>
  using tile_stride_t = tile_tuple_t<V...>;
 
 
+ // helper to remove a dimension from a shape (used for axis reductions)
+ template<int Axis, typename Shape>
+ struct tile_shape_remove_dim {
+ static_assert(Axis >= 0 && Axis < Shape::N, "Axis out of bounds for tile_shape_remove_dim");
+ };
+
+ // 1D -> scalar
+ template<int D0>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0>> {
+ using type = tile_shape_t<1>;
+ };
+
+ // 2D -> 1D
+ template<int D0, int D1>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1>> {
+ using type = tile_shape_t<D1>;
+ };
+
+ template<int D0, int D1>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1>> {
+ using type = tile_shape_t<D0>;
+ };
+
+ // 3D -> 2D
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D1, D2>;
+ };
+
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D0, D2>;
+ };
+
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D0, D1>;
+ };
+
+ // 4D -> 3D
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D1, D2, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D2, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D1, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<3, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D1, D2>;
+ };
+
+
+ // helper to insert an axis value into a coordinate (inverse of removing dimension)
+ // used for mapping output coordinates back to input coordinates during axis reduction
+ template<int Axis, int N>
+ CUDA_CALLABLE constexpr auto tile_coord_insert_axis(const tile_coord_t<N>& coord, int axis_val)
+ {
+ static_assert(Axis >= 0 && Axis <= N, "Axis out of bounds for tile_coord_insert_axis");
+
+ if constexpr (N == 0)
+ {
+ // Scalar -> 1D
+ static_assert(Axis == 0, "Invalid axis for scalar coordinate");
+ return tile_coord(axis_val);
+ }
+ else if constexpr (N == 1)
+ {
+ // 1D -> 2D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0]);
+ else
+ return tile_coord(coord[0], axis_val);
+ }
+ else if constexpr (N == 2)
+ {
+ // 2D -> 3D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0], coord[1]);
+ else if constexpr (Axis == 1)
+ return tile_coord(coord[0], axis_val, coord[1]);
+ else
+ return tile_coord(coord[0], coord[1], axis_val);
+ }
+ else // N == 3
+ {
+ // 3D -> 4D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0], coord[1], coord[2]);
+ else if constexpr (Axis == 1)
+ return tile_coord(coord[0], axis_val, coord[1], coord[2]);
+ else if constexpr (Axis == 2)
+ return tile_coord(coord[0], coord[1], axis_val, coord[2]);
+ else
+ return tile_coord(coord[0], coord[1], coord[2], axis_val);
+ }
+ }
+
+
  // represents a tile stored in global memory with dynamic strides
  // used to represent the source and offset for tile loads to register/shared
  // BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
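For illustration, a minimal sketch (not part of the diff) of how the two compile-time helpers above compose for an axis reduction; it assumes it sits inside this header, where tile_shape_t, tile_coord_t, tile_coord, and is_same are already defined, and the helper name below is hypothetical:

// illustration only: drop axis 1 of an 8x4 shape and rebuild an input coordinate
using in_shape_t  = tile_shape_t<8, 4>;
using out_shape_t = typename tile_shape_remove_dim<1, in_shape_t>::type;
static_assert(is_same<out_shape_t, tile_shape_t<8>>::value, "removing axis 1 of 8x4 leaves shape 8");

// map an output coordinate (i) of the reduction back to the input coordinate (i, j)
template <int Axis>
CUDA_CALLABLE constexpr auto input_coord_for_reduction(const tile_coord_t<1>& out_coord, int axis_idx)
{
    return tile_coord_insert_axis<Axis>(out_coord, axis_idx);   // Axis == 1 yields tile_coord(out_coord[0], axis_idx)
}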
@@ -581,7 +684,11 @@ struct tile_register_t
  const int thread = Layout::thread_from_linear(linear);
  const int reg = Layout::register_from_linear(linear);
 
- WP_TILE_SHARED Type scratch;
+ #if defined(__CUDA_ARCH__)
+ __shared__ Type scratch;
+ #else
+ Type scratch;
+ #endif
 
  // ensure any previously scheduled threads have finished reading from scratch
  WP_TILE_SYNC();
@@ -735,43 +842,124 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
  return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
  }
 
- inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ // On the CPU we use a fixed size block of stack memory for shared tile allocations.
+ // We store a pointer to the current allocation storage either in a reserved register
+ // (AArch64) or a static variable (x86-64).
+ #if !defined(__CUDA_ARCH__)
+ class tile_shared_storage_t;
+ #if defined(__aarch64__)
+ // x28 is is the last callee-saved register on AArch64. This allows us to call externally
+ // compiled functions without worrying about clobbering the pointer.
+ // We pass -target-feature +reserve-x28 to Clang to exclude it from register allocation.
+ register tile_shared_storage_t* shared_tile_storage asm("x28");
+ #else
+ // Ideally this would be thread_local, but LLVM's JIT doesn't support TLS yet
+ // There is also no support for something like -ffixed-r15 either
+ static tile_shared_storage_t* shared_tile_storage;
+ #endif
+ #endif
+ #endif
+
+ // This class manages a block of "shared" memory for use by tiles.
+ // On the GPU this maps to dynamic shared memory, while on the CPU we allocate
+ // a fixed size block of memory on the stack and manage allocations from it.
+ // An instance of this class gets created at the start of a kernel.
+ class tile_shared_storage_t
  {
+ private:
+ #if !defined(__CUDA_ARCH__)
+ #define WP_MAX_CPU_SHARED 256*1024
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ tile_shared_storage_t* old_value;
+ unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ char dynamic_smem_base[WP_MAX_CPU_SHARED]; // on CPU allocate a fixed 256k block to use for shared allocs
+ #endif
+ #endif
+
  // we maintain a per-thread offset into dynamic
  // shared memory that allows us to keep track of
  // current use across dynamic function calls
- WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
+ static inline CUDA_CALLABLE unsigned int* get_smem_base()
+ {
+ #if defined(__CUDA_ARCH__)
+ __shared__ unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ return smem_base;
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ return shared_tile_storage->smem_base;
+ #else
+ static unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ return smem_base;
+ #endif
+ }
+
+ static inline CUDA_CALLABLE char* get_dynamic_smem_base()
+ {
+ #if defined(__CUDA_ARCH__)
+ extern __shared__ char dynamic_smem_base[];
+ return dynamic_smem_base;
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ return shared_tile_storage->dynamic_smem_base;
+ #else
+ static char dynamic_smem_base[WP_MAX_CPU_SHARED];
+ return dynamic_smem_base;
+ #endif
+ }
 
- if (init)
+ public:
+ // cppcheck-suppress uninitMemberVar
+ inline CUDA_CALLABLE tile_shared_storage_t()
  {
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ // On the CPU save a pointer to this instance in a reserved register
+ // or static variable so it can be accessed from anywhere within a kernel.
+ old_value = shared_tile_storage;
+ shared_tile_storage = this;
+ #endif
+
+ init();
+ }
+
+ inline CUDA_CALLABLE ~tile_shared_storage_t()
+ {
+ check();
+
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ shared_tile_storage = old_value;
+ #endif
+ }
+
+ static inline CUDA_CALLABLE void init()
+ {
+ unsigned int* smem_base = get_smem_base();
+
  smem_base[WP_TILE_THREAD_IDX] = 0;
- return nullptr;
  }
- else if (check)
+
+ static inline CUDA_CALLABLE void check()
  {
+ unsigned int* smem_base = get_smem_base();
+
  assert(smem_base[WP_TILE_THREAD_IDX] == 0);
- return nullptr;
  }
- else
+
+ static inline CUDA_CALLABLE void* alloc(int num_bytes)
  {
- const int offset = smem_base[WP_TILE_THREAD_IDX];
-
+ unsigned int* smem_base = get_smem_base();
+ char* dynamic_smem_base = get_dynamic_smem_base();
+
+ const unsigned int offset = smem_base[WP_TILE_THREAD_IDX];
+
  // one entry per-thread so no need for synchronization
  smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
- assert(smem_base[WP_TILE_THREAD_IDX] >= 0);
 
- #ifdef __CUDA_ARCH__
- extern __shared__ char dynamic_smem_base[];
- #else
- // on CPU allocate a fixed 256k block to use for shared allocs
- static const int max_cpu_shared = 256*1024;
- static char dynamic_smem_base[max_cpu_shared];
-
- assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
+ #if !defined(__CUDA_ARCH__)
+ assert(smem_base[WP_TILE_THREAD_IDX] <= WP_MAX_CPU_SHARED);
  #endif
+
  return &(dynamic_smem_base[offset]);
  }
- }
+ };
 
 
  template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
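A hedged usage sketch of the allocator above (not from the diff; the function name, variable, and sizes are hypothetical), following the pattern the class comment describes: one instance constructed at kernel entry, alloc() calls paired with negative-size releases, and check() running from the destructor:

// illustration only: how a generated kernel scope is expected to drive the allocator
void example_kernel_scope()
{
    tile_shared_storage_t storage;  // ctor publishes the CPU storage pointer (when enabled) and zeroes per-thread offsets

    const int bytes = 64 * int(sizeof(float));
    float* tmp = (float*)tile_shared_storage_t::alloc(bytes);   // reserve space for a 64-element tile
    (void)tmp;                                                  // ... use tmp as tile storage ...
    tile_shared_storage_t::alloc(-bytes);                       // release: a negative size unwinds the per-thread offset
}                                                               // dtor calls check(), asserting the offset returned to zero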
@@ -939,10 +1127,10 @@ struct tile_shared_t
  {
  // update our per-thread shared memory allocator
  if (data.ptr)
- tile_alloc_shared(-Layout::Size*int(sizeof(T)));
+ tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
 
  if (grad.ptr)
- tile_alloc_shared(-Layout::Size*int(sizeof(T)));
+ tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
  }
  }
 
@@ -1095,6 +1283,46 @@ struct tile_shared_t
  adj_x -= grad(c);
  }
 
+ // perform AND between a scalar value and a single tile element
+ inline CUDA_CALLABLE void bit_and_inplace(const typename Layout::Coord& c, const Type& x)
+ {
+ // since multiple threads may access the same element
+ // we need to access using atomic operations
+ wp::atomic_and(&data(c), x);
+
+ WP_TILE_SYNC();
+ }
+
+ // backward of inplace scalar AND
+ inline CUDA_CALLABLE void adj_bit_and_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
+
+ // perform OR between a scalar value and a single tile element
+ inline CUDA_CALLABLE void bit_or_inplace(const typename Layout::Coord& c, const Type& x)
+ {
+ // since multiple threads may access the same element
+ // we need to access using atomic operations
+ wp::atomic_or(&data(c), x);
+
+ WP_TILE_SYNC();
+ }
+
+ // backward of inplace scalar OR
+ inline CUDA_CALLABLE void adj_bit_or_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
+ // perform XOR between a scalar value and a single tile element
+ inline CUDA_CALLABLE void bit_xor_inplace(const typename Layout::Coord& c, const Type& x)
+ {
+ // since multiple threads may access the same element
+ // we need to access using atomic operations
+ wp::atomic_xor(&data(c), x);
+
+ WP_TILE_SYNC();
+ }
+
+ // backward of inplace scalar XOR
+ inline CUDA_CALLABLE void adj_bit_xor_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
  // copy register tile to shared
  template <typename Tile>
  inline CUDA_CALLABLE void assign(const Tile& tile)
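A short sketch (not from the diff) of the call pattern for the element-wise bitwise updates above: every thread in the block issues the same update, so the write goes through wp::atomic_* and is followed by a block-wide sync, and the adj_bit_*_inplace counterparts stay empty because integer bitwise operations carry no useful gradient. The wrapper below and its tile argument are hypothetical:

// illustration only: set one bit in element (i, j) of a shared tile of unsigned integers
template <typename SharedTile>
inline CUDA_CALLABLE void set_flag_bit(SharedTile& flags, int i, int j, int bit)
{
    // atomic OR into the element; bit_or_inplace() ends with WP_TILE_SYNC()
    flags.bit_or_inplace(tile_coord(i, j), typename SharedTile::Type(1u << bit));
}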
@@ -1549,7 +1777,11 @@
  {
  // create a temporary shared tile so that
  // we can print it deterministically
- WP_TILE_SHARED T smem[L::Size];
+ #if defined(__CUDA_ARCH__)
+ __shared__ T smem[L::Size];
+ #else
+ T smem[L::Size];
+ #endif
  tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
 
  scratch.assign(*this);
@@ -1609,37 +1841,6 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
  {
  }
 
- // select specialization for shared tiles
- template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
- inline CUDA_CALLABLE auto select(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
- {
- // The double NOT operator !! casts to bool without compiler warnings.
- return (!!cond) ? b.copy_to_register() : a;
- }
-
- template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
- inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
- {
- // The double NOT operator !! casts to bool without compiler warnings.
- return (!!cond) ? b : a.copy_to_register();
- }
-
- template <typename C, typename T, typename L, bool Owner>
- inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
- {
- // The double NOT operator !! casts to bool without compiler warnings.
- return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
- }
-
- template <typename C, typename T, typename L, bool LOwner, bool ROwner>
- inline CUDA_CALLABLE auto select(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
- {
- // The double NOT operator !! casts to bool without compiler warnings.
- return (!!cond) ? tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr) : tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr);
- }
-
- // adj_select same as in builtin.h
-
  // where specialization for register/shared tiles
  template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
  inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
@@ -1690,7 +1891,7 @@
  inline CUDA_CALLABLE auto tile_alloc_empty()
  {
  constexpr int size = Shape::size();
- T* data = (T*)tile_alloc_shared(size*sizeof(T));
+ T* data = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
  T* grad = nullptr;
 
  #if FP_CHECK
@@ -1709,7 +1910,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
 
  if (RequiresGrad)
  {
- grad = (T*)tile_alloc_shared(size*sizeof(T));
+ grad = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
 
  for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
  grad[i] = T(0);
@@ -1887,6 +2088,14 @@ inline CUDA_CALLABLE auto tile_ones()
  return T(1);
  }
 
+ // value-initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_full(T x)
+ {
+ // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+ return x;
+ }
+
  // tile with evenly spaced values
  template <typename T, int Len>
  inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
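Like tile_zeros() and tile_ones() above, tile_full() only produces the fill value; broadcasting it across the destination is left to the tile assignment operator in the generated code. A hedged sketch of that pattern (the wrapper and tile variable below are hypothetical):

// illustration only: the generated kernel assigns the returned scalar to a tile variable,
// and the tile's assignment operator replicates it across all elements
template <typename TileVar>
inline CUDA_CALLABLE void fill_with_three(TileVar& var_0)
{
    var_0 = tile_full<float, 16, 8>(3.0f);   // returns 3.0f; the assignment broadcasts it into the 16x8 tile
}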
@@ -2438,6 +2647,43 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
  }
 
 
+ // tile & tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_and(TileA& a, TileB& b)
+ {
+ return tile_binary_map(bit_and, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_and(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+ }
+
+ // tile | tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_or(TileA& a, TileB& b)
+ {
+ return tile_binary_map(bit_or, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_or(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+ }
+
+ // tile ^ tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_xor(TileA& a, TileB& b)
+ {
+ return tile_binary_map(bit_xor, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_xor(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+ }
+
+
  template <typename TileA, typename TileB>
  inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
  {
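The three wrappers above build on tile_binary_map with the scalar bit_and / bit_or / bit_xor helpers, passing a again presumably so the result takes its shape and element type from the first operand; the adjoints are intentionally empty because bitwise operations are not differentiable. A hedged usage sketch (the function and tile names are hypothetical):

// illustration only: element-wise combinations of two integer tiles
template <typename TileA, typename TileB>
inline CUDA_CALLABLE auto combine_masks(TileA& values, TileB& mask)
{
    auto kept    = tile_bit_and(values, mask);   // values & mask
    auto flipped = tile_bit_xor(values, mask);   // values ^ mask
    return tile_bit_or(kept, flipped);           // (values & mask) | (values ^ mask)
}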
@@ -2557,24 +2803,227 @@ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj
  adj_b.grad_add(adj_b_reg);
  }
 
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_and_inplace(TileA& a, TileB& b)
+ {
+ using ShapeA = typename TileA::Layout::Shape;
+ using ShapeB = typename TileB::Layout::Shape;
+
+ // verify shapes and sizes are compatible
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise AND");
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise AND");
+
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+ auto a_reg = a.copy_to_register();
+ auto b_reg = b.copy_to_register();
+
+ using Layout = typename decltype(a_reg)::Layout;
+
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Layout::NumRegs; ++i)
+ {
+ const int linear = Layout::linear_from_register(i);
+
+ if(!Layout::valid(linear))
+ break;
+
+ a_reg.data[i] &= b_reg.data[i];
+ }
+
+ a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_and_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_or_inplace(TileA& a, TileB& b)
+ {
+ using ShapeA = typename TileA::Layout::Shape;
+ using ShapeB = typename TileB::Layout::Shape;
+
+ // verify shapes and sizes are compatible
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise OR");
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise OR");
+
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+ auto a_reg = a.copy_to_register();
+ auto b_reg = b.copy_to_register();
+
+ using Layout = typename decltype(a_reg)::Layout;
+
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Layout::NumRegs; ++i)
+ {
+ const int linear = Layout::linear_from_register(i);
+
+ if(!Layout::valid(linear))
+ break;
+
+ a_reg.data[i] |= b_reg.data[i];
+ }
+
+ a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_or_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_xor_inplace(TileA& a, TileB& b)
+ {
+ using ShapeA = typename TileA::Layout::Shape;
+ using ShapeB = typename TileB::Layout::Shape;
+
+ // verify shapes and sizes are compatible
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise XOR");
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise XOR");
+
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+ auto a_reg = a.copy_to_register();
+ auto b_reg = b.copy_to_register();
+
+ using Layout = typename decltype(a_reg)::Layout;
+
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Layout::NumRegs; ++i)
+ {
+ const int linear = Layout::linear_from_register(i);
+
+ if(!Layout::valid(linear))
+ break;
+
+ a_reg.data[i] ^= b_reg.data[i];
+ }
+
+ a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_xor_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
+
 
  template<typename Tile>
- typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
+ typename Tile::Type tile_extract(Tile& t, int i) {
+ return t.extract(tile_coord(i));
+ }
  template<typename Tile>
- typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
+ auto tile_extract(Tile& t, int i, int j) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i))[j];
+ } else {
+ return t.extract(tile_coord(i,j));
+ }
+ }
  template<typename Tile>
- typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
+ auto tile_extract(Tile& t, int i, int j, int k) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j))[k];
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i)).data[j][k];
+ } else {
+ return t.extract(tile_coord(i,j,k));
+ }
+ }
  template<typename Tile>
- typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
+ auto tile_extract(Tile& t, int i, int j, int k, int l) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j,k))[l];
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j)).data[k][l];
+ } else {
+ return t.extract(tile_coord(i,j,k,l));
+ }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j,k,l))[m];
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j,k)).data[l][m];
+ } else {
+ static_assert(always_false<Tile>::value,
+ "tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+ }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m, int n) {
+ if constexpr(is_matrix<typename Tile::Type>::value) {
+ return t.extract(tile_coord(i,j,k,l)).data[m][n];
+ } else {
+ static_assert(always_false<Tile>::value,
+ "tile_extract with 6 indices requires a tile of matrices (4D tile)");
+ }
+ }
 
  template<typename Tile, typename AdjTile>
- void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
- template<typename Tile, typename AdjTile>
- void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
- template<typename Tile, typename AdjTile>
- void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
- template<typename Tile, typename AdjTile>
- void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
+ void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
+ adj_t.adj_extract(tile_coord(i), adj_ret);
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, AdjType adj_ret) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ typename Tile::Type vector_adj{};
+ vector_adj[j] = adj_ret;
+ adj_t.adj_extract(tile_coord(i), vector_adj);
+ } else {
+ adj_t.adj_extract(tile_coord(i, j), adj_ret);
+ }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, AdjType adj_ret) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ typename Tile::Type vector_adj{};
+ vector_adj[k] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j), vector_adj);
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ typename Tile::Type matrix_adj{};
+ matrix_adj.data[j][k] = adj_ret;
+ adj_t.adj_extract(tile_coord(i), matrix_adj);
+ } else {
+ adj_t.adj_extract(tile_coord(i, j, k), adj_ret);
+ }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, AdjType adj_ret) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ typename Tile::Type vector_adj{};
+ vector_adj[l] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j, k), vector_adj);
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ typename Tile::Type matrix_adj{};
+ matrix_adj.data[k][l] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j), matrix_adj);
+ } else {
+ adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret);
+ }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, AdjType adj_ret) {
+ if constexpr(is_vector<typename Tile::Type>::value) {
+ typename Tile::Type vector_adj{};
+ vector_adj[m] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j, k, l), vector_adj);
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
+ typename Tile::Type matrix_adj{};
+ matrix_adj.data[l][m] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j, k), matrix_adj);
+ } else {
+ static_assert(always_false<Tile>::value,
+ "adj_tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+ }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, int n, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, AdjType adj_ret) {
+ if constexpr(is_matrix<typename Tile::Type>::value) {
+ typename Tile::Type matrix_adj{};
+ matrix_adj.data[m][n] = adj_ret;
+ adj_t.adj_extract(tile_coord(i, j, k, l), matrix_adj);
+ } else {
+ static_assert(always_false<Tile>::value,
+ "adj_tile_extract with 6 indices requires a tile of matrices (4D tile)");
+ }
+ }
 
 
  template<typename Tile>
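With the overloads above, trailing indices beyond the tile's rank now address components inside vector- or matrix-valued elements, and the matching adj_tile_extract variants scatter the incoming adjoint into a zero-initialized vector or matrix before forwarding it to adj_extract. A minimal sketch (not from the diff; a 2D tile of 3-component vectors is assumed, and the wrapper name is hypothetical):

// illustration only: read component k of the vector stored at tile element (i, j)
template <typename Vec3Tile>
inline CUDA_CALLABLE auto vector_component(Vec3Tile& t, int i, int j, int k)
{
    // Tile::Type is a vector type here, so the third index selects a component
    return tile_extract(t, i, j, k);
}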
@@ -2595,6 +3044,33 @@ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) {
  template<typename Tile>
  void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
 
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k,l), value); }
+
  template<typename Tile, typename AdjTile>
  void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
  template<typename Tile, typename AdjTile>
@@ -2613,6 +3089,33 @@ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type valu
  template<typename Tile, typename AdjTile>
  void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
 
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
  namespace partitioned_gemm
  {
 
@@ -3000,7 +3503,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
  #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
  do { \
  void function_name(dtype*, char*); \
- char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
+ char* buffer = (char*)wp::tile_shared_storage_t::alloc(shared_memory_size); \
  __align__(16) dtype data[ept]; \
  for(int b = 0; b < (int)batch_size; b++) { \
  dtype* inout = Xinout.data + (int)b * (int)ept; \
@@ -3009,7 +3512,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
  memcpy(inout, data, sizeof(dtype) * ept); \
  WP_TILE_SYNC(); \
  } \
- wp::tile_alloc_shared(-shared_memory_size); \
+ wp::tile_shared_storage_t::alloc(-shared_memory_size); \
  } while (0)
 
  #define tile_ifft tile_fft
@@ -3053,7 +3556,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
  #else
 
  // TODO: for batched Cholesky, need one info per batch
- WP_TILE_SHARED int info[1];
+ __shared__ int info[1];
 
  if (WP_TILE_THREAD_IDX == 0) {
  info[0] = 0;
@@ -3385,21 +3888,62 @@ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
  template <typename TileA, typename Scalar>
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
  {
- dest.data(tile_coord(i, j)) = src;
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ dest.data(tile_coord(i))[j] = src;
+ } else {
+ dest.data(tile_coord(i, j)) = src;
+ }
  WP_TILE_SYNC();
  }
  template <typename TileA, typename Scalar>
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
  {
- dest.data(tile_coord(i, j, k)) = src;
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j))[k] = src;
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ dest.data(tile_coord(i)).data[j][k] = src;
+ } else {
+ dest.data(tile_coord(i, j, k)) = src;
+ }
  WP_TILE_SYNC();
  }
  template <typename TileA, typename Scalar>
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
  {
- dest.data(tile_coord(i, j, k, l)) = src;
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j, k))[l] = src;
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j)).data[k][l] = src;
+ } else {
+ dest.data(tile_coord(i, j, k, l)) = src;
+ }
+ WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src)
+ {
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j, k, l))[m] = src;
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j, k)).data[l][m] = src;
+ } else {
+ static_assert(always_false<TileA>::value,
+ "assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+ }
  WP_TILE_SYNC();
  }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src)
+ {
+ if constexpr(is_matrix<typename TileA::Type>::value) {
+ dest.data(tile_coord(i, j, k, l)).data[m][n] = src;
+ } else {
+ static_assert(always_false<TileA>::value,
+ "assign with 6 indices requires a tile of matrices (4D tile)");
+ }
+ WP_TILE_SYNC();
+ }
+
 
  template <typename TileA, typename AdjTileA, typename Scalar>
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
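The same component addressing applies to stores: one index beyond the tile's rank targets a vector component and two target a matrix entry, with each overload ending in a block-wide sync. A short sketch (not from the diff; a 2D tile of 2x2 matrices is assumed, and the wrapper name is hypothetical):

// illustration only: write entry (r, c) of the matrix stored at tile element (i, j)
template <typename Mat22Tile>
inline CUDA_CALLABLE void set_matrix_entry(Mat22Tile& t, int i, int j, int r, int c, float value)
{
    // Tile::Type is a matrix type here, so the last two indices select row and column
    assign(t, i, j, r, c, value);
}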
@@ -3419,7 +3963,11 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& sr
  return;
  }
 
- adj_src += dest.grad(tile_coord(i, j));
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i))[j];
+ } else {
+ adj_src += dest.grad(tile_coord(i, j));
+ }
  }
  template <typename TileA, typename AdjTileA, typename Scalar>
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
@@ -3429,7 +3977,13 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Sca
  return;
  }
 
- adj_src += dest.grad(tile_coord(i, j, k));
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j))[k];
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i)).data[j][k];
+ } else {
+ adj_src += dest.grad(tile_coord(i, j, k));
+ }
  }
  template <typename TileA, typename AdjTileA, typename Scalar>
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
@@ -3439,7 +3993,45 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, co
  return;
  }
 
- adj_src += dest.grad(tile_coord(i, j, k, l));
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j, k))[l];
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j)).data[k][l];
+ } else {
+ adj_src += dest.grad(tile_coord(i, j, k, l));
+ }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, Scalar& adj_src)
+ {
+ if (dest.grad.ptr == nullptr)
+ {
+ return;
+ }
+
+ if constexpr(is_vector<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j, k, l))[m];
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j, k)).data[l][m];
+ } else {
+ static_assert(always_false<TileA>::value,
+ "adj_assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+ }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, Scalar& adj_src)
+ {
+ if (dest.grad.ptr == nullptr)
+ {
+ return;
+ }
+
+ if constexpr(is_matrix<typename TileA::Type>::value) {
+ adj_src += dest.grad(tile_coord(i, j, k, l)).data[m][n];
+ } else {
+ static_assert(always_false<TileA>::value,
+ "adj_assign with 6 indices requires a tile of matrices (4D tile)");
+ }
  }
 
  template <typename TileA, typename TileB, typename Coord>