warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -43,18 +43,10 @@
43
43
  };
44
44
  #endif
45
45
 
46
- // only used while building the warp core library
47
- #ifndef WP_TILE_BLOCK_DIM
48
- #define WP_TILE_BLOCK_DIM 256
49
- #endif
50
-
51
- #if !defined(__CUDA_ARCH__)
52
- #define WP_TILE_SHARED static
53
- #define WP_TILE_SYNC void
54
-
55
- #else
56
- #define WP_TILE_SHARED __shared__
46
+ #if defined(__CUDA_ARCH__)
57
47
  #define WP_TILE_SYNC __syncthreads
48
+ #else
49
+ #define WP_TILE_SYNC void
58
50
  #endif
59
51
 
60
52
  #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
@@ -140,7 +132,6 @@
140
132
  [ ] LayerNorm
141
133
  [ ] SoftMax
142
134
  [ ] GEMM
143
- [ ] warp.sim (CRBA)
144
135
  [ ] Batched MLP
145
136
  [ ] Layer norm
146
137
  [ ] FNO + Burgers equation
@@ -149,7 +140,6 @@
149
140
  [ ] MeshCNN (Modulus, Oliver)
150
141
  [ ] BioNemo (Ali)
151
142
  [ ] Skinning (David/Or/Vismay)
152
- [ ] warp.sim (VBD)
153
143
  [ ] Error checking
154
144
  [ ] Ensure functions passed to tile_map() are compatible with tile type
155
145
  [ ] Ensure that args passed to tile ops are compatible
@@ -213,6 +203,12 @@ struct is_same<T, T> {
213
203
  static constexpr bool value = true;
214
204
  };
215
205
 
206
+ // Helper for dependent static_assert failures
207
+ template <typename T>
208
+ struct always_false {
209
+ static constexpr bool value = false;
210
+ };
211
+
216
212
 
217
213
  template <int N>
218
214
  struct tile_coord_t
@@ -338,6 +334,113 @@ template <int... V>
338
334
  using tile_stride_t = tile_tuple_t<V...>;
339
335
 
340
336
 
337
+ // helper to remove a dimension from a shape (used for axis reductions)
338
+ template<int Axis, typename Shape>
339
+ struct tile_shape_remove_dim {
340
+ static_assert(Axis >= 0 && Axis < Shape::N, "Axis out of bounds for tile_shape_remove_dim");
341
+ };
342
+
343
+ // 1D -> scalar
344
+ template<int D0>
345
+ struct tile_shape_remove_dim<0, tile_shape_t<D0>> {
346
+ using type = tile_shape_t<1>;
347
+ };
348
+
349
+ // 2D -> 1D
350
+ template<int D0, int D1>
351
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1>> {
352
+ using type = tile_shape_t<D1>;
353
+ };
354
+
355
+ template<int D0, int D1>
356
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1>> {
357
+ using type = tile_shape_t<D0>;
358
+ };
359
+
360
+ // 3D -> 2D
361
+ template<int D0, int D1, int D2>
362
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2>> {
363
+ using type = tile_shape_t<D1, D2>;
364
+ };
365
+
366
+ template<int D0, int D1, int D2>
367
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2>> {
368
+ using type = tile_shape_t<D0, D2>;
369
+ };
370
+
371
+ template<int D0, int D1, int D2>
372
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2>> {
373
+ using type = tile_shape_t<D0, D1>;
374
+ };
375
+
376
+ // 4D -> 3D
377
+ template<int D0, int D1, int D2, int D3>
378
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2, D3>> {
379
+ using type = tile_shape_t<D1, D2, D3>;
380
+ };
381
+
382
+ template<int D0, int D1, int D2, int D3>
383
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2, D3>> {
384
+ using type = tile_shape_t<D0, D2, D3>;
385
+ };
386
+
387
+ template<int D0, int D1, int D2, int D3>
388
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2, D3>> {
389
+ using type = tile_shape_t<D0, D1, D3>;
390
+ };
391
+
392
+ template<int D0, int D1, int D2, int D3>
393
+ struct tile_shape_remove_dim<3, tile_shape_t<D0, D1, D2, D3>> {
394
+ using type = tile_shape_t<D0, D1, D2>;
395
+ };
396
+
397
+
398
+ // helper to insert an axis value into a coordinate (inverse of removing dimension)
399
+ // used for mapping output coordinates back to input coordinates during axis reduction
400
+ template<int Axis, int N>
401
+ CUDA_CALLABLE constexpr auto tile_coord_insert_axis(const tile_coord_t<N>& coord, int axis_val)
402
+ {
403
+ static_assert(Axis >= 0 && Axis <= N, "Axis out of bounds for tile_coord_insert_axis");
404
+
405
+ if constexpr (N == 0)
406
+ {
407
+ // Scalar -> 1D
408
+ static_assert(Axis == 0, "Invalid axis for scalar coordinate");
409
+ return tile_coord(axis_val);
410
+ }
411
+ else if constexpr (N == 1)
412
+ {
413
+ // 1D -> 2D
414
+ if constexpr (Axis == 0)
415
+ return tile_coord(axis_val, coord[0]);
416
+ else
417
+ return tile_coord(coord[0], axis_val);
418
+ }
419
+ else if constexpr (N == 2)
420
+ {
421
+ // 2D -> 3D
422
+ if constexpr (Axis == 0)
423
+ return tile_coord(axis_val, coord[0], coord[1]);
424
+ else if constexpr (Axis == 1)
425
+ return tile_coord(coord[0], axis_val, coord[1]);
426
+ else
427
+ return tile_coord(coord[0], coord[1], axis_val);
428
+ }
429
+ else // N == 3
430
+ {
431
+ // 3D -> 4D
432
+ if constexpr (Axis == 0)
433
+ return tile_coord(axis_val, coord[0], coord[1], coord[2]);
434
+ else if constexpr (Axis == 1)
435
+ return tile_coord(coord[0], axis_val, coord[1], coord[2]);
436
+ else if constexpr (Axis == 2)
437
+ return tile_coord(coord[0], coord[1], axis_val, coord[2]);
438
+ else
439
+ return tile_coord(coord[0], coord[1], coord[2], axis_val);
440
+ }
441
+ }
442
+
443
+
341
444
  // represents a tile stored in global memory with dynamic strides
