warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/tile.h CHANGED
@@ -1,18 +1,57 @@
1
- /** Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #pragma once
10
19
 
11
20
  #include "builtin.h"
12
21
 
22
+ #ifdef __clang__
23
+ // disable warnings related to C++17 extensions on CPU JIT builds
24
+ #pragma clang diagnostic push
25
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
26
+ #endif // __clang__
27
+
28
+ // Check if the CUDA toolkit is available
29
+ #if WP_ENABLE_CUDA || defined(__CUDACC_RTC__)
30
+
31
+ // If NVRTC is being used, do not include extra headers (NVRTC has built-in float4)
32
+ #ifdef __CUDACC_RTC__
33
+ // NVRTC: Use built-in float4 (no need for extra definitions)
34
+ #else
35
+ // NVCC: Include vector_types.h to get float4
36
+ #include <cuda_runtime.h>
37
+ #endif
38
+
39
+ #else
40
+ // If CUDA is not available (e.g., macOS build), manually define float4
41
+ struct alignas(16) float4 {
42
+ float x, y, z, w;
43
+ };
44
+ #endif
45
+
46
+ // only used while building the warp core library
47
+ #ifndef WP_TILE_BLOCK_DIM
48
+ #define WP_TILE_BLOCK_DIM 256
49
+ #endif
50
+
13
51
  #if !defined(__CUDA_ARCH__)
14
52
  #define WP_TILE_SHARED static
15
53
  #define WP_TILE_SYNC void
54
+
16
55
  #else
17
56
  #define WP_TILE_SHARED __shared__
18
57
  #define WP_TILE_SYNC __syncthreads
@@ -37,6 +76,14 @@
37
76
  #define WP_USE_ASYNC_PIPELINE 0
38
77
  #define WP_USE_REGISTER_GEMM 0
39
78
 
79
+ #if defined(__CUDACC_RTC__)
80
+ #define WP_TILE_THREAD_IDX threadIdx.x
81
+ #else
82
+ #define WP_TILE_THREAD_IDX 0
83
+ #endif //
84
+
85
+
86
+
40
87
  /* Tile Expressions
41
88
 
42
89
  [ ] Tiles
@@ -208,14 +255,14 @@ constexpr tile_coord_t<sizeof...(Ints)> tile_coord(Ints... idxs)
208
255
  }
209
256
 
210
257
  // helpers to construct a coord from a set of indices
211
- auto tile_coord(int i)
258
+ inline auto tile_coord(int i)
212
259
  {
213
260
  auto c = tile_coord_t<1>();
214
261
  c.indices[0] = i;
215
262
  return c;
216
263
  }
217
264
 
218
- auto tile_coord(int i, int j)
265
+ inline auto tile_coord(int i, int j)
219
266
  {
220
267
  auto c = tile_coord_t<2>();
221
268
  c.indices[0] = i;
@@ -223,7 +270,7 @@ auto tile_coord(int i, int j)
223
270
  return c;
224
271
  }
225
272
 
226
- auto tile_coord(int i, int j, int k)
273
+ inline auto tile_coord(int i, int j, int k)
227
274
  {
228
275
  auto c = tile_coord_t<3>();
229
276
  c.indices[0] = i;
@@ -232,7 +279,7 @@ auto tile_coord(int i, int j, int k)
232
279
  return c;
233
280
  }
234
281
 
235
- auto tile_coord(int i, int j, int k, int l)
282
+ inline auto tile_coord(int i, int j, int k, int l)
236
283
  {
237
284
  auto c = tile_coord_t<4>();
238
285
  c.indices[0] = i;
@@ -247,7 +294,7 @@ template <int... V>
247
294
  struct tile_tuple_t
248
295
  {
249
296
  static constexpr int N = sizeof...(V);
250
- static_assert(N > 0);
297
+ static_assert(N > 0, "Expected N > 0");
251
298
 
252
299
  static constexpr int data[N] = { V... };
253
300
 
@@ -400,7 +447,7 @@ struct tile_layout_register_t
400
447
 
401
448
  static inline CUDA_CALLABLE int linear_from_register(int reg)
402
449
  {
403
- return threadIdx.x + reg*WP_TILE_BLOCK_DIM;
450
+ return WP_TILE_THREAD_IDX + reg*WP_TILE_BLOCK_DIM;
404
451
  }
405
452
 
406
453
  static inline CUDA_CALLABLE int linear_from_coord(Coord c)
@@ -500,15 +547,6 @@ struct tile_register_t
500
547
  return data[reg];
501
548
  }
502
549
 
503
- // Returns the number of valid registers for this tile
504
- // i.e.: how many registers map to a valid coordinate.
505
- // When a tile's size is not aligned to the block dimension
506
- // some of the trailing registers may lie outside the valid range
507
- inline CUDA_CALLABLE int valid() const
508
- {
509
- return (int)floor(float(Size - threadIdx.x - 1)/WP_TILE_BLOCK_DIM) + 1;
510
- }
511
-
512
550
  inline CUDA_CALLABLE void assign(const tile_register_t<T, Layout>& tile)
513
551
  {
514
552
  for (int i=0; i < Layout::NumRegs; ++i)
@@ -535,7 +573,7 @@ struct tile_register_t
535
573
  // ensure any previously scheduled threads have finished reading from scratch
536
574
  WP_TILE_SYNC();
537
575
 
538
- if (threadIdx.x == thread)
576
+ if (WP_TILE_THREAD_IDX == thread)
539
577
  {
540
578
  scratch = data[reg];
541
579
  }
@@ -556,7 +594,7 @@ struct tile_register_t
556
594
  const int thread = Layout::thread_from_linear(linear);
557
595
  const int reg = Layout::register_from_linear(linear);
558
596
 
559
- if (threadIdx.x == thread)
597
+ if (WP_TILE_THREAD_IDX == thread)
560
598
  {
561
599
  data[reg] += adj_ret;
562
600
  }
@@ -659,7 +697,7 @@ struct tile_register_t
659
697
  // users can either specify a template explicitly or
660
698
  // pass in another concrete instance
661
699
  template<typename Tile>
662
- auto tile_register_like(Tile* t=NULL)
700
+ auto tile_register_like(Tile* t=nullptr)
663
701
  {
664
702
  using T = typename Tile::Type;
665
703
  using L = typename Tile::Layout;
@@ -685,26 +723,39 @@ inline CUDA_CALLABLE int tile_align(int num_bytes)
685
723
  return sign * ((num_bytes_abs + alignment - 1) / alignment) * alignment;
686
724
  }
687
725
 
688
- inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false)
726
+ inline CUDA_CALLABLE void* tile_alloc_shared(int num_bytes, bool init=false, bool check=false)
689
727
  {
690
728
  // we maintain a per-thread offset into dynamic
691
729
  // shared memory that allows us to keep track of
692
730
  // current use across dynamic function calls
693
- __shared__ int smem_base[WP_TILE_BLOCK_DIM];
731
+ WP_TILE_SHARED int smem_base[WP_TILE_BLOCK_DIM];
694
732
 
695
733
  if (init)
696
734
  {
697
- smem_base[threadIdx.x] = 0;
698
- return NULL;
735
+ smem_base[WP_TILE_THREAD_IDX] = 0;
736
+ return nullptr;
737
+ }
738
+ else if (check)
739
+ {
740
+ assert(smem_base[WP_TILE_THREAD_IDX] == 0);
741
+ return nullptr;
699
742
  }
700
743
  else
701
744
  {
702
- const int offset = smem_base[threadIdx.x];
745
+ const int offset = smem_base[WP_TILE_THREAD_IDX];
703
746
 
704
747
  // one entry per-thread so no need for synchronization
705
- smem_base[threadIdx.x] += tile_align(num_bytes);
748
+ smem_base[WP_TILE_THREAD_IDX] += tile_align(num_bytes);
706
749
 
750
+ #ifdef __CUDA_ARCH__
707
751
  extern __shared__ char dynamic_smem_base[];
752
+ #else
753
+ // on CPU allocate a fixed 256k block to use for shared allocs
754
+ static const int max_cpu_shared = 256*1024;
755
+ static char dynamic_smem_base[max_cpu_shared];
756
+
757
+ assert(smem_base[WP_TILE_THREAD_IDX] <= max_cpu_shared);
758
+ #endif
708
759
  return &(dynamic_smem_base[offset]);
709
760
  }
710
761
  }
@@ -838,12 +889,12 @@ struct tile_shared_t
838
889
  bool initialized;
839
890
 
840
891
  // default initialization (non-initialized)
841
- inline CUDA_CALLABLE tile_shared_t() : data(NULL), grad(NULL), initialized(false)
892
+ inline CUDA_CALLABLE tile_shared_t() : data(nullptr), grad(nullptr), initialized(false)
842
893
  {
843
894
  }
844
895
 
845
896
  // initialize from an existing tile's memory
846
- inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=NULL, bool initialized=true) : data(data), grad(grad), initialized(initialized)
897
+ inline CUDA_CALLABLE tile_shared_t(T* data, T* grad=nullptr, bool initialized=true) : data(data), grad(grad), initialized(initialized)
847
898
  {
848
899
  }
849
900
 
@@ -869,6 +920,7 @@ struct tile_shared_t
869
920
  }
870
921
 
871
922
 
923
+ /*
872
924
  // construct from another shared tile, this constructor
873
925
  // is invoked for reshape operations like `wp.tile_transpose()`
874
926
  template <typename OtherT, typename OtherLayout>
@@ -877,7 +929,7 @@ struct tile_shared_t
877
929
  using OtherTile = tile_shared_t<OtherT, OtherLayout>;
878
930
 
879
931
  // check dimensions are compatible
880
- static_assert(Size == OtherTile::Size);
932
+ static_assert(Size == OtherTile::Size, "Expected Size == OtherTile::Size");
881
933
 
882
934
  // alias tile directly
883
935
  data = rhs.data;
@@ -886,6 +938,7 @@ struct tile_shared_t
886
938
 
887
939
  return *this;
888
940
  }
941
+ */
889
942
 
