warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/__init__.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  # isort: skip_file
9
17
 
@@ -76,7 +84,12 @@ from warp.context import Stream, get_stream, set_stream, wait_stream, synchroniz
76
84
  from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
77
85
  from warp.context import RegisteredGLBuffer
78
86
  from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
79
- from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
87
+ from warp.context import (
88
+ set_mempool_release_threshold,
89
+ get_mempool_release_threshold,
90
+ get_mempool_used_mem_current,
91
+ get_mempool_used_mem_high,
92
+ )
80
93
  from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
81
94
  from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
82
95
 
@@ -112,6 +125,7 @@ from warp.paddle import device_from_paddle, device_to_paddle
112
125
  from warp.paddle import stream_from_paddle
113
126
 
114
127
  from warp.build import clear_kernel_cache
128
+ from warp.build import clear_lto_cache
115
129
 
116
130
  from warp.constants import *
117
131
 
warp/autograd.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import inspect
9
17
  import itertools
warp/bin/warp-clang.dll CHANGED
Binary file
warp/bin/warp.dll CHANGED
Binary file
warp/build.py CHANGED
@@ -1,15 +1,29 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import ctypes
17
+ import errno
18
+ import hashlib
19
+ import json
9
20
  import os
21
+ import time
22
+ from pathlib import Path
10
23
 
11
24
  import warp.config
12
25
  from warp.thirdparty import appdirs
26
+ from warp.types import *
13
27
 
14
28
  # From nvJitLink.h
15
29
  nvJitLink_input_type = {"cubin": 1, "ptx": 2, "ltoir": 3, "fatbin": 4, "object": 5, "library": 6}
@@ -123,6 +137,7 @@ def clear_kernel_cache() -> None:
123
137
 
124
138
  Only directories beginning with ``wp_`` will be deleted.
125
139
  This function only clears the cache for the current Warp version.
