warp_lang-1.10.0-py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (468)
  1. warp/__init__.py +334 -0
  2. warp/__init__.pyi +5856 -0
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1077 -0
  5. warp/_src/build.py +620 -0
  6. warp/_src/build_dll.py +642 -0
  7. warp/_src/builtins.py +10555 -0
  8. warp/_src/codegen.py +4361 -0
  9. warp/_src/config.py +178 -0
  10. warp/_src/constants.py +59 -0
  11. warp/_src/context.py +8352 -0
  12. warp/_src/dlpack.py +464 -0
  13. warp/_src/fabric.py +362 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +510 -0
  16. warp/_src/fem/cache.py +689 -0
  17. warp/_src/fem/dirichlet.py +190 -0
  18. warp/_src/fem/domain.py +553 -0
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +703 -0
  21. warp/_src/fem/field/nodal_field.py +403 -0
  22. warp/_src/fem/field/restriction.py +39 -0
  23. warp/_src/fem/field/virtual.py +1021 -0
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/_src/fem/geometry/adaptive_nanogrid.py +782 -0
  26. warp/_src/fem/geometry/closest_point.py +99 -0
  27. warp/_src/fem/geometry/deformed_geometry.py +277 -0
  28. warp/_src/fem/geometry/element.py +854 -0
  29. warp/_src/fem/geometry/geometry.py +693 -0
  30. warp/_src/fem/geometry/grid_2d.py +478 -0
  31. warp/_src/fem/geometry/grid_3d.py +539 -0
  32. warp/_src/fem/geometry/hexmesh.py +956 -0
  33. warp/_src/fem/geometry/nanogrid.py +660 -0
  34. warp/_src/fem/geometry/partition.py +483 -0
  35. warp/_src/fem/geometry/quadmesh.py +597 -0
  36. warp/_src/fem/geometry/tetmesh.py +762 -0
  37. warp/_src/fem/geometry/trimesh.py +588 -0
  38. warp/_src/fem/integrate.py +2507 -0
  39. warp/_src/fem/linalg.py +385 -0
  40. warp/_src/fem/operator.py +398 -0
  41. warp/_src/fem/polynomial.py +231 -0
  42. warp/_src/fem/quadrature/__init__.py +17 -0
  43. warp/_src/fem/quadrature/pic_quadrature.py +318 -0
  44. warp/_src/fem/quadrature/quadrature.py +665 -0
  45. warp/_src/fem/space/__init__.py +248 -0
  46. warp/_src/fem/space/basis_function_space.py +499 -0
  47. warp/_src/fem/space/basis_space.py +681 -0
  48. warp/_src/fem/space/dof_mapper.py +253 -0
  49. warp/_src/fem/space/function_space.py +312 -0
  50. warp/_src/fem/space/grid_2d_function_space.py +179 -0
  51. warp/_src/fem/space/grid_3d_function_space.py +229 -0
  52. warp/_src/fem/space/hexmesh_function_space.py +255 -0
  53. warp/_src/fem/space/nanogrid_function_space.py +199 -0
  54. warp/_src/fem/space/partition.py +435 -0
  55. warp/_src/fem/space/quadmesh_function_space.py +222 -0
  56. warp/_src/fem/space/restriction.py +221 -0
  57. warp/_src/fem/space/shape/__init__.py +152 -0
  58. warp/_src/fem/space/shape/cube_shape_function.py +1107 -0
  59. warp/_src/fem/space/shape/shape_function.py +134 -0
  60. warp/_src/fem/space/shape/square_shape_function.py +928 -0
  61. warp/_src/fem/space/shape/tet_shape_function.py +829 -0
  62. warp/_src/fem/space/shape/triangle_shape_function.py +674 -0
  63. warp/_src/fem/space/tetmesh_function_space.py +270 -0
  64. warp/_src/fem/space/topology.py +461 -0
  65. warp/_src/fem/space/trimesh_function_space.py +193 -0
  66. warp/_src/fem/types.py +114 -0
  67. warp/_src/fem/utils.py +488 -0
  68. warp/_src/jax.py +188 -0
  69. warp/_src/jax_experimental/__init__.py +14 -0
  70. warp/_src/jax_experimental/custom_call.py +389 -0
  71. warp/_src/jax_experimental/ffi.py +1286 -0
  72. warp/_src/jax_experimental/xla_ffi.py +658 -0
  73. warp/_src/marching_cubes.py +710 -0
  74. warp/_src/math.py +416 -0
  75. warp/_src/optim/__init__.py +14 -0
  76. warp/_src/optim/adam.py +165 -0
  77. warp/_src/optim/linear.py +1608 -0
  78. warp/_src/optim/sgd.py +114 -0
  79. warp/_src/paddle.py +408 -0
  80. warp/_src/render/__init__.py +14 -0
  81. warp/_src/render/imgui_manager.py +291 -0
  82. warp/_src/render/render_opengl.py +3638 -0
  83. warp/_src/render/render_usd.py +939 -0
  84. warp/_src/render/utils.py +162 -0
  85. warp/_src/sparse.py +2718 -0
  86. warp/_src/tape.py +1208 -0
  87. warp/_src/thirdparty/__init__.py +0 -0
  88. warp/_src/thirdparty/appdirs.py +598 -0
  89. warp/_src/thirdparty/dlpack.py +145 -0
  90. warp/_src/thirdparty/unittest_parallel.py +676 -0
  91. warp/_src/torch.py +393 -0
  92. warp/_src/types.py +5888 -0
  93. warp/_src/utils.py +1695 -0
  94. warp/autograd.py +33 -0
  95. warp/bin/libwarp-clang.dylib +0 -0
  96. warp/bin/libwarp.dylib +0 -0
  97. warp/build.py +29 -0
  98. warp/build_dll.py +24 -0
  99. warp/codegen.py +24 -0
  100. warp/constants.py +24 -0
  101. warp/context.py +33 -0
  102. warp/dlpack.py +24 -0
  103. warp/examples/__init__.py +24 -0
  104. warp/examples/assets/bear.usd +0 -0
  105. warp/examples/assets/bunny.usd +0 -0
  106. warp/examples/assets/cube.usd +0 -0
  107. warp/examples/assets/nonuniform.usd +0 -0
  108. warp/examples/assets/nvidia_logo.png +0 -0
  109. warp/examples/assets/pixel.jpg +0 -0
  110. warp/examples/assets/rocks.nvdb +0 -0
  111. warp/examples/assets/rocks.usd +0 -0
  112. warp/examples/assets/sphere.usd +0 -0
  113. warp/examples/assets/square_cloth.usd +0 -0
  114. warp/examples/benchmarks/benchmark_api.py +389 -0
  115. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  116. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  117. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  118. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  119. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  120. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  121. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  122. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  123. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  124. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  125. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  126. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  127. warp/examples/benchmarks/benchmark_launches.py +301 -0
  128. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  129. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  130. warp/examples/browse.py +37 -0
  131. warp/examples/core/example_cupy.py +86 -0
  132. warp/examples/core/example_dem.py +241 -0
  133. warp/examples/core/example_fluid.py +299 -0
  134. warp/examples/core/example_graph_capture.py +150 -0
  135. warp/examples/core/example_marching_cubes.py +195 -0
  136. warp/examples/core/example_mesh.py +180 -0
  137. warp/examples/core/example_mesh_intersect.py +211 -0
  138. warp/examples/core/example_nvdb.py +182 -0
  139. warp/examples/core/example_raycast.py +111 -0
  140. warp/examples/core/example_raymarch.py +205 -0
  141. warp/examples/core/example_render_opengl.py +290 -0
  142. warp/examples/core/example_sample_mesh.py +300 -0
  143. warp/examples/core/example_sph.py +411 -0
  144. warp/examples/core/example_spin_lock.py +93 -0
  145. warp/examples/core/example_torch.py +211 -0
  146. warp/examples/core/example_wave.py +269 -0
  147. warp/examples/core/example_work_queue.py +118 -0
  148. warp/examples/distributed/example_jacobi_mpi.py +506 -0
  149. warp/examples/fem/example_adaptive_grid.py +286 -0
  150. warp/examples/fem/example_apic_fluid.py +469 -0
  151. warp/examples/fem/example_burgers.py +261 -0
  152. warp/examples/fem/example_convection_diffusion.py +181 -0
  153. warp/examples/fem/example_convection_diffusion_dg.py +225 -0
  154. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  155. warp/examples/fem/example_deformed_geometry.py +172 -0
  156. warp/examples/fem/example_diffusion.py +196 -0
  157. warp/examples/fem/example_diffusion_3d.py +225 -0
  158. warp/examples/fem/example_diffusion_mgpu.py +225 -0
  159. warp/examples/fem/example_distortion_energy.py +228 -0
  160. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  161. warp/examples/fem/example_magnetostatics.py +242 -0
  162. warp/examples/fem/example_mixed_elasticity.py +293 -0
  163. warp/examples/fem/example_navier_stokes.py +263 -0
  164. warp/examples/fem/example_nonconforming_contact.py +300 -0
  165. warp/examples/fem/example_stokes.py +213 -0
  166. warp/examples/fem/example_stokes_transfer.py +262 -0
  167. warp/examples/fem/example_streamlines.py +357 -0
  168. warp/examples/fem/utils.py +1047 -0
  169. warp/examples/interop/example_jax_callable.py +146 -0
  170. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  171. warp/examples/interop/example_jax_kernel.py +232 -0
  172. warp/examples/optim/example_diffray.py +561 -0
  173. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  174. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  175. warp/examples/tile/example_tile_cholesky.py +88 -0
  176. warp/examples/tile/example_tile_convolution.py +66 -0
  177. warp/examples/tile/example_tile_fft.py +55 -0
  178. warp/examples/tile/example_tile_filtering.py +113 -0
  179. warp/examples/tile/example_tile_matmul.py +85 -0
  180. warp/examples/tile/example_tile_mcgp.py +191 -0
  181. warp/examples/tile/example_tile_mlp.py +385 -0
  182. warp/examples/tile/example_tile_nbody.py +199 -0
  183. warp/fabric.py +24 -0
  184. warp/fem/__init__.py +173 -0
  185. warp/fem/adaptivity.py +26 -0
  186. warp/fem/cache.py +30 -0
  187. warp/fem/dirichlet.py +24 -0
  188. warp/fem/field/__init__.py +24 -0
  189. warp/fem/field/field.py +26 -0
  190. warp/fem/geometry/__init__.py +21 -0
  191. warp/fem/geometry/closest_point.py +31 -0
  192. warp/fem/linalg.py +38 -0
  193. warp/fem/operator.py +32 -0
  194. warp/fem/polynomial.py +29 -0
  195. warp/fem/space/__init__.py +22 -0
  196. warp/fem/space/basis_space.py +24 -0
  197. warp/fem/space/shape/__init__.py +68 -0
  198. warp/fem/space/topology.py +24 -0
  199. warp/fem/types.py +24 -0
  200. warp/fem/utils.py +32 -0
  201. warp/jax.py +29 -0
  202. warp/jax_experimental/__init__.py +29 -0
  203. warp/jax_experimental/custom_call.py +29 -0
  204. warp/jax_experimental/ffi.py +39 -0
  205. warp/jax_experimental/xla_ffi.py +24 -0
  206. warp/marching_cubes.py +24 -0
  207. warp/math.py +37 -0
  208. warp/native/array.h +1687 -0
  209. warp/native/builtin.h +2327 -0
  210. warp/native/bvh.cpp +562 -0
  211. warp/native/bvh.cu +826 -0
  212. warp/native/bvh.h +555 -0
  213. warp/native/clang/clang.cpp +541 -0
  214. warp/native/coloring.cpp +622 -0
  215. warp/native/crt.cpp +51 -0
  216. warp/native/crt.h +568 -0
  217. warp/native/cuda_crt.h +1058 -0
  218. warp/native/cuda_util.cpp +677 -0
  219. warp/native/cuda_util.h +313 -0
  220. warp/native/error.cpp +77 -0
  221. warp/native/error.h +36 -0
  222. warp/native/exports.h +2023 -0
  223. warp/native/fabric.h +246 -0
  224. warp/native/hashgrid.cpp +311 -0
  225. warp/native/hashgrid.cu +89 -0
  226. warp/native/hashgrid.h +240 -0
  227. warp/native/initializer_array.h +41 -0
  228. warp/native/intersect.h +1253 -0
  229. warp/native/intersect_adj.h +375 -0
  230. warp/native/intersect_tri.h +348 -0
  231. warp/native/mat.h +5189 -0
  232. warp/native/mathdx.cpp +93 -0
  233. warp/native/matnn.h +221 -0
  234. warp/native/mesh.cpp +266 -0
  235. warp/native/mesh.cu +406 -0
  236. warp/native/mesh.h +2097 -0
  237. warp/native/nanovdb/GridHandle.h +533 -0
  238. warp/native/nanovdb/HostBuffer.h +591 -0
  239. warp/native/nanovdb/NanoVDB.h +6246 -0
  240. warp/native/nanovdb/NodeManager.h +323 -0
  241. warp/native/nanovdb/PNanoVDB.h +3390 -0
  242. warp/native/noise.h +859 -0
  243. warp/native/quat.h +1664 -0
  244. warp/native/rand.h +342 -0
  245. warp/native/range.h +145 -0
  246. warp/native/reduce.cpp +174 -0
  247. warp/native/reduce.cu +363 -0
  248. warp/native/runlength_encode.cpp +79 -0
  249. warp/native/runlength_encode.cu +61 -0
  250. warp/native/scan.cpp +47 -0
  251. warp/native/scan.cu +55 -0
  252. warp/native/scan.h +23 -0
  253. warp/native/solid_angle.h +466 -0
  254. warp/native/sort.cpp +251 -0
  255. warp/native/sort.cu +286 -0
  256. warp/native/sort.h +35 -0
  257. warp/native/sparse.cpp +241 -0
  258. warp/native/sparse.cu +435 -0
  259. warp/native/spatial.h +1306 -0
  260. warp/native/svd.h +727 -0
  261. warp/native/temp_buffer.h +46 -0
  262. warp/native/tile.h +4124 -0
  263. warp/native/tile_radix_sort.h +1112 -0
  264. warp/native/tile_reduce.h +838 -0
  265. warp/native/tile_scan.h +240 -0
  266. warp/native/tuple.h +189 -0
  267. warp/native/vec.h +2199 -0
  268. warp/native/version.h +23 -0
  269. warp/native/volume.cpp +501 -0
  270. warp/native/volume.cu +68 -0
  271. warp/native/volume.h +970 -0
  272. warp/native/volume_builder.cu +483 -0
  273. warp/native/volume_builder.h +52 -0
  274. warp/native/volume_impl.h +70 -0
  275. warp/native/warp.cpp +1143 -0
  276. warp/native/warp.cu +4604 -0
  277. warp/native/warp.h +358 -0
  278. warp/optim/__init__.py +20 -0
  279. warp/optim/adam.py +24 -0
  280. warp/optim/linear.py +35 -0
  281. warp/optim/sgd.py +24 -0
  282. warp/paddle.py +24 -0
  283. warp/py.typed +0 -0
  284. warp/render/__init__.py +22 -0
  285. warp/render/imgui_manager.py +29 -0
  286. warp/render/render_opengl.py +24 -0
  287. warp/render/render_usd.py +24 -0
  288. warp/render/utils.py +24 -0
  289. warp/sparse.py +51 -0
  290. warp/tape.py +24 -0
  291. warp/tests/__init__.py +1 -0
  292. warp/tests/__main__.py +4 -0
  293. warp/tests/assets/curlnoise_golden.npy +0 -0
  294. warp/tests/assets/mlp_golden.npy +0 -0
  295. warp/tests/assets/pixel.npy +0 -0
  296. warp/tests/assets/pnoise_golden.npy +0 -0
  297. warp/tests/assets/spiky.usd +0 -0
  298. warp/tests/assets/test_grid.nvdb +0 -0
  299. warp/tests/assets/test_index_grid.nvdb +0 -0
  300. warp/tests/assets/test_int32_grid.nvdb +0 -0
  301. warp/tests/assets/test_vec_grid.nvdb +0 -0
  302. warp/tests/assets/torus.nvdb +0 -0
  303. warp/tests/assets/torus.usda +105 -0
  304. warp/tests/aux_test_class_kernel.py +34 -0
  305. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  306. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  307. warp/tests/aux_test_dependent.py +29 -0
  308. warp/tests/aux_test_grad_customs.py +29 -0
  309. warp/tests/aux_test_instancing_gc.py +26 -0
  310. warp/tests/aux_test_module_aot.py +7 -0
  311. warp/tests/aux_test_module_unload.py +23 -0
  312. warp/tests/aux_test_name_clash1.py +40 -0
  313. warp/tests/aux_test_name_clash2.py +40 -0
  314. warp/tests/aux_test_reference.py +9 -0
  315. warp/tests/aux_test_reference_reference.py +8 -0
  316. warp/tests/aux_test_square.py +16 -0
  317. warp/tests/aux_test_unresolved_func.py +22 -0
  318. warp/tests/aux_test_unresolved_symbol.py +22 -0
  319. warp/tests/cuda/__init__.py +0 -0
  320. warp/tests/cuda/test_async.py +676 -0
  321. warp/tests/cuda/test_conditional_captures.py +1147 -0
  322. warp/tests/cuda/test_ipc.py +124 -0
  323. warp/tests/cuda/test_mempool.py +233 -0
  324. warp/tests/cuda/test_multigpu.py +169 -0
  325. warp/tests/cuda/test_peer.py +139 -0
  326. warp/tests/cuda/test_pinned.py +84 -0
  327. warp/tests/cuda/test_streams.py +691 -0
  328. warp/tests/geometry/__init__.py +0 -0
  329. warp/tests/geometry/test_bvh.py +335 -0
  330. warp/tests/geometry/test_hash_grid.py +259 -0
  331. warp/tests/geometry/test_marching_cubes.py +294 -0
  332. warp/tests/geometry/test_mesh.py +318 -0
  333. warp/tests/geometry/test_mesh_query_aabb.py +392 -0
  334. warp/tests/geometry/test_mesh_query_point.py +935 -0
  335. warp/tests/geometry/test_mesh_query_ray.py +323 -0
  336. warp/tests/geometry/test_volume.py +1103 -0
  337. warp/tests/geometry/test_volume_write.py +346 -0
  338. warp/tests/interop/__init__.py +0 -0
  339. warp/tests/interop/test_dlpack.py +730 -0
  340. warp/tests/interop/test_jax.py +1673 -0
  341. warp/tests/interop/test_paddle.py +800 -0
  342. warp/tests/interop/test_torch.py +1001 -0
  343. warp/tests/run_coverage_serial.py +39 -0
  344. warp/tests/test_adam.py +162 -0
  345. warp/tests/test_arithmetic.py +1096 -0
  346. warp/tests/test_array.py +3756 -0
  347. warp/tests/test_array_reduce.py +156 -0
  348. warp/tests/test_assert.py +303 -0
  349. warp/tests/test_atomic.py +336 -0
  350. warp/tests/test_atomic_bitwise.py +209 -0
  351. warp/tests/test_atomic_cas.py +312 -0
  352. warp/tests/test_bool.py +220 -0
  353. warp/tests/test_builtins_resolution.py +732 -0
  354. warp/tests/test_closest_point_edge_edge.py +327 -0
  355. warp/tests/test_codegen.py +974 -0
  356. warp/tests/test_codegen_instancing.py +1495 -0
  357. warp/tests/test_compile_consts.py +215 -0
  358. warp/tests/test_conditional.py +298 -0
  359. warp/tests/test_context.py +35 -0
  360. warp/tests/test_copy.py +319 -0
  361. warp/tests/test_ctypes.py +618 -0
  362. warp/tests/test_dense.py +73 -0
  363. warp/tests/test_devices.py +127 -0
  364. warp/tests/test_enum.py +136 -0
  365. warp/tests/test_examples.py +424 -0
  366. warp/tests/test_fabricarray.py +998 -0
  367. warp/tests/test_fast_math.py +72 -0
  368. warp/tests/test_fem.py +2204 -0
  369. warp/tests/test_fixedarray.py +229 -0
  370. warp/tests/test_fp16.py +136 -0
  371. warp/tests/test_func.py +501 -0
  372. warp/tests/test_future_annotations.py +100 -0
  373. warp/tests/test_generics.py +656 -0
  374. warp/tests/test_grad.py +893 -0
  375. warp/tests/test_grad_customs.py +339 -0
  376. warp/tests/test_grad_debug.py +341 -0
  377. warp/tests/test_implicit_init.py +411 -0
  378. warp/tests/test_import.py +45 -0
  379. warp/tests/test_indexedarray.py +1140 -0
  380. warp/tests/test_intersect.py +103 -0
  381. warp/tests/test_iter.py +76 -0
  382. warp/tests/test_large.py +177 -0
  383. warp/tests/test_launch.py +411 -0
  384. warp/tests/test_lerp.py +151 -0
  385. warp/tests/test_linear_solvers.py +223 -0
  386. warp/tests/test_lvalue.py +427 -0
  387. warp/tests/test_map.py +526 -0
  388. warp/tests/test_mat.py +3515 -0
  389. warp/tests/test_mat_assign_copy.py +178 -0
  390. warp/tests/test_mat_constructors.py +573 -0
  391. warp/tests/test_mat_lite.py +122 -0
  392. warp/tests/test_mat_scalar_ops.py +2913 -0
  393. warp/tests/test_math.py +212 -0
  394. warp/tests/test_module_aot.py +287 -0
  395. warp/tests/test_module_hashing.py +258 -0
  396. warp/tests/test_modules_lite.py +70 -0
  397. warp/tests/test_noise.py +252 -0
  398. warp/tests/test_operators.py +299 -0
  399. warp/tests/test_options.py +129 -0
  400. warp/tests/test_overwrite.py +551 -0
  401. warp/tests/test_print.py +408 -0
  402. warp/tests/test_quat.py +2653 -0
  403. warp/tests/test_quat_assign_copy.py +145 -0
  404. warp/tests/test_rand.py +339 -0
  405. warp/tests/test_reload.py +303 -0
  406. warp/tests/test_rounding.py +157 -0
  407. warp/tests/test_runlength_encode.py +196 -0
  408. warp/tests/test_scalar_ops.py +133 -0
  409. warp/tests/test_smoothstep.py +108 -0
  410. warp/tests/test_snippet.py +318 -0
  411. warp/tests/test_sparse.py +845 -0
  412. warp/tests/test_spatial.py +2859 -0
  413. warp/tests/test_spatial_assign_copy.py +160 -0
  414. warp/tests/test_special_values.py +361 -0
  415. warp/tests/test_static.py +640 -0
  416. warp/tests/test_struct.py +901 -0
  417. warp/tests/test_tape.py +242 -0
  418. warp/tests/test_transient_module.py +93 -0
  419. warp/tests/test_triangle_closest_point.py +192 -0
  420. warp/tests/test_tuple.py +361 -0
  421. warp/tests/test_types.py +615 -0
  422. warp/tests/test_utils.py +594 -0
  423. warp/tests/test_vec.py +1408 -0
  424. warp/tests/test_vec_assign_copy.py +143 -0
  425. warp/tests/test_vec_constructors.py +325 -0
  426. warp/tests/test_vec_lite.py +80 -0
  427. warp/tests/test_vec_scalar_ops.py +2327 -0
  428. warp/tests/test_verify_fp.py +100 -0
  429. warp/tests/test_version.py +75 -0
  430. warp/tests/tile/__init__.py +0 -0
  431. warp/tests/tile/test_tile.py +1519 -0
  432. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  433. warp/tests/tile/test_tile_cholesky.py +608 -0
  434. warp/tests/tile/test_tile_load.py +724 -0
  435. warp/tests/tile/test_tile_mathdx.py +156 -0
  436. warp/tests/tile/test_tile_matmul.py +179 -0
  437. warp/tests/tile/test_tile_mlp.py +400 -0
  438. warp/tests/tile/test_tile_reduce.py +950 -0
  439. warp/tests/tile/test_tile_shared_memory.py +376 -0
  440. warp/tests/tile/test_tile_sort.py +121 -0
  441. warp/tests/tile/test_tile_view.py +173 -0
  442. warp/tests/unittest_serial.py +47 -0
  443. warp/tests/unittest_suites.py +430 -0
  444. warp/tests/unittest_utils.py +469 -0
  445. warp/tests/walkthrough_debug.py +95 -0
  446. warp/torch.py +24 -0
  447. warp/types.py +51 -0
  448. warp/utils.py +31 -0
  449. warp_lang-1.10.0.dist-info/METADATA +459 -0
  450. warp_lang-1.10.0.dist-info/RECORD +468 -0
  451. warp_lang-1.10.0.dist-info/WHEEL +5 -0
  452. warp_lang-1.10.0.dist-info/licenses/LICENSE.md +176 -0
  453. warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  454. warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  455. warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  456. warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  457. warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  458. warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  459. warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  460. warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  461. warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  462. warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  463. warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  464. warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  465. warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  466. warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  467. warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  468. warp_lang-1.10.0.dist-info/top_level.txt +1 -0
warp/native/tile.h ADDED
@@ -0,0 +1,4124 @@
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #pragma once
+
+ #include "builtin.h"
+
+ #ifdef __clang__
+ // disable warnings related to C++17 extensions on CPU JIT builds
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
+ #endif // __clang__
+
+ // Check if the CUDA toolkit is available
+ #if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
+
+ // If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
+ #ifdef __CUDACC_RTC__
+ // NVRTC: Use built-in float4 (no need for extra definitions)
+ #else
+ // NVCC: Include vector_types.h to get float4
+ #include <cuda_runtime.h>
+ #endif
+
+ #else
+ // If CUDA is not available (e.g., macOS build), manually define float4
+ struct alignas(16) float4 {
+ float x, y, z, w;
+ };
+ #endif
+
+ #if defined(__CUDA_ARCH__)
+ #define WP_TILE_SYNC __syncthreads
+ #else
+ #define WP_TILE_SYNC void
+ #endif
+
+ #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
+ #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
+ #define WP_PRAGMA_UNROLL _Pragma("unroll")
+ #define WP_PRAGMA_NO_UNROLL _Pragma("unroll 1")
+ #else
+ #define WP_PRAGMA_UNROLL #pragma unroll
+ #define WP_PRAGMA_NO_UNROLL #pragma unroll 1
+ #endif
+
+ #else
+
+ #define WP_PRAGMA_UNROLL
+ #define WP_PRAGMA_NO_UNROLL
+
+ #endif
+
+ #define WP_USE_ASYNC_PIPELINE 0
+ #define WP_USE_REGISTER_GEMM 0
+
+ #if defined(__CUDACC_RTC__)
+ #define WP_TILE_THREAD_IDX threadIdx.x
+ #else
+ #define WP_TILE_THREAD_IDX 0
+ #endif //
+
+
+
+ /* Tile Expressions
+
+ [ ] Tiles
+ [x] Register, Shared, Global
+ [ ] Layouts
+ [x] Simple
+ [ ] Cute
+ [x] Remove Alloc type from tile_shared_t
+ [x] wp.launch_tiled() helper
+ [ ] Creation
+ [x] zeros
+ [x] ones
+ [x] arange
+ [x] tile()
+ [x] untile()
+ [ ] fromfunction()
+ [ ] explicit storage
+ [ ] Load/Store
+ [ ] 1D load/store variants
+ [ ] max_coord option for non-aligned loads
+ [ ] Indexed load
+ [x] wp.tile_atomic_add()
+ [ ] Maps
+ [x] Support user functions
+ [x] Support built-in functions
+ [ ] Support for lambda functions
+ [ ] Infer tile_map() output from operator type (e.g.: dot for each element)
+ [ ] Reductions
+ [x] Sum
+ [x] Forward
+ [x] Reverse
+ [x] Min
+ [x] Max
+ [x] Custom
+ [x] MatMul
+ [x] Forward
+ [x] Reverse
+ [ ] Operators
+ [ ] +, -, *, /, @?
+ [ ] += for matmul, e.g.: c += a@b, or c = a@b
+ [ ] Reshape
+ [ ] Broadcasting
+ [ ] Transpose
+ [x] Shared
+ [ ] Register
+ [ ] Slice
+ [ ] Runtime
+ [x] Compile-time block dimensions
+ [x] Switch between SIMT / Tile based execution if `block_dim` not provided to wp.launch()
+ [ ] Examples
+ [ ] Point registration
+ [ ] GEMM
+ [ ] MLP
+ [ ] LayerNorm
+ [ ] SoftMax
+ [ ] GEMM
+ [ ] Batched MLP
+ [ ] Layer norm
+ [ ] FNO + Burgers equation
+ [ ] Stochastic financial modeling
+ [ ] Convolution: https://github.com/NVIDIA/MinkowskiEngine/blob/master/src/convolution_kernel.cu#L123
+ [ ] MeshCNN (Modulus, Oliver)
+ [ ] BioNemo (Ali)
+ [ ] Skinning (David/Or/Vismay)
+ [ ] Error checking
+ [ ] Ensure functions passed to tile_map() are compatible with tile type
+ [ ] Ensure that args passed to tile ops are compatible
+ [ ] Ensure tile load/store operations don't go out of bounds of arrays in debug mode
+
+ */
+
+ /*
+ Notes on shared memory synchronization
+ ======================================
+
+ Currently, operations that write to shared memory tiles (e.g.: tile_load())
+ must synchronize before they return through WP_TILE_SYNC(); this
+ ensures subsequent read operations from the tile do not cause a race condition.
+
+ For tile_shared_t adjoints, the gradient accumulation is done through shared
+ memory atomics, i.e.: atomic_add(), since for broadcast tiles multiple threads
+ may map to the same location. Synchronization is still required after these
+ updates, since subsequent operations e.g.: adj_tile_load() will store the
+ gradients to memory, and all updates must be visible at that point, e.g.:
+
+ a = wp.tile_load(...)
+ b = wp.tile_load(...)
+ c = wp.tile_matmul(a, b)
+ wp.tile_store(c)
+
+ // loads incoming adjoints from global -> shared
+ wp.adj_tile_store(c, adj_c)
+ // consumes adj_c, requires synchronization
+ wp.adj_tile_matmul(a, b, adj_a, adj_b, adj_c)
+ // consumes adj_b, requires synchronization
+ wp.adj_tile_load(..., adj_b)
+ // consumes adj_a, requires synchronization
+ wp.adj_tile_load(..., adj_a)
+
+ Generally, synchronization of adjoint tiles happens automatically through the
+ tile_shared_t::add() and tile_shared_t::assign() functions,
+ but in some cases e.g.: tile_matmul() it is done manually.
+
+ The current synchronization strategy is conservative, and can lead to more
+ synchronization than necessary. A more sophisticated strategy would be
+ to track the 'dirty' state of shared tiles, and synchronize only when
+ necessary. In addition, custom synchronization for e.g.: tile_load()
+ operations could be added through a SyncProvider template parameter on
+ the tile_shared_t type, for example to support barrier synchronization
+ for asynchronous global to shared loads.
+ */
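
A minimal sketch of the write-then-sync pattern described above (illustrative only; `smem`, `src`, and `Size` are hypothetical names, and WP_TILE_BLOCK_DIM is assumed to be defined by the code generator, as elsewhere in this header):

    // cooperative write phase: each thread strides over the tile
    for (int i=WP_TILE_THREAD_IDX; i < Size; i += WP_TILE_BLOCK_DIM)
        smem[i] = src[i];

    // fence: make all writes visible block-wide before any thread reads
    WP_TILE_SYNC();

    // any thread may now safely read any element of smem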
+
+ namespace wp
+ {
+
+ // Primary template
+ template <typename T, typename U>
+ struct is_same {
+ static constexpr bool value = false;
+ };
+
+ // Specialization for the case when T and U are the same type
+ template <typename T>
+ struct is_same<T, T> {
+ static constexpr bool value = true;
+ };
+
+ // Helper for dependent static_assert failures
+ template <typename T>
+ struct always_false {
+ static constexpr bool value = false;
+ };
+
+
+ template <int N>
+ struct tile_coord_t
+ {
+ int indices[N];
+
+ CUDA_CALLABLE inline int operator[](int i) const { assert(0 <= i && i < N); return indices[i]; }
+ CUDA_CALLABLE inline int& operator[](int i) { assert(0 <= i && i < N); return indices[i]; }
+
+ CUDA_CALLABLE inline tile_coord_t<N> operator + (const tile_coord_t<N>& c) const
+ {
+ tile_coord_t<N> out;
+ for (int i=0; i < N; ++i)
+ {
+ out.indices[i] = indices[i] + c.indices[i];
+ }
+ return out;
+ }
+
+ static constexpr int size() { return N; }
+ };
+
+ // This function deduces N = sizeof...(Ints)
+ template <typename... Ints>
+ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
+ {
+ constexpr int N = sizeof...(Ints);
+
+ // Create the result
+ tile_coord_t<N> result{};
+
+ // Capture all arguments in a local array
+ int arr[] = { static_cast<int>(idxs)... };
+
+ // C++14 or later: 'for' is allowed in a constexpr context
+ for (int i = 0; i < N; ++i)
+ {
+ result.indices[i] = arr[i];
+ }
+
+ return result;
+ }
+
+ // helpers to construct a coord from a set of indices
+ inline auto tile_coord(int i)
+ {
+ auto c = tile_coord_t<1>();
+ c.indices[0] = i;
+ return c;
+ }
+
+ inline auto tile_coord(int i, int j)
+ {
+ auto c = tile_coord_t<2>();
+ c.indices[0] = i;
+ c.indices[1] = j;
+ return c;
+ }
+
+ inline auto tile_coord(int i, int j, int k)
+ {
+ auto c = tile_coord_t<3>();
+ c.indices[0] = i;
+ c.indices[1] = j;
+ c.indices[2] = k;
+ return c;
+ }
+
+ inline auto tile_coord(int i, int j, int k, int l)
+ {
+ auto c = tile_coord_t<4>();
+ c.indices[0] = i;
+ c.indices[1] = j;
+ c.indices[2] = k;
+ c.indices[3] = l;
+ return c;
+ }
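
A brief usage sketch for the coordinate helpers (illustrative only, not code from the file): the fixed-arity overloads and the variadic template build the same tile_coord_t, and coordinates add componentwise.

    inline void tile_coord_example()
    {
        auto a = tile_coord(1, 2);   // tile_coord_t<2> with indices {1, 2}
        auto b = tile_coord(3, 4);   // tile_coord_t<2> with indices {3, 4}
        auto c = a + b;              // componentwise add -> {4, 6}
        assert(c[0] == 4 && c[1] == 6);
    }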
+
+ // represents a compile-time int tuple for strides/shapes/coords
+ template <int... V>
+ struct tile_tuple_t
+ {
+ static constexpr int N = sizeof...(V);
+ static_assert(N > 0, "Expected N > 0");
+
+ static constexpr int data[N] = { V... };
+
+ static constexpr int dim(int i) { assert(i < N); return data[i]; }
+ static constexpr int size()
+ {
+ int res = data[0];
+ for (int i=1; i < N; ++i)
+ res *= data[i];
+
+ return res;
+ }
+ };
+
+ // simple helper to compute strides from a shape up to 4d
+ template <typename Shape>
+ struct compute_strides;
+
+ // 1D
+ template <int D0>
+ struct compute_strides< tile_tuple_t<D0> > { using Stride = tile_tuple_t<1>; };
+ // 2D
+ template <int D0, int D1>
+ struct compute_strides< tile_tuple_t<D0, D1> > { using Stride = tile_tuple_t<D1, 1>; };
+ // 3D
+ template <int D0, int D1, int D2>
+ struct compute_strides< tile_tuple_t<D0, D1, D2> > { using Stride = tile_tuple_t<(D1 * D2), D2, 1>; };
+ // 4D
+ template <int D0, int D1, int D2, int D3>
+ struct compute_strides< tile_tuple_t<D0, D1, D2, D3> > { using Stride = tile_tuple_t<(D1 * D2 * D3), (D2 * D3), D3, 1>; };
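
Row-major strides fall out by multiplying the trailing dimensions; a couple of compile-time checks make the pattern concrete (illustrative only, not code from the file):

    using S = compute_strides< tile_tuple_t<2, 3, 4> >::Stride;   // expect <12, 4, 1>
    static_assert(S::dim(0) == 12 && S::dim(1) == 4 && S::dim(2) == 1, "row-major strides");
    static_assert(tile_tuple_t<2, 3, 4>::size() == 24, "2*3*4 = 24 elements");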
+
+
+ // alias of tuple to represent shapes
+ template <int... V>
+ using tile_shape_t = tile_tuple_t<V...>;
+
+ // alias of tuple to represent strides
+ template <int... V>
+ using tile_stride_t = tile_tuple_t<V...>;
+
+
+ // helper to remove a dimension from a shape (used for axis reductions)
+ template<int Axis, typename Shape>
+ struct tile_shape_remove_dim {
+ static_assert(Axis >= 0 && Axis < Shape::N, "Axis out of bounds for tile_shape_remove_dim");
+ };
+
+ // 1D -> scalar
+ template<int D0>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0>> {
+ using type = tile_shape_t<1>;
+ };
+
+ // 2D -> 1D
+ template<int D0, int D1>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1>> {
+ using type = tile_shape_t<D1>;
+ };
+
+ template<int D0, int D1>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1>> {
+ using type = tile_shape_t<D0>;
+ };
+
+ // 3D -> 2D
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D1, D2>;
+ };
+
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D0, D2>;
+ };
+
+ template<int D0, int D1, int D2>
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2>> {
+ using type = tile_shape_t<D0, D1>;
+ };
+
+ // 4D -> 3D
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<0, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D1, D2, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<1, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D2, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<2, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D1, D3>;
+ };
+
+ template<int D0, int D1, int D2, int D3>
+ struct tile_shape_remove_dim<3, tile_shape_t<D0, D1, D2, D3>> {
+ using type = tile_shape_t<D0, D1, D2>;
+ };
+
+
+ // helper to insert an axis value into a coordinate (inverse of removing a dimension);
+ // used for mapping output coordinates back to input coordinates during axis reduction
+ template<int Axis, int N>
+ CUDA_CALLABLE constexpr auto tile_coord_insert_axis(const tile_coord_t<N>& coord, int axis_val)
+ {
+ static_assert(Axis >= 0 && Axis <= N, "Axis out of bounds for tile_coord_insert_axis");
+
+ if constexpr (N == 0)
+ {
+ // Scalar -> 1D
+ static_assert(Axis == 0, "Invalid axis for scalar coordinate");
+ return tile_coord(axis_val);
+ }
+ else if constexpr (N == 1)
+ {
+ // 1D -> 2D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0]);
+ else
+ return tile_coord(coord[0], axis_val);
+ }
+ else if constexpr (N == 2)
+ {
+ // 2D -> 3D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0], coord[1]);
+ else if constexpr (Axis == 1)
+ return tile_coord(coord[0], axis_val, coord[1]);
+ else
+ return tile_coord(coord[0], coord[1], axis_val);
+ }
+ else // N == 3
+ {
+ // 3D -> 4D
+ if constexpr (Axis == 0)
+ return tile_coord(axis_val, coord[0], coord[1], coord[2]);
+ else if constexpr (Axis == 1)
+ return tile_coord(coord[0], axis_val, coord[1], coord[2]);
+ else if constexpr (Axis == 2)
+ return tile_coord(coord[0], coord[1], axis_val, coord[2]);
+ else
+ return tile_coord(coord[0], coord[1], coord[2], axis_val);
+ }
+ }
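
The two helpers are inverses across an axis reduction. A compile-time check plus a hypothetical device-side sketch (illustrative only, not part of the shipped file):

    // removing axis 1 from shape (2, 3, 4) leaves (2, 4)
    static_assert(is_same<tile_shape_remove_dim<1, tile_shape_t<2, 3, 4>>::type,
                          tile_shape_t<2, 4>>::value, "axis 1 removed");

    // re-inserting axis 1 maps an output coord (i, k) back to an input coord (i, j, k)
    inline CUDA_CALLABLE void tile_axis_roundtrip_example()
    {
        auto c = tile_coord_insert_axis<1>(tile_coord(5, 7), 9); // -> (5, 9, 7)
        assert(c[0] == 5 && c[1] == 9 && c[2] == 7);
    }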
+
+
+ // represents a tile stored in global memory with dynamic strides
+ // used to represent the source and offset for tile loads to register/shared
+ // BoundsCheck: when true (default), validates array access bounds; when false, skips validation for performance
+ template <typename T, typename Shape_, bool BoundsCheck=true>
+ struct tile_global_t
+ {
+ using Type = T;
+ using Shape = Shape_;
+ using Coord = tile_coord_t<Shape::N>;
+
+ array_t<T> data;
+ Coord offset;
+
+ tile_global_t(array_t<T>& a, const Coord& c) : data(a), offset(c)
+ {
+ }
+
+ inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const
+ {
+ // element index
+ int index = 0;
+
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Shape::N; ++i)
+ {
+ // global = offset + coord
+ int c = offset[i] + coord[i];
+ index += data.strides[i]*c;
+ }
+
+ return index/sizeof(T);
+ }
+
+ inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
+ {
+ if constexpr (BoundsCheck)
+ {
+ // element index
+ int index = 0;
+
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Shape::N; ++i)
+ {
+ // global = offset + coord
+ int c = offset[i] + coord[i];
+
+ // handle out of bounds case
+ if (c >= data.shape[i])
+ return false;
+ else
+ index += data.strides[i]*c;
+ }
+
+ // array strides are in bytes so we convert to elements
+ out = index / sizeof(T);
+ return true;
+ }
+ else
+ {
+ out = index_from_coord(coord);
+ return true;
+ }
+ }
+
+ inline CUDA_CALLABLE T load(const Coord& coord) const
+ {
+ int i;
+ if (index(coord, i))
+ return data.data[i];
+ else
+ return T(0);
+ }
+
+ inline CUDA_CALLABLE T load_grad(const Coord& coord) const
+ {
+ int i;
+ if (index(coord, i))
+ return data.grad[i];
+ else
+ return T(0);
+ }
+
+ inline CUDA_CALLABLE void store(const Coord& coord, const T& x) const
+ {
+ int i;
+ if (index(coord, i))
+ data.data[i] = x;
+ }
+
+ inline CUDA_CALLABLE T atomic_add(const Coord& coord, const T& value) const
+ {
+ int i;
+ if (index(coord, i))
+ return wp::atomic_add(&data.data[i], value);
+ else
+ return T(0);
+ }
+
+ inline CUDA_CALLABLE T atomic_add_grad(const Coord& coord, const T& grad) const
+ {
+ int i;
+ if (index(coord, i))
+ return wp::atomic_add(&data.grad[i], grad);
+ else
+ return T(0);
+ }
+ };
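
Note the division by sizeof(T) above: array_t strides are stored in bytes, not elements. As a worked example (not code from the file), a dense float32 array of shape (8, 16) carries strides (64, 4) bytes, so with a zero offset the element at coord (i, j) resolves to (i*64 + j*4)/4 = i*16 + j; and whenever offset[i] + coord[i] >= shape[i], index() reports failure, which load() turns into a zero fill and store() into a no-op.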
+
+
+ template <typename Shape_>
+ struct tile_layout_register_t
+ {
+ using Shape = Shape_;
+ using Coord = tile_coord_t<Shape::N>;
+
+ static constexpr int Size = Shape::size();
+ static constexpr int NumRegs = (Size + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
+ static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
+
+ static inline CUDA_CALLABLE int linear_from_register(int reg)
+ {
+ return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
+ }
+
+ static inline CUDA_CALLABLE int linear_from_coord(Coord c)
+ {
+ int linear = 0;
+ int stride = 1;
+
+ WP_PRAGMA_UNROLL
+ for (int i=Shape::N-1; i >= 0; --i)
+ {
+ linear += c[i] * stride;
+ stride *= Shape::dim(i);
+ }
+ return linear;
+ }
+
+ static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+ {
+ Coord c;
+
+ WP_PRAGMA_UNROLL
+ for (int i=Shape::N-1; i >= 0; --i)
+ {
+ c[i] = linear%Shape::dim(i);
+ linear /= Shape::dim(i);
+ }
+
+ return c;
+ }
+
+ static inline CUDA_CALLABLE int thread_from_linear(int linear)
+ {
+ const int thread = linear%WP_TILE_BLOCK_DIM;
+ return thread;
+ }
+
+ static inline CUDA_CALLABLE int register_from_linear(int linear)
+ {
+ const int reg = linear/WP_TILE_BLOCK_DIM;
+ return reg;
+ }
+
+ static inline CUDA_CALLABLE bool valid(int linear)
+ {
+ if (Aligned || linear < Size)
+ return true;
+ else
+ return false;
+ }
+
+ };
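
A worked example of the mapping (assuming, hypothetically, that the code generator set WP_TILE_BLOCK_DIM = 256): a 16x16 tile has Size = 256, so NumRegs = 1 and the layout is Aligned. Linear element 37 belongs to thread 37 % 256 = 37 in register 37 / 256 = 0, and coord_from_linear(37) recovers (2, 5), since 37 = 2*16 + 5.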
+
+ // represents a tile stored in registers across a block
+ template <typename T, typename L>
+ struct tile_register_t
+ {
+ using Type = T;
+ using Layout = L;
+
+ T data[Layout::NumRegs];
+
+ inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
+ {
+ // zero-initialize by default, necessary for tile adjoints;
+ // need to check if this results in worse codegen
+ // than doing adj_var = tile_zeros() explicitly
+ // in the backwards pass and letting the default constructor
+ // avoid initialization
+
+ for (int i=0; i < Layout::NumRegs; ++i)
+ data[i] = value;
+ }
+
+ template <bool BoundsCheck>
+ inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
+ {
+ copy_from_global(t);
+ return *this;
+ }
+
+ // define the += operator which is used during backward pass codegen
+ // when returning a register tile from a user-defined function
+ inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, Layout>& rhs)
+ {
+ grad_add(rhs);
+ return *this;
+ }
+
+ inline CUDA_CALLABLE T& operator()(int reg)
+ {
+ assert(reg < Layout::NumRegs);
+ return data[reg];
+ }
+
+ inline CUDA_CALLABLE const T& operator()(int reg) const
+ {
+ assert(reg < Layout::NumRegs);
+ return data[reg];
+ }
+
+ inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
+ {
+ for (int i=0; i < Layout::NumRegs; ++i)
+ data[i] = tile.data[i];
+ }
+
+ inline CUDA_CALLABLE void zero()
+ {
+ for (int i=0; i < Layout::NumRegs; ++i)
+ data[i] = T(0);
+ }
+
+ // extract a single tile element to a native type
+ template <typename Coord>
+ inline CUDA_CALLABLE Type extract(const Coord& c)
+ {
+ // map from logical coords (i, j) -> (thread, reg)
+ const int linear = Layout::linear_from_coord(c);
+ const int thread = Layout::thread_from_linear(linear);
+ const int reg = Layout::register_from_linear(linear);
+
+ #if defined(__CUDA_ARCH__)
+ __shared__ Type scratch;
+ #else
+ Type scratch;
+ #endif
+
+ // ensure any previously scheduled threads have finished reading from scratch
+ WP_TILE_SYNC();
+
+ if (WP_TILE_THREAD_IDX == thread)
+ {
+ scratch = data[reg];
+ }
+
+ // ensure extraction thread has updated smem
+ WP_TILE_SYNC();
+
+ return scratch;
+ }
+
+
+ // backward version of scalar extract
+ template <typename Coord>
+ inline CUDA_CALLABLE void adj_extract(const Coord& c, Type adj_ret)
+ {
+ // map from logical coords (i, j) -> (thread, reg)
+ const int linear = Layout::linear_from_coord(c);
+ const int thread = Layout::thread_from_linear(linear);
+ const int reg = Layout::register_from_linear(linear);
+
+ if (WP_TILE_THREAD_IDX == thread)
+ {
+ data[reg] += adj_ret;
+ }
+ }
+
+ inline CUDA_CALLABLE void print() const;
+
+
+ // return the in-register version of this tile (nop)
+ inline CUDA_CALLABLE auto& copy_to_register()
+ {
+ return *this;
+ }
+
+ inline CUDA_CALLABLE const auto& copy_to_register() const
+ {
+ return *this;
+ }
+
+ // apply a lambda to all valid entries in the tile
+ // Op should be a functor that takes a register index and tile_coord_t as input
+ template <typename Op>
+ void apply(Op op)
+ {
+ WP_PRAGMA_UNROLL
+ for (int i=0; i < Layout::NumRegs; ++i)
+ {
+ int linear = Layout::linear_from_register(i);
+ if (!Layout::valid(linear))
+ break;
+
+ auto c = Layout::coord_from_linear(linear);
+ op(i, c);
+ }
+ }
+
+
+ // in-place gradient zero
+ inline CUDA_CALLABLE void grad_zero()
+ {
+ zero();
+ }
+
+ // accumulate gradients onto this tile
+ inline CUDA_CALLABLE void grad_add(const tile_register_t<T, Layout>& tile)
+ {
+ for (int i=0; i < Layout::NumRegs; ++i)
+ data[i] += tile.data[i];
+ }
+
+ inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+ {
+ apply([&](int reg, auto c) {data[reg] += global.load_grad(c);});
+ }
+
+ inline CUDA_CALLABLE auto& grad_to_register()
+ {
+ // nop for register tiles
+ return *this;
+ }
+
+ template <typename Global>
+ inline CUDA_CALLABLE void copy_to_global(const Global& dest)
+ {
+ apply([&](int reg, auto c) { dest.store(c, data[reg]); });
+ }
+
+ template <typename Global>
+ inline CUDA_CALLABLE void copy_from_global(const Global& src)
+ {
+ apply([&](int reg, auto c) { data[reg] = src.load(c); });
+ }
+
+ // add a register tile to a global array
+ template <typename Global>
+ inline CUDA_CALLABLE auto atomic_add(const Global& dest)
+ {
+ // allocate a tile to hold the previous dest value
+ auto previous = *this;
+
+ apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add(c, data[reg]); });
+ return previous;
+ }
+
+ // add a register tile to the gradient of a global array
+ template <typename Global>
+ inline CUDA_CALLABLE auto atomic_add_grad(const Global& dest)
+ {
+ // allocate a tile to hold the previous dest value
+ auto previous = *this;
+
+ apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add_grad(c, data[reg]); });
+ return previous;
+ }
+ };
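
A hypothetical usage sketch of apply() (illustrative only; assumes WP_TILE_BLOCK_DIM is defined by the code generator): fill a 4x8 register tile with its row-major linear index. apply() only visits (register, coord) slots that are valid for the calling thread, so the functor needs no bounds handling.

    inline CUDA_CALLABLE void tile_register_apply_example()
    {
        auto t = tile_register_t<float, tile_layout_register_t<tile_shape_t<4, 8>>>();

        // each thread writes only the registers it owns; c is the logical coord
        t.apply([&](int reg, auto c) { t.data[reg] = float(c[0]*8 + c[1]); });
    }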
+
+
+ // helper to allocate a register tile like another tile
+ // users can either specify a template explicitly or
+ // pass in another concrete instance
+ template<typename Tile>
+ auto tile_register_like(Tile* t=nullptr)
+ {
+ using T = typename Tile::Type;
+ using L = typename Tile::Layout;
+
+ return tile_register_t<T, tile_layout_register_t<typename L::Shape>>(T(0.0));
+ }
+
+ // helper to construct a register tile from a type and a list of dims
+ template <typename T, int... Dims>
+ auto tile_register()
+ {
+ return tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>();
+ }
+
+ inline CUDA_CALLABLE int tile_align(int num_bytes)
+ {
+ // note this must match the value in Python types.py
+ const int alignment = 16;
+
+ const int num_bytes_abs = num_bytes < 0 ? - num_bytes : num_bytes;
+ const int sign = num_bytes < 0 ? - 1 : 1;
+
+ return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
+ }
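
tile_align() rounds away from zero to a 16-byte boundary, keeping an allocation and its matching negative-size release symmetric: tile_align(40) = 48 and tile_align(-40) = -48 (worked example; tile_shared_t's destructor below depends on this symmetry when it returns memory by passing a negated size).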
+
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ // On the CPU we use a fixed-size block of stack memory for shared tile allocations.
+ // We store a pointer to the current allocation storage either in a reserved register
+ // (AArch64) or a static variable (x86-64).
+ #if !defined(__CUDA_ARCH__)
+ class tile_shared_storage_t;
+ #if defined(__aarch64__)
+ // x28 is the last callee-saved register on AArch64. This allows us to call externally
+ // compiled functions without worrying about clobbering the pointer.
+ // We pass -target-feature +reserve-x28 to Clang to exclude it from register allocation.
+ register tile_shared_storage_t* shared_tile_storage asm("x28");
+ #else
+ // Ideally this would be thread_local, but LLVM's JIT doesn't support TLS yet.
+ // There is also no support for something like -ffixed-r15 either.
+ static tile_shared_storage_t* shared_tile_storage;
+ #endif
+ #endif
+ #endif
+
+ // This class manages a block of "shared" memory for use by tiles.
+ // On the GPU this maps to dynamic shared memory, while on the CPU we allocate
+ // a fixed-size block of memory on the stack and manage allocations from it.
+ // An instance of this class gets created at the start of a kernel.
+ class tile_shared_storage_t
+ {
+ private:
+ #if !defined(__CUDA_ARCH__)
+ #define WP_MAX_CPU_SHARED 256*1024
+ #if defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ tile_shared_storage_t* old_value;
+ unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ char dynamic_smem_base[WP_MAX_CPU_SHARED]; // on CPU allocate a fixed 256k block to use for shared allocs
+ #endif
+ #endif
+
+ // we maintain a per-thread offset into dynamic
+ // shared memory that allows us to keep track of
+ // current use across dynamic function calls
+ static inline CUDA_CALLABLE unsigned int* get_smem_base()
+ {
+ #if defined(__CUDA_ARCH__)
+ __shared__ unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ return smem_base;
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ return shared_tile_storage->smem_base;
+ #else
+ static unsigned int smem_base[WP_TILE_BLOCK_DIM];
+ return smem_base;
+ #endif
+ }
+
+ static inline CUDA_CALLABLE char* get_dynamic_smem_base()
+ {
+ #if defined(__CUDA_ARCH__)
+ extern __shared__ char dynamic_smem_base[];
+ return dynamic_smem_base;
+ #elif defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ return shared_tile_storage->dynamic_smem_base;
+ #else
+ static char dynamic_smem_base[WP_MAX_CPU_SHARED];
+ return dynamic_smem_base;
+ #endif
+ }
+
+ public:
+ // cppcheck-suppress uninitMemberVar
+ inline CUDA_CALLABLE tile_shared_storage_t()
+ {
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ // On the CPU save a pointer to this instance in a reserved register
+ // or static variable so it can be accessed from anywhere within a kernel.
+ old_value = shared_tile_storage;
+ shared_tile_storage = this;
+ #endif
+
+ init();
+ }
+
+ inline CUDA_CALLABLE ~tile_shared_storage_t()
+ {
+ check();
+
+ #if !defined(__CUDA_ARCH__) && defined(WP_ENABLE_TILES_IN_STACK_MEMORY)
+ shared_tile_storage = old_value;
+ #endif
+ }
+
+ static inline CUDA_CALLABLE void init()
+ {
+ unsigned int* smem_base = get_smem_base();
+
+ smem_base[WP_TILE_THREAD_IDX] = 0;
+ }
+
+ static inline CUDA_CALLABLE void check()
+ {
+ unsigned int* smem_base = get_smem_base();
+
+ assert(smem_base[WP_TILE_THREAD_IDX] == 0);
+ }
+
+ static inline CUDA_CALLABLE void* alloc(int num_bytes)
+ {
+ unsigned int* smem_base = get_smem_base();
+ char* dynamic_smem_base = get_dynamic_smem_base();
+
+ const unsigned int offset = smem_base[WP_TILE_THREAD_IDX];
+
+ // one entry per-thread so no need for synchronization
+ smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
+
+ #if !defined(__CUDA_ARCH__)
+ assert(smem_base[WP_TILE_THREAD_IDX] <= WP_MAX_CPU_SHARED);
+ #endif
+
+ return &(dynamic_smem_base[offset]);
+ }
+ };
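
A hypothetical sketch of the allocate/release pairing (mirroring what tile_shared_t's constructor and destructor do further down; not code from the file): releases pass a negated size, so by kernel exit the per-thread offset is back to zero, which is exactly what check() asserts.

    inline CUDA_CALLABLE void shared_storage_sketch()
    {
        const int bytes = 64 * int(sizeof(float));

        // bump the per-thread offset and get a scratch block
        float* buf = (float*)tile_shared_storage_t::alloc(bytes);
        // ... use buf as the backing store for a shared tile ...
        (void)buf;

        // LIFO release: negated size returns the offset to its previous value
        tile_shared_storage_t::alloc(-bytes);
    }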
+
+
+ template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
+ struct tile_layout_strided_t
+ {
+ using Shape = Shape_;
+ using Stride = Stride_;
+ using Coord = tile_coord_t<Shape::N>;
+
+ static constexpr int Size = Shape::size();
+ static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
+
+ static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+ {
+ assert(linear < Size);
+
+ Coord c;
+
+ WP_PRAGMA_UNROLL
+ for (int d=Shape::N-1; d >= 0; --d)
+ {
+ c[d] = linear%Shape::dim(d);
+ linear /= Shape::dim(d);
+ }
+
+ return c;
+ }
+
+ static inline CUDA_CALLABLE int index_from_coord(Coord c)
+ {
+ int index = 0;
+
+ WP_PRAGMA_UNROLL
+ for (int d=0; d < Shape::N; ++d)
+ {
+ assert(c[d] < Shape::dim(d));
+
+ index += c[d]*Stride::dim(d);
+ }
+
+ return index;
+ }
+
+ // checks whether a strided layout is unique, i.e.: if memory locations are only
+ // ever referred to by one element in the tile; this is a basic test that only
+ // checks for broadcast dimensions; it would be possible to do the full check
+ // using sorted shape/strides in Python and add it as a template parameter to the type
+ static constexpr bool is_unique()
+ {
+ constexpr int N = Shape::N;
+
+ // check for any broadcast dimensions
+ for (int i=0; i < N; ++i)
+ if (Stride::dim(i) == 0)
+ return false;
+
+ return true;
+ }
+
+ static constexpr bool Unique = is_unique();
+
+ static inline CUDA_CALLABLE bool valid(int linear)
+ {
+ return linear < Size;
+ }
+
+ };
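
Two compile-time checks make the uniqueness test concrete (illustrative only; instantiation assumes WP_TILE_BLOCK_DIM is defined by the code generator): a zero stride marks a broadcast dimension, so multiple tile elements alias one memory location.

    // dense strides from compute_strides: every element maps to a distinct location
    static_assert(tile_layout_strided_t<tile_shape_t<4, 8>>::Unique, "dense layout is unique");

    // broadcasting along dim 0 (stride 0) aliases locations across rows
    static_assert(!tile_layout_strided_t<tile_shape_t<4, 8>, tile_stride_t<0, 1>>::Unique, "broadcast layout is not unique");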
1030
+
1031
+
1032
+ template <typename T, typename L, bool Owner_=true>
1033
+ struct tile_shared_t
1034
+ {
1035
+ using Type = T;
1036
+ using Layout = L;
1037
+ static constexpr bool Owner = Owner_;
1038
+
1039
+ struct Storage
1040
+ {
1041
+ T* ptr;
1042
+
1043
+ Storage(T* p) : ptr(p) {}
1044
+
1045
+ inline CUDA_CALLABLE T& operator()(typename Layout::Coord c)
1046
+ {
1047
+ assert(ptr);
1048
+
1049
+ int index = Layout::index_from_coord(c);
1050
+ return ptr[index];
1051
+ }
1052
+
1053
+ inline CUDA_CALLABLE const T& operator()(typename Layout::Coord c) const
1054
+ {
1055
+ assert(ptr);
1056
+
1057
+ int index = Layout::index_from_coord(c);
1058
+ return ptr[index];
1059
+ }
1060
+
1061
+ inline CUDA_CALLABLE T& operator()(int linear)
1062
+ {
1063
+ assert(ptr);
1064
+ assert(Layout::valid(linear));
1065
+
1066
+ auto c = Layout::coord_from_linear(linear);
1067
+ return (*this)(c);
1068
+ }
1069
+
1070
+ inline CUDA_CALLABLE const T& operator()(int linear) const
1071
+ {
1072
+ assert(ptr);
1073
+ assert(Layout::valid(linear));
1074
+
1075
+ auto c = Layout::coord_from_linear(linear);
1076
+ return (*this)(c);
1077
+ }
1078
+ };
1079
+
+     Storage data;
+     Storage grad;
+
+     // we need to track whether or not this tile's data has been initialized.
+     // once true, any re-initialization of data that follows needs a WP_TILE_SYNC()
+     // call to precede it, to allow threads that are still reading from this tile
+     // to complete their work, e.g., in a dynamic loop:
+     //   for i in range(x):
+     //       tile = wp.tile_load(arr, i, TILE_SIZE, storage="shared")
+     //       # read from tile...
+     bool initialized;
+
+     // default construction (uninitialized)
+     inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
+     {
+     }
+
+     // copying an owning shared tile would lead to a double deallocation,
+     // so we disallow it via static_assert; this also forces copies
+     // to be handled explicitly
+     inline CUDA_CALLABLE tile_shared_t(const tile_shared_t& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+     {
+         static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+     }
+
+     // move constructor
+     inline CUDA_CALLABLE tile_shared_t(tile_shared_t&& other) : data(other.data), grad(other.grad), initialized(other.initialized)
+     {
+         other.data.ptr = nullptr;
+         other.grad.ptr = nullptr;
+     }
+
+     template <typename OtherT, typename OtherLayout, bool OtherOwner>
+     inline CUDA_CALLABLE tile_shared_t(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& other) : data(other.data.ptr), grad(other.grad.ptr), initialized(other.initialized)
+     {
+         static_assert(!Owner, "Copy constructor is only supported for non-owning tiles.");
+         static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+     }
+
+     // initialize from an existing tile's memory
+     inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
+     {
+     }
+
+     inline CUDA_CALLABLE ~tile_shared_t()
+     {
+         if (Owner)
+         {
+             // update our per-thread shared memory allocator
+             if (data.ptr)
+                 tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
+
+             if (grad.ptr)
+                 tile_shared_storage_t::alloc(-Layout::Size*int(sizeof(T)));
+         }
+     }
+
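+     // Illustrative sketch (hypothetical usage, not from the original header):
+     // ownership selects between two lifetimes. An owning tile (Owner=true)
+     // returns its data/grad bytes to the per-thread shared memory allocator
+     // on destruction via the negative alloc() calls above, while a
+     // non-owning tile is a cheap alias that releases nothing:
+     //
+     //     // alias an existing tile's storage without taking ownership
+     //     tile_shared_t<float, SomeLayout, false> view(t.data.ptr, t.grad.ptr);
+     //
+     // This aliasing pattern is what wp::copy() and the where() overloads
+     // later in this file rely on.
+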
+     // assign from a register tile
+     inline CUDA_CALLABLE auto& operator=(const tile_register_t<Type, tile_layout_register_t<typename Layout::Shape>>& t)
+     {
+         assign(t);
+         return *this;
+     }
+
+     // assign from another shared tile; this operator is invoked
+     // for reshape operations like `wp.tile_transpose()`
+     // or `wp::copy()`
+     template <typename OtherT, typename OtherLayout, bool OtherOwner>
+     inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& rhs)
+     {
+         // check dimensions are compatible
+         static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+
+         if (Owner)
+         {
+             // if the tile owns the data we need to copy
+             assign(rhs);
+         }
+         else
+         {
+             // alias tile directly
+             data.ptr = rhs.data.ptr;
+             grad.ptr = rhs.grad.ptr;
+             initialized = rhs.initialized;
+         }
+
+         return *this;
+     }
+
+     inline CUDA_CALLABLE auto& operator=(const tile_shared_t& rhs)
+     {
+         if (Owner)
+         {
+             // if the tile owns the data we need to copy
+             assign(rhs);
+         }
+         else
+         {
+             // alias tile directly
+             data.ptr = rhs.data.ptr;
+             grad.ptr = rhs.grad.ptr;
+             initialized = rhs.initialized;
+         }
+
+         return *this;
+     }
+
+     // assign from a global tile (load)
+     template <bool BoundsCheck>
+     inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape, BoundsCheck>& t)
+     {
+         copy_from_global(t);
+         return *this;
+     }
+
+     // assign from a constant value
+     inline CUDA_CALLABLE auto& operator=(const T& x)
+     {
+         // sync if we are re-initializing data so that any threads that are still
+         // reading from this tile can complete their work, e.g.: if re-assigning
+         // to a tile during a dynamic loop
+         if (initialized)
+             WP_TILE_SYNC();
+
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+             data(i) = x;
+
+         initialized = true;
+         WP_TILE_SYNC();
+         return *this;
+     }
+
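+     // Illustrative note (not from the original header): the two syncs in the
+     // assignment above bracket a classic shared-memory hazard. Without the
+     // leading WP_TILE_SYNC(), a fast thread could overwrite elements that a
+     // slower thread is still reading from the previous iteration of a dynamic
+     // loop; the trailing sync then guarantees every element is written before
+     // any thread reads the new contents.
+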
+     // define the += operator which is used during backward pass codegen
+     // when returning a register tile from a user defined function
+     template<typename OtherLayout>
+     inline CUDA_CALLABLE auto& operator += (const tile_register_t<T, OtherLayout>& rhs)
+     {
+         grad_add(rhs);
+         return *this;
+     }
+
+     inline CUDA_CALLABLE auto& operator += (const tile_shared_t<T, Layout>& rhs)
+     {
+         grad_add(rhs);
+         return *this;
+     }
+
+     // in-place zero
+     inline CUDA_CALLABLE void zero()
+     {
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+             data(i) = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+     // extract a single tile element to a native type
+     inline CUDA_CALLABLE Type extract(const typename Layout::Coord& c)
+     {
+         return data(c);
+     }
+
+     // backward of scalar extraction
+     inline CUDA_CALLABLE void adj_extract(const typename Layout::Coord& c, Type adj_ret)
+     {
+         // since multiple threads may extract the same element
+         // we need to accumulate using atomic operations
+         wp::atomic_add(&grad(c), adj_ret);
+
+         WP_TILE_SYNC();
+     }
+
+     // add scalar value onto a single tile element
+     inline CUDA_CALLABLE void add_inplace(const typename Layout::Coord& c, const Type& x)
+     {
+         // since multiple threads may add to the same element
+         // we need to accumulate using atomic operations
+         wp::atomic_add(&data(c), x);
+
+         WP_TILE_SYNC();
+     }
+
+     // backward of inplace scalar addition
+     inline CUDA_CALLABLE void adj_add_inplace(const typename Layout::Coord& c, Type& adj_x)
+     {
+         adj_x += grad(c);
+     }
+
+     // subtract scalar value from a single tile element
+     inline CUDA_CALLABLE void sub_inplace(const typename Layout::Coord& c, const Type& x)
+     {
+         // since multiple threads may subtract from the same element
+         // we need to accumulate using atomic operations
+         wp::atomic_add(&data(c), -x);
+
+         WP_TILE_SYNC();
+     }
+
+     // backward of inplace scalar subtraction
+     inline CUDA_CALLABLE void adj_sub_inplace(const typename Layout::Coord& c, Type& adj_x)
+     {
+         adj_x -= grad(c);
+     }
+
+     // perform AND between a scalar value and a single tile element
+     inline CUDA_CALLABLE void bit_and_inplace(const typename Layout::Coord& c, const Type& x)
+     {
+         // since multiple threads may access the same element
+         // we need to access using atomic operations
+         wp::atomic_and(&data(c), x);
+
+         WP_TILE_SYNC();
+     }
+
+     // backward of inplace scalar AND (bitwise ops carry no gradient, so this is a no-op)
+     inline CUDA_CALLABLE void adj_bit_and_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
+     // perform OR between a scalar value and a single tile element
+     inline CUDA_CALLABLE void bit_or_inplace(const typename Layout::Coord& c, const Type& x)
+     {
+         // since multiple threads may access the same element
+         // we need to access using atomic operations
+         wp::atomic_or(&data(c), x);
+
+         WP_TILE_SYNC();
+     }
+
+     // backward of inplace scalar OR (no-op, as above)
+     inline CUDA_CALLABLE void adj_bit_or_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
+     // perform XOR between a scalar value and a single tile element
+     inline CUDA_CALLABLE void bit_xor_inplace(const typename Layout::Coord& c, const Type& x)
+     {
+         // since multiple threads may access the same element
+         // we need to access using atomic operations
+         wp::atomic_xor(&data(c), x);
+
+         WP_TILE_SYNC();
+     }
+
+     // backward of inplace scalar XOR (no-op, as above)
+     inline CUDA_CALLABLE void adj_bit_xor_inplace(const typename Layout::Coord& c, Type& adj_x) {}
+
+     // copy register tile to shared
+     template <typename Tile>
+     inline CUDA_CALLABLE void assign(const Tile& tile)
+     {
+         if (initialized)
+             WP_TILE_SYNC();
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             // handle case where tile size is not
+             // aligned to block dimensions
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             data(linear) = tile.data[i];
+         }
+
+         initialized = true;
+         WP_TILE_SYNC();
+     }
+
+     // shared tile deep copy
+     template <typename OtherT, typename OtherLayout, bool OtherOwner>
+     inline CUDA_CALLABLE void assign(const tile_shared_t<OtherT, OtherLayout, OtherOwner>& tile)
+     {
+         // check dimensions are compatible
+         static_assert(Layout::Size == OtherLayout::Size, "Expected Size == OtherLayout::Size");
+
+         if (initialized)
+             WP_TILE_SYNC();
+
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             data(c) = tile.data(c);
+         }
+
+         initialized = true;
+         WP_TILE_SYNC();
+     }
+
+     // in-place gradient zero
+     inline CUDA_CALLABLE void grad_zero()
+     {
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+             grad(i) = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+     // accumulate gradients onto this tile
+     template <typename Tile>
+     inline CUDA_CALLABLE void grad_add(const Tile& tile)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             // handle case where tile size is not
+             // aligned to block dimensions
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             // if the destination layout is unique (no broadcast dimensions)
+             // then we can use regular non-atomic accumulation
+             if (Layout::Unique)
+                 grad(linear) += tile.data[i];
+             else
+                 // use shared memory atomics to accumulate gradients
+                 // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
+                 // may map to a single location in shared memory
+                 wp::atomic_add(&grad(linear), tile.data[i]);
+         }
+
+         WP_TILE_SYNC();
+     }
+
+     // accumulate gradients onto this tile from another shared tile
+     inline CUDA_CALLABLE void grad_add(const tile_shared_t<T, Layout>& tile)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             grad(c) += tile.grad(c);
+         }
+
+         WP_TILE_SYNC();
+     }
+
+     // accumulate gradient onto this tile from a global array
+     inline CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             T g = global.load_grad(c);
+
+             if (Layout::Unique)
+             {
+                 // if the destination layout is unique (no broadcast dimensions)
+                 // then we can use regular non-atomic accumulation
+                 grad(c) += g;
+             }
+             else
+             {
+                 // use shared memory atomics to accumulate gradients
+                 // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
+                 // may map to a single location in shared memory
+                 wp::atomic_add(&grad(c), g);
+             }
+         }
+
+         WP_TILE_SYNC();
+     }
+
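+     // Illustrative note (not from the original header): Layout::Unique selects
+     // between the two accumulation strategies above. For a broadcast tile such
+     // as a (1, N) bias row expanded to (M, N), M incoming gradients target the
+     // same shared-memory element, so the plain `grad += ...` form would lose
+     // updates to the data race; wp::atomic_add makes the accumulation
+     // race-free at the cost of serializing the conflicting writes.
+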
+     // copy shared tile gradient to register
+     inline CUDA_CALLABLE auto grad_to_register()
+     {
+         using Tile = tile_register_t<T, tile_layout_register_t<typename Layout::Shape>>;
+         Tile out;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             out(i) = grad(linear);
+         }
+
+         return out;
+     }
+
+     // copy shared tile to register
+     inline CUDA_CALLABLE auto copy_to_register() const
+     {
+         auto out = tile_register_like(this);
+
+         using Layout = typename decltype(out)::Layout;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Layout::NumRegs; ++i)
+         {
+             const int linear = Layout::linear_from_register(i);
+
+             if (!Layout::valid(linear))
+                 break;
+
+             out(i) = data(linear);
+         }
+
+         return out;
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_to_global(const Global& dest)
+     {
+ #if defined(__CUDA_ARCH__)
+         // vectorized stores for specific input/output shapes
+         if constexpr (Layout::Shape::N == 2)
+         {
+             constexpr int lastdim = Layout::Shape::N-1;
+             constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
+             const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
+             const int elements = min(Layout::Shape::dim(1), (dest.data.shape[lastdim] - dest.offset[lastdim]));
+             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+             const bool aligned_stride = (dest.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
+
+             float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
+             const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
+
+             constexpr int M = Layout::Shape::dim(0);
+             constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+             if (contiguous_dest && contiguous_src && aligned_size && aligned_dst && aligned_stride && N)
+             {
+                 // alias of shared tile with 128bit type
+                 using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                 tile_shared_t<float4, SrcLayout, false> src128((float4*)data.ptr);
+
+                 assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
+                 assert(((uint64_t)(dest128))%sizeof(float4) == 0);
+
+                 const int stride_i = dest.data.strides[0]/sizeof(float4);
+                 const int stride_j = 1;
+
+                 WP_PRAGMA_UNROLL
+                 for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
+                 {
+                     auto c = SrcLayout::coord_from_linear(i);
+
+                     dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i);
+                 }
+
+                 return;
+             }
+         }
+ #endif //defined(__CUDA_ARCH__)
+
+         // scalar bounds-checked path
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             dest.store(c, data(i));
+         }
+     }
+
+     inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
+     {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+         unsigned long long saddr = 0ULL;
+         unsigned long long gaddr = 0ULL;
+
+         asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest));
+         asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src));
+
+         // Use cp.async on newer architectures
+         asm volatile(
+             "cp.async.ca.shared.global [%0], [%1], 16;\n"
+             :
+             : "l"(saddr), "l"(gaddr)
+         );
+ #else
+         // use regular load/store through register on older arches
+         *shared_dest = *global_src;
+ #endif
+     }
+
+     inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
+     {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+         asm volatile(
+             "cp.async.commit_group;\n"
+             "cp.async.wait_group 0;\n" ::);
+ #endif
+     }
+
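+     // Illustrative note (not from the original header): the two helpers above
+     // follow the standard sm_80+ cp.async pattern: enqueue any number of
+     // 16-byte asynchronous copies, then commit the group and drain it:
+     //
+     //     for (...)
+     //         cp_async_global_to_shared_128(dst + i, src + i); // enqueue
+     //     cp_async_commit_and_wait_all_128();  // commit_group + wait_group 0
+     //     WP_TILE_SYNC();                      // make writes visible block-wide
+     //
+     // On architectures before sm_80 both helpers degrade to a synchronous
+     // 128-bit copy through registers.
+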
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_from_global(const Global& src)
+     {
+         if (initialized)
+             WP_TILE_SYNC();
+
+ #if defined(__CUDA_ARCH__)
+         // vectorized loads for specific input/output shapes
+         if constexpr (Layout::Shape::N == 2)
+         {
+             constexpr int lastdim = Layout::Shape::N-1;
+             constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
+             const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
+             const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
+             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+             const bool aligned_stride = (src.data.strides[0]/sizeof(T))%Layout::Stride::dim(0) == 0;
+
+             float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
+             const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
+
+             constexpr int M = Layout::Shape::dim(0);
+             constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+             if (contiguous_dest && contiguous_src && aligned_size && aligned_src && aligned_stride && N)
+             {
+                 // alias of shared tile with 128bit type
+                 using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                 tile_shared_t<float4, DestLayout, false> dest128((float4*)data.ptr);
+
+                 assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
+                 assert(((uint64_t)(src128))%sizeof(float4) == 0);
+
+                 const int stride_i = src.data.strides[0]/sizeof(float4);
+                 const int stride_j = 1;
+
+                 WP_PRAGMA_UNROLL
+                 for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
+                 {
+                     auto c = DestLayout::coord_from_linear(i);
+
+ #if WP_USE_ASYNC_PIPELINE
+                     cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]);
+ #else
+                     dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]];
+ #endif // WP_USE_ASYNC_PIPELINE
+                 }
+
+ #if WP_USE_ASYNC_PIPELINE
+                 cp_async_commit_and_wait_all_128();
+ #endif // WP_USE_ASYNC_PIPELINE
+
+                 initialized = true;
+                 WP_TILE_SYNC();
+                 return;
+             }
+         }
+ #endif //defined(__CUDA_ARCH__)
+
+         // scalar bounds-checked path
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             data(i) = src.load(c);
+         }
+
+         initialized = true;
+         WP_TILE_SYNC();
+     }
+
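+     // Illustrative note (not from the original header): the float4 fast paths
+     // in copy_from_global()/copy_to_global() require all of the conditions
+     // checked above:
+     //   - a 2-D tile that is contiguous in its innermost dimension on both
+     //     the shared and global side,
+     //   - a row size that is a whole number of 16-byte float4 chunks,
+     //   - 16-byte alignment of the global base address, and
+     //   - a global row stride compatible with the tile's row stride.
+     // Any violation falls through to the scalar bounds-checked loop, so the
+     // vectorization is purely an optimization and never changes results.
+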
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add(Global& dest)
+     {
+         return copy_to_register().atomic_add(dest);
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add_grad(Global& dest)
+     {
+         return grad_to_register().atomic_add_grad(dest);
+     }
+
+     // overload for integral types
+     inline CUDA_CALLABLE void print_value(int x) const
+     {
+         printf("%d", x);
+     }
+
+     // overload for floating point types
+     template <typename ValueType>
+     inline CUDA_CALLABLE void print_value(ValueType x) const
+     {
+         printf("%g", x);
+     }
+
+     template <int Level = 0>
+     inline CUDA_CALLABLE void print_values(const Storage& storage, int index=0) const
+     {
+         using Shape = typename Layout::Shape;
+
+         if constexpr (Level < Shape::N)
+         {
+             if constexpr (Level == Shape::N - 1)
+             {
+                 // Special handling for 1D case
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     print_value(storage(index + i));
+
+                     if (i < Shape::dim(Level) - 1)
+                     {
+                         printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+             else if constexpr (Level == Shape::N - 2)
+             {
+                 // Special handling for 2D case
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     printf("[");
+                     for (int j=0; j < Shape::dim(Level+1); ++j)
+                     {
+                         print_value(storage(index));
+
+                         if (j < Shape::dim(Level+1) - 1)
+                         {
+                             printf(" ");
+                         }
+
+                         ++index;
+                     }
+
+                     printf("]");
+
+                     // next row
+                     if (i < Shape::dim(Level)-1)
+                     {
+                         printf("\n");
+
+                         // indent next row
+                         for (int k=0; k <= Shape::N-2; ++k)
+                             printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+             else
+             {
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     print_values<Level + 1>(storage, index + i * Shape::dim(Level));
+                     if (i < Shape::dim(Level) - 1)
+                     {
+                         printf("\n\n");
+
+                         // indent next row
+                         for (int k=0; k <= Level; ++k)
+                             printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+         }
+     }
+
+     inline CUDA_CALLABLE void print(bool reverse=false) const
+     {
+         if (WP_TILE_THREAD_IDX != 0)
+             return;
+
+         if (reverse)
+             print_values(grad);
+         else
+             print_values(data);
+
+         printf(" = tile(shape=(");
+         for (int i=0; i < Layout::Shape::N; ++i)
+         {
+             printf("%d", Layout::Shape::dim(i));
+             if (i != Layout::Shape::N-1)
+                 printf(",");
+         }
+
+         printf("), storage=shared)\n");
+     }
+ };
+
+
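+ // Illustrative sketch (hypothetical usage, not from the original header):
+ // how tile_shared_t composes with the entry points defined below; the
+ // 16x16 shape and the kernel context are assumptions, and `t` stands for
+ // a shared tile that Warp's codegen would normally allocate via
+ // tile_alloc_empty():
+ //
+ //     t = tile_load<float, true, 16, 16>(src, 0, 0); // copy_from_global()
+ //     t = 2.0f;                                      // broadcast a constant
+ //     tile_store<float, true>(dest, 0, 0, t);        // copy_to_global()
+ //
+ // Each assignment above inserts the WP_TILE_SYNC() barriers described in
+ // the comments on `initialized`.
+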
+ template <typename T, typename L>
+ void tile_register_t<T, L>::print() const
+ {
+     // create a temporary shared tile so that
+     // we can print it deterministically
+ #if defined(__CUDA_ARCH__)
+     __shared__ T smem[L::Size];
+ #else
+     T smem[L::Size];
+ #endif
+     tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
+
+     scratch.assign(*this);
+
+     WP_TILE_SYNC();
+
+     if (WP_TILE_THREAD_IDX == 0)
+     {
+         scratch.print_values(scratch.data, 0);
+
+         printf(" = tile(shape=(");
+         for (int i=0; i < L::Shape::N; ++i)
+         {
+             printf("%d", L::Shape::dim(i));
+             if (i != L::Shape::N-1)
+                 printf(",");
+         }
+
+         printf("), storage=register)\n");
+     }
+
+     WP_TILE_SYNC();
+ }
+
+ // print entry points
+ template <typename T, typename L>
+ inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
+
+ template <typename T, typename L>
+ inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+
+ template <typename T, typename L, bool O>
+ inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
+ {
+     return L::Shape::dim(0);
+ }
+
+ template <typename T, typename L, bool O, typename AdjTile>
+ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile& a, int& adj_ret)
+ {
+ }
+
+ template <typename T, typename L>
+ inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
+ {
+     return L::Shape::dim(0);
+ }
+
+ template <typename T, typename L, typename AdjTile>
+ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile& a, int& adj_ret)
+ {
+ }
+
+ // where specialization for register/shared tiles
+ template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_register_t<T, LRegister>& a, const tile_shared_t<T, LShared, Owner>& b)
+ {
+     // The double NOT operator !! casts to bool without compiler warnings.
+     return (!!cond) ? a : b.copy_to_register();
+ }
+
+ template <typename C, typename T, typename LRegister, typename LShared, bool Owner>
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, LShared, Owner>& a, const tile_register_t<T, LRegister>& b)
+ {
+     // The double NOT operator !! casts to bool without compiler warnings.
+     return (!!cond) ? a.copy_to_register() : b;
+ }
+
+ template <typename C, typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b)
+ {
+     // The double NOT operator !! casts to bool without compiler warnings.
+     return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+ }
+
+ template <typename C, typename T, typename L, bool LOwner, bool ROwner>
+ inline CUDA_CALLABLE auto where(const C& cond, const tile_shared_t<T, L, LOwner>& a, const tile_shared_t<T, L, ROwner>& b)
+ {
+     // The double NOT operator !! casts to bool without compiler warnings.
+     return (!!cond) ? tile_shared_t<T, L, false>(a.data.ptr, a.grad.ptr) : tile_shared_t<T, L, false>(b.data.ptr, b.grad.ptr);
+ }
+
+ // adj_where same as in builtin.h
+
+ // copy specialization for shared tiles; the lvalue this gets assigned to
+ // is owning, so this invokes the copy-assignment path
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE auto copy(const tile_shared_t<T, L, Owner>& t)
+ {
+     return tile_shared_t<T, L, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE void adj_copy(const tile_shared_t<T, L, Owner>& src, tile_shared_t<T, L, Owner>& adj_src, tile_shared_t<T, L, Owner>& adj_dest)
+ {
+     adj_src += adj_dest;
+     adj_dest.grad_zero();
+ }
+
+ // helpers to allocate shared tiles
+ template <typename T, typename Shape, typename Strides, bool RequiresGrad>
+ inline CUDA_CALLABLE auto tile_alloc_empty()
+ {
+     constexpr int size = Shape::size();
+     T* data = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
+     T* grad = nullptr;
+
+ #if FP_CHECK
+     // initialize tile to quiet NaN
+     uint32_t qnanbits = 0x7FC00000;
+     float qnan = *(float*)(&qnanbits);
+
+     for (int i=WP_TILE_THREAD_IDX; i < size; i += WP_TILE_BLOCK_DIM)
+         data[i] = T(qnan);
+
+     WP_TILE_SYNC();
+ #endif // FP_CHECK
+
+     if (RequiresGrad)
+     {
+         grad = (T*)tile_shared_storage_t::alloc(size*sizeof(T));
+
+         for (int i=WP_TILE_THREAD_IDX; i < size; i += WP_TILE_BLOCK_DIM)
+             grad[i] = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+     return tile_shared_t<T, tile_layout_strided_t<Shape, Strides>>(data, grad);
+ }
+
+
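+ // Illustrative note (not from the original header): tile_alloc_empty() and
+ // the owning ~tile_shared_t() form a matched pair on the per-thread shared
+ // memory allocator. Allocation reserves size*sizeof(T) bytes for data
+ // (plus the same again when RequiresGrad), and destruction passes the
+ // corresponding negative byte count back to tile_shared_storage_t::alloc(),
+ // keeping the stack-like shared-memory accounting balanced across a kernel.
+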
+ //-----------------------------------------------------------------------------------------------------
+ // High level entry points for each op (correspond to one Warp builtin)
+
+ // construct a tile from a local SIMT value (one per-thread)
+ template <typename T>
+ inline CUDA_CALLABLE auto tile(const T& x)
+ {
+     tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
+
+     using Layout = typename decltype(result)::Layout;
+     static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
+
+     result.data[0] = x;
+     return result;
+ }
+
+ // overload for constructing a tile from a per-thread vector
+ template <typename T, unsigned Length>
+ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
+ {
+     tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
+
+     using Layout = typename decltype(result)::Layout;
+     static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
+
+     for (unsigned i=0; i < Length; ++i)
+         result.data[i] = x[i];
+
+     return result;
+ }
+
+ // overload for constructing a tile from a per-thread matrix
+ template <unsigned Rows, unsigned Cols, typename T>
+ inline CUDA_CALLABLE auto tile(const wp::mat_t<Rows, Cols, T>& x)
+ {
+     tile_register_t<T, tile_layout_register_t<tile_shape_t<Rows, Cols, WP_TILE_BLOCK_DIM>>> result;
+
+     using Layout = typename decltype(result)::Layout;
+     static_assert(Layout::NumRegs == Rows*Cols, "Expected Layout::NumRegs == Rows*Cols");
+
+     for (unsigned i=0; i < Rows; ++i)
+         for (unsigned j=0; j < Cols; ++j)
+             result.data[i*Cols + j] = x.data[i][j];
+
+     return result;
+ }
+
+ // it is sufficient to use a single adjoint for all tile overload funcs
+ // it is also necessary, because we don't provide a dispatch_func for adjoint calls
+ // so the compiler will default to choosing based on argument types
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
+ {
+     static_assert(AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(AdjTile::Layout::Shape::N - 1) == WP_TILE_BLOCK_DIM");
+
+     auto adj_reg = adj_ret.copy_to_register();
+
+     if constexpr (AdjTile::Layout::Shape::N == 1)
+     {
+         adj_x += adj_reg.data[0];
+     }
+     else if constexpr (AdjTile::Layout::Shape::N == 2)
+     {
+         for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
+             adj_x[i] += adj_reg.data[i];
+     }
+     else if constexpr (AdjTile::Layout::Shape::N == 3)
+     {
+         for (unsigned i=0; i < AdjTile::Layout::Shape::dim(0); ++i)
+             for (unsigned j=0; j < AdjTile::Layout::Shape::dim(1); ++j)
+                 adj_x.data[i][j] += adj_reg.data[i*AdjTile::Layout::Shape::dim(1) + j];
+     }
+ }
+
+
+ template <typename Tile>
+ inline CUDA_CALLABLE auto untile(Tile& tile)
+ {
+     // code-gen should have set the tile to
+     // have exactly the block dimension so
+     // there is exactly one value per-thread
+     auto reg = tile.copy_to_register();
+
+     constexpr int N = Tile::Layout::Shape::N;
+
+     // scalar case
+     if constexpr(N == 1)
+     {
+         return reg.data[0];
+     }
+
+     // vector case
+     if constexpr(N == 2)
+     {
+         constexpr int Length = Tile::Layout::Shape::dim(0);
+         wp::vec_t<Length, typename Tile::Type> v;
+         for (int i=0; i < Length; ++i)
+             v[i] = reg.data[i];
+
+         return v;
+     }
+
+     // matrix case
+     if constexpr(N == 3)
+     {
+         constexpr int Rows = Tile::Layout::Shape::dim(0);
+         constexpr int Cols = Tile::Layout::Shape::dim(1);
+         wp::mat_t<Rows, Cols, typename Tile::Type> m;
+         for (int i=0; i < Rows; ++i)
+             for (int j=0; j < Cols; ++j)
+                 m.data[i][j] = reg.data[i*Cols + j];
+
+         return m;
+     }
+ }
+
+ template <typename Tile, typename Value>
+ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
+ {
+     auto adj = adj_tile.copy_to_register();
+
+     constexpr int N = Tile::Layout::Shape::N;
+
+     // scalar case
+     if constexpr(N == 1)
+     {
+         adj.data[0] += adj_ret;
+     }
+
+     // vector case
+     if constexpr(N == 2)
+     {
+         constexpr int Length = Tile::Layout::Shape::dim(0);
+         for (int i=0; i < Length; ++i)
+             adj.data[i] += adj_ret[i];
+     }
+
+     // matrix case
+     if constexpr(N == 3)
+     {
+         constexpr int Rows = Tile::Layout::Shape::dim(0);
+         constexpr int Cols = Tile::Layout::Shape::dim(1);
+         for (int i=0; i < Rows; ++i)
+             for (int j=0; j < Cols; ++j)
+                 adj.data[i*Cols + j] += adj_ret.data[i][j];
+     }
+
+     adj_tile.assign(adj);
+ }
+
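+ // Illustrative note (not from the original header): tile() and untile()
+ // are inverses. With a block of WP_TILE_BLOCK_DIM threads:
+ //
+ //     float s  = ...;        // one scalar per thread
+ //     auto  t  = tile(s);    // register tile of shape (WP_TILE_BLOCK_DIM)
+ //     float s2 = untile(t);  // back to one value per thread, s2 == s
+ //
+ // Likewise vec_t<L, T> maps to shape (L, WP_TILE_BLOCK_DIM) and
+ // mat_t<R, C, T> to shape (R, C, WP_TILE_BLOCK_DIM); the trailing
+ // dimension is always the block dimension asserted in adj_tile() above.
+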
+ // zero initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_zeros()
+ {
+     // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+     return T(0);
+ }
+
+ // one-initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_ones()
+ {
+     // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+     return T(1);
+ }
+
+ // value-initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_full(T x)
+ {
+     // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+     return x;
+ }
+
+ // tile with evenly spaced values
+ template <typename T, int Len>
+ inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
+ {
+     auto out = tile_register<T, Len>();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         // handle case where tile size is not
+         // aligned to block dimensions
+         if (!Layout::valid(linear))
+             break;
+
+         out.data[i] = start + linear*step;
+     }
+
+     return out;
+ }
+
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
+                                           T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
+
+ // entry point for load operations, these just return a reference to a global memory array + coordinate
+ template <typename T, bool BoundsCheck, unsigned... Shape, typename... Offset>
+ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Offset... offset)
+ {
+     return tile_global_t<T, tile_shape_t<Shape...>, BoundsCheck>(src, tile_coord(offset...));
+ }
+
+ // used for indexed loads and stores
+ template <typename T, int M, typename Coord>
+ inline CUDA_CALLABLE bool compute_index(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Coord c, int& out)
+ {
+     int index = 0;
+
+     WP_PRAGMA_UNROLL
+     for (int i = 0; i < Coord::size(); ++i)
+     {
+         if (i == axis)
+         {
+             // global = offset_coord + index_mapped_coord
+             int index_along_axis = offset[i] + indices.data(c[i]);
+
+             // handle out of bounds case
+             if (index_along_axis >= src.shape[i])
+                 return false;
+             else
+                 index += src.strides[i] * index_along_axis;
+         }
+         else
+         {
+             // global = offset_coord + coord
+             int g = offset[i] + c[i];
+
+             // handle out of bounds case
+             if (g >= src.shape[i])
+                 return false;
+             else
+                 index += src.strides[i] * g;
+         }
+     }
+
+     // array strides are in bytes so we convert to elements
+     out = index / sizeof(T);
+     return true;
+ }
+
+
+ template <unsigned... Shape, int M, typename T, typename... Offset>
+ inline CUDA_CALLABLE auto tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Offset... offset)
+ {
+     auto out = tile_register_t<T, tile_layout_register_t<tile_shape_t<Shape...>>>();
+     auto offset_coord = tile_coord(offset...);
+
+     out.apply([&](int reg, auto c) {
+         int i;
+         if (compute_index(src, indices, axis, offset_coord, c, i))
+             out.data[reg] = src.data[i];
+         else
+             out.data[reg] = T(0);
+     });
+
+     return out;
+ }
+
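+ // Illustrative note (not from the original header): tile_load_indexed() is
+ // a gather. Along `axis` the tile coordinate is remapped through the
+ // `indices` tile before the offset is applied; e.g. for a 2-D load with
+ // axis == 0, element (i, j) of the result reads
+ //
+ //     src[offset[0] + indices[i], offset[1] + j]
+ //
+ // with any out-of-bounds element replaced by zero, as computed by
+ // compute_index() above.
+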
+ // // entry point for tile store operations
+ // template <typename... Indices, typename T, typename Tile>
+ // inline CUDA_CALLABLE void tile_store(array_t<T>& dest, Tile& src, Indices... x)
+ // {
+ //     src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x)));
+ // }
+
+ // entry point for tile store operations
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x))); }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y))); }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z))); }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck>(dest, tile_coord(x, y, z, w))); }
+
+ template <typename T, int M, typename Tile, typename Coord>
+ inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+ {
+     auto src_reg = src.copy_to_register();
+
+     src_reg.apply([&](int reg, auto c) {
+         int i;
+         if (compute_index(dest, indices, axis, offset, c, i))
+             dest.data[i] = src_reg.data[reg];
+     });
+ }
+
+ // entry points for indexed tile store operations
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x), src); }
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y), src); }
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE void tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
+
+
+ // compiler struggles with these if they are one line
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) {
+     tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x));
+     return src.atomic_add(global);
+ }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) {
+     tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y));
+     return src.atomic_add(global);
+ }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) {
+     tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z));
+     return src.atomic_add(global);
+ }
+ template <typename T, bool BoundsCheck, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) {
+     tile_global_t<T, typename Tile::Layout::Shape, BoundsCheck> global(dest, tile_coord(x, y, z, w));
+     return src.atomic_add(global);
+ }
+
+ template <typename T, int M, typename Tile, typename Coord>
+ inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& src)
+ {
+     auto src_reg = src.copy_to_register();
+     auto ret_reg = tile_register_like<Tile>();
+
+     src_reg.apply([&](int reg, auto c) {
+         int i;
+         if (compute_index(dest, indices, axis, offset, c, i))
+             ret_reg.data[reg] = wp::atomic_add(&dest.data[i], src_reg.data[reg]);
+         else
+             ret_reg.data[reg] = T(0);
+     });
+
+     return ret_reg;
+ }
+
+ // entry points for indexed tile atomic add operations
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x), src); }
+
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y), src); }
+
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z), src); }
+
+ template <typename T, int M, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& src) { return tile_atomic_add_indexed(dest, indices, axis, tile_coord(x, y, z, w), src); }
+
+
+ //-------------------------------------
+ // Adjoints
+
+ template <typename T, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
+                                         array_t<T>& adj_src, Coord adj_c,
+                                         AdjTile& adj_ret)
+ {
+     tile_global_t<T, typename AdjTile::Layout::Shape> dest(src, c);
+
+     // we allow users to override grad of src
+     if (adj_src.data)
+         dest.data.grad = adj_src.data;
+
+     adj_ret.atomic_add_grad(dest);
+ }
+
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, array_t<T>& adj_src, int adj_x, int adj_y, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y), adj_src, tile_coord(0,0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y, z), adj_src, tile_coord(0,0,0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
+
+ template <typename T, int M, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset,
+                                                 array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset,
+                                                 AdjTile& adj_ret)
+ {
+     // we allow users to override grad of src
+     if (adj_src.data)
+         src.grad = adj_src.data;
+
+     auto adj_ret_reg = adj_ret.grad_to_register();
+
+     adj_ret_reg.apply([&](int reg, auto c) {
+         int i;
+         if (compute_index(src, indices, axis, offset, c, i))
+             wp::atomic_add(&src.grad[i], adj_ret_reg.data[reg]);
+     });
+ }
+
+ template <typename T, int M, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_ret)
+ {
+     adj_tile_load_indexed(src, indices, axis, tile_coord(x), adj_src, adj_indices, adj_axis, tile_coord(0), adj_ret);
+ }
+ template <typename T, int M, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_ret)
+ {
+     adj_tile_load_indexed(src, indices, axis, tile_coord(x, y), adj_src, adj_indices, adj_axis, tile_coord(0, 0), adj_ret);
+ }
+ template <typename T, int M, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret)
+ {
+     adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0), adj_ret);
+ }
+ template <typename T, int M, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load_indexed(array_t<T>& src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, array_t<T>& adj_src, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret)
+ {
+     adj_tile_load_indexed(src, indices, axis, tile_coord(x, y, z, w), adj_src, adj_indices, adj_axis, tile_coord(0, 0, 0, 0), adj_ret);
+ }
+
+ template <typename T, typename Tile, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
+ {
+     tile_global_t<T, typename AdjTile::Layout::Shape> src(dest, c);
+
+     // we allow users to override grad of dest
+     if (adj_dest.data)
+         src.data.grad = adj_dest.data;
+
+     if (src.data.grad == nullptr)
+         return;
+
+     adj_t.grad_add(src);
+ }
+
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(0,0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(0,0,0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
+
+ template <typename T, int M, typename Tile, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, Coord offset, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, Coord adj_offset, AdjTile& adj_t)
+ {
+     // we allow users to override grad of dest
+     if (adj_dest.data)
+         dest.grad = adj_dest.data;
+
+     auto adj_t_reg = tile_register_like<Tile>();
+
+     adj_t_reg.apply([&](int reg, auto c) {
+         int i;
+         if (compute_index(dest, indices, axis, offset, c, i))
+             adj_t_reg.data[reg] += dest.grad[i];
+     });
+
+     // write adjoints back
+     adj_t.grad_add(adj_t_reg);
+ }
+
+ template <typename T, int M, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
+
+ // adj_tile_atomic_add is an alias for adj_tile_store
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(adj_x), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(adj_x, adj_y), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(adj_x, adj_y, adj_z), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
+
+ // adj_tile_atomic_add_indexed is an alias for adj_tile_store_indexed
+ template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x), t, adj_dest, adj_indices, adj_axis, tile_coord(0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0), adj_t); }
+ template <typename T, int M, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add_indexed(array_t<T>& dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& indices, int axis, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, tile_shared_t<int, tile_layout_strided_t<tile_shape_t<M>>>& adj_indices, int adj_axis, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store_indexed(dest, indices, axis, tile_coord(x, y, z, w), t, adj_dest, adj_indices, adj_axis, tile_coord(0,0,0,0), adj_t); }
+
+ // unary map
+ template <typename Tile, typename Fwd, typename ReturnTile>
+ inline CUDA_CALLABLE auto tile_map(Fwd op, Tile& a, ReturnTile& r)
+ {
+     // verify shapes and sizes are compatible
+     using ShapeIn = typename Tile::Layout::Shape;
+     using ShapeOut = typename ReturnTile::Layout::Shape;
+
+     static_assert(ShapeIn::N == ShapeOut::N, "Number of tile dimensions must match for unary map");
+     static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for unary map");
+
+     auto out = tile_register_like<ReturnTile>();
+     auto a_reg = a.copy_to_register();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         out.data[i] = op(a_reg.data[i]);
+     }
+
+     return out;
+ }
+
+ template <typename Tile, typename AdjTile, typename Fwd, typename Adj>
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
+                                        Tile& a,
+                                        Adj adj_op,
+                                        Tile& adj_a,
+                                        AdjTile& adj_ret)
+ {
+     auto a_reg = a.copy_to_register();
+     auto adj_a_reg = tile_register_like<Tile>();
+     auto adj_ret_reg = adj_ret.grad_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
+     }
+
+     // write adjoints back
+     adj_a.grad_add(adj_a_reg);
+ }
+
+ // binary map
+ template <typename TileA, typename TileB, typename Fwd, typename ReturnTile>
+ inline CUDA_CALLABLE auto tile_map(Fwd op,
+                                    TileA& a,
+                                    TileB& b,
+                                    ReturnTile& r)
+ {
+     // verify shapes and sizes are compatible
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+     using ShapeOut = typename ReturnTile::Layout::Shape;
+
+     static_assert(ShapeA::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+     static_assert(ShapeB::N == ShapeOut::N, "Number of tile dimensions must match for binary map");
+
+     static_assert(ShapeA::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+     static_assert(ShapeB::size() == ShapeOut::size(), "Tile sizes must match for binary map");
+
+     auto out = tile_register_like<ReturnTile>();
+
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         out.data[i] = op(a_reg.data[i], b_reg.data[i]);
+     }
+
+     return out;
+ }
+
+ template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
+                                        TileA& a,
+                                        TileB& b,
+                                        Adj adj_op,
+                                        TileA& adj_a,
+                                        TileB& adj_b,
+                                        AdjTile& adj_ret)
+ {
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     // allocate storage for adjoints
+     auto adj_a_reg = tile_register_like<TileA>();
+     auto adj_b_reg = tile_register_like<TileB>();
+
+     auto adj_ret_reg = adj_ret.grad_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
+     }
+
+     adj_a.grad_add(adj_a_reg);
+     adj_b.grad_add(adj_b_reg);
+ }
+
2533
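+ // Usage sketch (illustrative only; assumes two float tiles `a` and `b` with
+ // identical shapes):
+ //
+ //     auto r = tile_register_like<decltype(a)>();
+ //     auto c = tile_map([](float x, float y) { return x * y; }, a, b, r);   // elementwise product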
+ // We wrap the operator in a lambda so that we don't have to do overload resolution for builtins such as wp.sin().
+ // This is important because many of the builtin operators don't follow consistent conventions on references for
+ // the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting.
+ // The `r` argument is a dummy return tile argument, because we can't template on the return tile type in a macro definition.
+ // So if we want users to be able to define functions that return a tile type that is different from the input type,
+ // we must pass an extra dummy return tile argument that is used to define the return type of tile_map.
+
+ #define tile_unary_map(op, a, r) tile_map([](auto x) { return op(x);}, a, r)
+ #define adj_tile_unary_map(op, a, r, adj_op, adj_a, adj_r, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
+
+ #define tile_binary_map(op, a, b, r) tile_map([](auto x, auto y) { return op(x, y);}, a, b, r)
+ #define adj_tile_binary_map(op, a, b, r, adj_op, adj_a, adj_b, adj_r, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
+
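+ // For example, the expansion of the macro above (shown for illustration only):
+ //
+ //     tile_unary_map(wp::sin, a, r)
+ //
+ // expands to
+ //
+ //     tile_map([](auto x) { return wp::sin(x); }, a, r)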
+ // -tile (unary neg)
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a, a); }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, a, wp::adj_neg, adj_a, adj_a, adj_ret); }
+
+
+ // tile + tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
+ {
+     return tile_binary_map(add, a, b, a);
+ }
+
+ // add overloads get called in user function adjoints generated by codegen (adj_tile += adj_ret)
+ template <typename T, typename L>
+ inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_register_t<T, L>& b) {
+     return tile_add(a, b);
+ }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_shared_t<T, L, Owner>& b) {
+     return tile_add(a, b);
+ }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE auto add(tile_register_t<T, L>& a, const tile_shared_t<T, L, Owner>& b) {
+     return tile_add(a, b);
+ }
+
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE auto add(tile_shared_t<T, L, Owner>& a, const tile_register_t<T, L>& b) {
+     return tile_add(a, b);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     adj_tile_binary_map(add, a, b, a, adj_add, adj_a, adj_b, adj_a, adj_c);
+ }
+
+ // tile - tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
+ {
+     return tile_binary_map(sub, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     adj_tile_binary_map(sub, a, b, a, adj_sub, adj_a, adj_b, adj_a, adj_c);
+ }
+
+
+ // tile*scalar
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
+ {
+     // promote scalar to a constant tile
+     auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
+
+     return tile_binary_map(mul, a, s_tile, a);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
+                                        Tile& adj_a, typename Tile::Type& adj_s,
+                                        AdjTile& adj_c)
+ {
+     auto s_tile = tile_register_like<Tile>();
+     auto adj_s_tile = tile_register_like<Tile>();
+
+     using Layout = typename decltype(adj_s_tile)::Layout;
+
+     // initialize to constant
+     s_tile = s;
+
+     adj_tile_binary_map(mul, a, s_tile, a, adj_mul, adj_a, adj_s_tile, adj_a, adj_c);
+
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_s += adj_s_tile.data[i];
+     }
+ }
+
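+ // Note (illustrative): tile_mul promotes the scalar to a constant register tile
+ // and reuses the binary map; conceptually, for y = s * a the backward pass above
+ // accumulates adj_a[i] += s * adj_y[i] and adj_s += a[i] * adj_y[i].
+ //
+ //     auto y = tile_mul(a, 2.0f);   // y[i] = 2 * a[i]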
+ // scalar*tile
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
+ {
+     return tile_mul(a, s);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
+                                        typename Tile::Type& adj_s, Tile& adj_a,
+                                        AdjTile& adj_c)
+ {
+     adj_tile_mul(a, s, adj_a, adj_s, adj_c);
+ }
+
+
+ // tile & tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_and(TileA& a, TileB& b)
+ {
+     return tile_binary_map(bit_and, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_and(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     // no-op: bitwise operations have no meaningful gradient
+ }
+
+ // tile | tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_or(TileA& a, TileB& b)
+ {
+     return tile_binary_map(bit_or, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_or(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     // no-op: bitwise operations have no meaningful gradient
+ }
+
+ // tile ^ tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_bit_xor(TileA& a, TileB& b)
+ {
+     return tile_binary_map(bit_xor, a, b, a);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_bit_xor(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     // no-op: bitwise operations have no meaningful gradient
+ }
+
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_add_inplace(TileA& a, TileB& b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
+
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(b_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         a_reg.data[i] += b_reg.data[i];
+     }
+
+     a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_add_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace addition");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace addition");
+
+     // allocate storage for adjoints
+     auto adj_a_reg = adj_a.grad_to_register();
+     auto adj_b_reg = tile_register_like<TileB>();
+
+     using Layout = typename decltype(adj_a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         adj_b_reg.data[i] += adj_a_reg.data[i];
+     }
+
+     adj_b.grad_add(adj_b_reg);
+ }
+
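+ // Note (for reference): for a += b the adjoint simply forwards the incoming
+ // gradient, i.e. adj_b += adj_a, which is what adj_tile_add_inplace implements above.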
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_sub_inplace(TileA& a, TileB& b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
+
+     // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         a_reg.data[i] -= b_reg.data[i];
+     }
+
+     a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_sub_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace subtraction");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace subtraction");
+
+     // allocate storage for adjoints
+     auto adj_a_reg = adj_a.grad_to_register();
+     auto adj_b_reg = tile_register_like<TileB>();
+
+     using Layout = typename decltype(adj_a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         adj_b_reg.data[i] -= adj_a_reg.data[i];
+     }
+
+     adj_b.grad_add(adj_b_reg);
+ }
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_and_inplace(TileA& a, TileB& b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise AND");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise AND");
+
+     // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         a_reg.data[i] &= b_reg.data[i];
+     }
+
+     a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_and_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}  // no-op: not differentiable
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_or_inplace(TileA& a, TileB& b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise OR");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise OR");
+
+     // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         a_reg.data[i] |= b_reg.data[i];
+     }
+
+     a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_or_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}  // no-op: not differentiable
+
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_bit_xor_inplace(TileA& a, TileB& b)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+
+     // verify shapes and sizes are compatible
+     static_assert(ShapeA::N == ShapeB::N, "Tile shapes must match for inplace bitwise XOR");
+     static_assert(ShapeA::size() == ShapeB::size(), "Tile sizes must match for inplace bitwise XOR");
+
+     // work with register tiles for inplace operations, regardless of the storage type of the input tiles
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         a_reg.data[i] ^= b_reg.data[i];
+     }
+
+     a.assign(a_reg);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_bit_xor_inplace(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b) {}  // no-op: not differentiable
+
+
+ template<typename Tile>
+ typename Tile::Type tile_extract(Tile& t, int i) {
+     return t.extract(tile_coord(i));
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i))[j];
+     } else {
+         return t.extract(tile_coord(i,j));
+     }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j))[k];
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i)).data[j][k];
+     } else {
+         return t.extract(tile_coord(i,j,k));
+     }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k, int l) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j,k))[l];
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j)).data[k][l];
+     } else {
+         return t.extract(tile_coord(i,j,k,l));
+     }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j,k,l))[m];
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j,k)).data[l][m];
+     } else {
+         static_assert(always_false<Tile>::value,
+                       "tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+     }
+ }
+ template<typename Tile>
+ auto tile_extract(Tile& t, int i, int j, int k, int l, int m, int n) {
+     if constexpr(is_matrix<typename Tile::Type>::value) {
+         return t.extract(tile_coord(i,j,k,l)).data[m][n];
+     } else {
+         static_assert(always_false<Tile>::value,
+                       "tile_extract with 6 indices requires a tile of matrices (4D tile)");
+     }
+ }
+
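+ // Indexing note (for reference): when the tile's element type is itself a vector
+ // or matrix, the trailing indices select into the element. E.g. for a 2D tile of
+ // vec3, tile_extract(t, i, j, k) returns component k of the vec3 stored at (i, j).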
+ template<typename Tile, typename AdjTile>
+ void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) {
+     adj_t.adj_extract(tile_coord(i), adj_ret);
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, AdjType adj_ret) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         typename Tile::Type vector_adj{};
+         vector_adj[j] = adj_ret;
+         adj_t.adj_extract(tile_coord(i), vector_adj);
+     } else {
+         adj_t.adj_extract(tile_coord(i, j), adj_ret);
+     }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, AdjType adj_ret) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         typename Tile::Type vector_adj{};
+         vector_adj[k] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j), vector_adj);
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         typename Tile::Type matrix_adj{};
+         matrix_adj.data[j][k] = adj_ret;
+         adj_t.adj_extract(tile_coord(i), matrix_adj);
+     } else {
+         adj_t.adj_extract(tile_coord(i, j, k), adj_ret);
+     }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, AdjType adj_ret) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         typename Tile::Type vector_adj{};
+         vector_adj[l] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j, k), vector_adj);
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         typename Tile::Type matrix_adj{};
+         matrix_adj.data[k][l] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j), matrix_adj);
+     } else {
+         adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret);
+     }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, AdjType adj_ret) {
+     if constexpr(is_vector<typename Tile::Type>::value) {
+         typename Tile::Type vector_adj{};
+         vector_adj[m] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j, k, l), vector_adj);
+     } else if constexpr(is_matrix<typename Tile::Type>::value) {
+         typename Tile::Type matrix_adj{};
+         matrix_adj.data[l][m] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j, k), matrix_adj);
+     } else {
+         static_assert(always_false<Tile>::value,
+                       "adj_tile_extract with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+     }
+ }
+ template<typename Tile, typename AdjTile, typename AdjType>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, int m, int n, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, AdjType adj_ret) {
+     if constexpr(is_matrix<typename Tile::Type>::value) {
+         typename Tile::Type matrix_adj{};
+         matrix_adj.data[m][n] = adj_ret;
+         adj_t.adj_extract(tile_coord(i, j, k, l), matrix_adj);
+     } else {
+         static_assert(always_false<Tile>::value,
+                       "adj_tile_extract with 6 indices requires a tile of matrices (4D tile)");
+     }
+ }
+
+
+ template<typename Tile>
+ void tile_add_inplace(Tile& t, int i, typename Tile::Type value) { t.add_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.add_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.add_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_sub_inplace(Tile& t, int i, typename Tile::Type value) { t.sub_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.sub_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_and_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_or_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k), value); }
+ template<typename Tile>
+ void tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value) { t.bit_xor_inplace(tile_coord(i,j,k,l), value); }
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_add_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_add_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_add_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_add_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_add_inplace(tile_coord(i, j, k, l), adj_value); }
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_sub_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_sub_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k), adj_value); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_sub_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) { adj_t.adj_sub_inplace(tile_coord(i, j, k, l), adj_value); }
+
+ // bitwise elementwise updates have no gradient; the adjoints below are no-ops
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_and_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_or_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, typename Tile::Type value, AdjTile& adj_t, int adj_i, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type& adj_value) {}
+ template<typename Tile, typename AdjTile>
+ void adj_tile_bit_xor_inplace(Tile& t, int i, int j, int k, int l, typename Tile::Type value, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type& adj_value) {}
+
+ namespace partitioned_gemm
+ {
+
+ template <typename T>
+ inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <typename T>
+ inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <int PartitionM, int PartitionN, typename Tile>
+ struct partition_t
+ {
+     static constexpr int M = PartitionM;
+     static constexpr int N = PartitionN;
+     static constexpr int Stride = Tile::Layout::Shape::dim(1);
+
+     using T = typename Tile::Type;
+
+     inline partition_t(Tile& A)
+     {
+         data = A.data.ptr;
+
+         // todo: do ceil div for non-multiples of M,N
+         shape[0] = Tile::Layout::Shape::dim(0)/PartitionM;
+         shape[1] = Tile::Layout::Shape::dim(1)/PartitionN;
+     }
+
+     // underlying data
+     T* data;
+
+     // partition dimensions
+     int shape[2];
+ };
+
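+ // Example (illustrative): partitioning a 16x16 tile with partition_t<4, 4, Tile>
+ // yields shape = {4, 4}, i.e. a 4x4 grid of 4x4 sub-blocks that threads can
+ // process independently.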
+ template <typename Partition>
+ inline int partition_size(const Partition& part)
+ {
+     return part.shape[0]*part.shape[1];
+ }
+
+ // returns the (i, j) block coordinates of a sub-tile given a linear index
+ template <typename Partition>
+ inline void partition_coord(const Partition& part, const int t, int& i, int& j)
+ {
+     i = t/part.shape[1];
+     j = t%part.shape[1];
+ }
+
+ template <typename Partition>
+ inline auto partition_load(const Partition& tile, int i, int j)
+ {
+     mat_t<Partition::M, Partition::N, typename Partition::T> out;
+
+     const int tile_i = i*Partition::M;
+     const int tile_j = j*Partition::N;
+
+     // gather one MxN sub-block into a local matrix
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             out.data[i][j] = partitioned_gemm::index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
+         }
+     }
+
+     return out;
+ }
+
+ template <typename Partition, typename Value>
+ inline void partition_store(const Partition& tile, int i, int j, const Value& value)
+ {
+     const int tile_i = Partition::M*i;
+     const int tile_j = Partition::N*j;
+
+     // scatter the MxN sub-block back to the underlying tile
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             index(tile.data, tile_i + i, tile_j + j, Partition::Stride) = value.data[i][j];
+         }
+     }
+ }
+
+
+ template <typename TileA, typename TileB, typename TileC>
+ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
+ {
+     const int TILE_M = 4;
+     const int TILE_N = 4;
+     const int TILE_K = 4;
+
+     auto A_tile = partition_t<TILE_M, TILE_K, TileA>(A);
+     auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
+     auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
+
+     //static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
+
+     const int length = partition_size(C_tile);
+
+     for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
+     {
+         int i, j;
+         partition_coord(C_tile, t, i, j);
+
+         // accumulator
+         auto sum = partition_load(C_tile, i, j);
+
+         WP_PRAGMA_UNROLL
+         for (int k=0; k < A_tile.shape[1]; k++)
+         {
+             const auto a = partition_load(A_tile, i, k);
+             const auto b = partition_load(B_tile, k, j);
+
+             sum += mul(a, b);
+         }
+
+         partition_store(C_tile, i, j, sum);
+     }
+ }
+
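+ // Thread mapping (for reference): each thread processes output blocks
+ // t, t + WP_TILE_BLOCK_DIM, t + 2*WP_TILE_BLOCK_DIM, ... so the grid of
+ // sub-blocks is covered even when it is larger than the thread block.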
+ template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
+ inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
+ {
+     for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto coord = LayoutC::coord_from_linear(t);
+
+         int i = coord[0];
+         int j = coord[1];
+
+         // accumulator
+         auto sum = C(coord)*scale;
+
+         WP_PRAGMA_UNROLL
+         for (int k=0; k < LayoutA::Shape::dim(1); k++)
+         {
+             const auto a = A(tile_coord(i, k));
+             const auto b = B(tile_coord(k, j));
+
+             sum = muladd<decltype(sum)>(a, b, sum);
+         }
+
+         C(coord) = sum;
+     }
+ }
+
+ // unblocked Cholesky factorization of A into lower-triangular L (A = L*L^T)
+ template <typename TileA, typename TileL>
+ inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
+ {
+     using T = typename TileA::Type;
+     constexpr int n = TileA::Layout::Shape::dim(1);
+
+     for (int j=0; j < n; ++j)
+     {
+         T s = A.data(tile_coord(j, j));
+
+         for (int k=0; k < j; ++k)
+         {
+             T r = L.data(tile_coord(j, k));
+             s -= r * r;
+         }
+
+         s = wp::sqrt(s);
+         T invS = 1.0 / s;
+
+         L.data(tile_coord(j, j)) = s;
+
+         for (int i=j+1; i < n; ++i)
+         {
+             s = A.data(tile_coord(i, j));
+
+             for (int k=0; k < j; ++k)
+             {
+                 s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
+             }
+
+             L.data(tile_coord(i, j)) = s * invS;
+         }
+
+         // zero out upper triangular portion
+         for (int k=j+1; k < n; ++k)
+         {
+             L.data(tile_coord(j,k)) = T(0.0);
+         }
+     }
+ }
+
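+ // Recurrences implemented above (for reference):
+ //
+ //     L(j,j) = sqrt( A(j,j) - sum_{k<j} L(j,k)^2 )
+ //     L(i,j) = ( A(i,j) - sum_{k<j} L(i,k) * L(j,k) ) / L(j,j)   for i > j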
+ // Writes into X
+ template <typename TileL, typename TileX, typename TileY>
+ inline CUDA_CALLABLE void scalar_cholesky_forward_substitution(TileL& L, TileX& X, TileY& Y)
+ {
+     using T = typename TileL::Type;
+
+     if constexpr (TileY::Layout::Shape::N == 1)
+     {
+         constexpr int n = TileL::Layout::Shape::dim(1);
+
+         for (int i=0; i < n; ++i)
+         {
+             T s = Y.data(tile_coord(i));
+
+             for (int j=0; j < i; ++j)
+                 s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
+
+             T diag = L.data(tile_coord(i, i));
+             X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
+         }
+     }
+     else if constexpr (TileY::Layout::Shape::N == 2)
+     {
+         constexpr int n = TileL::Layout::Shape::dim(1);
+         constexpr int m = TileY::Layout::Shape::dim(1);
+
+         for (int k=0; k < m; ++k)
+         {
+             for (int i=0; i < n; ++i)
+             {
+                 T s = Y.data(tile_coord(i,k));
+
+                 for (int j=0; j < i; ++j)
+                     s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j,k));
+
+                 T diag = L.data(tile_coord(i, i));
+                 X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
+             }
+         }
+     }
+ }
+
+ // Reads and writes X
+ template <typename TileL, typename TileX>
+ inline CUDA_CALLABLE void scalar_cholesky_back_substitution(TileL& L, TileX& X)
+ {
+     using T = typename TileL::Type;
+
+     if constexpr (TileX::Layout::Shape::N == 1)
+     {
+         constexpr int n = TileL::Layout::Shape::dim(1);
+
+         for (int i=n-1; i >= 0; --i)
+         {
+             T s = X.data(tile_coord(i));
+
+             for (int j=i+1; j < n; ++j)
+                 s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
+
+             T diag = L.data(tile_coord(i, i));
+             X.data(tile_coord(i)) = (diag != T(0.0f)) ? s / diag : s;
+         }
+     }
+     else if constexpr (TileX::Layout::Shape::N == 2)
+     {
+         constexpr int n = TileL::Layout::Shape::dim(1);
+         constexpr int m = TileX::Layout::Shape::dim(1);
+
+         for (int k=0; k < m; ++k)
+         {
+             for (int i=n-1; i >= 0; --i)
+             {
+                 T s = X.data(tile_coord(i,k));
+
+                 for (int j=i+1; j < n; ++j)
+                     s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j,k));
+
+                 T diag = L.data(tile_coord(i, i));
+                 X.data(tile_coord(i,k)) = (diag != T(0.0f)) ? s / diag : s;
+             }
+         }
+     }
+ }
+
+ template <typename TileL, typename TileX, typename TileY>
+ inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
+ {
+     scalar_cholesky_forward_substitution(L, X, Y);
+     scalar_cholesky_back_substitution(L, X);
+ }
+
+
+ } // namespace partitioned_gemm
+
+
+ // computes C = A*B + Add*C, where the Add template parameter (0 or 1) selects
+ // overwrite vs. accumulate semantics
+ template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+     using ShapeC = typename TileC::Layout::Shape;
+
+     static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
+     static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
+     static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
+
+     static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
+
+     using T = typename TileC::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
+ #else
+     T alpha = T(1.0);
+     T beta = T(Add);
+     fun_forward(&alpha, A.data.ptr, B.data.ptr, &beta, C.data.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+
+     return C;
+ }
+
+
+ // backward for the wp.tile_matmul(a, b, out) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
+ {
+     using T_A = typename TileA::Type;
+     using T_B = typename TileB::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     auto At = tile_transpose(A);
+     auto Bt = tile_transpose(B);
+
+     partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T_A(1.0));
+     partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T_B(1.0));
+ #else
+     T_A alpha_A = T_A(1.0);
+     T_A beta_A = T_A(1.0);
+     fun_backward_A(&alpha_A, adj_C.grad.ptr, B.data.ptr, &beta_A, adj_A.grad.ptr);
+     T_B alpha_B = T_B(1.0);
+     T_B beta_B = T_B(1.0);
+     fun_backward_B(&alpha_B, A.data.ptr, adj_C.grad.ptr, &beta_B, adj_B.grad.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+ }
+
+ // backward for the out = wp.tile_matmul(a, b) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
+ {
+     using T = typename TileC::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     auto At = tile_transpose(A);
+     auto Bt = tile_transpose(B);
+
+     partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+     partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+ #else
+     T alpha = T(1.0);
+     T beta = T(1.0);
+     fun_backward_A(&alpha, adj_C.grad.ptr, B.data.ptr, &beta, adj_A.grad.ptr);
+     fun_backward_B(&alpha, A.data.ptr, adj_C.grad.ptr, &beta, adj_B.grad.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+ }
+
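+ // Backward rules implemented above (for reference): for C = A*B,
+ //
+ //     adj_A += adj_C * B^T
+ //     adj_B += A^T * adj_C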
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+ #define tile_fft()
+ #define tile_ifft()
+
+ #define adj_tile_fft()
+ #define adj_tile_ifft()
+
+ #else
+
+ // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
+ // and remove the need for __align__(16) dtype data[...]
+ #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
+     do { \
+         void function_name(dtype*, char*); \
+         char* buffer = (char*)wp::tile_shared_storage_t::alloc(shared_memory_size); \
+         __align__(16) dtype data[ept]; \
+         for(int b = 0; b < (int)batch_size; b++) { \
+             dtype* inout = Xinout.data + (int)b * (int)ept; \
+             memcpy(data, inout, sizeof(dtype) * ept); \
+             function_name(data, buffer); \
+             memcpy(inout, data, sizeof(dtype) * ept); \
+             WP_TILE_SYNC(); \
+         } \
+         wp::tile_shared_storage_t::alloc(-shared_memory_size); \
+     } while (0)
+
+ #define tile_ifft tile_fft
+
+ // adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept are all ignored
+
+ #define adj_tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                      adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                      adj_Xinout) \
+     do { \
+         tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
+
+ #define adj_tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                       adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                       adj_Xinout) \
+     do { \
+         tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
+
+ #endif // !defined(__CUDA_ARCH__)
+
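+ // Note (for reference): the adjoint of a forward FFT is computed by running the
+ // inverse transform on the incoming adjoint, and vice versa, which is what the
+ // adj_tile_fft / adj_tile_ifft macros above implement.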
+ template <typename Fwd, typename TileA, typename TileL>
+ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
+ {
+     static_assert(TileA::Layout::Shape::N == 2, "Expected TileA::Layout::Shape::N == 2");
+     static_assert(TileL::Layout::Shape::N == 2, "Expected TileL::Layout::Shape::N == 2");
+
+     static_assert(TileA::Layout::Shape::dim(0) == TileA::Layout::Shape::dim(1), "Expected TileA to be square");
+     static_assert(TileL::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(1), "Expected TileL to be square");
+     static_assert(TileA::Layout::Shape::dim(0) == TileL::Layout::Shape::dim(0), "Expected A and L to have the same number of rows");
+     static_assert(TileA::Layout::Shape::dim(1) == TileL::Layout::Shape::dim(1), "Expected A and L to have the same number of columns");
+
+     // Copy to L
+     L = A;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     partitioned_gemm::scalar_cholesky(A, L);
+
+ #else
+
+     // TODO: for batched Cholesky, need one info per batch
+     __shared__ int info[1];
+
+     if (WP_TILE_THREAD_IDX == 0) {
+         info[0] = 0;
+     }
+
+     // Call cholesky on L
+     WP_TILE_SYNC();
+
+     fun_forward(L.data.ptr, info);
+
+     WP_TILE_SYNC();
+
+     // TODO: for batched Cholesky, check all batches
+ #if defined(_DEBUG)
+     if (WP_TILE_THREAD_IDX == 0 && info[0] != 0) {
+         printf("Non-zero status in Cholesky factorization, got %d\n", info[0]);
+     }
+ #endif
+
+     // Zero-out the upper triangular part of L
+     WP_PRAGMA_UNROLL
+     for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
+     {
+         auto c = TileL::Layout::coord_from_linear(i);
+
+         if(c[0] < c[1])
+             L.data(c) = 0.0;
+     }
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return L;
+ }
+
+ // backward pass not implemented: differentiating through tile_cholesky is unsupported
+ #define adj_tile_cholesky(function_name, A, L, \
+                           adj_function_name, adj_A, adj_L, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+ template <typename Fwd, typename TileL, typename TileX, typename TileY>
+ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& Y, TileY& X)
+ {
+     // Copy y to x
+     X = Y;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     partitioned_gemm::scalar_cholesky_solve(L, X, Y);
+
+ #else
+
+     // Call cholesky solve on L & x
+     WP_TILE_SYNC();
+
+     fun_forward(L.data.ptr, X.data.ptr);
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return X;
+ }
+
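+ // Solve pipeline (for reference): with A = L*L^T, tile_cholesky_solve computes
+ // x = A^{-1} y via two triangular solves:
+ //
+ //     z = L^{-1} y        (forward substitution)
+ //     x = L^{-T} z        (back substitution)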
+ #define adj_tile_cholesky_solve(function_name, L, Y, X, \
+                                 adj_function_name, adj_L, adj_Y, adj_X, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+
+ template <typename Fwd, typename TileL, typename TileY, typename TileZ>
+ TileZ& tile_lower_solve(Fwd fun_forward, TileL& L, TileY& y, TileZ& z)
+ {
+     // Copy y to z
+     z = y;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     partitioned_gemm::scalar_cholesky_forward_substitution(L, z, y);
+
+ #else
+
+     // Call cholesky solve on L & z
+     WP_TILE_SYNC();
+
+     fun_forward(L.data.ptr, z.data.ptr);
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return z;
+ }
+
+ #define adj_tile_lower_solve(function_name, L, y, z, \
+                              adj_function_name, adj_L, adj_y, adj_z, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+
+ template <typename Fwd, typename TileU, typename TileZ, typename TileX>
+ TileX& tile_upper_solve(Fwd fun_forward, TileU& U, TileZ& z, TileX& x)
+ {
+     // Copy z to x
+     x = z;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     auto L = tile_transpose(U);
+     partitioned_gemm::scalar_cholesky_back_substitution(L, x);
+
+ #else
+
+     // Call cholesky solve on U & x
+     WP_TILE_SYNC();
+
+     fun_forward(U.data.ptr, x.data.ptr);
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return x;
+ }
+
+ #define adj_tile_upper_solve(function_name, U, z, x, \
+                              adj_function_name, adj_U, adj_z, adj_x, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_transpose(Tile& t)
+ {
+     static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
+
+     // alias incoming tile
+     constexpr int M = Tile::Layout::Shape::dim(0);
+     constexpr int N = Tile::Layout::Shape::dim(1);
+
+     constexpr int StrideM = Tile::Layout::Stride::dim(0);
+     constexpr int StrideN = Tile::Layout::Stride::dim(1);
+
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N,M>, tile_stride_t<StrideN, StrideM>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     auto a = tile_transpose(adj_ret);
+     auto& b = adj_t;
+
+     adj_t.assign(tile_add(a,b));
+ }
+
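+ // Aliasing note (for reference): tile_transpose returns a zero-copy view that
+ // swaps the shape and strides, so writes through the transposed tile update the
+ // original storage.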
+ template <int N, int StrideN, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int StrideM, int StrideN, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     // nop, since memory is aliased, grads already accumulated
+ }
+
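+ // Usage sketch (illustrative; assumes `t` is a contiguous 1-D tile of length N):
+ // broadcasting across M rows aliases the same storage with a zero stride on the
+ // new dimension, e.g.
+ //
+ //     auto rows = tile_broadcast<M, N, 0, 1>(t);   // rows(i, j) reads t(j)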
+
+ template <typename ReturnTile, typename Tile, typename... Indices>
+ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
+ {
+     auto c = tile_coord(indices...);
+
+     // return new tile with same strides
+     typename Tile::Type* data_ptr = &t.data(c);
+     typename Tile::Type* grad_ptr = nullptr;
+
+     if (t.grad.ptr)
+         grad_ptr = &t.grad(c);
+
+     return ReturnTile(data_ptr, grad_ptr);
+ }
+
+
+ template <typename ReturnTile, typename Tile>
+ inline CUDA_CALLABLE auto tile_squeeze(Tile& t)
+ {
+     // ReturnTile layout is set in builtins.py
+     typename Tile::Type* data_ptr = t.data.ptr;
+     typename Tile::Type* grad_ptr = nullptr;
+
+     if (t.grad.ptr)
+         grad_ptr = t.grad.ptr;
+
+     return ReturnTile(data_ptr, grad_ptr);
+ }
+
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
+ inline CUDA_CALLABLE void adj_tile_squeeze(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
+ {
+     // nop, since memory is aliased, grads already accumulated
+ }
+
+
+ template <typename ReturnTile, typename Tile>
+ inline CUDA_CALLABLE auto tile_reshape(Tile& t)
+ {
+     // ReturnTile layout is set in builtins.py
+     typename Tile::Type* data_ptr = t.data.ptr;
+     typename Tile::Type* grad_ptr = nullptr;
+
+     if (t.grad.ptr)
+         grad_ptr = t.grad.ptr;
+
+     return ReturnTile(data_ptr, grad_ptr);
+ }
+
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
+ inline CUDA_CALLABLE void adj_tile_reshape(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
+ {
+     // nop, since memory is aliased, grads already accumulated
+ }
+
+
+ template <typename ReturnTile, typename Tile>
+ inline CUDA_CALLABLE auto tile_astype(Tile& t)
+ {
+     // verify shapes and sizes are compatible
+     using ShapeIn = typename Tile::Layout::Shape;
+     using ShapeOut = typename ReturnTile::Layout::Shape;
+
+     static_assert(ShapeIn::N == ShapeOut::N, "Tile shapes must match for data type casting");
+     static_assert(ShapeIn::size() == ShapeOut::size(), "Tile sizes must match for data type casting");
+
+     // work with register tiles for type casting
+     auto t_reg = t.copy_to_register();
+     auto result = tile_register_like<ReturnTile>();
+
+     using Layout = typename decltype(result)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i = 0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         if(!Layout::valid(linear))
+             break;
+
+         result.data[i] = static_cast<typename ReturnTile::Type>(t_reg.data[i]);
+     }
+
+     return result;
+ }
+
+ template <typename Tile, typename AdjTile, typename AdjReturnTile>
+ inline CUDA_CALLABLE void adj_tile_astype(Tile& t, AdjTile& adj_t, AdjReturnTile& adj_ret)
+ {
+     // gradients only flow between float conversions
+     if constexpr((is_same<typename AdjTile::Type, wp::float16>::value ||
+                   is_same<typename AdjTile::Type, wp::float32>::value ||
+                   is_same<typename AdjTile::Type, wp::float64>::value) &&
+                  (is_same<typename AdjReturnTile::Type, wp::float16>::value ||
+                   is_same<typename AdjReturnTile::Type, wp::float32>::value ||
+                   is_same<typename AdjReturnTile::Type, wp::float64>::value))
+     {
+         auto adj_ret_reg = adj_ret.grad_to_register();
+         auto adj_t_reg = tile_register_like<AdjTile>();
+
+         using Layout = typename decltype(adj_t_reg)::Layout;
+
+         WP_PRAGMA_UNROLL
+         for (int i = 0; i < Layout::NumRegs; ++i)
+         {
+             adj_t_reg.data[i] += static_cast<typename AdjTile::Type>(adj_ret_reg.data[i]);
+         }
+
+         adj_t.grad_add(adj_t_reg);
+     }
+ }
+
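+ // Gradient note (for reference): adj_tile_astype propagates gradients only when
+ // both tiles have floating-point element types (float16/32/64); casts involving
+ // integer types receive no gradient.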
+
+ // scalar element assignment into a tile held in shared memory; the number of
+ // indices selects between tile coordinates and inner vector/matrix components
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
+ {
+     dest.data(tile_coord(i)) = src;
+     WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
+ {
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         dest.data(tile_coord(i))[j] = src;
+     } else {
+         dest.data(tile_coord(i, j)) = src;
+     }
+     WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
+ {
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j))[k] = src;
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         dest.data(tile_coord(i)).data[j][k] = src;
+     } else {
+         dest.data(tile_coord(i, j, k)) = src;
+     }
+     WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
+ {
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j, k))[l] = src;
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j)).data[k][l] = src;
+     } else {
+         dest.data(tile_coord(i, j, k, l)) = src;
+     }
+     WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src)
+ {
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j, k, l))[m] = src;
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j, k)).data[l][m] = src;
+     } else {
+         static_assert(always_false<TileA>::value,
+                       "assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+     }
+     WP_TILE_SYNC();
+ }
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src)
+ {
+     if constexpr(is_matrix<typename TileA::Type>::value) {
+         dest.data(tile_coord(i, j, k, l)).data[m][n] = src;
+     } else {
+         static_assert(always_false<TileA>::value,
+                       "assign with 6 indices requires a tile of matrices (4D tile)");
+     }
+     WP_TILE_SYNC();
+ }
+
+
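These overloads back per-element stores into a tile from kernel code. A minimal sketch, assuming tiles support indexed assignment of the form `t[i, j] = x` and that `wp.tile_zeros` accepts the `shape`/`dtype` keywords shown (names are illustrative):

```python
import warp as wp

@wp.kernel
def set_element(out: wp.array2d(dtype=wp.float32)):
    t = wp.tile_zeros(shape=(4, 4), dtype=wp.float32)
    # would lower to the 2-index assign overload, followed by WP_TILE_SYNC
    t[0, 1] = 2.0
    wp.tile_store(out, t)
```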
+ // adjoints of scalar element assignment: read the gradient stored at the written
+ // coordinate and accumulate it into the adjoint of the source scalar
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, const Scalar& src, AdjTileA& adj_dest, int adj_i, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     adj_src += dest.grad(tile_coord(i));
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i))[j];
+     } else {
+         adj_src += dest.grad(tile_coord(i, j));
+     }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j))[k];
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i)).data[j][k];
+     } else {
+         adj_src += dest.grad(tile_coord(i, j, k));
+     }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j, k))[l];
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j)).data[k][l];
+     } else {
+         adj_src += dest.grad(tile_coord(i, j, k, l));
+     }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     if constexpr(is_vector<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j, k, l))[m];
+     } else if constexpr(is_matrix<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j, k)).data[l][m];
+     } else {
+         static_assert(always_false<TileA>::value,
+                       "adj_assign with 5 indices requires a tile of vectors (4D tile) or matrices (3D tile)");
+     }
+ }
+ template <typename TileA, typename AdjTileA, typename Scalar>
+ inline CUDA_CALLABLE void adj_assign(TileA& dest, int i, int j, int k, int l, int m, int n, const Scalar& src, AdjTileA& adj_dest, int adj_i, int adj_j, int adj_k, int adj_l, int adj_m, int adj_n, Scalar& adj_src)
+ {
+     if (dest.grad.ptr == nullptr)
+     {
+         return;
+     }
+
+     if constexpr(is_matrix<typename TileA::Type>::value) {
+         adj_src += dest.grad(tile_coord(i, j, k, l)).data[m][n];
+     } else {
+         static_assert(always_false<TileA>::value,
+                       "adj_assign with 6 indices requires a tile of matrices (4D tile)");
+     }
+ }
+
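A hedged sketch of how these adjoints surface to users: recording a kernel on a `wp.Tape` routes the stored tile gradient back into the scalar that was written. The launch and keyword spellings below (`wp.launch_tiled`, `fill_`, indexed tile stores) are assumptions about the public API, not taken from this diff:

```python
import warp as wp

@wp.kernel
def write_scale(s: wp.array(dtype=wp.float32), out: wp.array2d(dtype=wp.float32)):
    t = wp.tile_zeros(shape=(2, 2), dtype=wp.float32)
    t[1, 1] = s[0]          # adj_assign accumulates the tile gradient at (1, 1) into s's adjoint
    wp.tile_store(out, t)

s = wp.array([3.0], dtype=wp.float32, requires_grad=True)
out = wp.zeros((2, 2), dtype=wp.float32, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch_tiled(write_scale, dim=1, inputs=[s, out], block_dim=32)

out.grad.fill_(1.0)
tape.backward()
# s.grad[0] should now hold the gradient routed through the tile element
```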
+ // copy a source tile into a destination tile at the given coordinate offset
+ template <typename TileA, typename TileB, typename Coord>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offset)
+ {
+     using Layout = typename TileB::Layout;
+
+     for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto c = Layout::coord_from_linear(t);
+         dest.data(c + offset) = src.data(c);
+     }
+
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename Coord, typename AdjCoord>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
+                                           AdjTileA& adj_dest, AdjTileB& adj_src, AdjCoord adj_offset)
+ {
+     using Layout = typename TileB::Layout;
+
+     // gather the destination gradients at the offset region back into the source tile's gradient
+     for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto c = Layout::coord_from_linear(t);
+         src.grad(c) += dest.grad(c + offset);
+     }
+
+     WP_TILE_SYNC();
+ }
+
+
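At the Python level this pair backs sub-tile copies. A minimal sketch, assuming the builtin is exposed as `wp.tile_assign(dest, src, offset=...)` and that `wp.tile_ones`/`wp.tile_zeros` take the keywords shown (names are illustrative):

```python
import warp as wp

@wp.kernel
def blit(out: wp.array2d(dtype=wp.float32)):
    big = wp.tile_zeros(shape=(8, 8), dtype=wp.float32)
    small = wp.tile_ones(shape=(4, 4), dtype=wp.float32)
    # copy `small` into `big` starting at row 2, column 2
    wp.tile_assign(big, small, offset=(2, 2))
    wp.tile_store(out, big)
```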
+ // codegen entry points, which emit calls such as `tile_assign(dest, src, i, j, k)`;
+ // a better approach would be for codegen to generate `tile_assign(dest, src, tile_coord(i, j, k))`
+ // directly, i.e. call the implementations above, which would let us remove these overloads
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i) { tile_assign(dest, src, tile_coord(i)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j) { tile_assign(dest, src, tile_coord(i, j)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k) { tile_assign(dest, src, tile_coord(i, j, k)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l) { tile_assign(dest, src, tile_coord(i, j, k, l)); }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, AdjTileA& adj_dest, AdjTileB& adj_src, int) { adj_tile_assign(dest, src, tile_coord(i), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, AdjTileA& adj_dest, AdjTileB& adj_src, int, int) { adj_tile_assign(dest, src, tile_coord(i,j), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k,l), adj_dest, adj_src, tile_coord(0)); }
+
+
+ // c = a with the vector b added along the main diagonal
+ template <typename TileA, typename TileB, typename TileC>
+ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+     using ShapeC = typename TileC::Layout::Shape;
+
+     static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
+     static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
+
+     c = a;
+
+     for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
+     {
+         c.data(tile_coord(t, t)) += b.data(tile_coord(t));
+     }
+
+     WP_TILE_SYNC();
+
+     return c;
+ }
+
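A short usage sketch, assuming the builtin is exposed as `wp.tile_diag_add(a, d)` (a common way to regularize a matrix tile before factorization); names, shapes, and load keywords are illustrative:

```python
import warp as wp

@wp.kernel
def regularize(A: wp.array2d(dtype=wp.float32), d: wp.array(dtype=wp.float32),
               out: wp.array2d(dtype=wp.float32)):
    a = wp.tile_load(A, shape=(4, 4))
    b = wp.tile_load(d, shape=4)
    # adds b along the main diagonal of a; note the adjoint below is a no-op
    c = wp.tile_diag_add(a, b)
    wp.tile_store(out, c)
```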
+ template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
+ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
+ {
+     // empty adjoint: no gradients are propagated through tile_diag_add
+ }
+
+
+ } // namespace wp
+
+
+ #ifdef __clang__
+ #pragma clang diagnostic pop
+ #endif