342
445
  // used to represent the source and offset for tile loads to register/shared
343
446
  // BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
@@ -542,7 +645,7 @@ struct tile_register_t
542
645
 
543
646
  // define the += operator which is used during backward pass codegen
544
647
  // when returning a register tile from a user defined function
545
- inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
648
+ inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
546
649
  {
547
650
  grad_add(rhs);
548
651
  return *this;
@@ -581,7 +684,11 @@ struct tile_register_t
581
684
  const int thread = Layout::thread_from_linear(linear);
582
685
  const int reg = Layout::register_from_linear(linear);
583
686
 
584
- WP_TILE_SHARED Type scratch;
687
+ #if defined(__CUDA_ARCH__)
688
+ __shared__ Type scratch;
689
+ #else
690
+ Type scratch;
691
+ #endif
585
692
 
586
693
  // ensure any previously scheduled threads have finished reading from scratch
587
694
  WP_TILE_SYNC();
@@ -658,7 +765,7 @@ struct tile_register_t
658
765
  data[i] += tile.data[i];
659
766
  }
660
767
 
661
- CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
768
+ inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
662
769
  {
663
770
  apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
664
771
  }
@@ -735,42 +842,124 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
735
842
  return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
736
843
  }
737
844
 
738
- inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
845
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
846
+ // On the CPU we use a fixed size block of stack memory for shared tile allocations.
847
+ // We store a pointer to the current allocation storage either in a reserved register
848
+ // (AArch64) or a static variable (x86-64).
849
+ #if !defined(__CUDA_ARCH__)
850
+ class tile_shared_storage_t;
851
+ #if defined(__aarch64__)
852
+ // x28 is the last callee-saved register on AArch64. This allows us to call externally
853
+ // compiled functions without worrying about clobbering the pointer.
854
+ // We pass -target-feature +reserve-x28 to Clang to exclude it from register allocation.
855
+ register tile_shared_storage_t* shared_tile_storage asm("x28");
856
+ #else
857
+ // Ideally this would be thread_local, but LLVM's JIT doesn't support TLS yet
858
+ // There is also no support for something like -ffixed-r15 either
859
+ static tile_shared_storage_t* shared_tile_storage;
860
+ #endif
861
+ #endif
862
+ #endif
863
+
864
+ // This class manages a block of "shared" memory for use by tiles.
865
+ // On the GPU this maps to dynamic shared memory, while on the CPU we allocate
866
+ // a fixed size block of memory on the stack and manage allocations from it.
867
+ // An instance of this class gets created at the start of a kernel.
868
+ class tile_shared_storage_t
739
869
  {
870
+ private:
871
+ #if !defined(__CUDA_ARCH__)
872
+ #define WP_MAX_CPU_SHARED 256*1024
873
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
874
+ tile_shared_storage_t* old_value;
875
+ unsigned int smem_base[WP_TILE_BLOCK_DIM];
876
+ char dynamic_smem_base[WP_MAX_CPU_SHARED]; // on CPU allocate a fixed 256k block to use for shared allocs
877
+ #endif
878
+ #endif
879
+
740
880
  // we maintain a per-thread offset into dynamic
741
881
  // shared memory that allows us to keep track of
742
882
  // current use across dynamic function calls
743
- WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
883
+ static inline CUDA_CALLABLE unsigned int* get_smem_base()
884
+ {
885
+ #if defined(__CUDA_ARCH__)
886
+ __shared__ unsigned int smem_base[WP_TILE_BLOCK_DIM];
887
+ return smem_base;
888
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
889
+ return shared_tile_storage->smem_base;
890
+ #else
891
+ static unsigned int smem_base[WP_TILE_BLOCK_DIM];
892
+ return smem_base;
893
+ #endif
894
+ }
744
895
 
745
- if (init)
896
+ static inline CUDA_CALLABLE char* get_dynamic_smem_base()
746
897
  {
898
+ #if defined(__CUDA_ARCH__)
899
+ extern __shared__ char dynamic_smem_base[];
900
+ return dynamic_smem_base;
901
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
902
+ return shared_tile_storage->dynamic_smem_base;
903
+ #else
904
+ static char dynamic_smem_base[WP_MAX_CPU_SHARED];
905
+ return dynamic_smem_base;
906
+ #endif
907
+ }
908
+
909
+ public:
910
+ // cppcheck-suppress uninitMemberVar
911
+ inline CUDA_CALLABLE tile_shared_storage_t()
912
+ {
913
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
914
+ // On the CPU save a pointer to this instance in a reserved register
915
+ // or static variable so it can be accessed from anywhere within a kernel.
916
+ old_value = shared_tile_storage;
917
+ shared_tile_storage = this;
918
+ #endif
919
+
920
+ init();
921
+ }
922
+
923
+ inline CUDA_CALLABLE ~tile_shared_storage_t()
924
+ {
925
+ check();
926
+
927
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
928
+ shared_tile_storage = old_value;
929
+ #endif
930
+ }
931
+
932
+ static inline CUDA_CALLABLE void init()
933
+ {
934
+ unsigned int* smem_base = get_smem_base();
935
+
747
936
  smem_base[WP_TILE_THREAD_IDX] = 0;
748
- return nullptr;
749
937
  }
750
- else if (check)
938
+
939
+ static inline CUDA_CALLABLE void check()
751
940
  {
941
+ unsigned int* smem_base = get_smem_base();
942
+
752
943
  assert(smem_base[WP_TILE_THREAD_IDX] == 0);
753
- return nullptr;
754
944
  }
755
- else
945
+
946
+ static inline CUDA_CALLABLE void* alloc(int num_bytes)
756
947
  {
757
- const int offset = smem_base[WP_TILE_THREAD_IDX];
758
-
948
+ unsigned int* smem_base = get_smem_base();
949
+ char* dynamic_smem_base = get_dynamic_smem_base();
950
+
951
+ const unsigned int offset = smem_base[WP_TILE_THREAD_IDX];
952
+
759
953
  // one entry per-thread so no need for synchronization
760
954
  smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
761
955
 
762
- #ifdef __CUDA_ARCH__
763
- extern __shared__ char dynamic_smem_base[];
764
- #else
765
- // on CPU allocate a fixed 256k block to use for shared allocs
766
- static const int max_cpu_shared = 256*1024;
767
- static char dynamic_smem_base[max_cpu_shared];
768
-
769
- assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
956
+ #if !defined(__CUDA_ARCH__)
957
+ assert(smem_base[WP_TILE_THREAD_IDX] <= WP_MAX_CPU_SHARED);
770
958
  #endif
959
+
771
960
  return &(dynamic_smem_base[offset]);
772
961
  }
773
- }
962
+ };
774
963
 
