warp_lang-1.7.0-py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of warp-lang might be problematic.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/tile.h ADDED
@@ -0,0 +1,2584 @@
+ /*
+  * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  * SPDX-License-Identifier: Apache-2.0
+  *
+  * Licensed under the Apache License, Version 2.0 (the "License");
+  * you may not use this file except in compliance with the License.
+  * You may obtain a copy of the License at
+  *
+  * http://www.apache.org/licenses/LICENSE-2.0
+  *
+  * Unless required by applicable law or agreed to in writing, software
+  * distributed under the License is distributed on an "AS IS" BASIS,
+  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  * See the License for the specific language governing permissions and
+  * limitations under the License.
+  */
+
+ #pragma once
+
+ #include "builtin.h"
+
+ #ifdef __clang__
+ // disable warnings related to C++17 extensions on CPU JIT builds
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
+ #endif // __clang__
+
+ // Check if the CUDA toolkit is available
+ #if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
+
+ // If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
+ #ifdef __CUDACC_RTC__
+ // NVRTC: Use built-in float4 (no need for extra definitions)
+ #else
+ // NVCC: Include vector_types.h to get float4
+ #include <cuda_runtime.h>
+ #endif
+
+ #else
+ // If CUDA is not available (e.g., macOS build), manually define float4
+ struct alignas(16) float4 {
+     float x, y, z, w;
+ };
+ #endif
+
+ // only used while building the warp core library
+ #ifndef WP_TILE_BLOCK_DIM
+ #define WP_TILE_BLOCK_DIM 256
+ #endif
+
+ #if !defined(__CUDA_ARCH__)
+ #define WP_TILE_SHARED static
+ #define WP_TILE_SYNC void
+
+ #else
+ #define WP_TILE_SHARED __shared__
+ #define WP_TILE_SYNC __syncthreads
+ #endif
+
+ #if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
+ #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
+ #define WP_PRAGMA_UNROLL _Pragma("unroll")
+ #define WP_PRAGMA_NO_UNROLL _Pragma("unroll 1")
+ #else
+ #define WP_PRAGMA_UNROLL #pragma unroll
+ #define WP_PRAGMA_NO_UNROLL #pragma unroll 1
+ #endif
+
+ #else
+
+ #define WP_PRAGMA_UNROLL
+ #define WP_PRAGMA_NO_UNROLL
+
+ #endif
+
+ #define WP_USE_ASYNC_PIPELINE 0
+ #define WP_USE_REGISTER_GEMM 0
+
+ #if defined(__CUDACC_RTC__)
+ #define WP_TILE_THREAD_IDX threadIdx.x
+ #else
+ #define WP_TILE_THREAD_IDX 0
+ #endif //
+
+
+
+ /* Tile Expressions
+
+ [ ] Tiles
+     [x] Register, Shared, Global
+ [ ] Layouts
+     [x] Simple
+     [ ] Cute
+ [x] Remove Alloc type from tile_shared_t
+ [x] wp.launch_tiled() helper
+ [ ] Creation
+     [x] zeros
+     [x] ones
+     [x] arange
+     [x] tile()
+     [x] untile()
+     [ ] fromfunction()
+     [ ] explicit storage
+ [ ] Load/Store
+     [ ] 1D load/store variants
+     [ ] max_coord option for non-aligned loads
+     [ ] Indexed load
+     [x] wp.tile_atomic_add()
+ [ ] Maps
+     [x] Support user functions
+     [x] Support built-in functions
+     [ ] Support for lambda functions
+     [ ] Infer tile_map() output from operator type (e.g.: dot for each element)
+ [ ] Reductions
+     [x] Sum
+         [x] Forward
+         [x] Reverse
+     [x] Min
+     [x] Max
+     [x] Custom
+ [x] MatMul
+     [x] Forward
+     [x] Reverse
+ [ ] Operators
+     [ ] +, -, *, /, @?
+     [ ] += for matmul, e.g.: c += a@b, or c = a@b
+ [ ] Reshape
+     [ ] Broadcasting
+     [ ] Transpose
+         [x] Shared
+         [ ] Register
+     [ ] Slice
+ [ ] Runtime
+     [x] Compile-time block dimensions
+     [x] Switch between SIMT / Tile based execution if `block_dim` not provided to wp.launch()
+ [ ] Examples
+     [ ] Point registration
+     [ ] GEMM
+     [ ] MLP
+     [ ] LayerNorm
+     [ ] SoftMax
+     [ ] GEMM
+     [ ] warp.sim (CRBA)
+     [ ] Batched MLP
+     [ ] Layer norm
+     [ ] FNO + Burgers equation
+     [ ] Stochastic financial modeling
+     [ ] Convolution: https://github.com/NVIDIA/MinkowskiEngine/blob/master/src/convolution_kernel.cu#L123
+     [ ] MeshCNN (Modulus, Oliver)
+     [ ] BioNemo (Ali)
+     [ ] Skinning (David/Or/Vismay)
+     [ ] warp.sim (VBD)
+ [ ] Error checking
+     [ ] Ensure functions passed to tile_map() are compatible with tile type
+     [ ] Ensure that args passed to tile ops are compatible
+     [ ] Ensure tile load/store operations don't go out of bounds of arrays in debug mode
+
+ */
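The checklist above tracks the user-facing tile API that this header implements. A minimal sketch of how the checked-off pieces compose from Python, assuming the tile signatures documented for Warp 1.7 (shape/offset keyword arguments and the wp.launch_tiled() helper noted above); the kernel name and TILE_* constants are illustrative only:

    import numpy as np
    import warp as wp

    TILE_M, TILE_K, TILE_N = 32, 16, 64

    @wp.kernel
    def gemm_tile(A: wp.array2d(dtype=float),
                  B: wp.array2d(dtype=float),
                  C: wp.array2d(dtype=float)):
        # one output tile of C per block; i, j index the tile grid
        i, j = wp.tid()
        acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=float)
        for k in range(A.shape[1] // TILE_K):
            a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
            b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
            wp.tile_matmul(a, b, acc)  # acc += a @ b
        wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))

    A = wp.array(np.random.rand(64, 64), dtype=float)
    B = wp.array(np.random.rand(64, 128), dtype=float)
    C = wp.zeros((64, 128), dtype=float)

    # launch_tiled() assigns one block of block_dim threads to each logical tile
    wp.launch_tiled(gemm_tile, dim=(C.shape[0] // TILE_M, C.shape[1] // TILE_N),
                    inputs=[A, B, C], block_dim=128)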
+
+ /*
+ Notes on shared memory synchronization
+ ======================================
+
+ Currently operations that write to shared memory tiles (e.g.: tile_load())
+ must synchronize before they return through WP_TILE_SYNC(); this
+ ensures subsequent read operations from the tile do not cause a race condition.
+
+ For tile_shared_t adjoints, the gradient accumulation is done through shared
+ memory atomics, i.e.: atomic_add(), since for broadcast tiles multiple threads
+ may map to the same location. Synchronization is still required after these
+ updates, since subsequent operations e.g.: adj_tile_load() will store the
+ gradients to memory, and all updates must be visible at that point, e.g.:
+
+     a = wp.tile_load(...)
+     b = wp.tile_load(...)
+     c = wp.tile_matmul(a, b)
+     wp.tile_store(c)
+
+     // loads incoming adjoints from global -> shared
+     wp.adj_tile_store(c, adj_c)
+     // consumes adj_c, requires synchronization
+     wp.adj_tile_matmul(a, b, adj_a, adj_b, adj_c)
+     // consumes adj_b, requires synchronization
+     wp.adj_tile_load(..., adj_b)
+     // consumes adj_a, requires synchronization
+     wp.adj_tile_load(..., adj_a)
+
+ Generally synchronization to adjoint tiles will happen through the
+ tile_shared_t::add() and tile_shared_t::assign() functions automatically,
+ but in some cases e.g.: tile_matmul() it is done manually.
+
+ The current synchronization strategy is conservative, and can lead to more
+ synchronization than necessary. A more sophisticated strategy would be
+ to track the 'dirty' state of shared tiles, and synchronize only when
+ necessary. In addition, custom synchronization for e.g.: tile_load()
+ operations could be added through a SyncProvider template parameter on
+ the tile_shared_t type, for example to support barrier synchronization
+ for asynchronous global to shared loads.
+ */
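The forward/adjoint sequence in the note above is what Warp's code generator emits when a tile kernel is recorded and replayed backward. A hedged sketch of the user-level code that drives it, using the standard wp.Tape and requires_grad API and the same tile signatures assumed in the earlier sketch:

    import numpy as np
    import warp as wp

    TILE = 16

    @wp.kernel
    def matmul_once(A: wp.array2d(dtype=float),
                    B: wp.array2d(dtype=float),
                    C: wp.array2d(dtype=float)):
        i, j = wp.tid()
        a = wp.tile_load(A, shape=(TILE, TILE), offset=(i * TILE, 0))
        b = wp.tile_load(B, shape=(TILE, TILE), offset=(0, j * TILE))
        c = wp.tile_matmul(a, b)
        wp.tile_store(C, c, offset=(i * TILE, j * TILE))

    A = wp.array(np.random.rand(TILE, TILE), dtype=float, requires_grad=True)
    B = wp.array(np.random.rand(TILE, TILE), dtype=float, requires_grad=True)
    C = wp.zeros((TILE, TILE), dtype=float, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch_tiled(matmul_once, dim=(1, 1), inputs=[A, B, C], block_dim=64)

    # seeding C.grad and replaying the tape runs the
    # adj_tile_store -> adj_tile_matmul -> adj_tile_load chain described above
    C.grad.fill_(1.0)
    tape.backward()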
+
+ namespace wp
+ {
+
+ // Primary template
+ template <typename T, typename U>
+ struct is_same {
+     static constexpr bool value = false;
+ };
+
+ // Specialization for the case when T and U are the same type
+ template <typename T>
+ struct is_same<T, T> {
+     static constexpr bool value = true;
+ };
+
+
+ template <int N>
+ struct tile_coord_t
+ {
+     int indices[N];
+
+     CUDA_CALLABLE inline int operator[](int i) const { assert(0 <= i && i < N); return indices[i]; }
+     CUDA_CALLABLE inline int& operator[](int i) { assert(0 <= i && i < N); return indices[i]; }
+
+     CUDA_CALLABLE inline tile_coord_t<N> operator + (const tile_coord_t<N>& c) const
+     {
+         tile_coord_t<N> out;
+         for (int i=0; i < N; ++i)
+         {
+             out.indices[i] = indices[i] + c.indices[i];
+         }
+         return out;
+     }
+ };
+
+ // This function deduces N = sizeof...(Ints)
+ template <typename... Ints>
+ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
+ {
+     constexpr int N = sizeof...(Ints);
+
+     // Create the result
+     tile_coord_t<N> result{};
+
+     // Capture all arguments in a local array
+     int arr[] = { static_cast<int>(idxs)... };
+
+     // C++14 or later: 'for' is allowed in a constexpr context
+     for (int i = 0; i < N; ++i)
+     {
+         result.indices[i] = arr[i];
+     }
+
+     return result;
+ }
+
+ // helpers to construct a coord from a set of indices
+ inline auto tile_coord(int i)
+ {
+     auto c = tile_coord_t<1>();
+     c.indices[0] = i;
+     return c;
+ }
+
+ inline auto tile_coord(int i, int j)
+ {
+     auto c = tile_coord_t<2>();
+     c.indices[0] = i;
+     c.indices[1] = j;
+     return c;
+ }
+
+ inline auto tile_coord(int i, int j, int k)
+ {
+     auto c = tile_coord_t<3>();
+     c.indices[0] = i;
+     c.indices[1] = j;
+     c.indices[2] = k;
+     return c;
+ }
+
+ inline auto tile_coord(int i, int j, int k, int l)
+ {
+     auto c = tile_coord_t<4>();
+     c.indices[0] = i;
+     c.indices[1] = j;
+     c.indices[2] = k;
+     c.indices[3] = l;
+     return c;
+ }
+
+ // represents a compile time int tuple for strides/shapes/coords
+ template <int... V>
+ struct tile_tuple_t
+ {
+     static constexpr int N = sizeof...(V);
+     static_assert(N > 0, "Expected N > 0");
+
+     static constexpr int data[N] = { V... };
+
+     static constexpr int dim(int i) { assert(i < N); return data[i]; }
+     static constexpr int size()
+     {
+         int res = data[0];
+         for (int i=1; i < N; ++i)
+             res *= data[i];
+
+         return res;
+     }
+ };
+
+ // simple helper to compute strides from a shape up to 4d
+ template <typename Shape>
+ struct compute_strides;
+
+ // 1D
+ template <int D0>
+ struct compute_strides< tile_tuple_t<D0> > { using Stride = tile_tuple_t<1>; };
+ // 2D
+ template <int D0, int D1>
+ struct compute_strides< tile_tuple_t<D0, D1> > { using Stride = tile_tuple_t<D1, 1>; };
+ // 3D
+ template <int D0, int D1, int D2>
+ struct compute_strides< tile_tuple_t<D0, D1, D2> > { using Stride = tile_tuple_t<(D1 * D2), D2, 1>; };
+ // 4D
+ template <int D0, int D1, int D2, int D3>
+ struct compute_strides< tile_tuple_t<D0, D1, D2, D3> > { using Stride = tile_tuple_t<(D1 * D2 * D3), (D2 * D3), D3, 1>; };
+
+
+ // alias of tuple to represent shapes
+ template <int... V>
+ using tile_shape_t = tile_tuple_t<V...>;
+
+ // alias of tuple to represent stride
+ template <int... V>
+ using tile_stride_t = tile_tuple_t<V...>;
+
+
+ // represents a tile stored in global memory with dynamic strides
+ // used to represent the source and offset for tile loads to register/shared
+ template <typename T, typename Shape_>
+ struct tile_global_t
+ {
+     using Type = T;
+     using Shape = Shape_;
+     using Coord = tile_coord_t<Shape::N>;
+
+     array_t<T> data;
+     Coord offset;
+
+     tile_global_t(array_t<T>& a, const Coord& c) : data(a), offset(c)
+     {
+     }
+
+     inline CUDA_CALLABLE int index_from_coord(const Coord& coord) const
+     {
+         // element index
+         int index = 0;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Shape::N; ++i)
+         {
+             // global = offset + coord
+             int c = offset[i] + coord[i];
+             index += data.strides[i]*c;
+         }
+
+         return index/sizeof(T);
+     }
+
+     inline CUDA_CALLABLE bool index(const Coord& coord, int& out) const
+     {
+         // element index
+         int index = 0;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Shape::N; ++i)
+         {
+             // global = offset + coord
+             int c = offset[i] + coord[i];
+
+             // handle out of bounds case
+             if (c >= data.shape[i])
+                 return false;
+             else
+                 index += data.strides[i]*c;
+         }
+
+         // array strides are in bytes so we convert to elements
+         out = index / sizeof(T);
+         return true;
+     }
+
+     inline CUDA_CALLABLE T load(const Coord& coord) const
+     {
+         int i;
+         if (index(coord, i))
+             return data.data[i];
+         else
+             return T(0);
+     }
+
+     inline CUDA_CALLABLE T load_grad(const Coord& coord) const
+     {
+         int i;
+         if (index(coord, i))
+             return data.grad[i];
+         else
+             return T(0);
+     }
+
+     inline CUDA_CALLABLE void store(const Coord& coord, const T& x) const
+     {
+         int i;
+         if (index(coord, i))
+             data.data[i] = x;
+     }
+
+     inline CUDA_CALLABLE T atomic_add(const Coord& coord, const T& value) const
+     {
+         int i;
+         if (index(coord, i))
+             return wp::atomic_add(&data.data[i], value);
+         else
+             return T(0);
+     }
+
+     inline CUDA_CALLABLE T atomic_add_grad(const Coord& coord, const T& grad) const
+     {
+         int i;
+         if (index(coord, i))
+             return wp::atomic_add(&data.grad[i], grad);
+         else
+             return T(0);
+     }
+ };
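Two details of tile_global_t worth calling out: array_t strides are in bytes, so index() divides by sizeof(T) after the stride multiply, and a coordinate past the array extent turns a load() into T(0) rather than a fault. A small Python model of that addressing (the function name is hypothetical; NumPy supplies the byte strides):

    import numpy as np

    def tile_global_index(shape, byte_strides, offset, coord, itemsize=4):
        """Return the element index for offset+coord, or None when out of bounds."""
        index = 0
        for i, (o, c) in enumerate(zip(offset, coord)):
            g = o + c                      # global coordinate along dim i
            if g >= shape[i]:
                return None                # load() would return T(0) here
            index += byte_strides[i] * g   # strides are in bytes
        return index // itemsize           # convert bytes -> elements

    a = np.zeros((8, 8), dtype=np.float32)
    print(tile_global_index(a.shape, a.strides, offset=(4, 0), coord=(1, 3)))  # 43
    print(tile_global_index(a.shape, a.strides, offset=(4, 0), coord=(7, 0)))  # None (row 11 >= 8)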
+
+ template <typename Shape_>
+ struct tile_layout_register_t
+ {
+     using Shape = Shape_;
+     using Coord = tile_coord_t<Shape::N>;
+
+     static constexpr int Size = Shape::size();
+     static constexpr int NumRegs = (Size + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
+     static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
+
+     static inline CUDA_CALLABLE int linear_from_register(int reg)
+     {
+         return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
+     }
+
+     static inline CUDA_CALLABLE int linear_from_coord(Coord c)
+     {
+         int linear = 0;
+         int stride = 1;
+
+         WP_PRAGMA_UNROLL
+         for (int i=Shape::N-1; i >= 0; --i)
+         {
+             linear += c[i] * stride;
+             stride *= Shape::dim(i);
+         }
+         return linear;
+     }
+
+     static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+     {
+         Coord c;
+
+         WP_PRAGMA_UNROLL
+         for (int i=Shape::N-1; i >= 0; --i)
+         {
+             c[i] = linear%Shape::dim(i);
+             linear /= Shape::dim(i);
+         }
+
+         return c;
+     }
+
+     static inline CUDA_CALLABLE int thread_from_linear(int linear)
+     {
+         const int thread = linear%WP_TILE_BLOCK_DIM;
+         return thread;
+     }
+
+     static inline CUDA_CALLABLE int register_from_linear(int linear)
+     {
+         const int reg = linear/WP_TILE_BLOCK_DIM;
+         return reg;
+     }
+
+     static inline CUDA_CALLABLE bool valid(int linear)
+     {
+         if (Aligned || linear < Size)
+             return true;
+         else
+             return false;
+     }
+
+ };
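The register layout spreads the Size elements of a tile round-robin across the block: element `linear` lives in register linear / WP_TILE_BLOCK_DIM of thread linear % WP_TILE_BLOCK_DIM, and every thread carries NumRegs = ceil(Size / WP_TILE_BLOCK_DIM) registers. A quick Python model of that mapping (illustrative only):

    BLOCK_DIM = 256  # WP_TILE_BLOCK_DIM

    def num_regs(size):
        return (size + BLOCK_DIM - 1) // BLOCK_DIM

    def owner(linear):
        # mirrors thread_from_linear / register_from_linear above
        return linear % BLOCK_DIM, linear // BLOCK_DIM

    size = 16 * 32                  # a 16x32 tile -> 512 elements
    assert num_regs(size) == 2      # two registers per thread, fully "Aligned"
    assert owner(0) == (0, 0)       # element 0 -> thread 0, register 0
    assert owner(257) == (1, 1)     # element 257 -> thread 1, register 1

    # for a non-aligned tile (e.g. 300 elements) the trailing registers of some
    # threads point past the tile, which is what Layout::valid() filters out
    assert num_regs(300) == 2 and 300 % BLOCK_DIM != 0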
+
+ // represents a tile stored in registers across a block
+ template <typename T, typename L>
+ struct tile_register_t
+ {
+     using Type = T;
+     using Layout = L;
+
+     T data[Layout::NumRegs];
+
+     inline CUDA_CALLABLE tile_register_t(T value=T(0.0))
+     {
+         // zero-initialize by default necessary for tile adjoints
+         // need to check if this results in worse codegen
+         // than doing adj_var = tile_zeros() explicitly
+         // in backwards pass and letting default constructor
+         // avoid initialization
+
+         for (int i=0; i < Layout::NumRegs; ++i)
+             data[i] = value;
+     }
+
+     inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+     {
+         copy_from_global(t);
+         return *this;
+     }
+
+     // define the += operator which is used during backward pass codegen
+     // when returning a register tile from a user defined function
+     inline CUDA_CALLABLE auto& operator += (tile_register_t<T, Layout>& rhs)
+     {
+         grad_add(rhs);
+         return *this;
+     }
+
+     inline CUDA_CALLABLE T& operator()(int reg)
+     {
+         assert(reg < Layout::NumRegs);
+         return data[reg];
+     }
+
+     inline CUDA_CALLABLE const T& operator()(int reg) const
+     {
+         assert(reg < Layout::NumRegs);
+         return data[reg];
+     }
+
+     inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
+     {
+         for (int i=0; i < Layout::NumRegs; ++i)
+             data[i] = tile.data[i];
+     }
+
+     inline CUDA_CALLABLE void zero()
+     {
+         for (int i=0; i < Layout::NumRegs; ++i)
+             data[i] = T(0);
+     }
+
+     // extract a single tile element to a native type
+     template <typename Coord>
+     inline CUDA_CALLABLE Type extract(const Coord& c)
+     {
+         // map from logical coords (i, j) -> (thread, reg)
+         const int linear = Layout::linear_from_coord(c);
+         const int thread = Layout::thread_from_linear(linear);
+         const int reg = Layout::register_from_linear(linear);
+
+         WP_TILE_SHARED Type scratch;
+
+         // ensure any previously scheduled threads have finished reading from scratch
+         WP_TILE_SYNC();
+
+         if (WP_TILE_THREAD_IDX == thread)
+         {
+             scratch = data[reg];
+         }
+
+         // ensure extraction thread has updated smem
+         WP_TILE_SYNC();
+
+         return scratch;
+     }
+
+
+     // backward version of scalar extract
+     template <typename Coord>
+     inline CUDA_CALLABLE void adj_extract(const Coord& c, Type adj_ret)
+     {
+         // map from logical coords (i, j) -> (thread, reg)
+         const int linear = Layout::linear_from_coord(c);
+         const int thread = Layout::thread_from_linear(linear);
+         const int reg = Layout::register_from_linear(linear);
+
+         if (WP_TILE_THREAD_IDX == thread)
+         {
+             data[reg] += adj_ret;
+         }
+     }
+
+     inline CUDA_CALLABLE void print() const;
+
+
+     // return the in-register version of this tile (nop)
+     inline CUDA_CALLABLE auto& copy_to_register()
+     {
+         return *this;
+     }
+
+     inline CUDA_CALLABLE const auto& copy_to_register() const
+     {
+         return *this;
+     }
+
+     // apply a lambda to all valid entries in the tile
+     // Op should be a functor that takes a register index and tile_coord_t as input
+     template <typename Op>
+     void apply(Op op)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Layout::NumRegs; ++i)
+         {
+             int linear = Layout::linear_from_register(i);
+             if (!Layout::valid(linear))
+                 break;
+
+             auto c = Layout::coord_from_linear(linear);
+             op(i, c);
+         }
+     }
+
+
+     // in-place gradient zero
+     inline CUDA_CALLABLE void grad_zero()
+     {
+         zero();
+     }
+
+     // accumulate gradients onto this tile
+     inline CUDA_CALLABLE void grad_add(const tile_register_t<T, Layout>& tile)
+     {
+         for (int i=0; i < Layout::NumRegs; ++i)
+             data[i] += tile.data[i];
+     }
+
+     CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+     {
+         apply([&](int reg, auto c) {data[reg] = global.load_grad(c);});
+     }
+
+     inline CUDA_CALLABLE auto& grad_to_register()
+     {
+         // nop for register tiles
+         return *this;
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_to_global(const Global& dest)
+     {
+         apply([&](int reg, auto c) { dest.store(c, data[reg]); });
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_from_global(const Global& src)
+     {
+         apply([&](int reg, auto c) { data[reg] = src.load(c); });
+     }
+
+     // add a register tile to a global array
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add(const Global& dest)
+     {
+         // allocate a tile to hold previous dest value
+         auto previous = *this;
+
+         apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add(c, data[reg]); });
+         return previous;
+     }
+
+     // add a register tile to the gradient of a global array
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add_grad(const Global& dest)
+     {
+         // allocate a tile to hold previous dest value
+         auto previous = *this;
+
+         apply([&](int reg, auto c) { previous.data[reg] = dest.atomic_add_grad(c, data[reg]); });
+         return previous;
+     }
+ };
+
+
+ // helper to allocate a register tile like another tile
+ // users can either specify a template explicitly or
+ // pass in another concrete instance
+ template<typename Tile>
+ auto tile_register_like(Tile* t=nullptr)
+ {
+     using T = typename Tile::Type;
+     using L = typename Tile::Layout;
+
+     return tile_register_t<T, tile_layout_register_t<typename L::Shape>>(T(0.0));
+ }
+
+ // helper to construct a register tile from a type and a list of dims
+ template <typename T, int... Dims>
+ auto tile_register()
+ {
+     return tile_register_t<T, tile_layout_register_t<tile_shape_t<Dims...>>>();
+ }
+
+ inline CUDA_CALLABLE int tile_align(int num_bytes)
+ {
+     // note: this must match the value in Python types.py
+     const int alignment = 16;
+
+     const int num_bytes_abs = num_bytes < 0 ? - num_bytes : num_bytes;
+     const int sign = num_bytes < 0 ? - 1 : 1;
+
+     return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
+ }
+
+ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
+ {
+     // we maintain a per-thread offset into dynamic
+     // shared memory that allows us to keep track of
+     // current use across dynamic function calls
+     WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
+
+     if (init)
+     {
+         smem_base[WP_TILE_THREAD_IDX] = 0;
+         return nullptr;
+     }
+     else if (check)
+     {
+         assert(smem_base[WP_TILE_THREAD_IDX] == 0);
+         return nullptr;
+     }
+     else
+     {
+         const int offset = smem_base[WP_TILE_THREAD_IDX];
+
+         // one entry per-thread so no need for synchronization
+         smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
+
+ #ifdef __CUDA_ARCH__
+         extern __shared__ char dynamic_smem_base[];
+ #else
+         // on CPU allocate a fixed 256k block to use for shared allocs
+         static const int max_cpu_shared = 256*1024;
+         static char dynamic_smem_base[max_cpu_shared];
+
+         assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
+ #endif
+         return &(dynamic_smem_base[offset]);
+     }
+ }
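tile_alloc_shared() is a per-thread bump allocator over the block's dynamic shared memory: requests are rounded to 16-byte alignment by tile_align(), and tile_shared_t's destructor frees by passing a negative byte count, so allocations unwind stack-fashion. A small Python model of the bookkeeping for one thread (names hypothetical):

    ALIGN = 16  # must match the tile alignment used in Python types.py

    def tile_align(num_bytes):
        # round the magnitude up to a 16-byte multiple, keeping the sign
        sign = -1 if num_bytes < 0 else 1
        return sign * ((abs(num_bytes) + ALIGN - 1) // ALIGN) * ALIGN

    offset = 0  # smem_base[tid] for a single thread

    def alloc(num_bytes):
        global offset
        base = offset
        offset += tile_align(num_bytes)  # a negative size acts as a free
        return base

    a = alloc(100)        # rounds to 112 bytes
    b = alloc(24)         # rounds to 32 bytes
    assert (a, b, offset) == (0, 112, 144)
    alloc(-24)            # destructor of the second tile
    alloc(-100)           # destructor of the first tile
    assert offset == 0    # the `check` path asserts this is restored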
+
+
+ template <typename Shape_, typename Stride_= typename compute_strides<Shape_>::Stride>
+ struct tile_layout_strided_t
+ {
+     using Shape = Shape_;
+     using Stride = Stride_;
+     using Coord = tile_coord_t<Shape::N>;
+
+     static constexpr int Size = Shape::size();
+     static constexpr bool Aligned = Size%WP_TILE_BLOCK_DIM == 0;
+
+     static inline CUDA_CALLABLE auto coord_from_linear(int linear)
+     {
+         assert(linear < Size);
+
+         Coord c;
+
+         WP_PRAGMA_UNROLL
+         for (int d=Shape::N-1; d >= 0; --d)
+         {
+             c[d] = linear%Shape::dim(d);
+             linear /= Shape::dim(d);
+         }
+
+         return c;
+     }
+
+     static inline CUDA_CALLABLE int index_from_coord(Coord c)
+     {
+         int index = 0;
+
+         WP_PRAGMA_UNROLL
+         for (int d=0; d < Shape::N; ++d)
+         {
+             assert(c[d] < Shape::dim(d));
+
+             index += c[d]*Stride::dim(d);
+         }
+
+         return index;
+     }
+
+     // checks whether a strided layout is unique, i.e.: if memory locations are only
+     // ever referred to by one element in the tile; this is a basic test that only
+     // checks for broadcast dimensions, it would be possible to do the full check
+     // using sorted shape/strides in Python and add it as a template parameter to the type
+     static constexpr bool is_unique()
+     {
+         constexpr int N = Shape::N;
+
+         // check for any broadcast dimensions
+         for (int i=0; i < N; ++i)
+             if (Stride::dim(i) == 0)
+                 return false;
+
+         return true;
+     }
+
+     static constexpr bool Unique = is_unique();
+
+     static inline CUDA_CALLABLE bool valid(int linear)
+     {
+         return linear < Size;
+     }
+
+ };
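compute_strides (further up) supplies the canonical row-major strides that parameterize this layout, and a zero stride marks a broadcast dimension, which is exactly what is_unique() rejects so that gradient writes fall back to atomics. A short Python model:

    def row_major_strides(shape):
        # mirrors compute_strides<...> above: innermost stride is 1
        strides = [1] * len(shape)
        for i in range(len(shape) - 2, -1, -1):
            strides[i] = strides[i + 1] * shape[i + 1]
        return tuple(strides)

    def is_unique(strides):
        # a zero stride means a broadcast dimension: several tile elements
        # alias one memory location, so gradients need atomic accumulation
        return all(s != 0 for s in strides)

    assert row_major_strides((4, 5, 6)) == (30, 6, 1)  # tile_tuple_t<(D1*D2), D2, 1>
    assert is_unique((30, 6, 1))
    assert not is_unique((0, 1))   # e.g. a row vector broadcast over rows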
+
+
+ template <typename T, typename L, bool Owner_=true>
+ struct tile_shared_t
+ {
+     using Type = T;
+     using Layout = L;
+     static constexpr bool Owner = Owner_;
+
+     struct Storage
+     {
+         T* ptr;
+
+         Storage(T* p) : ptr(p) {}
+
+         inline CUDA_CALLABLE T& operator()(typename Layout::Coord c)
+         {
+             assert(ptr);
+
+             int index = Layout::index_from_coord(c);
+             return ptr[index];
+         }
+
+         inline CUDA_CALLABLE const T& operator()(typename Layout::Coord c) const
+         {
+             assert(ptr);
+
+             int index = Layout::index_from_coord(c);
+             return ptr[index];
+         }
+
+         inline CUDA_CALLABLE T& operator()(int linear)
+         {
+             assert(ptr);
+             assert(Layout::valid(linear));
+
+             auto c = Layout::coord_from_linear(linear);
+             return (*this)(c);
+         }
+
+         inline CUDA_CALLABLE const T& operator()(int linear) const
+         {
+             assert(ptr);
+             assert(Layout::valid(linear));
+
+             auto c = Layout::coord_from_linear(linear);
+             return (*this)(c);
+         }
+     };
+
+     Storage data;
+     Storage grad;
+
+     // we need to track whether or not this tile's data has been initialized.
+     // once true, any re-initialization of data that follows needs a WP_TILE_SYNC()
+     // call to precede it, to allow threads that are still reading from this tile
+     // to complete their work, e.g.: in a dynamic loop:
+     // for i in range(x):
+     //     tile = wp.tile_load(arr, i, TILE_SIZE, storage="shared")
+     //     # read from tile...
+     bool initialized;
+
+     // default initialization (non-initialized)
+     inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
+     {
+     }
+
+     // initialize from an existing tile's memory
+     inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
+     {
+     }
+
+     inline CUDA_CALLABLE ~tile_shared_t()
+     {
+         if (Owner)
+         {
+             // update our per-thread shared memory allocator
+             if (data.ptr)
+                 tile_alloc_shared(-Layout::Size*int(sizeof(T)));
+
+             if (grad.ptr)
+                 tile_alloc_shared(-Layout::Size*int(sizeof(T)));
+         }
+     }
+
+     // assign from a register tile
+     template <typename Tile>
+     inline CUDA_CALLABLE auto& operator=(const Tile& t)
+     {
+         assign(t);
+         return *this;
+     }
+
+
+     /*
+     // construct from another shared tile, this constructor
+     // is invoked for reshape operations like `wp.tile_transpose()`
+     template <typename OtherT, typename OtherLayout>
+     inline CUDA_CALLABLE auto& operator=(const tile_shared_t<OtherT, OtherLayout>& rhs)
+     {
+         using OtherTile = tile_shared_t<OtherT, OtherLayout>;
+
+         // check dimensions are compatible
+         static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
+
+         // alias tile directly
+         data = rhs.data;
+         grad = rhs.grad;
+         initialized = rhs.initialized;
+
+         return *this;
+     }
+     */
+
+     // assign from a global tile (load)
+     inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
+     {
+         copy_from_global(t);
+         return *this;
+     }
+
+     // assign from a constant value
+     inline CUDA_CALLABLE auto& operator=(const T& x)
+     {
+         // sync if we are re-initializing data so that any threads that are still
+         // reading from this tile can complete their work, e.g.: if re-assigning
+         // to a tile during a dynamic loop
+         if (initialized)
+             WP_TILE_SYNC();
+
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+             data(i) = x;
+
+         initialized = true;
+         WP_TILE_SYNC();
+         return *this;
+     }
+
+     // in-place zero
+     inline CUDA_CALLABLE void zero()
+     {
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+             data(i) = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+     // extract a single tile element to a native type
+     inline CUDA_CALLABLE Type extract(const typename Layout::Coord& c)
+     {
+         return data(c);
+     }
+
+     // backward of scalar extraction
+     inline CUDA_CALLABLE void adj_extract(const typename Layout::Coord& c, Type adj_ret)
+     {
+         // since multiple threads may extract the same element
+         // we need to accumulate using atomic operations
+         wp::atomic_add(&grad(c), adj_ret);
+
+         WP_TILE_SYNC();
+     }
+
+
+     // copy register tile to shared
+     template <typename Tile>
+     inline CUDA_CALLABLE void assign(const Tile& tile)
+     {
+         if (initialized)
+             WP_TILE_SYNC();
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             // handle case where tile size is not
+             // aligned to block dimensions
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             data(linear) = tile.data[i];
+         }
+
+         initialized = true;
+         WP_TILE_SYNC();
+     }
+
+     // in-place gradient zero
+     inline CUDA_CALLABLE void grad_zero()
+     {
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
+             grad(i) = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+
+     // accumulate gradients onto this tile
+     template <typename Tile>
+     inline CUDA_CALLABLE void grad_add(const Tile& tile)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             // handle case where tile size is not
+             // aligned to block dimensions
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             // if the destination layout is unique (no broadcast dimensions)
+             // then we can use regular non-atomic accumulation
+             if (Layout::Unique)
+                 grad(linear) += tile.data[i];
+             else
+                 // use shared memory atomics to accumulate gradients
+                 // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
+                 // may map to a single location in shared memory
+                 wp::atomic_add(&grad(linear), tile.data[i]);
+         }
+
+         WP_TILE_SYNC();
+     }
+
+     // accumulate gradient onto this tile from a global array
+     CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
+     {
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             T g = global.load_grad(c);
+
+             if (Layout::Unique)
+             {
+                 // if the destination layout is unique (no broadcast dimensions)
+                 // then we can use regular non-atomic accumulation
+                 grad(c) += g;
+             }
+             else
+             {
+                 // use shared memory atomics to accumulate gradients
+                 // since for broadcast tiles (e.g.: a bias vector) multiple incoming threads
+                 // may map to a single location in shared memory
+                 wp::atomic_add(&grad(c), g);
+             }
+         }
+
+         WP_TILE_SYNC();
+     }
+
+     // copy shared tile gradients to register
+     inline CUDA_CALLABLE auto grad_to_register()
+     {
+         using Tile = tile_register_t<T, tile_layout_register_t<typename Layout::Shape>>;
+         Tile out;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Tile::Layout::NumRegs; ++i)
+         {
+             const int linear = Tile::Layout::linear_from_register(i);
+
+             if (!Tile::Layout::valid(linear))
+                 break;
+
+             out(i) = grad(linear);
+         }
+
+         return out;
+     }
+
+     // copy shared tile to register
+     inline CUDA_CALLABLE auto copy_to_register() const
+     {
+         auto out = tile_register_like(this);
+
+         using Layout = typename decltype(out)::Layout;
+
+         WP_PRAGMA_UNROLL
+         for (int i=0; i < Layout::NumRegs; ++i)
+         {
+             const int linear = Layout::linear_from_register(i);
+
+             if (!Layout::valid(linear))
+                 break;
+
+             out(i) = data(linear);
+         }
+
+         return out;
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_to_global(const Global& dest)
+     {
+ #if defined(__CUDA_ARCH__)
+         // vectorized stores for specific input/output shapes
+         if constexpr (Layout::Shape::N == 2)
+         {
+             constexpr int lastdim = Layout::Shape::N-1;
+             constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
+             const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
+             const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
+             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+
+             float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
+             const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
+
+             if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
+             {
+                 constexpr int M = Layout::Shape::dim(0);
+                 constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+                 // alias of shared tile with 128bit type
+                 using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                 tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
+
+                 assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
+                 assert(((uint64_t)(dest128))%sizeof(float4) == 0);
+
+                 const int stride_i = dest.data.strides[0]/sizeof(float4);
+                 const int stride_j = 1;
+
+                 WP_PRAGMA_UNROLL
+                 for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
+                 {
+                     auto c = SrcLayout::coord_from_linear(i);
+
+                     dest128[stride_i*c[0] + stride_j*c[1]] = src128.data(i);
+                 }
+
+                 return;
+             }
+         }
+ #endif // defined(__CUDA_ARCH__)
+
+         // scalar bounds checked path
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             dest.store(c, data(i));
+         }
+     }
+
+     inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
+     {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+         unsigned long long saddr = 0ULL;
+         unsigned long long gaddr = 0ULL;
+
+         asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(saddr) : "l"(shared_dest));
+         asm volatile("cvta.to.global.u64 %0, %1;" : "=l"(gaddr) : "l"(global_src));
+
+         // Use cp.async on newer architectures
+         asm volatile(
+             "cp.async.ca.shared.global [%0], [%1], 16;\n"
+             :
+             : "l"(saddr), "l"(gaddr)
+         );
+ #else
+         // use regular load/store through register on older arches
+         *shared_dest = *global_src;
+ #endif
+     }
+
+     inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
+     {
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+         asm volatile(
+             "cp.async.commit_group;\n"
+             "cp.async.wait_group 0;\n" ::);
+ #endif
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE void copy_from_global(const Global& src)
+     {
+         if (initialized)
+             WP_TILE_SYNC();
+
+ #if defined(__CUDA_ARCH__)
+         // vectorized loads for specific input/output shapes
+         if constexpr (Layout::Shape::N == 2)
+         {
+             constexpr int lastdim = Layout::Shape::N-1;
+             constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
+             const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
+             const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
+             const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
+
+             float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
+             const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
+
+             if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
+             {
+                 constexpr int M = Layout::Shape::dim(0);
+                 constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
+
+                 // alias of shared tile with 128bit type
+                 using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
+                 tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
+
+                 assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
+                 assert(((uint64_t)(src128))%sizeof(float4) == 0);
+
+                 const int stride_i = src.data.strides[0]/sizeof(float4);
+                 const int stride_j = 1;
+
+                 WP_PRAGMA_UNROLL
+                 for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
+                 {
+                     auto c = DestLayout::coord_from_linear(i);
+
+ #if WP_USE_ASYNC_PIPELINE
+                     cp_async_global_to_shared_128(&dest128.data(i), &src128[stride_i*c[0] + stride_j*c[1]]);
+ #else
+                     dest128.data(i) = src128[stride_i*c[0] + stride_j*c[1]];
+ #endif // WP_USE_ASYNC_PIPELINE
+                 }
+
+ #if WP_USE_ASYNC_PIPELINE
+                 cp_async_commit_and_wait_all_128();
+ #endif // WP_USE_ASYNC_PIPELINE
+
+                 initialized = true;
+                 WP_TILE_SYNC();
+                 return;
+             }
+         }
+ #endif // defined(__CUDA_ARCH__)
+
+         // scalar bounds checked path
+         WP_PRAGMA_UNROLL
+         for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
+         {
+             auto c = Layout::coord_from_linear(i);
+             data(i) = src.load(c);
+         }
+
+         initialized = true;
+         WP_TILE_SYNC();
+     }
1281
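+     // Editor's sketch (illustration only, not part of the library source):
+     // the four conditions gating the float4 fast path above, extracted into
+     // a standalone host-side helper; elem_size stands in for sizeof(T), and
+     // <cstdint>/<cstddef> are assumed to be available.
+ #if 0
+     inline bool can_use_float4_path(const void* base, size_t elem_size,
+                                     size_t row_elems, size_t row_stride_bytes,
+                                     bool dest_contiguous)
+     {
+         const bool contiguous_src = row_stride_bytes == elem_size;     // rows are dense
+         const bool aligned_size   = (row_elems * elem_size) % 16 == 0; // whole float4s per row
+         const bool aligned_base   = ((uintptr_t)base) % 16 == 0;       // 16-byte aligned pointer
+         return dest_contiguous && contiguous_src && aligned_size && aligned_base;
+     }
+ #endif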
+
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add(Global& dest)
+     {
+         copy_to_register().atomic_add(dest);
+     }
+
+     template <typename Global>
+     inline CUDA_CALLABLE auto atomic_add_grad(Global& dest)
+     {
+         grad_to_register().atomic_add_grad(dest);
+     }
+
+     // overload for integral types
+     inline CUDA_CALLABLE void print_value(int x) const
+     {
+         printf("%d", x);
+     }
+
+     // overload for floating point types
+     template <typename ValueType>
+     inline CUDA_CALLABLE void print_value(ValueType x) const
+     {
+         printf("%g", x);
+     }
+
+
+     template <int Level = 0>
+     inline CUDA_CALLABLE void print_values(const Storage& storage, int index=0) const
+     {
+         using Shape = typename Layout::Shape;
+
+         if constexpr (Level < Shape::N)
+         {
+             if constexpr (Level == Shape::N - 1)
+             {
+                 // Special handling for 1D case
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     print_value(storage(index + i));
+
+                     if (i < Shape::dim(Level) - 1)
+                     {
+                         printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+             else if constexpr (Level == Shape::N - 2)
+             {
+                 // Special handling for 2D case
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     printf("[");
+                     for (int j=0; j < Shape::dim(Level+1); ++j)
+                     {
+                         print_value(storage(index));
+
+                         if (j < Shape::dim(Level+1) - 1)
+                         {
+                             printf(" ");
+                         }
+
+                         ++index;
+                     }
+
+                     printf("]");
+
+                     // next row
+                     if (i < Shape::dim(Level)-1)
+                     {
+                         printf("\n");
+
+                         // indent next row
+                         for (int k=0; k <= Shape::N-2; ++k)
+                             printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+             else
+             {
+                 printf("[");
+                 for (int i = 0; i < Shape::dim(Level); ++i)
+                 {
+                     print_values<Level + 1>(storage, index + i * Shape::dim(Level));
+                     if (i < Shape::dim(Level) - 1)
+                     {
+                         printf("\n\n");
+
+                         // indent next row
+                         for (int k=0; k <= Level; ++k)
+                             printf(" ");
+                     }
+                 }
+                 printf("]");
+             }
+         }
+     }
+
+
+     inline CUDA_CALLABLE void print(bool reverse=false) const
+     {
+         if (WP_TILE_THREAD_IDX != 0)
+             return;
+
+         if (reverse)
+             print_values(grad);
+         else
+             print_values(data);
+
+         printf(" = tile(shape=(");
+         for (int i=0; i < Layout::Shape::N; ++i)
+         {
+             printf("%d", Layout::Shape::dim(i));
+             if (i != Layout::Shape::N-1)
+                 printf(",");
+         }
+
+         printf("), storage=shared)\n");
+     }
+ };
+
+
+ template <typename T, typename L>
+ void tile_register_t<T, L>::print() const
+ {
+     // create a temporary shared tile so that
+     // we can print it deterministically
+     WP_TILE_SHARED T smem[L::Size];
+     tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
+
+     scratch.assign(*this);
+
+     WP_TILE_SYNC();
+
+     if (WP_TILE_THREAD_IDX == 0)
+     {
+         scratch.print_values(scratch.data, 0);
+
+         printf(" = tile(shape=(");
+         for (int i=0; i < L::Shape::N; ++i)
+         {
+             printf("%d", L::Shape::dim(i));
+             if (i != L::Shape::N-1)
+                 printf(",");
+         }
+
+         printf("), storage=register)\n");
+     }
+
+     WP_TILE_SYNC();
+ }
+
+ // print entry points
+ template <typename T, typename L>
+ inline CUDA_CALLABLE void print(const tile_register_t<T, L>& t) { t.print(); }
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print(); }
+
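+ // Editor's note (illustrative, not part of the original source): for a
+ // 2x3 shared tile holding {{1, 2, 3}, {4, 5, 6}}, the printers above emit
+ // output of roughly this form:
+ //
+ //     [[1 2 3]
+ //      [4 5 6]] = tile(shape=(2,3), storage=shared)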
+
+ template <typename T, typename L, bool O>
+ inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
+ {
+     return L::Shape::dim(0);
+ }
+
+ template <typename T, typename L, bool O, typename AdjTile>
+ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile& a, int& adj_ret)
+ {
+ }
+
+ template <typename T, typename L>
+ inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
+ {
+     return L::Shape::dim(0);
+ }
+
+ template <typename T, typename L, typename AdjTile>
+ inline CUDA_CALLABLE void adj_len(const tile_register_t<T,L>& t, const AdjTile& a, int& adj_ret)
+ {
+ }
+
+ template <typename T, typename L>
+ inline CUDA_CALLABLE void adj_print(const tile_register_t<T, L>& t, const tile_register_t<T, L>& a) { a.print(); }
+ template <typename T, typename L, bool Owner>
+ inline CUDA_CALLABLE void adj_print(const tile_shared_t<T, L, Owner>& t, const tile_shared_t<T, L, Owner>& a) { a.print(true); }
+
+ // helpers to allocate shared tiles
+ template <typename T, typename Shape, bool RequiresGrad>
+ inline CUDA_CALLABLE auto tile_alloc_empty()
+ {
+     constexpr int size = Shape::size();
+     T* data = (T*)tile_alloc_shared(size*sizeof(T));
+     T* grad = nullptr;
+
+ #if FP_CHECK
+
+     // initialize tile to quiet NaN
+     uint32_t qnanbits = 0x7FC00000;
+     float qnan = *(float*)(&qnanbits);
+
+     for (int i=WP_TILE_THREAD_IDX; i < size; i += WP_TILE_BLOCK_DIM)
+         data[i] = T(qnan);
+
+     WP_TILE_SYNC();
+
+ #endif // FP_CHECK
+
+     if (RequiresGrad)
+     {
+         grad = (T*)tile_alloc_shared(size*sizeof(T));
+
+         for (int i=WP_TILE_THREAD_IDX; i < size; i += WP_TILE_BLOCK_DIM)
+             grad[i] = T(0);
+
+         WP_TILE_SYNC();
+     }
+
+     return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
+ }
+
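+ // Editor's sketch (illustration only): the FP_CHECK pre-fill above writes a
+ // quiet NaN by bit pattern (exponent all ones, top mantissa bit set). A
+ // portable host-side equivalent would use memcpy rather than a pointer cast,
+ // which sidesteps the strict-aliasing question entirely:
+ #if 0
+ #include <cstdint>
+ #include <cstring>
+
+ inline float quiet_nan_from_bits()
+ {
+     const uint32_t qnanbits = 0x7FC00000u;
+     float qnan;
+     std::memcpy(&qnan, &qnanbits, sizeof(qnan)); // well-defined type pun
+     return qnan;
+ }
+ #endif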
+
+ //-----------------------------------------------------------------------------------------------------
+ // High level entry points for each op (correspond to one Warp builtin)
+
+ // construct a tile from a local SIMT value (one per-thread)
+ template <typename T>
+ inline CUDA_CALLABLE auto tile(const T& x)
+ {
+     tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
+
+     using Layout = typename decltype(result)::Layout;
+     static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
+
+     result.data[0] = x;
+     return result;
+ }
+
+ // overload for constructing a tile from a per-thread vector
+ template <typename T, unsigned Length>
+ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
+ {
+     tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
+
+     using Layout = typename decltype(result)::Layout;
+     static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
+
+     for (int i=0; i < Length; ++i)
+         result.data[i] = x[i];
+
+     return result;
+ }
+
+
+ // adjoint of constructing a tile from a local SIMT value (one per-thread)
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
+ {
+     static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
+     static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
+
+     auto adj_reg = adj_ret.copy_to_register();
+
+     adj_x += adj_reg.data[0];
+ }
+
+ template <typename T, unsigned Length, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
+ {
+     static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
+     static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
+     static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
+
+     auto adj_reg = adj_ret.copy_to_register();
+
+     for (int i=0; i < Length; ++i)
+         adj_x[i] += adj_reg.data[i];
+ }
+
+
+ template <typename Tile>
+ inline CUDA_CALLABLE auto untile(Tile& tile)
+ {
+     // code-gen should have set the tile to
+     // have exactly the block dimension so
+     // there is exactly one value per-thread
+     auto reg = tile.copy_to_register();
+
+     constexpr int N = Tile::Layout::Shape::N;
+
+     // scalar case
+     if constexpr(N == 1)
+     {
+         return reg.data[0];
+     }
+
+     // vector case
+     if constexpr(N == 2)
+     {
+         constexpr int Length = Tile::Layout::Shape::dim(0);
+         wp::vec_t<Length, typename Tile::Type> v;
+         for (int i=0; i < Length; ++i)
+             v[i] = reg.data[i];
+
+         return v;
+     }
+ }
+
+
+ template <typename Tile, typename Value>
+ inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret)
+ {
+     auto adj = adj_tile.copy_to_register();
+
+     constexpr int N = Tile::Layout::Shape::N;
+
+     // scalar case
+     if constexpr(N == 1)
+     {
+         adj.data[0] += adj_ret;
+     }
+
+     // vector case
+     if constexpr(N == 2)
+     {
+         constexpr int Length = Tile::Layout::Shape::dim(0);
+         for (int i=0; i < Length; ++i)
+             adj.data[i] += adj_ret[i];
+     }
+
+     adj_tile.assign(adj);
+ }
+
+ // zero-initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_zeros()
+ {
+     // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+     return T(0);
+ }
+
+ // one-initialized tile
+ template <typename T, unsigned... Shape>
+ inline CUDA_CALLABLE auto tile_ones()
+ {
+     // tile variable assignment operator will handle initialization (since lhs could be shared/register tile)
+     return T(1);
+ }
+
+
+ // tile with evenly spaced values
+ template <typename T, int Len>
+ inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step)
+ {
+     auto out = tile_register<T, Len>();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         const int linear = Layout::linear_from_register(i);
+
+         // handle case where tile size is not
+         // aligned to block dimensions
+         if (!Layout::valid(linear))
+             break;
+
+         out.data[i] = start + linear*step;
+     }
+
+     return out;
+ }
+
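+ // Editor's note (illustrative): tile_arange(2, 10, 2) over a length-4 tile
+ // yields [2 4 6 8]. linear_from_register() maps each register slot back to
+ // its linear element index, and Layout::valid() guards the ragged tail when
+ // the tile length is not a multiple of WP_TILE_BLOCK_DIM. Note that `stop`
+ // is unused here; the tile length comes from the Len template parameter,
+ // presumably computed by code-gen from (stop - start) / step.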
+
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_arange(T start, T stop, T step,
+                                           T& adj_start, T& adj_stop, T& adj_step, AdjTile& adj_ret) {}
+
+ // entry point for load operations; these just return a reference to a global memory array + coordinate
+ template <unsigned... Shape, typename... Indices, typename T>
+ inline CUDA_CALLABLE auto tile_load(array_t<T>& src, Indices... offset)
+ {
+     return tile_global_t<T, tile_shape_t<Shape...>>(src, tile_coord(offset...));
+ }
+
+ // // entry point for tile store operations
+ // template <typename... Indices, typename T, typename Tile>
+ // inline CUDA_CALLABLE void tile_store(array_t<T>& dest, Tile& src, Indices... x)
+ // {
+ //     src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x)));
+ // }
+
+ // entry points for tile store operations
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE void tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { src.copy_to_global(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
+
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z))); }
+ template <typename T, typename Tile>
+ inline CUDA_CALLABLE auto tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& src) { return src.atomic_add(tile_global_t<T, typename Tile::Layout::Shape>(dest, tile_coord(x, y, z, w))); }
+
+ //-------------------------------------
+ // Adjoints
+
+ template <typename T, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, Coord c,
+                                         array_t<T>& adj_src, Coord adj_c,
+                                         AdjTile& adj_ret)
+ {
+     tile_global_t<T, typename AdjTile::Layout::Shape> dest(src, c);
+
+     // we allow users to override grad of src
+     if (adj_src.data)
+         dest.data.grad = adj_src.data;
+
+     adj_ret.atomic_add_grad(dest);
+ }
+
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, array_t<T>& adj_src, int adj_x, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x), adj_src, tile_coord(0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, array_t<T>& adj_src, int adj_x, int adj_y, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y), adj_src, tile_coord(0,0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y, z), adj_src, tile_coord(0,0,0), adj_ret); }
+ template <typename T, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_load(array_t<T>& src, int x, int y, int z, int w, array_t<T>& adj_src, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_ret) { adj_tile_load(src, tile_coord(x, y, z, w), adj_src, tile_coord(0,0,0,0), adj_ret); }
+
+ template <typename T, typename Tile, typename AdjTile, typename Coord>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, array_t<T>& adj_dest, Coord adj_c, AdjTile& adj_t)
+ {
+     tile_global_t<T, typename AdjTile::Layout::Shape> src(dest, c);
+
+     // we allow users to override grad of dest
+     if (adj_dest.data)
+         src.data.grad = adj_dest.data;
+
+     if (src.data.grad == nullptr)
+         return;
+
+     adj_t.grad_add(src);
+ }
+
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(0,0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(0,0,0), adj_t); }
+ template <typename T, typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(0,0,0,0), adj_t); }
+
+ // adj_tile_atomic_add is an alias for adj_tile_store
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, Tile& t, array_t<T>& adj_dest, int adj_x, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x), t, adj_dest, tile_coord(adj_x), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y), t, adj_dest, tile_coord(adj_x, adj_y), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z), t, adj_dest, tile_coord(adj_x, adj_y, adj_z), adj_t); }
+ template <typename T, typename Tile, typename AdjTile, typename AdjRet>
+ inline CUDA_CALLABLE void adj_tile_atomic_add(array_t<T>& dest, int x, int y, int z, int w, Tile& t, array_t<T>& adj_dest, int adj_x, int adj_y, int adj_z, int adj_w, AdjTile& adj_t, AdjRet& adj_ret) { adj_tile_store(dest, tile_coord(x, y, z, w), t, adj_dest, tile_coord(adj_x, adj_y, adj_z, adj_w), adj_t); }
+
+
+ // unary map
+ template <typename Tile, typename Fwd>
+ inline CUDA_CALLABLE auto tile_map(Fwd op,
+                                    Tile& a)
+ {
+     auto out = tile_register_like<Tile>();
+     auto a_reg = a.copy_to_register();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         out.data[i] = op(a_reg.data[i]);
+     }
+
+     return out;
+ }
+
+ template <typename Tile, typename AdjTile, typename Fwd, typename Adj>
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
+                                        Tile& a,
+                                        Adj adj_op,
+                                        Tile& adj_a,
+                                        AdjTile& adj_ret)
+ {
+     auto a_reg = a.copy_to_register();
+     auto adj_a_reg = tile_register_like<Tile>();
+     auto adj_ret_reg = adj_ret.grad_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_op(a_reg.data[i], adj_a_reg.data[i], adj_ret_reg.data[i]);
+     }
+
+     // write adjoints back
+     adj_a.grad_add(adj_a_reg);
+ }
+
+ // binary map
+ template <typename TileA, typename TileB, typename Fwd>
+ inline CUDA_CALLABLE auto tile_map(Fwd op,
+                                    TileA& a,
+                                    TileB& b)
+ {
+     auto out = tile_register_like<TileA>();
+
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     using Layout = typename decltype(out)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         out.data[i] = op(a_reg.data[i], b_reg.data[i]);
+     }
+
+     return out;
+ }
+
+ template <typename TileA, typename TileB, typename Fwd, typename Adj, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_map(Fwd op,
+                                        TileA& a,
+                                        TileB& b,
+                                        Adj adj_op,
+                                        TileA& adj_a,
+                                        TileB& adj_b,
+                                        AdjTile& adj_ret)
+ {
+     auto a_reg = a.copy_to_register();
+     auto b_reg = b.copy_to_register();
+
+     // allocate storage for adjoints
+     auto adj_a_reg = tile_register_like<TileA>();
+     auto adj_b_reg = tile_register_like<TileB>();
+
+     auto adj_ret_reg = adj_ret.grad_to_register();
+
+     using Layout = typename decltype(a_reg)::Layout;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_op(a_reg.data[i], b_reg.data[i], adj_a_reg.data[i], adj_b_reg.data[i], adj_ret_reg.data[i]);
+     }
+
+     adj_a.grad_add(adj_a_reg);
+     adj_b.grad_add(adj_b_reg);
+ }
+
+ // wrap the operator in a lambda so that we don't have to do overload resolution for builtins like wp.sin();
+ // this is important because many of the builtin operators don't follow particular conventions on references for
+ // the `adj_ret` parameter, which means it's not possible to figure out the overload we need using simple casting
+ #define tile_unary_map(op, a) tile_map([](auto x) { return op(x);}, a)
+ #define adj_tile_unary_map(op, a, adj_op, adj_a, adj_ret) adj_tile_map([](auto x) { return op(x);}, a, [](auto x, auto& adj_x, auto adj_ret) { adj_op(x, adj_x, adj_ret);}, adj_a, adj_ret)
+
+ #define tile_binary_map(op, a, b) tile_map([](auto x, auto y) { return op(x, y);}, a, b)
+ #define adj_tile_binary_map(op, a, b, adj_op, adj_a, adj_b, adj_ret) adj_tile_map([](auto x, auto y) { return op(x, y);}, a, b, [](auto x, auto y, auto& adj_x, auto& adj_y, auto adj_ret) { adj_op(x, y, adj_x, adj_y, adj_ret);}, adj_a, adj_b, adj_ret)
+
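+ // Editor's note (illustrative): tile_unary_map(wp::sin, a) expands to
+ // tile_map([](auto x) { return wp::sin(x); }, a), so the unresolved
+ // overload set `wp::sin` is only ever named inside the lambda body, where
+ // ordinary overload resolution applies.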
+
+ // -tile (unary neg)
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_neg(Tile& a) { return tile_unary_map(wp::neg, a); }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_neg(Tile& a, Tile& adj_a, AdjTile& adj_ret) { adj_tile_unary_map(wp::neg, a, wp::adj_neg, adj_a, adj_ret); }
+
+ // tile + tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_add(TileA& a, TileB& b)
+ {
+     return tile_binary_map(add, a, b);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_add(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     adj_tile_binary_map(add, a, b, adj_add, adj_a, adj_b, adj_c);
+ }
+
+ // tile - tile
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE auto tile_sub(TileA& a, TileB& b)
+ {
+     return tile_binary_map(sub, a, b);
+ }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_sub(TileA& a, TileB& b, AdjTileA& adj_a, AdjTileB& adj_b, AdjTile& adj_c)
+ {
+     adj_tile_binary_map(sub, a, b, adj_sub, adj_a, adj_b, adj_c);
+ }
+
+ // tile*scalar
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_mul(Tile& a, const typename Tile::Type& s)
+ {
+     // promote scalar to a constant tile
+     auto s_tile = tile_register_t<typename Tile::Type, tile_layout_register_t<typename Tile::Layout::Shape>>(s);
+
+     return tile_binary_map(mul, a, s_tile);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_mul(Tile& a, const typename Tile::Type& s,
+                                        Tile& adj_a, typename Tile::Type& adj_s,
+                                        AdjTile& adj_c)
+ {
+     auto s_tile = tile_register_like<Tile>();
+     auto adj_s_tile = tile_register_like<Tile>();
+
+     using Layout = typename decltype(adj_s_tile)::Layout;
+
+     // initialize to constant
+     s_tile = s;
+
+     adj_tile_binary_map(mul, a, s_tile, adj_mul, adj_a, adj_s_tile, adj_c);
+
+     for (int i=0; i < Layout::NumRegs; ++i)
+     {
+         adj_s += adj_s_tile.data[i];
+     }
+ }
+
+
+ // scalar*tile
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_mul(const typename Tile::Type& s, Tile& a)
+ {
+     return tile_mul(a, s);
+ }
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_mul(const typename Tile::Type& s, Tile& a,
+                                        typename Tile::Type& adj_s, Tile& adj_a,
+                                        AdjTile& adj_c)
+ {
+     adj_tile_mul(a, s, adj_a, adj_s, adj_c);
+ }
+
+ template<typename Tile>
+ typename Tile::Type tile_extract(Tile& t, int i) { return t.extract(tile_coord(i)); }
+ template<typename Tile>
+ typename Tile::Type tile_extract(Tile& t, int i, int j) { return t.extract(tile_coord(i,j)); }
+ template<typename Tile>
+ typename Tile::Type tile_extract(Tile& t, int i, int j, int k) { return t.extract(tile_coord(i,j,k)); }
+ template<typename Tile>
+ typename Tile::Type tile_extract(Tile& t, int i, int j, int k, int l) { return t.extract(tile_coord(i,j,k,l)); }
+
+ template<typename Tile, typename AdjTile>
+ void adj_tile_extract(Tile& t, int i, AdjTile& adj_t, int adj_i, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i), adj_ret); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_extract(Tile& t, int i, int j, AdjTile& adj_t, int adj_i, int adj_j, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j), adj_ret); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k), adj_ret); }
+ template<typename Tile, typename AdjTile>
+ void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
+
+
+ namespace partitioned_gemm
+ {
+
+ template <typename T>
+ inline CUDA_CALLABLE const T& index(const T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <typename T>
+ inline CUDA_CALLABLE T& index(T* __restrict__ p, int i, int j, int stride)
+ {
+     return p[i*stride + j];
+ }
+
+ template <int PartitionM, int PartitionN, typename Tile>
+ struct partition_t
+ {
+     static constexpr int M = PartitionM;
+     static constexpr int N = PartitionN;
+     static constexpr int Stride = Tile::Layout::Shape::dim(1);
+
+     using T = typename Tile::Type;
+
+     inline partition_t(Tile& A)
+     {
+         data = A.data.ptr;
+
+         // todo: do ceil div for non-multiples of M,N
+         shape[0] = Tile::Layout::Shape::dim(0)/PartitionM;
+         shape[1] = Tile::Layout::Shape::dim(1)/PartitionN;
+     }
+
+     // underlying data
+     T* data;
+
+     // partition dimensions
+     int shape[2];
+ };
+
+ template <typename Partition>
+ inline int partition_size(const Partition& part)
+ {
+     return part.shape[0]*part.shape[1];
+ }
+
+ // returns the i, j coordinates of a sub-tile given a linear index
+ template <typename Partition>
+ inline void partition_coord(const Partition& part, const int t, int& i, int& j)
+ {
+     i = t/part.shape[1];
+     j = t%part.shape[1];
+ }
+
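+ // Editor's note (illustrative): partition_coord() is the usual row-major
+ // de-linearization. For a 2x3 grid of sub-tiles (shape = {2, 3}):
+ //
+ //     t = 0 -> (0,0)   t = 1 -> (0,1)   t = 2 -> (0,2)
+ //     t = 3 -> (1,0)   t = 4 -> (1,1)   t = 5 -> (1,2)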
+
+ template <typename Partition>
+ inline auto partition_load(const Partition& tile, int i, int j)
+ {
+     mat_t<Partition::M, Partition::N, typename Partition::T> out;
+
+     const int tile_i = i*Partition::M;
+     const int tile_j = j*Partition::N;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             out.data[i][j] = partitioned_gemm::index(tile.data, tile_i + i, tile_j + j, Partition::Stride);
+         }
+     }
+
+     return out;
+ }
+
+ template <typename Partition, typename Value>
+ inline void partition_store(const Partition& tile, int i, int j, const Value& value)
+ {
+     const int tile_i = Partition::M*i;
+     const int tile_j = Partition::N*j;
+
+     WP_PRAGMA_UNROLL
+     for (int i=0; i < Partition::M; ++i)
+     {
+         WP_PRAGMA_UNROLL
+         for (int j=0; j < Partition::N; ++j)
+         {
+             index(tile.data, tile_i + i, tile_j + j, Partition::Stride) = value.data[i][j];
+         }
+     }
+ }
+
+
+ template <typename TileA, typename TileB, typename TileC>
+ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
+ {
+     const int TILE_M = 4;
+     const int TILE_N = 4;
+     const int TILE_K = 4;
+
+     auto A_tile = partition_t<TILE_M, TILE_K, TileA>(A);
+     auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
+     auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
+
+     //static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
+
+     const int length = partition_size(C_tile);
+
+     for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
+     {
+         int i, j;
+         partition_coord(C_tile, t, i, j);
+
+         // accumulator
+         auto sum = partition_load(C_tile, i, j);
+
+         WP_PRAGMA_UNROLL
+         for (int k=0; k < A_tile.shape[1]; k++)
+         {
+             const auto a = partition_load(A_tile, i, k);
+             const auto b = partition_load(B_tile, k, j);
+
+             sum += mul(a, b);
+         }
+
+         partition_store(C_tile, i, j, sum);
+     }
+ }
+
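+ // Editor's sketch (illustration only, not part of the library source): a
+ // single-threaded host analogue of matmul() above, assuming square n x n
+ // matrices where n is an exact multiple of the 4x4 partition. Each (bi, bj)
+ // block plays the role of one thread's sub-tile of C.
+ #if 0
+ void partitioned_matmul_host(const float* A, const float* B, float* C, int n)
+ {
+     const int P = 4; // mirrors TILE_M / TILE_N / TILE_K above
+
+     for (int bi = 0; bi < n; bi += P)
+         for (int bj = 0; bj < n; bj += P)
+             for (int bk = 0; bk < n; bk += P)
+                 // accumulate the P x P product of block A(bi, bk) and B(bk, bj)
+                 for (int i = bi; i < bi + P; ++i)
+                     for (int j = bj; j < bj + P; ++j)
+                     {
+                         float sum = C[i*n + j];
+                         for (int k = bk; k < bk + P; ++k)
+                             sum += A[i*n + k] * B[k*n + j];
+                         C[i*n + j] = sum;
+                     }
+ }
+ #endif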
+
+ template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
+ inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
+ {
+     for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto coord = LayoutC::coord_from_linear(t);
+
+         int i = coord[0];
+         int j = coord[1];
+
+         // accumulator
+         auto sum = C(coord)*scale;
+
+         WP_PRAGMA_UNROLL
+         for (int k=0; k < LayoutA::Shape::dim(1); k++)
+         {
+             const auto a = A(tile_coord(i, k));
+             const auto b = B(tile_coord(k, j));
+
+             sum = muladd<decltype(sum)>(a, b, sum);
+         }
+
+         C(coord) = sum;
+     }
+ }
+
+
+ template <typename TileA, typename TileL>
+ inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
+ {
+     using T = typename TileA::Type;
+     constexpr int n = TileA::Layout::Shape::dim(1);
+
+     for (int j=0; j < n; ++j)
+     {
+         T s = A.data(tile_coord(j, j));
+
+         for (int k=0; k < j; ++k)
+         {
+             T r = L.data(tile_coord(j, k));
+             s -= r * r;
+         }
+
+         s = wp::sqrt(s);
+         T invS = T(1.0) / s;
+
+         L.data(tile_coord(j, j)) = s;
+
+         for (int i=j+1; i < n; ++i)
+         {
+             s = A.data(tile_coord(i, j));
+
+             for (int k=0; k < j; ++k)
+             {
+                 s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
+             }
+
+             L.data(tile_coord(i, j)) = s * invS;
+         }
+
+         // zero out upper triangular portion
+         for (int k=j+1; k < n; ++k)
+         {
+             L.data(tile_coord(j, k)) = T(0.0);
+         }
+     }
+ }
+
+
+ template <typename TileL, typename TileX, typename TileY>
+ inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
+ {
+     using T = typename TileL::Type;
+     constexpr int n = TileL::Layout::Shape::dim(1);
+
+     // forward substitution: solve L z = y (z written into X)
+     for (int i=0; i < n; ++i)
+     {
+         T s = Y.data(tile_coord(i));
+
+         for (int j=0; j < i; ++j)
+             s -= L.data(tile_coord(i, j)) * X.data(tile_coord(j));
+
+         X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+     }
+
+     // back substitution: solve L^T x = z
+     for (int i=n-1; i >= 0; --i)
+     {
+         T s = X.data(tile_coord(i));
+
+         for (int j=i+1; j < n; ++j)
+             s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
+
+         X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
+     }
+ }
+
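+ // Editor's sketch (illustration only): the same factorization on a plain
+ // 2x2 array. For A = {{4, 2}, {2, 3}} the factor is L = {{2, 0}, {1, sqrt(2)}},
+ // and L * L^T reproduces A, which makes a handy sanity check.
+ #if 0
+ #include <cmath>
+
+ void cholesky_2x2(const double A[2][2], double L[2][2])
+ {
+     L[0][0] = std::sqrt(A[0][0]);
+     L[1][0] = A[1][0] / L[0][0];
+     L[1][1] = std::sqrt(A[1][1] - L[1][0]*L[1][0]);
+     L[0][1] = 0.0; // upper triangle zeroed, as in scalar_cholesky()
+ }
+ #endif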
+
+ } // namespace partitioned_gemm
+
+
+ template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+     using ShapeC = typename TileC::Layout::Shape;
+
+     static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
+     static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
+     static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
+
+     static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
+
+     using T = typename TileA::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
+ #else
+     fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+
+     return C;
+ }
+
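+ // Editor's note (illustrative): the Add template parameter acts as the beta
+ // term of the GEMM, i.e. C = A*B + Add*C, so Add = 0 overwrites the output
+ // tile while Add = 1 accumulates into it; on the device path it is forwarded
+ // to the MathDx kernel as the beta scalar.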
+
+ // backward for the wp.tile_matmul(a, b, out) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C)
+ {
+     using T = typename TileA::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     auto At = tile_transpose(A);
+     auto Bt = tile_transpose(B);
+
+     partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+     partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+ #else
+     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
+     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+ }
+
+ // backward for the out = wp.tile_matmul(a, b) syntax
+ template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
+ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
+                      Fwd adj_fun_forward, AdjA adj_fun_backward_A, AdjB adj_fun_backward_B, TileA& adj_A, TileB& adj_B, TileC& adj_C, TileC& adj_ret)
+ {
+     using T = typename TileA::Type;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+     auto At = tile_transpose(A);
+     auto Bt = tile_transpose(B);
+
+     partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
+     partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
+ #else
+     fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
+     fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
+ #endif
+
+     WP_TILE_SYNC();
+ }
+
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+ #define tile_fft()
+ #define tile_ifft()
+
+ #define adj_tile_fft()
+ #define adj_tile_ifft()
+
+ #else
+
+ // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
+ // and remove the need for __align__(16) dtypes data[...]
+ #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
+     do { \
+         void function_name(dtype*, dtype*); \
+         char* buffer = (char*)wp::tile_alloc_shared(shared_memory_size); \
+         __align__(16) dtype data[ept]; \
+         for(int b = 0; b < (int)batch_size; b++) { \
+             dtype* inout = Xinout.data + (int)b * (int)ept; \
+             memcpy(data, inout, sizeof(dtype) * ept); \
+             function_name(data, (dtype*)buffer); \
+             memcpy(inout, data, sizeof(dtype) * ept); \
+             WP_TILE_SYNC(); \
+         } \
+         wp::tile_alloc_shared(-shared_memory_size); \
+     } while (0)
+
+ #define tile_ifft tile_fft
+
+ // adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept are all ignored
+
+ #define adj_tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                      adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                      adj_Xinout) \
+     do { \
+         tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
+
+ #define adj_tile_ifft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout, \
+                       adj_function_name, adj_dtype, adj_shared_memory_size, adj_batch_size, adj_ept, \
+                       adj_Xinout) \
+     do { \
+         tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
+     } while (0)
+
+ #endif // !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+
+ template <typename Fwd, typename TileA, typename TileL>
+ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
+ {
+     // Copy to L
+     L = A;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     partitioned_gemm::scalar_cholesky(A, L);
+
+ #else
+
+     // Call cholesky on L
+     WP_TILE_SYNC();
+
+     fun_forward(L.data.ptr, TileL::Layout::Shape::dim(0));
+
+     WP_TILE_SYNC();
+
+     // Zero out the upper triangular part of L
+     WP_PRAGMA_UNROLL
+     for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
+     {
+         auto c = TileL::Layout::coord_from_linear(i);
+
+         if (c[0] < c[1])
+             L.data(c) = 0.0;
+     }
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return L;
+ }
+
+
+ #define adj_tile_cholesky(function_name, A, L, \
+                           adj_function_name, adj_A, adj_L, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+ template <typename Fwd, typename TileL, typename TileX, typename TileY>
+ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
+ {
+     // Copy X to Y
+     Y = X;
+
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
+
+     partitioned_gemm::scalar_cholesky_solve(L, X, Y);
+
+ #else
+
+     // Call cholesky solve on L & Y
+     WP_TILE_SYNC();
+
+     fun_forward(L.data.ptr, Y.data.ptr);
+
+     WP_TILE_SYNC();
+
+ #endif
+
+     return Y;
+ }
+
+
+ #define adj_tile_cholesky_solve(function_name, L, X, Y, \
+                                 adj_function_name, adj_L, adj_X, adj_Y, adj_ret) \
+     do { \
+         assert(false); \
+     } while (0)
+
+ template <typename Tile>
+ inline CUDA_CALLABLE auto tile_transpose(Tile& t)
+ {
+     static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
+
+     // alias incoming tile
+     constexpr int M = Tile::Layout::Shape::dim(0);
+     constexpr int N = Tile::Layout::Shape::dim(1);
+
+     constexpr int StrideM = Tile::Layout::Stride::dim(0);
+     constexpr int StrideN = Tile::Layout::Stride::dim(1);
+
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N,M>, tile_stride_t<StrideN, StrideM>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
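+ // Editor's sketch (illustration only): transposition as a zero-copy stride
+ // swap. With row-major strides {N, 1}, element (i, j) lives at i*N + j;
+ // reading the same buffer with shape and strides swapped yields the
+ // transpose without moving any data.
+ #if 0
+ inline float at(const float* p, int i, int j, int si, int sj)
+ {
+     return p[i*si + j*sj];
+ }
+ // for any buffer p: at(p, i, j, /*si*/N, /*sj*/1) == at(p, j, i, /*si*/1, /*sj*/N)
+ #endif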
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     auto a = tile_transpose(adj_ret);
+     auto b = adj_t;
+
+     adj_t.assign(tile_add(a, b));
+ }
+
+ template <int N, int StrideN, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int StrideM, int StrideN, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
+ template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
+ {
+     // alias incoming tile with new strides
+     return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
+ }
+
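+ // Editor's note (an assumption based on the aliasing constructors above):
+ // a broadcast view maps a smaller tile onto a larger shape by passing a
+ // stride of 0 along the broadcast dimensions, so every index along such a
+ // dimension resolves to the same underlying element. No data is copied,
+ // which is also why adj_tile_broadcast() below is a no-op.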
+
+ template <typename Tile, typename AdjTile>
+ inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
+ {
+     // no-op: since the memory is aliased, grads have already been accumulated
+ }
+
+ template <typename ReturnType, typename Tile, typename... Indices>
+ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
+ {
+     auto c = tile_coord(indices...);
+
+     // return new tile with same strides
+     typename Tile::Type* data_ptr = &t.data(c);
+     typename Tile::Type* grad_ptr = nullptr;
+
+     if (t.grad.ptr)
+         grad_ptr = &t.grad(c);
+
+     return ReturnType(data_ptr, grad_ptr);
+ }
+
+
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
+ {
+     dest.data(tile_coord(i)) = src;
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, const Scalar& src)
+ {
+     dest.data(tile_coord(i, j)) = src;
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, const Scalar& src)
+ {
+     dest.data(tile_coord(i, j, k)) = src;
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename Scalar>
+ inline CUDA_CALLABLE void assign(TileA& dest, int i, int j, int k, int l, const Scalar& src)
+ {
+     dest.data(tile_coord(i, j, k, l)) = src;
+     WP_TILE_SYNC();
+ }
+
+ template <typename TileA, typename TileB, typename Coord>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offset)
+ {
+     using Layout = typename TileB::Layout;
+
+     for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto c = Layout::coord_from_linear(t);
+         dest.data(c + offset) = src.data(c);
+     }
+
+     WP_TILE_SYNC();
+ }
+
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB, typename Coord, typename AdjCoord>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
+                                           AdjTileA& adj_dest, AdjTileB& adj_src, AdjCoord adj_offset)
+ {
+     using Layout = typename TileB::Layout;
+
+     for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
+     {
+         auto c = Layout::coord_from_linear(t);
+         src.grad(c) += dest.grad(c + offset);
+     }
+
+     WP_TILE_SYNC();
+ }
+
+ // codegen entry points, which emit calls like `tile_assign(dest, src, i, j, k)`;
+ // a better approach would be for codegen to directly generate `tile_assign(dest, src, tile_coord(i, j, k))`,
+ // i.e. call the implementation methods above, in which case these overloads could be removed
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i) { tile_assign(dest, src, tile_coord(i)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j) { tile_assign(dest, src, tile_coord(i, j)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k) { tile_assign(dest, src, tile_coord(i, j, k)); }
+ template <typename TileA, typename TileB>
+ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l) { tile_assign(dest, src, tile_coord(i, j, k, l)); }
+
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, AdjTileA& adj_dest, AdjTileB& adj_src, int) { adj_tile_assign(dest, src, tile_coord(i), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, AdjTileA& adj_dest, AdjTileB& adj_src, int, int) { adj_tile_assign(dest, src, tile_coord(i,j), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k), adj_dest, adj_src, tile_coord(0)); }
+ template <typename TileA, typename TileB, typename AdjTileA, typename AdjTileB>
+ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, int i, int j, int k, int l, AdjTileA& adj_dest, AdjTileB& adj_src, int, int, int, int) { adj_tile_assign(dest, src, tile_coord(i,j,k,l), adj_dest, adj_src, tile_coord(0)); }
+
+
+ template <typename TileA, typename TileB, typename TileC>
+ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
+ {
+     using ShapeA = typename TileA::Layout::Shape;
+     using ShapeB = typename TileB::Layout::Shape;
+     using ShapeC = typename TileC::Layout::Shape;
+
+     static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
+     static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
+     static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
+
+     c = a;
+
+     for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
+     {
+         c.data(tile_coord(t, t)) += b.data(tile_coord(t));
+     }
+
+     WP_TILE_SYNC();
+
+     return c;
+ }
+
+ template <typename TileA, typename TileB, typename TileC, typename AdjTileA, typename AdjTileB, typename AdjTileC>
+ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTileA& adj_a, AdjTileB& adj_b, AdjTileC& adj_c, AdjTileC& adj_ret)
+ {
+     assert(false);
+ }
+
+
+ } // namespace wp
+
+ #ifdef __clang__
+ #pragma clang diagnostic pop
+ #endif