890
943
  // assign from a global tile (load)
891
944
  inline CUDA_CALLABLE auto& operator=(const tile_global_t<T, typename Layout::Shape>& t)
@@ -903,7 +956,7 @@ struct tile_shared_t
903
956
  if (initialized)
904
957
  WP_TILE_SYNC();
905
958
 
906
- for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
959
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
907
960
  data(i) = x;
908
961
 
909
962
  initialized = true;
@@ -914,7 +967,7 @@ struct tile_shared_t
914
967
  // in-place zero
915
968
  inline CUDA_CALLABLE void zero()
916
969
  {
917
- for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
970
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
918
971
  data(i) = T(0);
919
972
 
920
973
  WP_TILE_SYNC();
@@ -964,7 +1017,7 @@ struct tile_shared_t
964
1017
  // in-place gradient zero
965
1018
  inline CUDA_CALLABLE void grad_zero()
966
1019
  {
967
- for (int i=threadIdx.x; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
1020
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i+= WP_TILE_BLOCK_DIM)
968
1021
  grad(i) = T(0);
969
1022
 
970
1023
  WP_TILE_SYNC();
@@ -1004,7 +1057,7 @@ struct tile_shared_t
1004
1057
  CUDA_CALLABLE void grad_add(const tile_global_t<T, typename Layout::Shape>& global)
1005
1058
  {
1006
1059
  WP_PRAGMA_UNROLL
1007
- for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1060
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1008
1061
  {
1009
1062
  auto c = Layout::coord_from_linear(i);
1010
1063
  T g = global.load_grad(c);
@@ -1072,6 +1125,8 @@ struct tile_shared_t
1072
1125
  template <typename Global>
1073
1126
  inline CUDA_CALLABLE void copy_to_global(const Global& dest)
1074
1127
  {
1128
+
1129
+ #if defined(__CUDA_ARCH__)
1075
1130
  // vectorized loads for specific input/output shapes
1076
1131
  if constexpr (Layout::Shape::N == 2)
1077
1132
  {
@@ -1100,7 +1155,7 @@ struct tile_shared_t
1100
1155
  const int stride_j = 1;
1101
1156
 
1102
1157
  WP_PRAGMA_UNROLL
1103
- for (int i=threadIdx.x; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
1158
+ for (int i=WP_TILE_THREAD_IDX; i < SrcLayout::Size; i += WP_TILE_BLOCK_DIM)
1104
1159
  {
1105
1160
  auto c = SrcLayout::coord_from_linear(i);
1106
1161
 
@@ -1111,17 +1166,18 @@ struct tile_shared_t
1111
1166
  }
1112
1167
  }
1113
1168
 
1169
+ #endif //defined(__CUDA_ARCH__)
1170
+
1114
1171
  // scalar bounds checked path
1115
1172
  WP_PRAGMA_UNROLL
1116
- for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1173
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1117
1174
  {
1118
1175
  auto c = Layout::coord_from_linear(i);
1119
1176
  dest.store(c, data(i));
1120
1177
  }
1121
1178
  }
1122
1179
 
1123
- __device__ __forceinline__
1124
- void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
1180
+ inline CUDA_CALLABLE void cp_async_global_to_shared_128(float4* shared_dest, const float4* global_src)
1125
1181
  {
1126
1182
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1127
1183
 
@@ -1143,8 +1199,7 @@ struct tile_shared_t
1143
1199
  #endif
1144
1200
  }
1145
1201
 
1146
- __device__ __forceinline__
1147
- void cp_async_commit_and_wait_all_128()
1202
+ inline CUDA_CALLABLE void cp_async_commit_and_wait_all_128()
1148
1203
  {
1149
1204
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
1150
1205
  asm volatile(
@@ -1159,6 +1214,8 @@ struct tile_shared_t
1159
1214
  if (initialized)
1160
1215
  WP_TILE_SYNC();
1161
1216
 
1217
+ #if defined(__CUDA_ARCH__)
1218
+
1162
1219
  // vectorized loads for specific input/output shapes
1163
1220
  if constexpr (Layout::Shape::N == 2)
1164
1221
  {
@@ -1187,7 +1244,7 @@ struct tile_shared_t
1187
1244
  const int stride_j = 1;
1188
1245
 
1189
1246
  WP_PRAGMA_UNROLL
1190
- for (int i=threadIdx.x; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
1247
+ for (int i=WP_TILE_THREAD_IDX; i < DestLayout::Size; i += WP_TILE_BLOCK_DIM)
1191
1248
  {
1192
1249
  auto c = DestLayout::coord_from_linear(i);
1193
1250
 
@@ -1208,9 +1265,11 @@ struct tile_shared_t
1208
1265
  }
1209
1266
  }
1210
1267
 
1268
+ #endif //defined(__CUDA_ARCH__)
1269
+
1211
1270
  // scalar bounds checked path
1212
1271
  WP_PRAGMA_UNROLL
1213
- for (int i=threadIdx.x; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1272
+ for (int i=WP_TILE_THREAD_IDX; i < Layout::Size; i += WP_TILE_BLOCK_DIM)
1214
1273
  {
1215
1274
  auto c = Layout::coord_from_linear(i);
1216
1275
  data(i) = src.load(c);
@@ -1323,7 +1382,7 @@ struct tile_shared_t
1323
1382
 
1324
1383
  inline CUDA_CALLABLE void print(bool reverse=false) const
1325
1384
  {
1326
- if (threadIdx.x != 0)
1385
+ if (WP_TILE_THREAD_IDX != 0)
1327
1386
  return;
1328
1387
 
1329
1388
  if (reverse)
@@ -1350,13 +1409,13 @@ void tile_register_t<T, L>::print() const
1350
1409
  // create a temporary shared tile so that
1351
1410
  // we can print it deterministically
1352
1411
  WP_TILE_SHARED T smem[L::Size];
1353
- tile_shared_t<T, tile_layout_strided_t<typename L::Shape>> scratch(smem, NULL);
1412
+ tile_shared_t<T, tile_layout_strided_t<typename L::Shape>, false> scratch(smem, nullptr);
1354
1413
 
1355
1414
  scratch.assign(*this);
1356
1415
 
1357
1416
  WP_TILE_SYNC();
1358
1417
 
1359
- if (threadIdx.x == 0)
1418
+ if (WP_TILE_THREAD_IDX == 0)
1360
1419
  {
1361
1420
  scratch.print_values(scratch.data, 0);
1362
1421
 
@@ -1383,7 +1442,7 @@ inline CUDA_CALLABLE void print(const tile_shared_t<T, L, Owner>& t) { t.print()
1383
1442
  template <typename T, typename L, bool O>
1384
1443
  inline CUDA_CALLABLE int len(const tile_shared_t<T, L, O>& t)
1385
1444
  {
1386
- return Tile::Layout::Shape::dim(0);
1445
+ return L::Shape::dim(0);
1387
1446
  }
1388
1447
 
1389
1448
  template <typename T, typename L, bool O, typename AdjTile>
@@ -1394,7 +1453,7 @@ inline CUDA_CALLABLE void adj_len(const tile_shared_t<T,L,O>& t, const AdjTile&
1394
1453
  template <typename T, typename L>
1395
1454
  inline CUDA_CALLABLE int len(const tile_register_t<T, L>& t)
1396
1455
  {
1397
- return Tile::Layout::Shape::dim(0);
1456
+ return L::Shape::dim(0);
1398
1457
  }
1399
1458
 
1400
1459
  template <typename T, typename L, typename AdjTile>
@@ -1416,12 +1475,16 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1416
1475
 
1417
1476
  { constexpr int size = Shape::size();
1418
1477
  T* data = (T*)tile_alloc_shared(size*sizeof(T));
1419
- T* grad = NULL;
1478
+ T* grad = nullptr;
1420
1479
 
1421
1480
  #if FP_CHECK
1422
1481
 
1423
- for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
1424
- data[i] = T(nanf(""));
1482
+ // initialize tile to quiet nan
1483
+ uint32_t qnanbits = 0x7FC00000;
1484
+ float qnan = *(float*)(&qnanbits);
1485
+
1486
+ for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
1487
+ data[i] = T(qnan);
1425
1488
 
1426
1489
  WP_TILE_SYNC();
1427
1490
 
@@ -1432,7 +1495,7 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1432
1495
  {
1433
1496
  grad = (T*)tile_alloc_shared(size*sizeof(T));
1434
1497
 
1435
- for (int i=threadIdx.x; i < size; i+= WP_TILE_BLOCK_DIM)
1498
+ for (int i=WP_TILE_THREAD_IDX; i < size; i+= WP_TILE_BLOCK_DIM)
1436
1499
  grad[i] = T(0);
1437
1500
 
1438
1501
  WP_TILE_SYNC();
@@ -1441,30 +1504,6 @@ inline CUDA_CALLABLE auto tile_alloc_empty()
1441
1504
  return tile_shared_t<T, tile_layout_strided_t<Shape>>(data, grad);
1442
1505
  }
1443
1506
 
1444
- template <typename T, int M, int N, bool RequiresGrad>
1445
- inline CUDA_CALLABLE auto tile_alloc_zeros()
1446
- {
1447
- // compute the total storage required for the tile (may be different from M*N) for broadcast tiles
1448
- constexpr int Len = M*N;
1449
- T* data = (T*)tile_alloc_shared(Len*sizeof(T));
1450
- T* grad = NULL;
1451
-
1452
- for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1453
- data[i] = T(0);
1454
-
1455
- if (RequiresGrad)
1456
- {
1457
- grad = (T*)tile_alloc_shared(Len*sizeof(T));
1458
-
1459
- for (int i=threadIdx.x; i < Len; i+= WP_TILE_BLOCK_DIM)
1460
- grad[i] = T(0);
1461
- }
1462
-
1463
- WP_TILE_SYNC();
1464
-
1465
- return tile_shared_t<T, tile_layout_strided_t<tile_shape_t<M, N>>(data, grad);
1466
- }
1467
-
1468
1507
 
1469
1508
  //-----------------------------------------------------------------------------------------------------
1470
1509
  // High level entry points for each op (correspond to one Warp builtin)
@@ -1476,7 +1515,7 @@ inline CUDA_CALLABLE auto tile(const T& x)
1476
1515
  tile_register_t<T, tile_layout_register_t<tile_shape_t<WP_TILE_BLOCK_DIM>>> result;
1477
1516
 
1478
1517
  using Layout = typename decltype(result)::Layout;
1479
- static_assert(Layout::NumRegs == 1);
1518
+ static_assert(Layout::NumRegs == 1, "Expected Layout::NumRegs == 1");
1480
1519
 
1481
1520
  result.data[0] = x;
1482
1521
  return result;
@@ -1489,7 +1528,7 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1489
1528
  tile_register_t<T, tile_layout_register_t<tile_shape_t<Length, WP_TILE_BLOCK_DIM>>> result;
1490
1529
 
1491
1530
  using Layout = typename decltype(result)::Layout;
1492
- static_assert(Layout::NumRegs == Length);
1531
+ static_assert(Layout::NumRegs == Length, "Expected Layout::NumRegs == Length");
1493
1532
 
1494
1533
  for (int i=0; i < Length; ++i)
1495
1534
  result.data[i] = x[i];
@@ -1501,8 +1540,8 @@ inline CUDA_CALLABLE auto tile(const wp::vec_t<Length, T>& x)
1501
1540
  template <typename T, typename AdjTile>
1502
1541
  inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1503
1542
  {
1504
- static_assert(AdjTile::Layout::Shape::N == 1);
1505
- static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM);
1543
+ static_assert(AdjTile::Layout::Shape::N == 1, "Expected AdjTile::Layout::Shape::N == 1");
1544
+ static_assert(AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(0) == WP_TILE_BLOCK_DIM");
1506
1545
 
1507
1546
  auto adj_reg = adj_ret.copy_to_register();
1508
1547
 
@@ -1512,9 +1551,9 @@ inline CUDA_CALLABLE void adj_tile(const T& x, T& adj_x, AdjTile& adj_ret)
1512
1551
  template <typename T, unsigned Length, typename AdjTile>
1513
1552
  inline CUDA_CALLABLE void adj_tile(const wp::vec_t<Length, T>& x, wp::vec_t<Length, T>& adj_x, AdjTile& adj_ret)
1514
1553
  {
1515
- static_assert(AdjTile::Layout::Shape::N == 2);
1516
- static_assert(AdjTile::Layout::Shape::dim(0) == Length);
1517
- static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM);
1554
+ static_assert(AdjTile::Layout::Shape::N == 2, "Expected AdjTile::Layout::Shape::N == 2");
1555
+ static_assert(AdjTile::Layout::Shape::dim(0) == Length, "Expected AdjTile::Layout::Shape::dim(0) == Length");
1556
+ static_assert(AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM, "Expected AdjTile::Layout::Shape::dim(1) == WP_TILE_BLOCK_DIM");
1518
1557
 
1519
1558
  auto adj_reg = adj_ret.copy_to_register();
1520
1559
 
@@ -1692,7 +1731,7 @@ inline CUDA_CALLABLE void adj_tile_store(array_t<T>& dest, Coord c, Tile& t, arr
1692
1731
  if (adj_dest.data)
1693
1732
  src.data.grad = adj_dest.data;
1694
1733
 
1695
- if (src.data.grad == NULL)
1734
+ if (src.data.grad == nullptr)
1696
1735
  return;
1697
1736
 
1698
1737
  adj_t.grad_add(src);
@@ -1927,7 +1966,6 @@ void adj_tile_extract(Tile& t, int i, int j, int k, AdjTile& adj_t, int adj_i, i
1927
1966
  template<typename Tile, typename AdjTile>
1928
1967
  void adj_tile_extract(Tile& t, int i, int j, int k, int l, AdjTile& adj_t, int adj_i, int adj_j, int adj_k, int adj_l, typename Tile::Type adj_ret) { adj_t.adj_extract(tile_coord(i, j, k, l), adj_ret); }
1929
1968
 
1930
- #if WP_USE_REGISTER_GEMM
1931
1969
 
1932
1970
  namespace partitioned_gemm
1933
1971
  {
@@ -2033,9 +2071,11 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
2033
2071
  auto B_tile = partition_t<TILE_K, TILE_N, TileB>(B);
2034
2072
  auto C_tile = partition_t<TILE_M, TILE_N, TileC>(out);
2035
2073
 
2074
+ //static_assert(is_same<typename TileA::Type, typename TileB::Type>::value);
2075
+
2036
2076
  const int length = partition_size(C_tile);
2037
2077
 
2038
- for (int t=threadIdx.x; t < length; t += blockDim.x)
2078
+ for (int t=WP_TILE_THREAD_IDX; t < length; t += WP_TILE_BLOCK_DIM)
2039
2079
  {
2040
2080
  int i, j;
2041
2081
  partition_coord(C_tile, t, i, j);
@@ -2055,10 +2095,102 @@ inline CUDA_CALLABLE void matmul(TileA& A, TileB& B, TileC& out)
2055
2095
  partition_store(C_tile, i, j, sum);
2056
2096
  }
2057
2097
  }
2058
-
2059
- } // namespace partition_gemm
2060
2098
 
2061
- #endif // WP_USE_REGISTER_GEMM
2099
+ template <typename LayoutA, typename LayoutB, typename LayoutC, typename StorageA, typename StorageB, typename StorageC, typename T>
2100
+ inline CUDA_CALLABLE void scalar_matmul(const StorageA& A, const StorageB& B, StorageC& C, T scale)
2101
+ {
2102
+ for (int t=WP_TILE_THREAD_IDX; t < LayoutC::Size; t += WP_TILE_BLOCK_DIM)
2103
+ {
2104
+ auto coord = LayoutC::coord_from_linear(t);
2105
+
2106
+ int i = coord[0];
2107
+ int j = coord[1];
2108
+
2109
+ // accumulator
2110
+ auto sum = C(coord)*scale;
2111
+
2112
+ WP_PRAGMA_UNROLL
2113
+ for (int k=0; k < LayoutA::Shape::dim(1); k++)
2114
+ {
2115
+ const auto a = A(tile_coord(i, k));
2116
+ const auto b = B(tile_coord(k, j));
2117
+
2118
+ sum = muladd<decltype(sum)>(a, b, sum);
2119
+ }
2120
+
2121
+ C(coord) = sum;
2122
+ }
2123
+ }
2124
+
2125
+ template <typename TileA, typename TileL>
2126
+ inline CUDA_CALLABLE void scalar_cholesky(TileA& A, TileL& L)
2127
+ {
2128
+ using T = typename TileA::Type;
2129
+ constexpr int n = TileA::Layout::Shape::dim(1);
2130
+
2131
+ for (int j=0; j < n; ++j)
2132
+ {
2133
+ T s = A.data(tile_coord(j, j));
2134
+
2135
+ for (int k=0; k < j; ++k)
2136
+ {
2137
+ T r = L.data(tile_coord(j, k));
2138
+ s -= r * r;
2139
+ }
2140
+
2141
+ s = wp::sqrt(s);
2142
+ T invS = 1.0 / s;
2143
+
2144
+ L.data(tile_coord(j, j)) = s;
2145
+
2146
+ for (int i=j+1; i < n; ++i)
2147
+ {
2148
+ s = A.data(tile_coord(i, j));
2149
+
2150
+ for (int k=0; k < j; ++k)
2151
+ {
2152
+ s -= L.data(tile_coord(i, k)) * L.data(tile_coord(j, k));
2153
+ }
2154
+
2155
+ L.data(tile_coord(i, j)) = s * invS;
2156
+ }
2157
+
2158
+ // zero out upper triangular portion
2159
+ for (int k=j+1; k < n; ++k)
2160
+ {
2161
+ L.data(tile_coord(j,k)) = T(0.0);
2162
+ }
2163
+ }
2164
+ }
2165
+
2166
+ template <typename TileL, typename TileX, typename TileY>
2167
+ inline CUDA_CALLABLE void scalar_cholesky_solve(TileL& L, TileX& X, TileY& Y)
2168
+ {
2169
+ using T = typename TileL::Type;
2170
+ constexpr int n = TileL::Layout::Shape::dim(1);
2171
+
2172
+ for (int i=0; i < n; ++i)
2173
+ {
2174
+ T s = Y.data(tile_coord(i));
2175
+
2176
+ for (int j=0; j < i; ++j)
2177
+ s -= L.data(tile_coord(i,j)) * X.data(tile_coord(j));
2178
+
2179
+ X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2180
+ }
2181
+
2182
+ for (int i=n-1; i >= 0; --i)
2183
+ {
2184
+ T s = X.data(tile_coord(i));
2185
+
2186
+ for (int j=i+1; j < n; ++j)
2187
+ s -= L.data(tile_coord(j, i)) * X.data(tile_coord(j));
2188
+
2189
+ X.data(tile_coord(i)) = s / L.data(tile_coord(i, i));
2190
+ }
2191
+ }
2192
+
2193
+ } // namespace partition_gemm
2062
2194
 
2063
2195
 
2064
2196
  template <int Add, typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
@@ -2068,19 +2200,19 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
2068
2200
  using ShapeB = typename TileB::Layout::Shape;
2069
2201
  using ShapeC = typename TileC::Layout::Shape;
2070
2202
 
2071
- static_assert(ShapeA::N == 2);
2072
- static_assert(ShapeB::N == 2);
2073
- static_assert(ShapeC::N == 2);
2203
+ static_assert(ShapeA::N == 2, "Expected ShapeA::N == 2");
2204
+ static_assert(ShapeB::N == 2, "Expected ShapeB::N == 2");
2205
+ static_assert(ShapeC::N == 2, "Expected ShapeC::N == 2");
2074
2206
 
2075
- static_assert(ShapeA::dim(1) == ShapeB::dim(0));
2076
- static_assert(ShapeC::dim(0) == ShapeA::dim(0));
2077
- static_assert(ShapeC::dim(1) == ShapeB::dim(1));
2207
+ static_assert(ShapeA::dim(1) == ShapeB::dim(0), "Expected ShapeA::dim(1) == ShapeB::dim(0)");
2208
+ static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
2209
+ static_assert(ShapeC::dim(1) == ShapeB::dim(1), "Expected ShapeC::dim(1) == ShapeB::dim(1)");
2078
2210
 
2079
2211
 
2080
2212
  using T = typename TileA::Type;
2081
2213
 
2082
- #if WP_USE_REGISTER_GEMM
2083
- partitioned_gemm::matmul(A, B, C);
2214
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2215
+ partitioned_gemm::scalar_matmul<typename TileA::Layout, typename TileB::Layout, typename TileC::Layout>(A.data, B.data, C.data, T(Add));
2084
2216
  #else
2085
2217
  fun_forward(T(1.0), A.data.ptr, B.data.ptr, T(Add), C.data.ptr);
2086
2218
  #endif
@@ -2090,6 +2222,7 @@ TileC& tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, Ti
2090
2222
  return C;
2091
2223
  }
2092
2224
 
2225
+
2093
2226
  // backward for the wp.tile_matmul(a, b, out) syntax
2094
2227
  template <typename Fwd, typename AdjA, typename AdjB, typename TileA, typename TileB, typename TileC>
2095
2228
  void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B, TileA& A, TileB& B, TileC& C,
@@ -2097,8 +2230,17 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2097
2230
  {
2098
2231
  using T = typename TileA::Type;
2099
2232
 
2233
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2234
+ auto At = tile_transpose(A);
2235
+ auto Bt = tile_transpose(B);
2236
+
2237
+ partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
2238
+ partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
2239
+ #else
2100
2240
  fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
2101
2241
  fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
2242
+ #endif
2243
+
2102
2244
  WP_TILE_SYNC();
2103
2245
  }
2104
2246
 
@@ -2109,11 +2251,30 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2109
2251
  {
2110
2252
  using T = typename TileA::Type;
2111
2253
 
2254
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2255
+ auto At = tile_transpose(A);
2256
+ auto Bt = tile_transpose(B);
2257
+
2258
+ partitioned_gemm::scalar_matmul<typename TileC::Layout, typename decltype(Bt)::Layout, typename TileA::Layout>(adj_C.grad, Bt.data, adj_A.grad, T(1.0));
2259
+ partitioned_gemm::scalar_matmul<typename decltype(At)::Layout, typename TileC::Layout, typename TileB::Layout>(At.data, adj_C.grad, adj_B.grad, T(1.0));
2260
+ #else
2112
2261
  fun_backward_A(T(1.0), adj_C.grad.ptr, B.data.ptr, T(1.0), adj_A.grad.ptr);
2113
2262
  fun_backward_B(T(1.0), A.data.ptr, adj_C.grad.ptr, T(1.0), adj_B.grad.ptr);
2263
+ #endif
2264
+
2114
2265
  WP_TILE_SYNC();
2115
2266
  }
2116
2267
 
2268
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2269
+
2270
+ #define tile_fft()
2271
+ #define tile_ifft()
2272
+
2273
+ #define adj_tile_fft()
2274
+ #define adj_tile_ifft()
2275
+
2276
+ #else
2277
+
2117
2278
  // TODO(lcambier): use a properly overaligned complex type that matches cuFFTDx's expectation
2118
2279
  // and remove the need for __align__(16) dtypes data[...]
2119
2280
  #define tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, Xinout) \
@@ -2149,12 +2310,21 @@ void adj_tile_matmul(Fwd fun_forward, AdjA fun_backward_A, AdjB fun_backward_B,
2149
2310
  tile_fft(function_name, dtype, shared_memory_size, batch_size, ept, adj_Xinout); \
2150
2311
  } while (0)
2151
2312
 
2313
+ #endif // !defined(__CUDA_ARCH__)
2314
+
2152
2315
  template <typename Fwd, typename TileA, typename TileL>
2153
2316
  TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2154
2317
  {
2155
2318
  // Copy to L
2156
2319
  L = A;
2157
2320
 
2321
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2322
+
2323
+ partitioned_gemm::scalar_cholesky(A, L);
2324
+
2325
+ #else
2326
+
2327
+
2158
2328
  // Call cholesky on L
2159
2329
  WP_TILE_SYNC();
2160
2330
 
@@ -2165,7 +2335,7 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2165
2335
  // Zero-out the upper triangular part of L
2166
2336
 
2167
2337
  WP_PRAGMA_UNROLL
2168
- for (int i=threadIdx.x; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
2338
+ for (int i=WP_TILE_THREAD_IDX; i < TileL::Layout::Size; i += WP_TILE_BLOCK_DIM)
2169
2339
  {
2170
2340
  auto c = TileL::Layout::coord_from_linear(i);
2171
2341
 
@@ -2174,7 +2344,9 @@ TileL& tile_cholesky(Fwd fun_forward, TileA& A, TileL& L)
2174
2344
  }
2175
2345
 
2176
2346
  WP_TILE_SYNC();
2177
-
2347
+
2348
+ #endif
2349
+
2178
2350
  return L;
2179
2351
  }
2180
2352
 
@@ -2191,6 +2363,12 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2191
2363
 
2192
2364
  Y = X;
2193
2365
 
2366
+ #if !defined(__CUDA_ARCH__) || WP_ENABLE_MATHDX == 0
2367
+
2368
+ partitioned_gemm::scalar_cholesky_solve(L, X, Y);
2369
+
2370
+ #else
2371
+
2194
2372
  // Call cholesky solve on L & y
2195
2373
 
2196
2374
  WP_TILE_SYNC();
@@ -2199,6 +2377,8 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2199
2377
 
2200
2378
  WP_TILE_SYNC();
2201
2379
 
2380
+ #endif
2381
+
2202
2382
  return Y;
2203
2383
  }
2204
2384
 
@@ -2211,7 +2391,7 @@ TileY& tile_cholesky_solve(Fwd fun_forward, TileL& L, TileX& X, TileY& Y)
2211
2391
  template <typename Tile>
2212
2392
  inline CUDA_CALLABLE auto tile_transpose(Tile& t)
2213
2393
  {
2214
- static_assert(Tile::Layout::Shape::N == 2);
2394
+ static_assert(Tile::Layout::Shape::N == 2, "Expected Tile::Layout::Shape::N == 2");
2215
2395
 
2216
2396
  // alias incoming tile
2217
2397
  constexpr int M = Tile::Layout::Shape::dim(0);
@@ -2232,13 +2412,34 @@ inline CUDA_CALLABLE void adj_tile_transpose(Tile& t, Tile& adj_t, AdjTile& adj_
2232
2412
  adj_t.assign(tile_add(a,b));
2233
2413
  }
2234
2414
 
2415
+ template <int N, int StrideN, typename Tile>
2416
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
2417
+ {
2418
+ // alias incoming tile with new strides
2419
+ return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<N>, tile_stride_t<StrideN>>, false>(t.data.ptr, t.grad.ptr);
2420
+ }
2421
+
2235
2422
  template <int M, int N, int StrideM, int StrideN, typename Tile>
2236
2423
  inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
2237
- {
2424
+ {
2238
2425
  // alias incoming tile with new strides
2239
2426
  return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N>, tile_stride_t<StrideM, StrideN>>, false>(t.data.ptr, t.grad.ptr);
2240
2427
  }
2241
2428
 
2429
+ template <int M, int N, int O, int StrideM, int StrideN, int StrideO, typename Tile>
2430
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
2431
+ {
2432
+ // alias incoming tile with new strides
2433
+ return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O>, tile_stride_t<StrideM, StrideN, StrideO>>, false>(t.data.ptr, t.grad.ptr);
2434
+ }
2435
+
2436
+ template <int M, int N, int O, int P, int StrideM, int StrideN, int StrideO, int StrideP, typename Tile>
2437
+ inline CUDA_CALLABLE auto tile_broadcast(Tile& t)
2438
+ {
2439
+ // alias incoming tile with new strides
2440
+ return tile_shared_t<typename Tile::Type, tile_layout_strided_t<tile_shape_t<M, N, O, P>, tile_stride_t<StrideM, StrideN, StrideO, StrideP>>, false>(t.data.ptr, t.grad.ptr);
2441
+ }
2442
+
2242
2443
  template <typename Tile, typename AdjTile>
2243
2444
  inline CUDA_CALLABLE void adj_tile_broadcast(Tile& t, Tile& adj_t, AdjTile& adj_ret)
2244
2445
  {
@@ -2252,7 +2453,7 @@ inline CUDA_CALLABLE auto tile_view(Tile& t, Indices... indices)
2252
2453
 
2253
2454
  // return new tile with same strides
2254
2455
  typename Tile::Type* data_ptr = &t.data(c);
2255
- typename Tile::Type* grad_ptr = NULL;
2456
+ typename Tile::Type* grad_ptr = nullptr;
2256
2457
 
2257
2458
  if (t.grad.ptr)
2258
2459
  grad_ptr = &t.grad(c);
@@ -2297,7 +2498,7 @@ inline CUDA_CALLABLE void tile_assign(TileA& dest, TileB& src, const Coord& offs
2297
2498
  {
2298
2499
  using Layout = typename TileB::Layout;
2299
2500
 
2300
- for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
2501
+ for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
2301
2502
  {
2302
2503
  auto c = Layout::coord_from_linear(t);
2303
2504
  dest.data(c + offset) = src.data(c);
@@ -2312,7 +2513,7 @@ inline CUDA_CALLABLE void adj_tile_assign(TileA& dest, TileB& src, Coord offset,
2312
2513
  {
2313
2514
  using Layout = typename TileB::Layout;
2314
2515
 
2315
- for (int t=threadIdx.x; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
2516
+ for (int t=WP_TILE_THREAD_IDX; t < Layout::Size; t += WP_TILE_BLOCK_DIM)
2316
2517
  {
2317
2518
  auto c = Layout::coord_from_linear(t);
2318
2519
  src.grad(c) += dest.grad(c + offset);
@@ -2351,14 +2552,14 @@ inline CUDA_CALLABLE TileC& tile_diag_add(TileA& a, TileB& b, TileC& c)
2351
2552
  using ShapeB = typename TileB::Layout::Shape;
2352
2553
  using ShapeC = typename TileC::Layout::Shape;
2353
2554
 
2354
- static_assert(ShapeA::dim(0) == ShapeA::dim(1));
2355
- static_assert(ShapeB::dim(0) == ShapeA::dim(0));
2356
- static_assert(ShapeC::dim(0) == ShapeA::dim(0));
2357
- static_assert(ShapeC::dim(0) == ShapeC::dim(1));
2555
+ static_assert(ShapeA::dim(0) == ShapeA::dim(1), "Expected ShapeA::dim(0) == ShapeA::dim(1)");
2556
+ static_assert(ShapeB::dim(0) == ShapeA::dim(0), "Expected ShapeB::dim(0) == ShapeA::dim(0)");
2557
+ static_assert(ShapeC::dim(0) == ShapeA::dim(0), "Expected ShapeC::dim(0) == ShapeA::dim(0)");
2558
+ static_assert(ShapeC::dim(0) == ShapeC::dim(1), "Expected ShapeC::dim(0) == ShapeC::dim(1)");
2358
2559
 
2359
2560
  c = a;
2360
2561
 
2361
- for (int t=threadIdx.x; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
2562
+ for (int t=WP_TILE_THREAD_IDX; t < ShapeA::dim(0); t += WP_TILE_BLOCK_DIM)
2362
2563
  {
2363
2564
  c.data(tile_coord(t, t)) += b.data(tile_coord(t));
2364
2565
  }
@@ -2377,3 +2578,7 @@ inline CUDA_CALLABLE void adj_tile_diag_add(TileA& a, TileB& b, TileC& c, AdjTil
2377
2578
 
2378
2579
  } // namespace wp
2379
2580
 
2581
+
2582
+ #ifdef __clang__
2583
+ #pragma clang diagnostic pop
2584
+ #endif