775
964
 
776
965
  template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
@@ -905,6 +1094,28 @@ struct tile_shared_t
905
1094
  {
906
1095
  }
907
1096
 
1097
+ // we delete the copy constructor because in the case the shared tile is owning,
1098
+ // this leads to a double deallocation.
1099
+ // this also forces one to handle copies explicitly
1100
+ inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
1101
+ {
1102
+ static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
1103
+ }
1104
+
1105
+ // move constructor
1106
+ inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
1107
+ {
1108
+ other.data.ptr = nullptr;
1109
+ other.grad.ptr = nullptr;
1110
+ }
1111
+
1112
+ template <typename OtherT, typename OtherLayout, bool OtherOwner>
1113
+ inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
1114
+ {
1115
+ static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
1116
+ static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
1117
+ }
1118
+
908
1119
  // initialize from an existing tile's memory
909
1120
  inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
910
1121
  {
@@ -916,10 +1127,10 @@ struct tile_shared_t
916
1127
  {
917
1128
  // update our per-thread shared memory allocator
918
1129
  if (data.ptr)
919
- tile_alloc_shared(-Layout::Size*int(sizeof(T)));
1130
+ tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
920
1131
 
921
1132
  if (grad.ptr)
922
- tile_alloc_shared(-Layout::Size*int(sizeof(T)));
1133
+ tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
923
1134
  }
924
1135
  }
925
1136
 
@@ -932,19 +1143,47 @@ struct tile_shared_t
932
1143
 
933
1144
  // construct from another shared tile, this constructor
934
1145
  // is invoked for reshape operations like `wp.tile_transpose()`
1146
+ // or `wp::copy()`
935
1147
  template <typename OtherT, typename OtherLayout, bool OtherOwner>
936
1148
  inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
937
1149
  {
938
1150
  // check dimensions are compatible
939
1151
  static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
940
1152
 
941
- // alias tile directly
942
- data.ptr = rhs.data.ptr;
943
- grad.ptr = rhs.grad.ptr;
944
- initialized = rhs.initialized;
1153
+
1154
+ if (Owner)
1155
+ {
1156
+ // if the tile owns the data we need to copy
1157
+ assign(rhs);
1158
+ }
1159
+ else
1160
+ {
1161
+ // alias tile directly
1162
+ data.ptr = rhs.data.ptr;
1163
+ grad.ptr = rhs.grad.ptr;
1164
+ initialized = rhs.initialized;
1165
+ }
945
1166
 
946
1167
  return *this;
947
- }
1168
+ }
1169
+
1170
+ inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
1171
+ {
1172
+ if (Owner)
1173
+ {
1174
+ // if the tile owns the data we need to copy
1175
+ assign(rhs);
1176
+ }
1177
+ else
1178
+ {
1179
+ // alias tile directly
1180
+ data.ptr = rhs.data.ptr;
1181
+ grad.ptr = rhs.grad.ptr;
1182
+ initialized = rhs.initialized;
1183
+ }
1184
+
1185
+ return *this;
1186
+ }
948
1187
 
949
1188
  // assign from a global tile (load)
950
1189
 
@@ -972,6 +1211,21 @@ struct tile_shared_t
972
1211
  return *this;
973
1212
  }
974
1213
 
1214
+ // define the += operator which is used during backward pass codegen
1215
+ // when returning a register tile from a user defined function
1216
+ template<typename OtherLayout>
1217
+ inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
1218
+ {
1219
+ grad_add(rhs);
1220
+ return *this;
1221
+ }
1222
+
1223
+ inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
1224
+ {
1225
+ grad_add(rhs);
1226
+ return *this;
1227
+ }
1228
+
975
1229
  // in-place zero
976
1230
  inline CUDA_CALLABLE void zero()
977
1231
  {
@@ -1029,6 +1283,46 @@ struct tile_shared_t
1029
1283
  adj_x -= grad(c);
1030
1284
  }
1031
1285
 
