warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,30 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #pragma once
10
19
 
11
20
  #include <nanovdb/NanoVDB.h>
12
21
 
22
+ #define WP_VOLUME_BUILDER_INSTANTIATE_TYPES \
23
+ EXPAND_BUILDER_TYPE(int32_t) \
24
+ EXPAND_BUILDER_TYPE(float) \
25
+ EXPAND_BUILDER_TYPE(nanovdb::Vec3f) \
26
+ EXPAND_BUILDER_TYPE(nanovdb::Vec4f) \
27
+
13
28
  template <typename BuildT> struct BuildGridParams
14
29
  {
15
30
  nanovdb::Map map;
warp/native/volume_impl.h CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #pragma once
warp/native/warp.cpp CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #include "warp.h"
@@ -142,11 +151,6 @@ int is_cuda_compatibility_enabled()
142
151
  return int(WP_ENABLE_CUDA_COMPATIBILITY);
143
152
  }
144
153
 
145
- int is_cutlass_enabled()
146
- {
147
- return int(WP_ENABLE_CUTLASS);
148
- }
149
-
150
154
  int is_mathdx_enabled()
151
155
  {
152
156
  return int(WP_ENABLE_MATHDX);
@@ -995,6 +999,8 @@ WP_API int cuda_device_is_mempool_supported(int ordinal) { return 0; }
995
999
  WP_API int cuda_device_is_ipc_supported(int ordinal) { return 0; }
996
1000
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold) { return 0; }
997
1001
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal) { return 0; }
1002
+ WP_API uint64_t cuda_device_get_mempool_used_mem_current(int ordinal) { return 0; }
1003
+ WP_API uint64_t cuda_device_get_mempool_used_mem_high(int ordinal) { return 0; }
998
1004
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem) {}
999
1005
 
1000
1006
  WP_API void* cuda_context_get_current() { return NULL; }
@@ -1024,6 +1030,7 @@ WP_API void* cuda_ipc_open_event_handle(void* context, char* handle) { return NU
1024
1030
 
1025
1031
  WP_API void* cuda_stream_create(void* context, int priority) { return NULL; }
1026
1032
  WP_API void cuda_stream_destroy(void* context, void* stream) {}
1033
+ WP_API int cuda_stream_query(void* stream) { return 0; }
1027
1034
  WP_API void cuda_stream_register(void* context, void* stream) {}
1028
1035
  WP_API void cuda_stream_unregister(void* context, void* stream) {}
1029
1036
  WP_API void* cuda_stream_get_current() { return NULL; }
@@ -1036,7 +1043,8 @@ WP_API int cuda_stream_get_priority(void* stream) { return 0; }
1036
1043
 
1037
1044
  WP_API void* cuda_event_create(void* context, unsigned flags) { return NULL; }
1038
1045
  WP_API void cuda_event_destroy(void* event) {}
1039
- WP_API void cuda_event_record(void* event, void* stream) {}
1046
+ WP_API int cuda_event_query(void* event) { return 0; }
1047
+ WP_API void cuda_event_record(void* event, void* stream, bool timing) {}
1040
1048
  WP_API void cuda_event_synchronize(void* event) {}
1041
1049
  WP_API float cuda_event_elapsed_time(void* start_event, void* end_event) { return 0.0f; }
1042
1050
 
warp/native/warp.cu CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #include "warp.h"
@@ -1879,6 +1888,62 @@ uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
1879
1888
  return threshold;
1880
1889
  }
1881
1890
 
1891
+ uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
1892
+ {
1893
+ if (ordinal < 0 || ordinal > int(g_devices.size()))
1894
+ {
1895
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1896
+ return 0;
1897
+ }
1898
+
1899
+ if (!g_devices[ordinal].is_mempool_supported)
1900
+ return 0;
1901
+
1902
+ cudaMemPool_t pool;
1903
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1904
+ {
1905
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1906
+ return 0;
1907
+ }
1908
+
1909
+ uint64_t mem_used = 0;
1910
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
1911
+ {
1912
+ fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
1913
+ return 0;
1914
+ }
1915
+
1916
+ return mem_used;
1917
+ }
1918
+
1919
+ uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
1920
+ {
1921
+ if (ordinal < 0 || ordinal > int(g_devices.size()))
1922
+ {
1923
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1924
+ return 0;
1925
+ }
1926
+
1927
+ if (!g_devices[ordinal].is_mempool_supported)
1928
+ return 0;
1929
+
1930
+ cudaMemPool_t pool;
1931
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1932
+ {
1933
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1934
+ return 0;
1935
+ }
1936
+
1937
+ uint64_t mem_high_water_mark = 0;
1938
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
1939
+ {
1940
+ fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
1941
+ return 0;
1942
+ }
1943
+
1944
+ return mem_high_water_mark;
1945
+ }
1946
+
1882
1947
  void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
