warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/tests/test_utils.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import contextlib
9
17
  import inspect
@@ -79,7 +87,7 @@ def test_array_scan_error_unsupported_dtype(test, device):
79
87
 
80
88
 
81
89
  def test_radix_sort_pairs(test, device):
82
- keyTypes = [int, wp.float32]
90
+ keyTypes = [int, wp.float32, wp.int64]
83
91
 
84
92
  for keyType in keyTypes:
85
93
  keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
@@ -89,18 +97,46 @@ def test_radix_sort_pairs(test, device):
89
97
  assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3)))
90
98
 
91
99
 
92
- def test_radix_sort_pairs_empty(test, device):
100
+ def test_segmented_sort_pairs(test, device):
93
101
  keyTypes = [int, wp.float32]
94
102
 
103
+ for keyType in keyTypes:
104
+ keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
105
+ values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
106
+ wp.utils.segmented_sort_pairs(
107
+ keys,
108
+ values,
109
+ 8,
110
+ wp.array((0, 4), dtype=int, device=device),
111
+ wp.array((4, 8), dtype=int, device=device),
112
+ )
113
+ assert_np_equal(keys.numpy()[:8], np.array((2, 4, 7, 8, 1, 3, 5, 6)))
114
+ assert_np_equal(values.numpy()[:8], np.array((2, 4, 1, 3, 5, 8, 7, 6)))
115
+
116
+
117
+ def test_radix_sort_pairs_empty(test, device):
118
+ keyTypes = [int, wp.float32, wp.int64]
119
+
95
120
  for keyType in keyTypes:
96
121
  keys = wp.array((), dtype=keyType, device=device)
97
122
  values = wp.array((), dtype=int, device=device)
98
123
  wp.utils.radix_sort_pairs(keys, values, 0)
99
124
 
100
125
 
101
- def test_radix_sort_pairs_error_insufficient_storage(test, device):
126
+ def test_segmented_sort_pairs_empty(test, device):
102
127
  keyTypes = [int, wp.float32]
103
128
 
129
+ for keyType in keyTypes:
130
+ keys = wp.array((), dtype=keyType, device=device)
131
+ values = wp.array((), dtype=int, device=device)
132
+ wp.utils.segmented_sort_pairs(
133
+ keys, values, 0, wp.array((), dtype=int, device=device), wp.array((), dtype=int, device=device)
134
+ )
135
+
136
+
137
+ def test_radix_sort_pairs_error_insufficient_storage(test, device):
138
+ keyTypes = [int, wp.float32, wp.int64]
139
+
104
140
  for keyType in keyTypes:
105
141
  keys = wp.array((1, 2, 3), dtype=keyType, device=device)
106
142
  values = wp.array((1, 2, 3), dtype=int, device=device)
@@ -111,9 +147,28 @@ def test_radix_sort_pairs_error_insufficient_storage(test, device):
111
147
  wp.utils.radix_sort_pairs(keys, values, 3)
112
148
 
113
149
 
114
- def test_radix_sort_pairs_error_unsupported_dtype(test, device):
150
+ def test_segmented_sort_pairs_error_insufficient_storage(test, device):
115
151
  keyTypes = [int, wp.float32]
116
152
 
153
+ for keyType in keyTypes:
154
+ keys = wp.array((1, 2, 3), dtype=keyType, device=device)
155
+ values = wp.array((1, 2, 3), dtype=int, device=device)
156
+ with test.assertRaisesRegex(
157
+ RuntimeError,
158
+ r"Array storage must be large enough to contain 2\*count elements$",
159
+ ):
160
+ wp.utils.segmented_sort_pairs(
161
+ keys,
162
+ values,
163
+ 3,
164
+ wp.array((0,), dtype=int, device=device),
165
+ wp.array((3,), dtype=int, device=device),
166
+ )
167
+
168
+
169
+ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
170
+ keyTypes = [int, wp.float32, wp.int64]
171
+
117
172
  for keyType in keyTypes:
118
173
  keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
119
174
  values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
@@ -124,6 +179,25 @@ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
124
179
  wp.utils.radix_sort_pairs(keys, values, 1)
125
180
 
126
181
 