1286
+ // perform AND between a scalar value and a single tile element
1287
+ inline CUDA_CALLABLE void bit_and_inplace(const typename Layout::Coord& c, const Type& x)
1288
+ {
1289
+ // since multiple threads may access the same element
1290
+ // we need to access using atomic operations
1291
+ wp::atomic_and(&data(c), x);
1292
+
1293
+ WP_TILE_SYNC();
1294
+ }
1295
+
1296
+ // backward of inplace scalar AND
1297
+ inline CUDA_CALLABLE void adj_bit_and_inplace(const typename Layout::Coord& c, Type& adj_x) {}
1298
+
1299
+
1300
+ // perform OR between a scalar value and a single tile element
1301
+ inline CUDA_CALLABLE void bit_or_inplace(const typename Layout::Coord& c, const Type& x)
1302
+ {
1303
+ // since multiple threads may access the same element
1304
+ // we need to access using atomic operations
1305
+ wp::atomic_or(&data(c), x);
1306
+
1307
+ WP_TILE_SYNC();
1308
+ }
1309
+
1310
+ // backward of inplace scalar OR
1311
+ inline CUDA_CALLABLE void adj_bit_or_inplace(const typename Layout::Coord& c, Type& adj_x) {}
1312
+
1313
+ // perform XOR between a scalar value and a single tile element
1314
+ inline CUDA_CALLABLE void bit_xor_inplace(const typename Layout::Coord& c, const Type& x)
1315
+ {
1316
+ // since multiple threads may access the same element
1317
+ // we need to access using atomic operations
1318
+ wp::atomic_xor(&data(c), x);
1319
+
1320
+ WP_TILE_SYNC();
1321
+ }
1322
+
1323
+ // backward of inplace scalar XOR
1324
+ inline CUDA_CALLABLE void adj_bit_xor_inplace(const typename Layout::Coord& c, Type& adj_x) {}
1325
+
1032
1326
  // copy register tile to shared
1033
1327
  template <typename Tile>
1034
1328
  inline CUDA_CALLABLE void assign(const Tile& tile)
@@ -1053,6 +1347,27 @@ struct tile_shared_t
1053
1347
  WP_TILE_SYNC();
1054
1348
  }
1055
1349
 
1350
+ // shared tile deep copy
1351
+ template <typename OtherT, typename OtherLayout, bool OtherOwner>
1352
+ inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
1353
+ {
1354
+ // check dimensions are compatible
1355
+ static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
1356
+
1357
+ if (initialized)
1358
+ WP_TILE_SYNC();
1359
+
1360
+ WP_PRAGMA_UNROLL
1361
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1362
+ {
1363
+ auto c = Layout::coord_from_linear(i);
1364
+ data(c) = tile.data(c);
1365
+ }
1366
+
1367
+ initialized = true;
1368
+ WP_TILE_SYNC();
1369
+ }
1370
+
1056
1371
  // in-place gradient zero
1057
1372
  inline CUDA_CALLABLE void grad_zero()
1058
1373
  {
@@ -1092,8 +1407,21 @@ struct tile_shared_t
1092
1407
  WP_TILE_SYNC();
1093
1408
  }
1094
1409
 
1410
+ // accumulate gradients onto this tile from another shared tile
1411
+ inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
1412
+ {
1413
+ WP_PRAGMA_UNROLL
1414
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1415
+ {
1416
+ auto c = Layout::coord_from_linear(i);
1417
+ grad(c) += tile.grad(c);
1418
+ }
1419
+
1420
+ WP_TILE_SYNC();
1421
+ }
1422
+
1095
1423
  // accumulate gradient onto this tile from a global array