1883
1948
  {
1884
1949
  // use temporary storage if user didn't specify pointers
@@ -2362,6 +2427,19 @@ void cuda_stream_destroy(void* context, void* stream)
2362
2427
  check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2363
2428
  }
2364
2429
 
2430
+ int cuda_stream_query(void* stream)
2431
+ {
2432
+ CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
2433
+
2434
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2435
+ {
2436
+ // Abnormal, print out error
2437
+ check_cu(res);
2438
+ }
2439
+
2440
+ return res;
2441
+ }
2442
+
2365
2443
  void cuda_stream_register(void* context, void* stream)
2366
2444
  {
2367
2445
  if (!stream)
@@ -2456,9 +2534,30 @@ void cuda_event_destroy(void* event)
2456
2534
  check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
2457
2535
  }
2458
2536
 
2459
- void cuda_event_record(void* event, void* stream)
2537
+ int cuda_event_query(void* event)
2538
+ {
2539
+ CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
2540
+
2541
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2542
+ {
2543
+ // Abnormal, print out error
2544
+ check_cu(res);
2545
+ }
2546
+
2547
+ return res;
2548
+ }
2549
+
2550
+ void cuda_event_record(void* event, void* stream, bool timing)
2460
2551
  {
2461
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2552
+ if (timing && !g_captures.empty() && cuda_stream_is_capturing(stream))
2553
+ {
2554
+ // record timing event during graph capture
2555
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2556
+ }
2557
+ else
2558
+ {
2559
+ check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2560
+ }
2462
2561
  }
2463
2562
 
2464
2563
  void cuda_event_synchronize(void* event)
@@ -2805,6 +2904,12 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
2805
2904
  opts.push_back("--define-macro=WP_VERIFY_FP");
2806
2905
  else
2807
2906
  opts.push_back("--undefine-macro=WP_VERIFY_FP");
2907
+
2908
+ #if WP_ENABLE_MATHDX
2909
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
2910
+ #else
2911
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
2912
+ #endif
2808
2913
 
2809
2914
  if (fast_math)
2810
2915
  opts.push_back("--use_fast_math");
@@ -2814,10 +2919,6 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
2814
2919
  else
2815
2920
  opts.push_back("--fmad=false");
2816
2921
 
2817
- char include_cutlass[max_path];
2818
- sprintf(include_cutlass, "--include-path=%s/cutlass/include", include_dir);
2819
- opts.push_back(include_cutlass);
2820
-
2821
2922
  std::vector<std::string> cuda_include_opt;
2822
2923
  for(int i = 0; i < num_cuda_include_dirs; i++)
2823
2924
  {
@@ -3173,7 +3274,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
3173
3274
  std::vector<char> lto(lto_size);
3174
3275
  CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
3175
3276
 
3176
- // This fatbin is universal, ie it is the same for any instantations of a cusolver device function
3277
+ // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
3177
3278
  size_t fatbin_size = 0;
3178
3279
  CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
3179
3280
 
@@ -3530,9 +3631,6 @@ void cuda_timing_end(timing_result_t* results, int size)
3530
3631
  #include "sparse.cu"
3531
3632
  #include "volume.cu"
3532
3633
  #include "volume_builder.cu"
3533
- #if WP_ENABLE_CUTLASS
3534
- #include "cutlass_gemm.cu"
3535
- #endif
3536
3634
 
3537
3635
  //#include "spline.inl"
3538
3636
  //#include "volume.inl"
warp/native/warp.h CHANGED
@@ -1,9 +1,18 @@
1
- /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- * NVIDIA CORPORATION and its licensors retain all intellectual property
3
- * and proprietary rights in and to this software, related documentation
4
- * and any modifications thereto. Any use, reproduction, disclosure or
5
- * distribution of this software and related documentation without an express
6
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
7
16
  */
8
17
 
9
18
  #pragma once
@@ -32,8 +41,6 @@ extern "C"
32
41
  WP_API int is_cuda_enabled();
33
42
  // whether Warp was compiled with enhanced CUDA compatibility
34
43
  WP_API int is_cuda_compatibility_enabled();
35
- // whether Warp was compiled with CUTLASS support
36
- WP_API int is_cutlass_enabled();
37
44
  // whether Warp was compiled with MathDx support
38
45
  WP_API int is_mathdx_enabled();
39
46
  // whether Warp was compiled with debug support
@@ -103,10 +110,6 @@ extern "C"
103
110
  WP_API void hash_grid_destroy_device(uint64_t id);
104
111
  WP_API void hash_grid_update_device(uint64_t id, float cell_width, const wp::array_t<wp::vec3>* points);
105
112
 
106
- WP_API bool cutlass_gemm(void* context, int compute_capability, int m, int n, int k, const char* datatype,
107
- const void* a, const void* b, const void* c, void* d, float alpha, float beta,
108
- bool row_major_a, bool row_major_b, bool allow_tf32x3_arith, int batch_count);
109
-
110
113
  WP_API uint64_t volume_create_host(void* buf, uint64_t size, bool copy, bool owner);
111
114
  WP_API void volume_get_tiles_host(uint64_t id, void* buf);
112
115
  WP_API void volume_get_voxels_host(uint64_t id, void* buf);
@@ -117,9 +120,7 @@ extern "C"
117
120
  WP_API void volume_get_voxels_device(uint64_t id, void* buf);
118
121
  WP_API void volume_destroy_device(uint64_t id);
119
122
 
120
- WP_API uint64_t volume_f_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, float bg_value);
121
- WP_API uint64_t volume_v_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, float bg_value[3]);
122
- WP_API uint64_t volume_i_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, int bg_value);
123
+ WP_API uint64_t volume_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, const void* bg_value, uint32_t bg_value_size, const char* bg_value_type);
123
124
  WP_API uint64_t volume_index_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space);