182
+ def test_segmented_sort_pairs_error_unsupported_dtype(test, device):
183
+ keyTypes = [int, wp.float32]
184
+
185
+ for keyType in keyTypes:
186
+ keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
187
+ values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
188
+ with test.assertRaisesRegex(
189
+ RuntimeError,
190
+ r"Unsupported data type$",
191
+ ):
192
+ wp.utils.segmented_sort_pairs(
193
+ keys,
194
+ values,
195
+ 1,
196
+ wp.array((0,), dtype=int, device=device),
197
+ wp.array((3,), dtype=int, device=device),
198
+ )
199
+
200
+
127
201
  def test_array_sum(test, device):
128
202
  for dtype in (wp.float32, wp.float64):
129
203
  with test.subTest(dtype=dtype):
@@ -460,6 +534,20 @@ add_function_test(
460
534
  test_radix_sort_pairs_error_unsupported_dtype,
461
535
  devices=devices,
462
536
  )
537
+ add_function_test(TestUtils, "test_segmented_sort_pairs", test_segmented_sort_pairs, devices=devices)
538
+ add_function_test(TestUtils, "test_segmented_sort_pairs_empty", test_segmented_sort_pairs, devices=devices)
539
+ add_function_test(
540
+ TestUtils,
541
+ "test_segmented_sort_pairs_error_insufficient_storage",
542
+ test_segmented_sort_pairs_error_insufficient_storage,
543
+ devices=devices,
544
+ )
545
+ add_function_test(
546
+ TestUtils,
547
+ "test_segmented_sort_pairs_error_unsupported_dtype",
548
+ test_segmented_sort_pairs_error_unsupported_dtype,
549
+ devices=devices,
550
+ )
463
551
  add_function_test(TestUtils, "test_array_sum", test_array_sum, devices=devices)
464
552
  add_function_test(
465
553
  TestUtils, "test_array_sum_error_out_dtype_mismatch", test_array_sum_error_out_dtype_mismatch, devices=devices
warp/tests/test_vec.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
  from typing import Any
@@ -1036,7 +1044,7 @@ def test_casting_constructors(test, device, dtype, register_kernels=False):
1036
1044
  assert_np_equal(out, a_grad.numpy())
1037
1045
 
1038
1046
 
1039
- def test_vec_assign(test, device, dtype, register_kernels=False):
1047
+ def test_vector_assign_inplace(test, device, dtype, register_kernels=False):
1040
1048
  np_type = np.dtype(dtype)
1041
1049
  wp_type = wp.types.np_dtype_to_warp_type[np_type]
1042
1050
 
@@ -1077,16 +1085,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1077
1085
  g = a_vec[0] + a_vec[1]
1078
1086
  x[tid] = g
1079
1087
 
1080
- def vectest_in_register_overwrite(x: wp.array(dtype=vec3), a: wp.array(dtype=vec3)):
1081
- tid = wp.tid()
1082
-
1083
- f = vec3(wp_type(0.0))
1084
- a_vec = a[tid]
1085
- f = a_vec
1086
- f[1] = wp_type(3.0)
1087
-
1088
- x[tid] = f
1089
-
1090
1088
  def vectest_component(x: wp.array(dtype=vec3), y: wp.array(dtype=wp_type)):
1091
1089
  i = wp.tid()
1092
1090
 
@@ -1098,7 +1096,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1098
1096
 
1099
1097
  kernel_read_write_store = getkernel(vectest_read_write_store, suffix=dtype.__name__)
1100
1098
  kernel_in_register = getkernel(vectest_in_register, suffix=dtype.__name__)
1101
- kernel_in_register_overwrite = getkernel(vectest_in_register_overwrite, suffix=dtype.__name__)
1102
1099
  kernel_component = getkernel(vectest_component, suffix=dtype.__name__)
1103
1100
 
1104
1101
  if register_kernels:
@@ -1148,7 +1145,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1148
1145
  x = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1149
1146
  y = wp.ones(1, dtype=wp_type, device=device, requires_grad=True)
1150
1147
 
1151
- tape = wp.Tape()
1152
1148
  with tape:
1153
1149
  wp.launch(kernel_component, dim=1, inputs=[x, y], device=device)
1154
1150
 
@@ -1157,20 +1153,6 @@ def test_vec_assign(test, device, dtype, register_kernels=False):
1157
1153
  assert_np_equal(x.numpy(), np.array([[1.0, 2.0, 3.0]], dtype=np_type))
1158
1154
  assert_np_equal(y.grad.numpy(), np.array([6.0], dtype=np_type))
1159
1155
 
1160
- tape.reset()
1161
-
1162
- x = wp.zeros(1, dtype=vec3, device=device, requires_grad=True)
1163
- a = wp.ones(1, dtype=vec3, device=device, requires_grad=True)
1164
-
1165
- tape = wp.Tape()
1166
- with tape:
1167
- wp.launch(kernel_in_register_overwrite, dim=1, inputs=[x, a], device=device)
1168
-
1169
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1170
-
1171
- assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0]], dtype=np_type))
1172
- assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=np_type))
1173
-
1174
1156
 