1096
- CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
1424
+ inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
1097
1425
  {
1098
1426
  WP_PRAGMA_UNROLL
1099
1427
  for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
@@ -1449,7 +1777,11 @@ void tile_register_t<T, L>::print() const
1449
1777
  {
1450
1778
  // create a temporary shared tile so that
1451
1779
  // we can print it deterministically
1452
- WP_TILE_SHARED T smem[L::Size];
1780
+ #if defined(__CUDA_ARCH__)
1781
+ __shared__ T smem[L::Size];
1782
+ #else
1783
+ T smem[L::Size];
1784
+ #endif
1453
1785
  tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
1454
1786
 
1455
1787
  scratch.assign(*this);
@@ -1477,9 +1809,16 @@ void tile_register_t<T, L>::print() const
1477
1809
  // print entry points
1478
1810
  template <typename T, typename L>
1479
1811
  inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
1812
+
1813
+ template <typename T, typename L>
1814
+ inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
1815
+
1480
1816
  template <typename T, typename L, bool Owner>
1481
1817
  inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
1482
1818
 
1819
+ template <typename T, typename L, bool Owner>
1820
+ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
1821
+
1483
1822
  template <typename T, typename L, bool O>
1484
1823
  inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
1485
1824
  {
@@ -1502,20 +1841,57 @@ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile&
1502
1841
  {
1503
1842
  }
1504
1843
 
1844
+ // where specialization for register/shared tiles
1845
+ template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
1846
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
1847
+ {
1848
+ // The double NOT operator !! casts to bool without compiler warnings.
1849
+ return (!!cond) ? a : b.copy_to_register();
1850
+ }
1851
+
1852
+ template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
1853
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
1854
+ {
1855
+ // The double NOT operator !! casts to bool without compiler warnings.
1856
+ return (!!cond) ? a.copy_to_register() : b;
1857
+ }
1505
1858
 
1506
- template <typename T, typename L>
1507
- inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
1508
- template <typename T, typename L, bool Owner>
1509
- inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
1859
+ template <typename C, typename T, typename L, bool Owner>
1860
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
1861
+ {
1862
+ // The double NOT operator !! casts to bool without compiler warnings.
1863
+ return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
1864
+ }
1510
1865
 
1866
+ template <typename C, typename T, typename L, bool LOwner, bool ROwner>
1867
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
1868
+ {
1869
+ // The double NOT operator !! casts to bool without compiler warnings.
1870
+ return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
1871
+ }
1511
1872
 
1873
+ // adj_where same as in builtin.h
1874
+
1875
+ // copy specialization for shared tiles, the lvalue this gets assigned to is owning, thus, this invokes the copy assign path
1876
+ template <typename T, typename L, bool Owner>
1877
+ inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
1878
+ {
1879
+ return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
1880
+ }
1881
+
1882
+ template <typename T, typename L, bool Owner>
1883
+ inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
1884
+ {
1885
+ adj_src += adj_dest;
1886
+ adj_dest.grad_zero();
1887
+ }
1512
1888
 
1513
1889
  // helpers to allocate shared tiles
1514
1890
  template <typename T, typename Shape, typename Strides, bool RequiresGrad>
1515
1891
  inline CUDA_CALLABLE auto tile_alloc_empty()
1516
1892
  {
1517
1893
  constexpr int size = Shape::size();
1518
- T* data = (T*)tile_alloc_shared(size*sizeof(T));
1894
+ T* data = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
1519
1895
  T* grad = nullptr;
1520
1896
 
1521
1897
  #if FP_CHECK
@@ -1534,7 +1910,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1534
1910
 
1535
1911
  if (RequiresGrad)
1536
1912
  {
1537
- grad = (T*)tile_alloc_shared(size*sizeof(T));
1913
+ grad = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
1538
1914
 
1539
1915
  for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
1540
1916
  grad[i] = T(0);
@@ -1712,6 +2088,14 @@ inline CUDA_CALLABLE auto tile_ones()
1712
2088
  return T(1);
1713
2089
  }
1714
2090
 
2091
+ // value-initialized tile
2092
+ template <typename T, unsigned... Shape>
2093
+ inline CUDA_CALLABLE auto tile_full(T x)
2094
+ {
2095
+ // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
2096
+ return x;
2097
+ }
2098
+
1715
2099
  // tile with evenly spaced values
1716
2100
  template <typename T, int Len>
1717
2101
  inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
@@ -2263,6 +2647,43 @@ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
2263
2647
  }
2264
2648
 
2265
2649
 
2650
+ // tile & tile
2651
+ template <typename TileA, typename TileB>
2652
+ inline CUDA_CALLABLE auto tile_bit_and(TileA& a, TileB& b)
2653
+ {
2654
+ return tile_binary_map(bit_and, a, b, a);
2655
+ }
2656
+
2657
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
2658
+ inline CUDA_CALLABLE void adj_tile_bit_and(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
2659
+ {
2660
+ }
2661
+
2662
+ // tile | tile
2663
+ template <typename TileA, typename TileB>
2664
+ inline CUDA_CALLABLE auto tile_bit_or(TileA& a, TileB& b)
2665
+ {
2666
+ return tile_binary_map(bit_or, a, b, a);
2667
+ }
2668
+
2669
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
2670
+ inline CUDA_CALLABLE void adj_tile_bit_or(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
2671
+ {
2672
+ }
2673
+
2674
+ // tile ^ tile
2675
+ template <typename TileA, typename TileB>
2676
+ inline CUDA_CALLABLE auto tile_bit_xor(TileA& a, TileB& b)
2677
+ {
2678
+ return tile_binary_map(bit_xor, a, b, a);
2679
+ }
2680
+
2681
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
2682
+ inline CUDA_CALLABLE void adj_tile_bit_xor(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
2683
+ {
2684
+ }
2685
+
2686
+
2266
2687
  template <typename TileA, typename TileB>
2267
2688
  inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
2268
2689
  {
@@ -2382,24 +2803,227 @@ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj
2382
2803
  adj_b.grad_add(adj_b_reg);
2383
2804
  }
2384
2805
 
2806
+ template <typename TileA, typename TileB>
2807
+ inline CUDA_CALLABLE void tile_bit_and_inplace(TileA& a, TileB& b)
2808
+ {
2809
+ using ShapeA = typename TileA::Layout::Shape;
2810
+ using ShapeB = typename TileB::Layout::Shape;
2811
+
2812
+ // verify shapes and sizes are compatible
2813
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise AND");
2814
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise AND");
2815
+
2816
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
2817
+ auto a_reg = a.copy_to_register();
2818
+ auto b_reg = b.copy_to_register();
2819
+
2820
+ using Layout = typename decltype(a_reg)::Layout;
2821
+
2822
+ WP_PRAGMA_UNROLL
2823
+ for (int i=0; i < Layout::NumRegs; ++i)
2824
+ {
2825
+ const int linear = Layout::linear_from_register(i);
2826
+
2827
+ if(!Layout::valid(linear))
2828
+ break;
2829
+
2830
+ a_reg.data[i] &= b_reg.data[i];
2831
+ }
2832
+
2833
+ a.assign(a_reg);
2834
+ }
2835
+
2836
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2837
+ inline CUDA_CALLABLE void adj_tile_bit_and_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
2838
+
2839
+ template <typename TileA, typename TileB>
2840
+ inline CUDA_CALLABLE void tile_bit_or_inplace(TileA& a, TileB& b)
2841
+ {
2842
+ using ShapeA = typename TileA::Layout::Shape;
2843
+ using ShapeB = typename TileB::Layout::Shape;
2844
+
2845
+ // verify shapes and sizes are compatible
2846
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise OR");
2847
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise OR");
2848
+
2849
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
2850
+ auto a_reg = a.copy_to_register();
2851
+ auto b_reg = b.copy_to_register();
2852
+
2853
+ using Layout = typename decltype(a_reg)::Layout;
2854
+
2855
+ WP_PRAGMA_UNROLL
2856
+ for (int i=0; i < Layout::NumRegs; ++i)
2857
+ {
2858
+ const int linear = Layout::linear_from_register(i);
2859
+
2860
+ if(!Layout::valid(linear))
2861
+ break;
2862
+
2863
+ a_reg.data[i] |= b_reg.data[i];
2864
+ }
2865
+
2866
+ a.assign(a_reg);
2867
+ }
2868
+
2869
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2870
+ inline CUDA_CALLABLE void adj_tile_bit_or_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
2871
+
2872
+ template <typename TileA, typename TileB>
2873
+ inline CUDA_CALLABLE void tile_bit_xor_inplace(TileA& a, TileB& b)
2874
+ {
2875
+ using ShapeA = typename TileA::Layout::Shape;
2876
+ using ShapeB = typename TileB::Layout::Shape;
2877
+
2878
+ // verify shapes and sizes are compatible
2879
+ static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise XOR");
2880
+ static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise XOR");
2881
+
2882
+ // work with register tiles for inplace operations, regardless of the storage type of the input tiles
2883
+ auto a_reg = a.copy_to_register();
2884
+ auto b_reg = b.copy_to_register();
2885
+
2886
+ using Layout = typename decltype(a_reg)::Layout;
2887
+
2888
+ WP_PRAGMA_UNROLL
2889
+ for (int i=0; i < Layout::NumRegs; ++i)
2890
+ {
2891
+ const int linear = Layout::linear_from_register(i);
2892
+
2893
+ if(!Layout::valid(linear))
2894
+ break;
2895
+
2896
+ a_reg.data[i] ^= b_reg.data[i];
2897
+ }
2898
+
2899
+ a.assign(a_reg);
2900
+ }
2901
+
2902
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
2903
+ inline CUDA_CALLABLE void adj_tile_bit_xor_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}
2904
+
2385
2905
 
2386
2906
  template<typename Tile>
2387
- typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
2907
+ typename Tile::Type tile_extract(Tile& t, int i) {
2908
+ return t.extract(tile_coord(i));
2909
+ }
2388
2910
  template<typename Tile>
2389
- typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
2911
+ auto tile_extract(Tile& t, int i, int j) {
2912
+ if constexpr(is_vector<typename Tile::Type>::value) {
2913
+ return t.extract(tile_coord(i))[j];
2914
+ } else {
2915
+ return t.extract(tile_coord(i,j));
2916
+ }
2917
+ }
2390
2918
  template<typename Tile>
2391
- typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
2919
+ auto tile_extract(Tile& t, int i, int j, int k) {
2920
+ if constexpr(is_vector<typename Tile::Type>::value) {
2921
+ return t.extract(tile_coord(i,j))[k];
2922
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
2923
+ return t.extract(tile_coord(i)).data[j][k];
2924
+ } else {
2925
+ return t.extract(tile_coord(i,j,k));
2926
+ }
2927
+ }
2392
2928
  template<typename Tile>
2393
- typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
2929
+ auto tile_extract(Tile& t, int i, int j, int k, int l) {
2930
+ if constexpr(is_vector<typename Tile::Type>::value) {
2931
+ return t.extract(tile_coord(i,j,k))[l];
2932
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
2933
+ return t.extract(tile_coord(i,j)).data[k][l];
2934
+ } else {
2935
+ return t.extract(tile_coord(i,j,k,l));
2936
+ }
2937
+ }
2938
+ template<typename Tile>
2939
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m) {
2940
+ if constexpr(is_vector<typename Tile::Type>::value) {
2941
+ return t.extract(tile_coord(i,j,k,l))[m];
2942
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
2943
+ return t.extract(tile_coord(i,j,k)).data[l][m];
2944
+ } else {
2945
+ static_assert(always_false<Tile>::value,
2946
+ "tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
2947
+ }
2948
+ }
2949
+ template<typename Tile>
2950
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m, int n) {
2951
+ if constexpr(is_matrix<typename Tile::Type>::value) {
2952
+ return t.extract(tile_coord(i,j,k,l)).data[m][n];
2953
+ } else {
2954
+ static_assert(always_false<Tile>::value,
2955
+ "tile_extract with 6 indices requires a tile of matrices (4D tile)");
2956
+ }
2957
+ }
2394
2958
 
2395
2959
  template<typename Tile, typename AdjTile>
2396
- void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
2397
- template<typename Tile, typename AdjTile>
2398
- void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
2399
- template<typename Tile, typename AdjTile>
2400
- void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
2401
- template<typename Tile, typename AdjTile>
2402
- void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
2960
+ void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
2961
+ adj_t.adj_extract(tile_coord(i), adj_ret);
2962
+ }
2963
+ template<typename Tile, typename AdjTile, typename AdjType>
2964
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, AdjType adj_ret) {
2965
+ if constexpr(is_vector<typename Tile::Type>::value) {
2966
+ typename Tile::Type vector_adj{};
2967
+ vector_adj[j] = adj_ret;
2968
+ adj_t.adj_extract(tile_coord(i), vector_adj);
2969
+ } else {
2970
+ adj_t.adj_extract(tile_coord(i, j), adj_ret);
2971
+ }
2972
+ }
2973
+ template<typename Tile, typename AdjTile, typename AdjType>
2974
+ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, AdjType adj_ret) {
2975
+ if constexpr(is_vector<typename Tile::Type>::value) {
2976
+ typename Tile::Type vector_adj{};
2977
+ vector_adj[k] = adj_ret;
2978
+ adj_t.adj_extract(tile_coord(i, j), vector_adj);
2979
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
2980
+ typename Tile::Type matrix_adj{};
2981
+ matrix_adj.data[j][k] = adj_ret;
2982
+ adj_t.adj_extract(tile_coord(i), matrix_adj);
2983
+ } else {
2984
+ adj_t.adj_extract(tile_coord(i, j, k), adj_ret);
2985
+ }
2986
+ }
2987
+ template<typename Tile, typename AdjTile, typename AdjType>
2988
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, AdjType adj_ret) {
2989
+ if constexpr(is_vector<typename Tile::Type>::value) {
2990
+ typename Tile::Type vector_adj{};
2991
+ vector_adj[l] = adj_ret;
2992
+ adj_t.adj_extract(tile_coord(i, j, k), vector_adj);
2993
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
2994
+ typename Tile::Type matrix_adj{};
2995
+ matrix_adj.data[k][l] = adj_ret;
2996
+ adj_t.adj_extract(tile_coord(i, j), matrix_adj);
2997
+ } else {
2998
+ adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret);
2999
+ }
3000
+ }
3001
+ template<typename Tile, typename AdjTile, typename AdjType>
3002
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, AdjType adj_ret) {
3003
+ if constexpr(is_vector<typename Tile::Type>::value) {
3004
+ typename Tile::Type vector_adj{};
3005
+ vector_adj[m] = adj_ret;
3006
+ adj_t.adj_extract(tile_coord(i, j, k, l), vector_adj);
3007
+ } else if constexpr(is_matrix<typename Tile::Type>::value) {
3008
+ typename Tile::Type matrix_adj{};
3009
+ matrix_adj.data[l][m] = adj_ret;
3010
+ adj_t.adj_extract(tile_coord(i, j, k), matrix_adj);
3011
+ } else {
3012
+ static_assert(always_false<Tile>::value,
3013
+ "adj_tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
3014
+ }
3015
+ }
3016
+ template<typename Tile, typename AdjTile, typename AdjType>
3017
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, int n, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, AdjType adj_ret) {
3018
+ if constexpr(is_matrix<typename Tile::Type>::value) {
3019
+ typename Tile::Type matrix_adj{};
3020
+ matrix_adj.data[m][n] = adj_ret;
3021
+ adj_t.adj_extract(tile_coord(i, j, k, l), matrix_adj);
3022
+ } else {
3023
+ static_assert(always_false<Tile>::value,
3024
+ "adj_tile_extract with 6 indices requires a tile of matrices (4D tile)");
3025
+ }
3026
+ }
2403
3027
 