124
125
  WP_API uint64_t volume_from_active_voxels_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space);
125
126
 
@@ -164,6 +165,15 @@ extern "C"
164
165
  WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n);
165
166
  WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n);
166
167
 
168
+ WP_API void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n);
169
+ WP_API void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n);
170
+
171
+ WP_API void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
172
+ WP_API void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
173
+
174
+ WP_API void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
175
+ WP_API void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
176
+
167
177
  WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
168
178
  WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
169
179
 
@@ -176,6 +186,7 @@ extern "C"
176
186
  int* tpl_columns,
177
187
  void* tpl_values,
178
188
  bool prune_numerical_zeros,
189
+ bool masked,
179
190
  int* bsr_offsets,
180
191
  int* bsr_columns,
181
192
  void* bsr_values,
@@ -190,6 +201,7 @@ extern "C"
190
201
  int* tpl_columns,
191
202
  void* tpl_values,
192
203
  bool prune_numerical_zeros,
204
+ bool masked,
193
205
  int* bsr_offsets,
194
206
  int* bsr_columns,
195
207
  void* bsr_values,
@@ -204,6 +216,7 @@ extern "C"
204
216
  int* tpl_columns,
205
217
  void* tpl_values,
206
218
  bool prune_numerical_zeros,
219
+ bool masked,
207
220
  int* bsr_offsets,
208
221
  int* bsr_columns,
209
222
  void* bsr_values,
@@ -218,6 +231,7 @@ extern "C"
218
231
  int* tpl_columns,
219
232
  void* tpl_values,
220
233
  bool prune_numerical_zeros,
234
+ bool masked,
221
235
  int* bsr_offsets,
222
236
  int* bsr_columns,
223
237
  void* bsr_values,
@@ -274,6 +288,8 @@ extern "C"
274
288
  WP_API int cuda_device_is_ipc_supported(int ordinal);
275
289
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
276
290
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
291
+ WP_API uint64_t cuda_device_get_mempool_used_mem_current(int ordinal);
292
+ WP_API uint64_t cuda_device_get_mempool_used_mem_high(int ordinal);
277
293
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
278
294
 
279
295
  WP_API void* cuda_context_get_current();
@@ -309,6 +325,7 @@ extern "C"
309
325
 
310
326
  WP_API void* cuda_stream_create(void* context, int priority);
311
327
  WP_API void cuda_stream_destroy(void* context, void* stream);
328
+ WP_API int cuda_stream_query(void* stream);
312
329
  WP_API void cuda_stream_register(void* context, void* stream);
313
330
  WP_API void cuda_stream_unregister(void* context, void* stream);
314
331
  WP_API void* cuda_stream_get_current();
@@ -321,7 +338,8 @@ extern "C"
321
338
 
322
339
  WP_API void* cuda_event_create(void* context, unsigned flags);
323
340
  WP_API void cuda_event_destroy(void* event);
324
- WP_API void cuda_event_record(void* event, void* stream);
341
+ WP_API int cuda_event_query(void* event);
342
+ WP_API void cuda_event_record(void* event, void* stream, bool timing=false);
325
343
  WP_API void cuda_event_synchronize(void* event);
326
344
  WP_API float cuda_event_elapsed_time(void* start_event, void* end_event);
327
345
 
warp/optim/__init__.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  from .adam import Adam
9
17
  from .sgd import SGD
warp/optim/adam.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import warp as wp
9
17
 