140
+ LTO artifacts are not affected.
126
141
  """
127
142
 
128
143
  warp.context.init()
@@ -137,3 +152,406 @@ def clear_kernel_cache() -> None:
137
152
  if os.path.isdir(item_path) and item.startswith("wp_"):
138
153
  # Remove the directory and its contents
139
154
  shutil.rmtree(item_path, ignore_errors=True)
155
+
156
+
157
def clear_lto_cache() -> None:
    """Delete the ``lto`` subdirectory of the Warp kernel cache.

    LTO-IR artifacts written by the ``build_lto_*`` helpers live under
    ``<kernel_cache_dir>/lto``; removing that directory forces them to be rebuilt
    the next time they are requested. Only the cache of the running Warp version
    is affected, and errors raised while deleting individual files are ignored.
    """
    warp.context.init()

    runtime_ready = warp.context.runtime is not None
    assert runtime_ready, "The kernel cache directory is not configured; wp.init() has not been called yet or failed."

    cache_root = warp.config.kernel_cache_dir
    lto_dir = os.path.join(cache_root, "lto")
    if not os.path.isdir(lto_dir):
        # nothing cached yet
        return

    import shutil

    # best-effort purge of the directory and everything below it
    shutil.rmtree(lto_dir, ignore_errors=True)
175
+
176
+
177
def safe_rename(src, dst, attempts=5, delay=0.1):
    """Move ``src`` to ``dst`` via ``os.rename``, tolerating an existing target.

    If the OS reports that ``dst`` already exists (``FileExistsError``) or is a
    non-empty directory (``ENOTEMPTY``), the function returns silently without
    moving anything, on the assumption that another process populated ``dst``
    first; callers are expected to copy their outputs over individually.

    Any other ``OSError`` is retried up to ``attempts`` times with ``delay``
    seconds of sleep between tries (works around transient file locks observed
    on Windows). After the last failed attempt a message is printed and the
    error is re-raised.

    NOTE: with ``attempts <= 0`` no rename is attempted and the call returns None.
    """
    last_attempt = attempts - 1
    for attempt in range(attempts):
        try:
            os.rename(src, dst)
        except FileExistsError:
            return
        except OSError as e:
            if e.errno == errno.ENOTEMPTY:
                # destination directory already filled in by someone else
                return
            if attempt != last_attempt:
                time.sleep(delay)
                continue
            print(f"Could not update Warp cache with compiled binaries, trying to rename {src} to {dst}, error {e}")
            raise e
        else:
            return
202
+
203
+
204
def hash_symbol(symbol):
    """Return the hexadecimal SHA-256 digest of ``symbol`` encoded as UTF-8.

    Callers use a prefix of this digest to derive short, filesystem-safe cache
    file names from (potentially long) LTO symbol names.
    """
    return hashlib.sha256(symbol.encode("utf-8")).hexdigest()
208
+
209
+
210
def get_lto_cache_dir():
    """Return the path of the LTO artifact cache.

    This is the ``lto`` subdirectory of ``warp.config.kernel_cache_dir``. The
    directory is not created here.
    """
    return os.path.join(warp.config.kernel_cache_dir, "lto")
213
+
214
+
215
def get_cached_lto(path):
    """Return the raw bytes of a cached LTO artifact, or ``None`` on a cache miss.

    Args:
        path: Path of the cache file to read.

    Returns:
        The file contents as ``bytes``, or ``None`` when ``path`` does not exist.
    """
    # EAFP instead of exists()+open(): the cache directory is shared between
    # processes and may be removed (e.g. by clear_lto_cache()) between a check
    # and the open, so a vanished file is treated as a miss rather than a crash.
    # NotADirectoryError mirrors os.path.exists() returning False when a parent
    # path component is a regular file.
    try:
        with open(path, "rb") as f:
            return f.read()
    except (FileNotFoundError, NotADirectoryError):
        return None
222
+
223
+
224
def get_cached_lto_meta(path, symbol):
    """Look up ``symbol`` in a cached JSON metadata file.

    Args:
        path: Path of the JSON file holding a mapping of symbols to metadata.
        symbol: Key to look up in the decoded JSON object.

    Returns:
        The value stored under ``symbol``, or ``None`` when ``path`` does not exist.

    Raises:
        KeyError: if the file exists but contains no entry for ``symbol``.
    """
    # EAFP: another process may delete the shared cache between an existence
    # check and open(); treat a missing file as "no cached metadata".
    try:
        with open(path, "r") as f:
            keys = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        return None
    return keys[symbol]
232
+
233
+
234
def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
    """Build (or load from cache) the cuBLASDx LTO-IR used by ``tile_matmul``.

    The generated symbol name encodes every compile-time parameter and doubles as
    the cache key. Lookup order: ``builder.ltoirs`` (current module), then the
    on-disk LTO cache, and finally a fresh compile through the native
    ``cuda_compile_dot`` entry point.

    Args:
        M, N, K: Matrix dimensions baked into the symbol and passed to the
            compiler (presumably the usual GEMM shapes -- confirm with caller).
        adtype, bdtype, cdtype: Warp types of A, B and C: ``float16``/``float32``/
            ``float64`` for real data, or ``vec2h``/``vec2f``/``vec2d`` for complex.
        alayout, blayout, clayout: ``"rowmajor"`` or ``"colmajor"``.
        arch: Target architecture number (e.g. 90); larger values are clamped to 90.
        num_threads: Thread count baked into the symbol and passed to the compiler.
        builder: Object exposing ``ltoirs`` and ``ltoirs_decl`` dicts, which are
            updated with the LTO-IR bytes and the C++ declaration of the symbol.

    Returns:
        ``(lto_symbol, lto_code_data)``: the symbol name and the LTO-IR bytes.

    Raises:
        TypeError: unsupported dtype, or A/B/C mixing real and complex types.
        ValueError: unsupported layout string.
        RuntimeError: the native compiler reported failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    # Maps Python/Warp types to (C++ type, cuBLASDx precision enum, element type),
    # where element type 0 is real and 1 is complex (stored as a vec2)
    def cublasdx_type_map(dtype):
        if dtype == float16:
            return ("wp::float16", 3, 0)
        if dtype == float32:
            return ("wp::float32", 5, 0)
        if dtype == float64:
            return ("wp::float64", 6, 0)
        if dtype == vec2h:
            return ("wp::vec2h", 3, 1)
        if dtype == vec2f:
            return ("wp::vec2f", 5, 1)
        if dtype == vec2d:
            return ("wp::vec2d", 6, 1)
        raise TypeError("Unsupported input type in tile_matmul")

    def cublasdx_arrangement_map(layout):
        if layout == "colmajor":
            return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
        if layout == "rowmajor":
            return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
        raise ValueError("Unsupported layout in tile_matmul")

    (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
    (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
    (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
    a_arrangement = cublasdx_arrangement_map(alayout)
    b_arrangement = cublasdx_arrangement_map(blayout)
    c_arrangement = cublasdx_arrangement_map(clayout)

    if a_type != b_type or a_type != c_type:
        raise TypeError("tile_matmul(A, B, C) requires all inputs to be real or complex")

    element_type = a_type

    lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
    # C++ declaration registered next to the LTO-IR (computed once for both paths below)
    ltoir_decl = f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol]

    # hash symbol and determine output path
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # early out if LTO for this symbol is already built but not cached in current module
    lto_code_data = get_cached_lto(lto_path)
    if lto_code_data is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl
        return lto_symbol, lto_code_data

    # create a temporary (process unique) dir for build outputs before moving to the binary dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary path to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)

    try:
        # compile LTO
        result = warp.context.runtime.core.cuda_compile_dot(
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            K,
            a_prec,
            b_prec,
            c_prec,
            element_type,
            a_arrangement,
            b_arrangement,
            c_arrangement,
            num_threads,
        )

        if not result:
            raise RuntimeError("Failed to compile tile_matmul")

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl

        # try to move process outputs to cache
        safe_rename(build_dir, lto_dir)

        if os.path.exists(lto_dir) and not os.path.exists(lto_path):
            # safe_rename() leaves build_dir in place when lto_dir already exists,
            # so move our output file into the destination lto dir individually
            try:
                os.rename(temp_lto_path, lto_path)
            except OSError:
                # another process likely updated the lto dir first
                pass
    finally:
        # clean up build_dir used for this process -- also on failure, so stale
        # per-process directories do not accumulate next to the cache
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data
357
+
358
+
359
def build_lto_solver(M, N, solver, solver_enum, fill_mode, arch, precision_enum, num_threads, parameter_list, builder):
    """Build (or fetch from cache) the LTO-IR and universal fatbin for a tile solver.

    Lookup order: the current module's ``builder`` dicts, then the on-disk LTO
    cache directory, and finally a fresh compile via ``cuda_compile_solver``.

    Args:
        M: First tile dimension, forwarded to the native compiler and used in the symbol name.
        N: Second tile dimension, forwarded to the native compiler and used in the symbol name.
        solver: Solver name; used as the prefix of the generated symbol.
        solver_enum: Solver identifier forwarded to the native compiler.
        fill_mode: Fill mode forwarded to the native compiler.
        arch: Target architecture; clamped to at most 90 before use.
        precision_enum: Precision identifier forwarded to the native compiler.
        num_threads: Thread count forwarded to the native compiler.
        parameter_list: Parenthesised C parameter list appended to the symbol
            to form its ``void`` forward declaration.
        builder: Module builder whose ``ltoirs``, ``ltoirs_decl`` and
            ``fatbins`` dicts are populated for this symbol.

    Returns:
        Tuple ``(lto_symbol, lto_code_data)``.

    Raises:
        RuntimeError: If the native solver compile reports failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    lto_symbol = f"{solver}_{M}_{N}_{arch}_{precision_enum}"
    ltoir_decl = f"void {lto_symbol}{parameter_list};"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol]

    # hash symbol and determine output paths
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # we also cache a universal fatbin binary for this symbol
    universal_fatbin_name = f"{h[:7]}_fatbin.lto"
    universal_fatbin_path = os.path.join(lto_dir, universal_fatbin_name)

    lto_code_data = get_cached_lto(lto_path)
    universal_fatbin_code_data = get_cached_lto(universal_fatbin_path)

    # early out if both artifacts are already built on disk but not yet registered in this module
    if lto_code_data is not None and universal_fatbin_code_data is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl
        builder.fatbins[lto_symbol] = universal_fatbin_code_data
        return lto_symbol, lto_code_data

    # create a temporary (process unique) dir for build outputs before moving to the cache dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary paths to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)
    temp_universal_fatbin_path = os.path.join(build_dir, universal_fatbin_name)

    try:
        # compile LTO (and universal fatbin) into the temporary build dir
        result = warp.context.runtime.core.cuda_compile_solver(
            temp_universal_fatbin_path.encode("utf-8"),
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            M,
            N,
            solver_enum,
            precision_enum,
            fill_mode,
            num_threads,
        )

        if not result:
            # partial outputs are discarded together with build_dir in the finally clause
            raise RuntimeError("Failed to compile tile_cholesky")

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()
        with open(temp_universal_fatbin_path, "rb") as f:
            universal_fatbin_code_data = f.read()

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.ltoirs_decl[lto_symbol] = ltoir_decl
        builder.fatbins[lto_symbol] = universal_fatbin_code_data

        # try to move process outputs to lto cache
        safe_rename(build_dir, lto_dir)

        if os.path.exists(lto_dir):
            # if the whole-dir rename did not happen (cache dir already existed),
            # move each file individually unless another process got there first
            for path, temp_path in (
                (lto_path, temp_lto_path),
                (universal_fatbin_path, temp_universal_fatbin_path),
            ):
                if not os.path.exists(path):
                    try:
                        os.rename(temp_path, path)
                    except OSError:
                        # another process likely updated the lto dir first
                        pass
    finally:
        # always clean up build_dir used for this process, including on failure
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data
456
+
457
+
458
def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
    """Build (or fetch from cache) the LTO-IR and shared-memory size for a tile FFT.

    Lookup order: the current module's ``builder`` dicts, then the on-disk LTO
    cache directory (an ``.lto`` file plus a ``.meta`` JSON file holding the
    shared-memory requirement), and finally a fresh compile via ``cuda_compile_fft``.

    Args:
        arch: Target architecture; clamped to at most 90 before use.
        size: FFT size, forwarded to the native compiler and used in the symbol name.
        ept: Forwarded to the native compiler and used in the symbol name
            (presumably elements per thread; TODO confirm).
        direction: Used only to form the symbol name.
        dir: Direction value forwarded to the native compiler. The name shadows
            the builtin but is part of the public signature.
        precision: Precision value forwarded to the native compiler and used in the symbol name.
        builder: Module builder whose ``ltoirs`` and ``shared_memory_bytes``
            dicts are populated for this symbol.

    Returns:
        Tuple ``(lto_symbol, lto_code_data, shared_memory_bytes)``.

    Raises:
        RuntimeError: If the native FFT compile reports failure.
    """
    import shutil

    # TODO: MathDx doesn't yet have heuristics for Blackwell
    arch = min(arch, 90)

    lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"

    # early out if LTO for this symbol is already cached in current module
    if lto_symbol in builder.ltoirs:
        return lto_symbol, builder.ltoirs[lto_symbol], builder.shared_memory_bytes[lto_symbol]

    # hash symbol and determine output paths
    h = hash_symbol(lto_symbol)

    lto_dir = get_lto_cache_dir()
    lto_name = f"{h[:7]}.lto"
    lto_path = os.path.join(lto_dir, lto_name)

    # we also cache shared memory requirements for this kernel in a .meta file
    meta_name = f"{h[:7]}.meta"
    meta_path = os.path.join(lto_dir, meta_name)

    # early out if LTO for this symbol is already built but not cached in current module
    lto_code_data = get_cached_lto(lto_path)
    shared_memory_bytes = get_cached_lto_meta(meta_path, lto_symbol)

    if lto_code_data is not None and shared_memory_bytes is not None:
        builder.ltoirs[lto_symbol] = lto_code_data
        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
        return lto_symbol, lto_code_data, shared_memory_bytes

    # create a temporary (process unique) dir for build outputs before moving to the cache dir
    build_dir = f"{lto_dir}_p{os.getpid()}"

    # dir may exist from previous attempts / runs / archs
    Path(build_dir).mkdir(parents=True, exist_ok=True)

    # temporary paths to compile to in build_dir
    temp_lto_path = os.path.join(build_dir, lto_name)
    temp_meta_path = os.path.join(build_dir, meta_name)

    # out-parameter filled in by the native compiler
    shared_memory_size = ctypes.c_int(0)

    try:
        result = warp.context.runtime.core.cuda_compile_fft(
            temp_lto_path.encode("utf-8"),
            lto_symbol.encode("utf-8"),
            0,
            None,
            None,
            arch,
            size,
            ept,
            dir,
            precision,
            ctypes.byref(shared_memory_size),
        )

        if not result:
            # partial outputs are discarded together with build_dir in the finally clause
            raise RuntimeError("Failed to compile tile_fft")

        shared_memory_bytes = Tile.round_up(shared_memory_size.value)

        with open(temp_lto_path, "rb") as f:
            lto_code_data = f.read()

        # output meta file with shared memory requirements for this lto_symbol
        with open(temp_meta_path, "w") as meta_file:
            json.dump({lto_symbol: shared_memory_bytes}, meta_file)

        builder.ltoirs[lto_symbol] = lto_code_data
        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes

        # try to move process outputs to cache
        safe_rename(build_dir, lto_dir)

        if os.path.exists(lto_dir):
            # if the whole-dir rename did not happen (cache dir already existed),
            # move each file individually unless another process got there first
            for path, temp_path in ((lto_path, temp_lto_path), (meta_path, temp_meta_path)):
                if not os.path.exists(path):
                    try:
                        os.rename(temp_path, path)
                    except OSError:
                        # another process likely updated the lto dir first
                        pass
    finally:
        # always clean up build_dir used for this process, including on failure
        shutil.rmtree(build_dir, ignore_errors=True)

    return lto_symbol, lto_code_data, shared_memory_bytes
warp/build_dll.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import os
9
17
  import platform
@@ -139,14 +147,6 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
139
147
  cuda_home = args.cuda_path
140
148
  cuda_cmd = None
141
149
 
142
- if args.quick:
143
- cutlass_includes = ""
144
- cutlass_enabled = "WP_ENABLE_CUTLASS=0"
145
- else:
146
- cutlass_home = "warp/native/cutlass"
147
- cutlass_includes = f'-I"{cutlass_home}/include" -I"{cutlass_home}/tools/util/include"'
148
- cutlass_enabled = "WP_ENABLE_CUTLASS=1"
149
-
150
150
  if args.quick or cu_path is None:
151
151
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
152
152
  else:
@@ -262,7 +262,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
262
262
  iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
263
263
  debug = "_DEBUG"
264
264
 
265
- cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '
265
+ cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '
266
266
 
267
267
  if args.mode == "debug":
268
268
  cpp_flags += "/Zi /Od /D WP_ENABLE_DEBUG=1"
@@ -291,10 +291,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
291
291
  cu_out = cu_path + ".o"
292
292
 
293
293
  if mode == "debug":
294
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
294
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
295
295
 
296
296
  elif mode == "release":
297
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
297
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
298
298
 
299
299
  with ScopedTimer("build_cuda", active=args.verbose):
300
300
  run_cmd(cuda_cmd)
@@ -321,7 +321,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
321
321
  else:
322
322
  version = "-fabi-version=13" # GCC 8.2+
323
323
 
324
- cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{cutlass_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
324
+ cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
325
325
 
326
326
  if mode == "debug":
327
327
  cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
@@ -349,10 +349,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
349
349
  cu_out = cu_path + ".o"
350
350
 
351
351
  if mode == "debug":
352
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
352
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
353
353
 
354
354
  elif mode == "release":
355
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
355
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
356
356
 
357
357
  with ScopedTimer("build_cuda", active=args.verbose):
358
358
  run_cmd(cuda_cmd)