warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401) hide show
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,373 +0,0 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
- */
8
-
9
- #include "builtin.h"
10
- #include "temp_buffer.h"
11
- #include "cuda_util.h"
12
-
13
- #include "cutlass/cutlass.h"
14
- #include "cutlass/gemm/device/gemm_universal.h"
15
- #include "cutlass/util/device_memory.h"
16
-
17
- #define F16_STR "<f2"
18
- #define F32_STR "<f4"
19
- #define F64_STR "<f8"
20
-
21
- namespace wp {
22
-
23
- template <typename Gemm>
24
- bool run_gemm(int m, int n, int k, int batch_count, const void* a, const void* b, const void* c, void* d, float alpha, float beta) {
25
- //
26
- // Initialize arguments
27
- //
28
- typename Gemm::EpilogueOutputOp::Params epilogue_params(
29
- (typename Gemm::EpilogueOutputOp::ElementCompute)alpha,
30
- (typename Gemm::EpilogueOutputOp::ElementCompute)beta);
31
-
32
- typename Gemm::Arguments arguments{
33
- batch_count == 1 ? cutlass::gemm::GemmUniversalMode::kGemm : cutlass::gemm::GemmUniversalMode::kBatched ,
34
- cutlass::gemm::GemmCoord{m, n, k}, // Problem size
35
- batch_count,
36
- epilogue_params,
37
- a, b, c, d,
38
- int64_t(m * k), int64_t(k * n), int64_t(m * n), int64_t(m * n), // Batch strides
39
- Gemm::LayoutA::packed({m, k}).stride(0), Gemm::LayoutB::packed({k, n}).stride(0), n, n
40
- };
41
-
42
- Gemm gemm;
43
- size_t workspace_size = Gemm::get_workspace_size(arguments);
44
- ScopedTemporary<> workspace(WP_CURRENT_CONTEXT, workspace_size);
45
- cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
46
- cutlass::Status status = gemm.initialize(arguments, workspace.buffer(), stream);
47
-
48
- if (status != cutlass::Status::kSuccess) {
49
- cudaError_t error = cudaGetLastError();
50
- std::cerr << "Error initializing GEMM: " << cudaGetErrorString(error) << "\n";
51
- return false;
52
- }
53
-
54
- //
55
- // Run the GEMM
56
- //
57
-
58
- status = gemm(stream);
59
- if (status != cutlass::Status::kSuccess) {
60
- cudaError_t error = cudaGetLastError();
61
- std::cerr << "Runtime error: " << cudaGetErrorString(error) << "\n";
62
- return false;
63
- }
64
-
65
- return true;
66
- }
67
-
68
- template <
69
- int ComputeCapability,
70
- typename Element_,
71
- typename LayoutA,
72
- typename LayoutB
73
- >
74
- struct DefaultGemmConfig;
75
-
76
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
77
-
78
- // Partial specialization for SM80 F64 Tensor Cores
79
- template <typename LayoutA, typename LayoutB>
80
- struct DefaultGemmConfig<80, double, LayoutA, LayoutB> {
81
- using Gemm = cutlass::gemm::device::GemmUniversal<
82
- double, LayoutA, // ElementA and LayoutA
83
- double, LayoutB, // ElementB and LayoutB
84
- double, cutlass::layout::RowMajor, // ElementC and LayoutC
85
- double, // ElementAccumulator
86
- cutlass::arch::OpClassTensorOp, // Operation type
87
- cutlass::arch::Sm80, // Architecture
88
- cutlass::gemm::GemmShape<128, 128, 16>, // ThreadblockShape
89
- cutlass::gemm::GemmShape<32, 64, 16>, // WarpShape
90
- cutlass::gemm::GemmShape<8, 8, 4>, // Instruction Shape
91
- cutlass::epilogue::thread::LinearCombination< // Epilogue
92
- double,
93
- 1,
94
- double,
95
- double>,
96
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
97
- 3 // Stages
98
- >;
99
- };
100
-
101
- // Partial specialization for SM80 F32 Tensor Cores
102
- template <typename LayoutA, typename LayoutB>
103
- struct DefaultGemmConfig<80, float, LayoutA, LayoutB> {
104
- using Gemm = cutlass::gemm::device::GemmUniversal<
105
- float, LayoutA, // ElementA and LayoutA
106
- float, LayoutB, // ElementB and LayoutB
107
- float, cutlass::layout::RowMajor, // ElementC and LayoutC
108
- float, // ElementAccumulator
109
- cutlass::arch::OpClassTensorOp, // Operation type
110
- cutlass::arch::Sm80, // Architecture
111
- cutlass::gemm::GemmShape<256, 128, 16>, // ThreadblockShape
112
- cutlass::gemm::GemmShape<64, 64, 16>, // WarpShape
113
- cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape
114
- cutlass::epilogue::thread::LinearCombination< // Epilogue
115
- float,
116
- 128 / cutlass::sizeof_bits<float>::value,
117
- float,
118
- float>,
119
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
120
- 3, // Stages
121
- 4, 4, // AlignmentA and AlignmentB
122
- cutlass::arch::OpMultiplyAddFastF32 // Math mode -- use 3xTF32
123
- >;
124
- };
125
-
126
- // Partial specialization for SM80 F16 Tensor Cores
127
- template <typename LayoutA, typename LayoutB>
128
- struct DefaultGemmConfig<80, cutlass::half_t, LayoutA, LayoutB> {
129
- using Gemm = cutlass::gemm::device::GemmUniversal<
130
- cutlass::half_t, LayoutA, // ElementA and LayoutA
131
- cutlass::half_t, LayoutB, // ElementB and LayoutB
132
- cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
133
- cutlass::half_t, // ElementAccumulator
134
- cutlass::arch::OpClassTensorOp, // Operation type
135
- cutlass::arch::Sm80, // Architecture
136
- cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
137
- cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
138
- cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape
139
- cutlass::epilogue::thread::LinearCombination< // Epilogue
140
- cutlass::half_t,
141
- 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
142
- cutlass::half_t,
143
- cutlass::half_t>,
144
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
145
- 3 // Stages
146
- >;
147
- };
148
-
149
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
150
-
151
- // Partial specialization for SM75 F16 Tensor Cores
152
- template <typename LayoutA, typename LayoutB>
153
- struct DefaultGemmConfig<75, cutlass::half_t, LayoutA, LayoutB> {
154
- using Gemm = cutlass::gemm::device::GemmUniversal<
155
- cutlass::half_t, LayoutA, // ElementA and LayoutA
156
- cutlass::half_t, LayoutB, // ElementB and LayoutB
157
- cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
158
- cutlass::half_t, // ElementAccumulator
159
- cutlass::arch::OpClassTensorOp, // Operation type
160
- cutlass::arch::Sm75, // Architecture
161
- cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
162
- cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
163
- cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape
164
- cutlass::epilogue::thread::LinearCombination< // Epilogue
165
- cutlass::half_t,
166
- 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
167
- cutlass::half_t,
168
- cutlass::half_t>,
169
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
170
- 2 // Stages
171
- >;
172
- };
173
-
174
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
175
-
176
- // Partial specialization for SM70 F16 Tensor Cores
177
- template <typename LayoutA, typename LayoutB>
178
- struct DefaultGemmConfig<70, cutlass::half_t, LayoutA, LayoutB> {
179
- using Gemm = cutlass::gemm::device::GemmUniversal<
180
- cutlass::half_t, LayoutA, // ElementA and LayoutA
181
- cutlass::half_t, LayoutB, // ElementB and LayoutB
182
- cutlass::half_t, cutlass::layout::RowMajor, // ElementC and LayoutC
183
- cutlass::half_t, // ElementAccumulator
184
- cutlass::arch::OpClassTensorOp, // Operation type
185
- cutlass::arch::Sm70, // Architecture
186
- cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
187
- cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
188
- cutlass::gemm::GemmShape<8, 8, 4>, // Instruction Shape
189
- cutlass::epilogue::thread::LinearCombination< // Epilogue
190
- cutlass::half_t,
191
- 128 / cutlass::sizeof_bits<cutlass::half_t>::value,
192
- cutlass::half_t,
193
- cutlass::half_t>,
194
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
195
- 2 // Stages
196
- >;
197
- };
198
-
199
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
200
-
201
- // Partial specialization for SM50 SIMT
202
- template <typename Element, typename LayoutA, typename LayoutB>
203
- struct DefaultGemmConfig<50, Element, LayoutA, LayoutB> {
204
- using Gemm = cutlass::gemm::device::GemmUniversal<
205
- Element, LayoutA, // ElementA and LayoutA
206
- Element, LayoutB, // ElementB and LayoutB
207
- Element, cutlass::layout::RowMajor, // ElementC and LayoutC
208
- Element, // ElementAccumulator
209
- cutlass::arch::OpClassSimt, // Operation type
210
- cutlass::arch::Sm50, // Architecture
211
- cutlass::gemm::GemmShape<128, 128, 8>, // ThreadblockShape
212
- cutlass::gemm::GemmShape<32, 64, 8>, // WarpShape
213
- cutlass::gemm::GemmShape<1, 1, 1>, // Instruction Shape
214
- cutlass::epilogue::thread::LinearCombination< // Epilogue
215
- Element,
216
- 1,
217
- Element,
218
- Element>,
219
- cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, // Swizzling
220
- 2 // Stages
221
- >;
222
- };
223
-
224
- //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
225
-
226
- extern "C" {
227
-
228
- WP_API
229
- bool cutlass_gemm(
230
- void* context, int compute_capability,
231
- int m, int n, int k,
232
- const char* datatype_str,
233
- const void* a, const void* b, const void* c, void* d,
234
- float alpha, float beta,
235
- bool row_major_a, bool row_major_b,
236
- bool allow_tf32x3_arith,
237
- int batch_count) {
238
-
239
- std::string datatype(datatype_str);
240
-
241
- ContextGuard guard(context);
242
-
243
- // Specializations for using Tensor Cores and A/B RowMajor/ColumnMajor designations
244
- if (compute_capability == 80) {
245
- if (datatype == F64_STR) {
246
- if (row_major_a && row_major_b) {
247
- using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
248
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
249
- } else if (!row_major_a && row_major_b) {
250
- using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
251
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
252
- } else if (row_major_a && !row_major_b) {
253
- using Gemm = DefaultGemmConfig<80, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
254
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
255
- } else if (!row_major_a && !row_major_b) {
256
- using Gemm = DefaultGemmConfig<80, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
257
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
258
- }
259
- } else if (datatype == F32_STR && allow_tf32x3_arith) {
260
- if (row_major_a && row_major_b) {
261
- using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
262
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
263
- } else if (!row_major_a && row_major_b) {
264
- using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
265
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
266
- } else if (row_major_a && !row_major_b) {
267
- using Gemm = DefaultGemmConfig<80, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
268
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
269
- } else if (!row_major_a && !row_major_b) {
270
- using Gemm = DefaultGemmConfig<80, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
271
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
272
- }
273
- } else if (datatype == F16_STR) {
274
- if (row_major_a && row_major_b) {
275
- using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
276
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
277
- } else if (!row_major_a && row_major_b) {
278
- using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
279
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
280
- } else if (row_major_a && !row_major_b) {
281
- using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
282
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
283
- } else if (!row_major_a && !row_major_b) {
284
- using Gemm = DefaultGemmConfig<80, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
285
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
286
- }
287
- }
288
- } else if (compute_capability == 75) {
289
- if (datatype == F16_STR) {
290
- if (row_major_a && row_major_b) {
291
- using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
292
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
293
- } else if (!row_major_a && row_major_b) {
294
- using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
295
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
296
- } else if (row_major_a && !row_major_b) {
297
- using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
298
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
299
- } else if (!row_major_a && !row_major_b) {
300
- using Gemm = DefaultGemmConfig<75, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
301
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
302
- }
303
- }
304
- } else if (compute_capability == 70) {
305
- if (datatype == F16_STR) {
306
- if (row_major_a && row_major_b) {
307
- using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
308
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
309
- } else if (!row_major_a && row_major_b) {
310
- using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
311
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
312
- } else if (row_major_a && !row_major_b) {
313
- using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
314
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
315
- } else if (!row_major_a && !row_major_b) {
316
- using Gemm = DefaultGemmConfig<70, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
317
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
318
- }
319
- }
320
- }
321
-
322
- // No Tensor Core capability available. Run a SIMT kernel
323
- if (datatype == F64_STR) {
324
- if (row_major_a && row_major_b) {
325
- using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
326
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
327
- } else if (!row_major_a && row_major_b) {
328
- using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
329
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
330
- } else if (row_major_a && !row_major_b) {
331
- using Gemm = DefaultGemmConfig<50, double, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
332
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
333
- } else if (!row_major_a && !row_major_b) {
334
- using Gemm = DefaultGemmConfig<50, double, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
335
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
336
- }
337
- } else if (datatype == F32_STR) {
338
- if (row_major_a && row_major_b) {
339
- using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
340
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
341
- } else if (!row_major_a && row_major_b) {
342
- using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
343
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
344
- } else if (row_major_a && !row_major_b) {
345
- using Gemm = DefaultGemmConfig<50, float, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
346
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
347
- } else if (!row_major_a && !row_major_b) {
348
- using Gemm = DefaultGemmConfig<50, float, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
349
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
350
- }
351
- } else if (datatype == F16_STR) {
352
- if (row_major_a && row_major_b) {
353
- using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::RowMajor>::Gemm;
354
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
355
- } else if (!row_major_a && row_major_b) {
356
- using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>::Gemm;
357
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
358
- } else if (row_major_a && !row_major_b) {
359
- using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::RowMajor, cutlass::layout::ColumnMajor>::Gemm;
360
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
361
- } else if (!row_major_a && !row_major_b) {
362
- using Gemm = DefaultGemmConfig<50, cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::layout::ColumnMajor>::Gemm;
363
- return run_gemm<Gemm>(m, n, k, batch_count, a, b, c, d, alpha, beta);
364
- }
365
- }
366
-
367
- std::cerr << "Data type " << datatype << " is not currently supported." << std::endl;
368
- return false;
369
- }
370
-
371
- }
372
-
373
- } // namespace wp