warp/optim/linear.py CHANGED
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from math import sqrt
2
17
  from typing import Any, Callable, Optional, Tuple, Union
3
18
 
@@ -851,7 +866,7 @@ def _diag_mv_vec_kernel(
851
866
  def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
852
867
  zero = type(coeff)(0.0)
853
868
  one = type(coeff)(1.0)
854
- return wp.select(coeff == zero, one / wp.select(use_abs, coeff, wp.abs(coeff)), one)
869
+ return wp.where(coeff == zero, one, one / wp.where(use_abs, wp.abs(coeff), coeff))
855
870
 
856
871
 
857
872
  @wp.kernel
@@ -902,7 +917,7 @@ def _cg_kernel_1(
902
917
  ):
903
918
  i = wp.tid()
904
919
 
905
- alpha = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_old[0] / p_Ap[0])
920
+ alpha = wp.where(resid[0] > tol, rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
906
921
 
907
922
  x[i] = x[i] + alpha * p[i]
908
923
  r[i] = r[i] - alpha * Ap[i]
@@ -920,7 +935,7 @@ def _cg_kernel_2(
920
935
  # p = r + (rz_new / rz_old) * p;
921
936
  i = wp.tid()
922
937
 
923
- beta = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_new[0] / rz_old[0])
938
+ beta = wp.where(resid[0] > tol, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
924
939
 
925
940
  p[i] = z[i] + beta * p[i]
926
941
 
@@ -940,7 +955,7 @@ def _cr_kernel_1(
940
955
  ):
941
956
  i = wp.tid()
942
957
 
943
- alpha = wp.select(resid[0] > tol and y_Ap[0] > 0.0, zAz_old.dtype(0.0), zAz_old[0] / y_Ap[0])
958
+ alpha = wp.where(resid[0] > tol and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
944
959
 
945
960
  x[i] = x[i] + alpha * p[i]
946
961
  r[i] = r[i] - alpha * Ap[i]
@@ -961,7 +976,7 @@ def _cr_kernel_2(
961
976
  # p = r + (rz_new / rz_old) * p;
962
977
  i = wp.tid()
963
978
 
964
- beta = wp.select(resid[0] > tol and zAz_old[0] > 0.0, zAz_old.dtype(0.0), zAz_new[0] / zAz_old[0])
979
+ beta = wp.where(resid[0] > tol and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
965
980
 
966
981
  p[i] = z[i] + beta * p[i]
967
982
  Ap[i] = Az[i] + beta * Ap[i]
@@ -980,7 +995,7 @@ def _bicgstab_kernel_1(
980
995
  ):
981
996
  i = wp.tid()
982
997
 
983
- alpha = wp.select(resid[0] > tol, rho_old.dtype(0.0), rho_old[0] / r0v[0])
998
+ alpha = wp.where(resid[0] > tol, rho_old[0] / r0v[0], rho_old.dtype(0.0))
984
999
 
985
1000
  x[i] += alpha * y[i]
986
1001
  r[i] -= alpha * v[i]
@@ -999,7 +1014,7 @@ def _bicgstab_kernel_2(
999
1014
  ):
1000
1015
  i = wp.tid()
1001
1016
 
1002
- omega = wp.select(resid[0] > tol, st.dtype(0.0), st[0] / tt[0])
1017
+ omega = wp.where(resid[0] > tol, st[0] / tt[0], st.dtype(0.0))
1003
1018
 
1004
1019
  x[i] += omega * z[i]
1005
1020
  r[i] -= omega * t[i]
@@ -1019,8 +1034,8 @@ def _bicgstab_kernel_3(
1019
1034
  ):
1020
1035
  i = wp.tid()
1021
1036
 
1022
- beta = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] * tt[0] / (r0v[0] * st[0]))
1023
- beta_omega = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] / r0v[0])
1037
+ beta = wp.where(resid[0] > tol, rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
1038
+ beta_omega = wp.where(resid[0] > tol, rho_new[0] / r0v[0], st.dtype(0.0))
1024
1039
 
1025
1040
  p[i] = r[i] + beta * p[i] - beta_omega * v[i]
1026
1041
 
@@ -1108,7 +1123,7 @@ def _gmres_arnoldi_normalize_kernel(
1108
1123
  alpha: wp.array(dtype=Any),
1109
1124
  ):
1110
1125
  tid = wp.tid()
1111
- y[tid] = wp.select(alpha[0] == alpha.dtype(0.0), x[tid] / wp.sqrt(alpha[0]), x[tid])
1126
+ y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / wp.sqrt(alpha[0]))
1112
1127
 
1113
1128
 
1114
1129
  @wp.kernel
warp/optim/sgd.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  from typing import Any
9
17
 
warp/paddle.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  from __future__ import annotations
9
17