2404
3028
 
2405
3029
  template<typename Tile>
@@ -2420,6 +3044,33 @@ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) {
2420
3044
  template<typename Tile>
2421
3045
  void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
2422
3046
 
3047
+ template<typename Tile>
3048
+ void tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i), value); }
3049
+ template<typename Tile>
3050
+ void tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j), value); }
3051
+ template<typename Tile>
3052
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k), value); }
3053
+ template<typename Tile>
3054
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k,l), value); }
3055
+
3056
+ template<typename Tile>
3057
+ void tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i), value); }
3058
+ template<typename Tile>
3059
+ void tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j), value); }
3060
+ template<typename Tile>
3061
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k), value); }
3062
+ template<typename Tile>
3063
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k,l), value); }
3064
+
3065
+ template<typename Tile>
3066
+ void tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i), value); }
3067
+ template<typename Tile>
3068
+ void tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j), value); }
3069
+ template<typename Tile>
3070
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k), value); }
3071
+ template<typename Tile>
3072
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k,l), value); }
3073
+
2423
3074
  template<typename Tile, typename AdjTile>
2424
3075
  void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
2425
3076
  template<typename Tile, typename AdjTile>
@@ -2438,6 +3089,33 @@ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type valu
2438
3089
  template<typename Tile, typename AdjTile>