1175
1157
  @wp.kernel
1176
1158
  def test_vector_constructor_value_func():
@@ -1317,15 +1299,15 @@ def vector_augassign_kernel(
1317
1299
  def test_vector_augassign(test, device):
1318
1300
  N = 3
1319
1301
 
1320
- a = wp.zeros(N, dtype=wp.vec3, requires_grad=True)
1321
- b = wp.ones(N, dtype=wp.vec3, requires_grad=True)
1302
+ a = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
1303
+ b = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)
1322
1304
 
1323
- c = wp.zeros(N, dtype=wp.vec3, requires_grad=True)
1324
- d = wp.ones(N, dtype=wp.vec3, requires_grad=True)
1305
+ c = wp.zeros(N, dtype=wp.vec3, requires_grad=True, device=device)
1306
+ d = wp.ones(N, dtype=wp.vec3, requires_grad=True, device=device)
1325
1307
 
1326
1308
  tape = wp.Tape()
1327
1309
  with tape:
1328
- wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d])
1310
+ wp.launch(vector_augassign_kernel, N, inputs=[a, b, c, d], device=device)
1329
1311
 
1330
1312
  tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
1331
1313
 
@@ -1338,6 +1320,38 @@ def test_vector_augassign(test, device):
1338
1320
  assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
1339
1321
 
1340
1322
 
1323
+ def test_vector_assign_copy(test, device):
1324
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
1325
+ try:
1326
+ wp.config.enable_vector_component_overwrites = True
1327
+
1328
+ @wp.kernel
1329
+ def vec_in_register_overwrite(x: wp.array(dtype=wp.vec3), a: wp.array(dtype=wp.vec3)):
1330
+ tid = wp.tid()
1331
+
1332
+ f = wp.vec3(0.0)
1333
+ a_vec = a[tid]
1334
+ f = a_vec
1335
+ f[1] = 3.0
1336
+
1337
+ x[tid] = f
1338
+
1339
+ x = wp.zeros(1, dtype=wp.vec3, device=device, requires_grad=True)
1340
+ a = wp.ones(1, dtype=wp.vec3, device=device, requires_grad=True)
1341
+
1342
+ tape = wp.Tape()
1343
+ with tape:
1344
+ wp.launch(vec_in_register_overwrite, dim=1, inputs=[x, a], device=device)
1345
+
1346
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1347
+
1348
+ assert_np_equal(x.numpy(), np.array([[1.0, 3.0, 1.0]], dtype=float))
1349
+ assert_np_equal(a.grad.numpy(), np.array([[1.0, 0.0, 1.0]], dtype=float))
1350
+
1351
+ finally:
1352
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1353
+
1354
+
1341
1355
  devices = get_test_devices()
1342
1356
 
1343
1357
 
@@ -1406,8 +1420,8 @@ for dtype in np_float_types:
1406
1420
  )
1407
1421
  add_function_test_register_kernel(
1408
1422
  TestVec,
1409
- f"test_vec_assign_{dtype.__name__}",
1410
- test_vec_assign,
1423
+ f"test_vector_assign_inplace_{dtype.__name__}",
1424
+ test_vector_assign_inplace,
1411
1425
  devices=devices,
1412
1426
  dtype=dtype,
1413
1427
  )
@@ -1460,6 +1474,12 @@ add_function_test(
1460
1474
  test_vector_augassign,
1461
1475
  devices=devices,
1462
1476
  )
1477
+ add_function_test(
1478
+ TestVec,
1479
+ "test_vector_assign_copy",
1480
+ test_vector_assign_copy,
1481
+ devices=devices,
1482
+ )
1463
1483
 
1464
1484
 
1465
1485
  if __name__ == "__main__":
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
File without changes