2439
3090
  void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
2440
3091
 
3092
+ template<typename Tile, typename AdjTile>
3093
+ void adj_tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
3094
+ template<typename Tile, typename AdjTile>
3095
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
3096
+ template<typename Tile, typename AdjTile>
3097
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
3098
+ template<typename Tile, typename AdjTile>
3099
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
3100
+
3101
+ template<typename Tile, typename AdjTile>
3102
+ void adj_tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
3103
+ template<typename Tile, typename AdjTile>
3104
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
3105
+ template<typename Tile, typename AdjTile>
3106
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
3107
+ template<typename Tile, typename AdjTile>
3108
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
3109
+
3110
+ template<typename Tile, typename AdjTile>
3111
+ void adj_tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
3112
+ template<typename Tile, typename AdjTile>
3113
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
3114
+ template<typename Tile, typename AdjTile>
3115
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
3116
+ template<typename Tile, typename AdjTile>
3117
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
3118
+
2441
3119
  namespace partitioned_gemm
2442
3120
  {
2443
3121
 
@@ -2825,7 +3503,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2825
3503
  #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
2826
3504
  do { \
2827
3505
  void function_name(dtype*, char*); \
2828
- char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
3506
+ char* buffer = (char*)wp::tile_shared_storage_t::alloc(shared_memory_size); \
2829
3507
  __align__(16) dtype data[ept]; \
2830
3508
  for(int b = 0; b < (int)batch_size; b++) { \
2831
3509
  dtype* inout = Xinout.data + (int)b * (int)ept; \
@@ -2834,7 +3512,7 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2834
3512
  memcpy(inout, data, sizeof(dtype) * ept); \
2835
3513
  WP_TILE_SYNC(); \
2836
3514
  } \
2837
- wp::tile_alloc_shared(-shared_memory_size); \
3515
+ wp::tile_shared_storage_t::alloc(-shared_memory_size); \
2838
3516
  } while (0)
2839
3517
 
2840
3518
  #define tile_ifft tile_fft
@@ -2878,7 +3556,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2878
3556
  #else
2879
3557
 
2880
3558
  // TODO: for batched Cholesky, need one info per batch
2881
- WP_TILE_SHARED int info[1];
3559
+ __shared__ int info[1];
2882
3560
 
2883
3561
  if (WP_TILE_THREAD_IDX == 0) {
2884
3562
  info[0] = 0;
@@ -3048,7 +3726,7 @@ template <typename Tile, typename AdjTile>
3048
3726
  inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
3049
3727
  {
3050
3728
  auto a = tile_transpose(adj_ret);
3051
- auto b = adj_t;
3729
+ auto& b = adj_t;
3052
3730
 
3053
3731
  adj_t.assign(tile_add(a,b));
3054
3732
  }
@@ -3210,22 +3888,63 @@ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
3210
3888
  template <typename TileA, typename Scalar>
3211
3889
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
3212
3890
  {
3213
- dest.data(tile_coord(i, j)) = src;
3891
+ if constexpr(is_vector<typename TileA::Type>::value) {
3892
+ dest.data(tile_coord(i))[j] = src;
3893
+ } else {
3894
+ dest.data(tile_coord(i, j)) = src;
3895
+ }
3214
3896
  WP_TILE_SYNC();
3215
3897
  }
3216
3898
  template <typename TileA, typename Scalar>
3217
3899
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
3218
3900
  {
3219
- dest.data(tile_coord(i, j, k)) = src;
3901
+ if constexpr(is_vector<typename TileA::Type>::value) {
3902
+ dest.data(tile_coord(i, j))[k] = src;
3903
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
3904
+ dest.data(tile_coord(i)).data[j][k] = src;
3905
+ } else {
3906
+ dest.data(tile_coord(i, j, k)) = src;
3907
+ }
3220
3908
  WP_TILE_SYNC();
3221
3909
  }
3222
3910
  template <typename TileA, typename Scalar>
3223
3911
  inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
3224
3912
  {
3225
- dest.data(tile_coord(i, j, k, l)) = src;
3913
+ if constexpr(is_vector<typename TileA::Type>::value) {
3914
+ dest.data(tile_coord(i, j, k))[l] = src;
3915
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
3916
+ dest.data(tile_coord(i, j)).data[k][l] = src;
3917
+ } else {
3918
+ dest.data(tile_coord(i, j, k, l)) = src;
3919
+ }
3920
+ WP_TILE_SYNC();
3921
+ }
3922
+ template <typename TileA, typename Scalar>
3923
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src)
3924
+ {
3925
+ if constexpr(is_vector<typename TileA::Type>::value) {
3926
+ dest.data(tile_coord(i, j, k, l))[m] = src;
3927
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
3928
+ dest.data(tile_coord(i, j, k)).data[l][m] = src;
3929
+ } else {
3930
+ static_assert(always_false<TileA>::value,
3931
+ "assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
3932
+ }
3933
+ WP_TILE_SYNC();
3934
+ }
3935
+ template <typename TileA, typename Scalar>
3936
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src)
3937
+ {
3938
+ if constexpr(is_matrix<typename TileA::Type>::value) {
3939
+ dest.data(tile_coord(i, j, k, l)).data[m][n] = src;
3940
+ } else {
3941
+ static_assert(always_false<TileA>::value,
3942
+ "assign with 6 indices requires a tile of matrices (4D tile)");
3943
+ }
3226
3944
  WP_TILE_SYNC();
3227
3945
  }
3228
3946
 
3947
+
3229
3948
  template <typename TileA, typename AdjTileA, typename Scalar>
3230
3949
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
3231
3950
  {
@@ -3244,7 +3963,11 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& sr
3244
3963
  return;
3245
3964
  }
3246
3965
 
3247
- adj_src += dest.grad(tile_coord(i, j));
3966
+ if constexpr(is_vector<typename TileA::Type>::value) {
3967
+ adj_src += dest.grad(tile_coord(i))[j];
3968
+ } else {
3969
+ adj_src += dest.grad(tile_coord(i, j));
3970
+ }
3248
3971
  }
3249
3972
  template <typename TileA, typename AdjTileA, typename Scalar>
3250
3973
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
@@ -3254,7 +3977,13 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Sca
3254
3977
  return;
3255
3978
  }
3256
3979
 
3257
- adj_src += dest.grad(tile_coord(i, j, k));
3980
+ if constexpr(is_vector<typename TileA::Type>::value) {
3981
+ adj_src += dest.grad(tile_coord(i, j))[k];
3982
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
3983
+ adj_src += dest.grad(tile_coord(i)).data[j][k];
3984
+ } else {
3985
+ adj_src += dest.grad(tile_coord(i, j, k));
3986
+ }
3258
3987
  }
3259
3988
  template <typename TileA, typename AdjTileA, typename Scalar>
3260
3989
  inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
@@ -3264,7 +3993,45 @@ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, co
3264
3993
  return;
3265
3994
  }
3266
3995
 
3267
- adj_src += dest.grad(tile_coord(i, j, k, l));
3996
+ if constexpr(is_vector<typename TileA::Type>::value) {
3997
+ adj_src += dest.grad(tile_coord(i, j, k))[l];
3998
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
3999
+ adj_src += dest.grad(tile_coord(i, j)).data[k][l];
4000
+ } else {
4001
+ adj_src += dest.grad(tile_coord(i, j, k, l));
4002
+ }
4003
+ }
4004
+ template <typename TileA, typename AdjTileA, typename Scalar>
4005
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, Scalar& adj_src)
4006
+ {
4007
+ if (dest.grad.ptr == nullptr)
4008
+ {
4009
+ return;
4010
+ }
4011
+
4012
+ if constexpr(is_vector<typename TileA::Type>::value) {
4013
+ adj_src += dest.grad(tile_coord(i, j, k, l))[m];
4014
+ } else if constexpr(is_matrix<typename TileA::Type>::value) {
4015
+ adj_src += dest.grad(tile_coord(i, j, k)).data[l][m];
4016
+ } else {
4017
+ static_assert(always_false<TileA>::value,
4018
+ "adj_assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
4019
+ }
4020
+ }
4021
+ template <typename TileA, typename AdjTileA, typename Scalar>
4022
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, Scalar& adj_src)
4023
+ {
4024
+ if (dest.grad.ptr == nullptr)
4025
+ {
4026
+ return;
4027
+ }
4028
+
4029
+ if constexpr(is_matrix<typename TileA::Type>::value) {
4030
+ adj_src += dest.grad(tile_coord(i, j, k, l)).data[m][n];
4031
+ } else {
4032
+ static_assert(always_false<TileA>::value,
4033
+ "adj_assign with 6 indices requires a tile of matrices (4D tile)");
4034
+ }
3268
4035
  }
3269
4036
 
3270
4037
  template <typename TileA, typename TileB, typename Coord>