warp-lang 1.6.1__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -12,8 +20,6 @@ import numpy as np
12
20
  import warp as wp
13
21
  from warp.tests.unittest_utils import *
14
22
 
15
- wp.init() # For wp.context.runtime.core.is_mathdx_enabled()
16
-
17
23
  TILE_M = wp.constant(8)
18
24
  TILE_N = wp.constant(4)
19
25
  TILE_K = wp.constant(8)
@@ -208,7 +214,6 @@ def test_tile_binary_map(test, device):
208
214
  assert_np_equal(B_wp.grad.numpy(), B_grad)
209
215
 
210
216
 
211
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
212
217
  def test_tile_grouped_gemm(test, device):
213
218
  @wp.kernel
214
219
  def tile_grouped_gemm(A: wp.array3d(dtype=float), B: wp.array3d(dtype=float), C: wp.array3d(dtype=float)):
@@ -248,60 +253,62 @@ def test_tile_grouped_gemm(test, device):
248
253
  assert_np_equal(C_wp.numpy(), C, 1e-6)
249
254
 
250
255
 
251
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
252
- def test_tile_gemm(test, device):
253
- @wp.kernel
254
- def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
255
- # output tile index
256
- i, j = wp.tid()
256
+ def test_tile_gemm(dtype):
257
+ def test(test, device):
258
+ @wp.kernel
259
+ def tile_gemm(A: wp.array2d(dtype=dtype), B: wp.array2d(dtype=dtype), C: wp.array2d(dtype=dtype)):
260
+ # output tile index
261
+ i, j = wp.tid()
257
262
 
258
- sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32)
263
+ sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=dtype)
259
264
 
260
- M = A.shape[0]
261
- N = B.shape[1]
262
- K = A.shape[1]
265
+ M = A.shape[0]
266
+ N = B.shape[1]
267
+ K = A.shape[1]
263
268
 
264
- count = int(K / TILE_K)
269
+ count = int(K / TILE_K)
265
270
 
266
- for k in range(0, count):
267
- a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
268
- b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
271
+ for k in range(0, count):
272
+ a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
273
+ b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))
269
274
 
270
- # sum += a*b
271
- wp.tile_matmul(a, b, sum)
275
+ # sum += a*b
276
+ wp.tile_matmul(a, b, sum)
272
277
 
273
- wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
278
+ wp.tile_store(C, sum, offset=(i * TILE_M, j * TILE_N))
274
279
 
275
- M = TILE_M * 7
276
- K = TILE_K * 6
277
- N = TILE_N * 5
280
+ M = TILE_M * 7
281
+ K = TILE_K * 6
282
+ N = TILE_N * 5
278
283
 
279
- rng = np.random.default_rng(42)
280
- A = rng.random((M, K), dtype=np.float32)
281
- B = rng.random((K, N), dtype=np.float32)
282
- C = np.zeros((M, N), dtype=np.float32)
284
+ rng = np.random.default_rng(42)
285
+ A = rng.random((M, K), dtype=float).astype(wp.dtype_to_numpy(dtype))
286
+ B = rng.random((K, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
287
+ C = np.zeros((M, N), dtype=float).astype(wp.dtype_to_numpy(dtype))
283
288
 
284
- A_wp = wp.array(A, requires_grad=True, device=device)
285
- B_wp = wp.array(B, requires_grad=True, device=device)
286
- C_wp = wp.array(C, requires_grad=True, device=device)
289
+ A_wp = wp.array(A, requires_grad=True, device=device)
290
+ B_wp = wp.array(B, requires_grad=True, device=device)
291
+ C_wp = wp.array(C, requires_grad=True, device=device)
287
292
 
288
- with wp.Tape() as tape:
289
- wp.launch_tiled(
290
- tile_gemm,
291
- dim=(int(M / TILE_M), int(N / TILE_N)),
292
- inputs=[A_wp, B_wp, C_wp],
293
- block_dim=TILE_DIM,
294
- device=device,
295
- )
293
+ with wp.Tape() as tape:
294
+ wp.launch_tiled(
295
+ tile_gemm,
296
+ dim=(int(M / TILE_M), int(N / TILE_N)),
297
+ inputs=[A_wp, B_wp, C_wp],
298
+ block_dim=TILE_DIM,
299
+ device=device,
300
+ )
296
301
 
297
- assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-5)
302
+ assert_np_equal(C_wp.numpy(), A @ B, tol=1.0e-1)
298
303
 
299
- adj_C = np.ones_like(C)
304
+ adj_C = np.ones_like(C)
300
305
 
301
- tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
306
+ tape.backward(grads={C_wp: wp.array(adj_C, device=device)})
302
307
 
303
- assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-5)
304
- assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-5)
308
+ assert_np_equal(A_wp.grad.numpy(), adj_C @ B.T, tol=1.0e-1)
309
+ assert_np_equal(B_wp.grad.numpy(), A.T @ adj_C, 1.0e-1)
310
+
311
+ return test
305
312
 
306
313
 
307
314
  @wp.kernel
@@ -542,7 +549,6 @@ def test_tile_transpose(test, device):
542
549
  assert_np_equal(output.numpy(), input.numpy().T)
543
550
 
544
551
 
545
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
546
552
  def test_tile_transpose_matmul(test, device):
547
553
  @wp.kernel
548
554
  def test_tile_transpose_matmul_kernel(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
@@ -564,9 +570,36 @@ def test_tile_transpose_matmul(test, device):
564
570
 
565
571
 
566
572
  @wp.kernel
567
- def test_tile_broadcast_add_kernel(
573
+ def test_tile_broadcast_add_1d_kernel(
574
+ input_a: wp.array(dtype=float), input_b: wp.array(dtype=float), output: wp.array(dtype=float)
575
+ ):
576
+ a = wp.tile_load(input_a, shape=(10,))
577
+ b = wp.tile_load(input_b, shape=(1,))
578
+
579
+ c = wp.tile_broadcast(b, shape=(10,))
580
+ d = a + c
581
+
582
+ wp.tile_store(output, d)
583
+
584
+
585
+ def test_tile_broadcast_add_1d(test, device):
586
+ N = 10
587
+
588
+ # implicit 1-dim ([1], 1)
589
+ a = wp.array(np.arange(0, N, dtype=np.float32), device=device)
590
+ b = wp.array(np.ones(1, dtype=np.float32), device=device)
591
+ out = wp.zeros((N,), dtype=float, device=device)
592
+
593
+ wp.launch_tiled(test_tile_broadcast_add_1d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
594
+
595
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
596
+
597
+
598
+ @wp.kernel
599
+ def test_tile_broadcast_add_2d_kernel(
568
600
  input_a: wp.array2d(dtype=float), input_b: wp.array(dtype=float), output: wp.array2d(dtype=float)
569
601
  ):
602
+ # implicit 1-dim ([1], 10)
570
603
  a = wp.tile_load(input_a, shape=(10, 10))
571
604
  b = wp.tile_load(input_b, shape=10)
572
605
 
@@ -576,7 +609,7 @@ def test_tile_broadcast_add_kernel(
576
609
  wp.tile_store(output, d)
577
610
 
578
611
 
579
- def test_tile_broadcast_add(test, device):
612
+ def test_tile_broadcast_add_2d(test, device):
580
613
  M = 10
581
614
  N = 10
582
615
 
@@ -584,7 +617,62 @@ def test_tile_broadcast_add(test, device):
584
617
  b = wp.array(np.arange(0, N, dtype=np.float32), device=device)
585
618
  out = wp.zeros((M, N), dtype=float, device=device)
586
619
 
587
- wp.launch_tiled(test_tile_broadcast_add_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
620
+ wp.launch_tiled(test_tile_broadcast_add_2d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
621
+
622
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
623
+
624
+
625
+ @wp.kernel
626
+ def test_tile_broadcast_add_3d_kernel(
627
+ input_a: wp.array3d(dtype=float), input_b: wp.array3d(dtype=float), output: wp.array3d(dtype=float)
628
+ ):
629
+ a = wp.tile_load(input_a, shape=(4, 10, 12))
630
+ b = wp.tile_load(input_b, shape=(4, 10, 1))
631
+
632
+ c = wp.tile_broadcast(b, shape=(4, 10, 12))
633
+ d = a + c
634
+
635
+ wp.tile_store(output, d)
636
+
637
+
638
+ def test_tile_broadcast_add_3d(test, device):
639
+ M = 4
640
+ N = 10
641
+ O = 12
642
+
643
+ # explicit 1-dim (M, N, 1) to (M, N, O)
644
+ a = wp.array(np.ones((M, N, O), dtype=np.float32), device=device)
645
+ b = wp.array(np.arange(0, M * N, dtype=np.float32).reshape((M, N, 1)), device=device)
646
+ out = wp.zeros((M, N, O), dtype=float, device=device)
647
+
648
+ wp.launch_tiled(test_tile_broadcast_add_3d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
649
+ assert_np_equal(out.numpy(), a.numpy() + b.numpy())
650
+
651
+
652
+ @wp.kernel
653
+ def test_tile_broadcast_add_4d_kernel(
654
+ input_a: wp.array4d(dtype=float), input_b: wp.array4d(dtype=float), output: wp.array4d(dtype=float)
655
+ ):
656
+ a = wp.tile_load(input_a, shape=(4, 10, 5, 6))
657
+ b = wp.tile_load(input_b, shape=(4, 1, 5, 1))
658
+ c = wp.tile_broadcast(b, shape=(4, 10, 5, 6))
659
+ d = a + c
660
+
661
+ wp.tile_store(output, d)
662
+
663
+
664
+ def test_tile_broadcast_add_4d(test, device):
665
+ M = 4
666
+ N = 10
667
+ O = 5
668
+ P = 6
669
+
670
+ # explicit 1-dims (M, 1, O, 1) to (M, N, O, P)
671
+ a = wp.array(np.ones((M, N, O, P), dtype=np.float32), device=device)
672
+ b = wp.array(np.arange(0, M * O, dtype=np.float32).reshape((M, 1, O, 1)), device=device)
673
+ out = wp.zeros((M, N, O, P), dtype=float, device=device)
674
+
675
+ wp.launch_tiled(test_tile_broadcast_add_4d_kernel, dim=[1], inputs=[a, b, out], block_dim=32, device=device)
588
676
 
589
677
  assert_np_equal(out.numpy(), a.numpy() + b.numpy())
590
678
 
@@ -657,7 +745,7 @@ def test_tile_print(test, device):
657
745
  wp.synchronize()
658
746
 
659
747
 
660
- devices = get_cuda_test_devices()
748
+ devices = get_test_devices()
661
749
 
662
750
 
663
751
  class TestTile(unittest.TestCase):
@@ -669,15 +757,20 @@ add_function_test(TestTile, "test_tile_copy_2d", test_tile_copy_2d, devices=devi
669
757
  add_function_test(TestTile, "test_tile_unary_map", test_tile_unary_map, devices=devices)
670
758
  add_function_test(TestTile, "test_tile_binary_map", test_tile_binary_map, devices=devices)
671
759
  add_function_test(TestTile, "test_tile_grouped_gemm", test_tile_grouped_gemm, devices=devices)
672
- add_function_test(TestTile, "test_tile_gemm", test_tile_gemm, devices=devices)
760
+ add_function_test(TestTile, "test_tile_gemm_fp16", test_tile_gemm(wp.float16), devices=devices)
761
+ add_function_test(TestTile, "test_tile_gemm_fp32", test_tile_gemm(wp.float32), devices=devices)
762
+ add_function_test(TestTile, "test_tile_gemm_fp64", test_tile_gemm(wp.float64), devices=devices)
673
763
  add_function_test(TestTile, "test_tile_transpose", test_tile_transpose, devices=devices)
674
764
  add_function_test(TestTile, "test_tile_transpose_matmul", test_tile_transpose_matmul, devices=devices)
675
765
  add_function_test(TestTile, "test_tile_operators", test_tile_operators, devices=devices)
676
- add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices)
766
+ add_function_test(TestTile, "test_tile_sum", test_tile_sum, devices=devices, check_output=False)
677
767
  add_function_test(TestTile, "test_tile_sum_launch", test_tile_sum_launch, devices=devices)
678
768
  add_function_test(TestTile, "test_tile_extract", test_tile_extract, devices=devices)
679
769
  add_function_test(TestTile, "test_tile_extract_repeated", test_tile_extract_repeated, devices=devices)
680
- add_function_test(TestTile, "test_tile_broadcast_add", test_tile_broadcast_add, devices=devices)
770
+ add_function_test(TestTile, "test_tile_broadcast_add_1d", test_tile_broadcast_add_1d, devices=devices)
771
+ add_function_test(TestTile, "test_tile_broadcast_add_2d", test_tile_broadcast_add_2d, devices=devices)
772
+ add_function_test(TestTile, "test_tile_broadcast_add_3d", test_tile_broadcast_add_3d, devices=devices)
773
+ add_function_test(TestTile, "test_tile_broadcast_add_4d", test_tile_broadcast_add_4d, devices=devices)
681
774
  add_function_test(TestTile, "test_tile_broadcast_grad", test_tile_broadcast_grad, devices=devices)
682
775
  add_function_test(TestTile, "test_tile_len", test_tile_len, devices=devices)
683
776
  add_function_test(TestTile, "test_tile_print", test_tile_print, devices=devices, check_output=False)
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -368,7 +376,7 @@ def test_tile_load_fortran(test, device):
368
376
  assert_array_equal(B_wp.grad, A_wp.grad)
369
377
 
370
378
 
371
- devices = get_cuda_test_devices()
379
+ devices = get_test_devices()
372
380
 
373
381
 
374
382
  class TestTileLoad(unittest.TestCase):
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import functools
9
17
  import unittest
@@ -84,6 +92,7 @@ def tile_math_fft_kernel_vec2d(gx: wp.array2d(dtype=wp.vec2d), gy: wp.array2d(dt
84
92
  wp.tile_store(gy, xy)
85
93
 
86
94
 
95
+ @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
87
96
  def test_tile_math_fft(test, device, wp_dtype):
88
97
  np_real_dtype = {wp.vec2f: np.float32, wp.vec2d: np.float64}[wp_dtype]
89
98
  np_cplx_dtype = {wp.vec2f: np.complex64, wp.vec2d: np.complex128}[wp_dtype]
@@ -164,31 +173,33 @@ def test_tile_math_cholesky(test, device):
164
173
  # TODO: implement and test backward pass
165
174
 
166
175
 
167
- devices = get_cuda_test_devices()
176
+ all_devices = get_test_devices()
177
+ cuda_devices = get_cuda_test_devices()
168
178
 
169
179
 
170
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
171
180
  class TestTileMathDx(unittest.TestCase):
172
181
  pass
173
182
 
174
183
 
175
184
  # check_output=False so we can enable libmathdx's logging without failing the tests
176
- add_function_test(TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=devices, check_output=False)
177
185
  add_function_test(
178
- TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=devices, check_output=False
186
+ TestTileMathDx, "test_tile_math_matmul", test_tile_math_matmul, devices=all_devices, check_output=False
187
+ )
188
+ add_function_test(
189
+ TestTileMathDx, "test_tile_math_cholesky", test_tile_math_cholesky, devices=all_devices, check_output=False
179
190
  )
180
191
  add_function_test(
181
192
  TestTileMathDx,
182
193
  "test_tile_math_fft_vec2f",
183
194
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2f),
184
- devices=devices,
195
+ devices=cuda_devices,
185
196
  check_output=False,
186
197
  )
187
198
  add_function_test(
188
199
  TestTileMathDx,
189
200
  "test_tile_math_fft_vec2d",
190
201
  functools.partial(test_tile_math_fft, wp_dtype=wp.vec2d),
191
- devices=devices,
202
+ devices=cuda_devices,
192
203
  check_output=False,
193
204
  )
194
205
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import os
9
17
 
@@ -14,11 +22,6 @@ import warp.examples
14
22
  import warp.optim
15
23
  from warp.tests.unittest_utils import *
16
24
 
17
- wp.init()
18
-
19
- # needs to be constant for the whole module
20
- NUM_THREADS = 32
21
-
22
25
 
23
26
  def create_layer(rng, dim_in, dim_hid, dtype=float):
24
27
  w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
@@ -37,10 +40,12 @@ def create_array(rng, dim_in, dim_hid, dtype=float):
37
40
  return a
38
41
 
39
42
 
40
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
41
43
  def test_multi_layer_nn(test, device):
42
44
  import torch as tc
43
45
 
46
+ if device.is_cuda and not wp.context.runtime.core.is_mathdx_enabled():
47
+ test.skipTest("Skipping test on CUDA device without MathDx (tolerance)")
48
+
44
49
  NUM_FREQ = wp.constant(8)
45
50
 
46
51
  DIM_IN = wp.constant(4 * NUM_FREQ) # sin,cos for both x,y at each frequency
@@ -52,7 +57,13 @@ def test_multi_layer_nn(test, device):
52
57
 
53
58
  BATCH_SIZE = min(512, int((IMG_WIDTH * IMG_HEIGHT) / 8))
54
59
 
60
+ if device.is_cpu:
61
+ NUM_THREADS = 1
62
+ else:
63
+ NUM_THREADS = 32
64
+
55
65
  dtype = wp.float16
66
+ npdtype = wp.types.warp_type_to_np_dtype[dtype]
56
67
 
57
68
  @wp.func
58
69
  def relu(x: dtype):
@@ -66,7 +77,7 @@ def test_multi_layer_nn(test, device):
66
77
  def zero(loss: wp.array(dtype=float)):
67
78
  loss[0] = 0.0
68
79
 
69
- @wp.kernel
80
+ @wp.kernel(module="unique")
70
81
  def compute(
71
82
  batches: wp.array(dtype=int),
72
83
  input: wp.array2d(dtype=dtype),
@@ -162,7 +173,9 @@ def test_multi_layer_nn(test, device):
162
173
  input = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_IN, dtype=dtype)
163
174
  output = create_array(rng, IMG_WIDTH * IMG_HEIGHT, DIM_OUT)
164
175
 
165
- reference_np = np.load(os.path.join(os.path.dirname(__file__), "assets/pixel.npy"), allow_pickle=True) / 255.0
176
+ reference_np = (
177
+ np.load(os.path.join(os.path.dirname(__file__), "..", "assets", "pixel.npy"), allow_pickle=True) / 255.0
178
+ )
166
179
  reference = wp.array(reference_np, dtype=float)
167
180
 
168
181
  assert reference.shape[1] == IMG_WIDTH * IMG_HEIGHT
@@ -224,7 +237,7 @@ def test_multi_layer_nn(test, device):
224
237
  z_np = np.maximum(weights_3.numpy() @ z_np + bias_3.numpy(), 0.0)
225
238
 
226
239
  # test numpy forward
227
- assert_np_equal(output.numpy()[:, indices], z_np, tol=1.0e-2)
240
+ assert_np_equal(output.numpy()[:, indices].astype(npdtype), z_np, tol=1.0e-2)
228
241
 
229
242
  # torch
230
243
  input_tc = tc.tensor(input.numpy()[:, indices], requires_grad=True, device=torch_device)
@@ -252,7 +265,9 @@ def test_multi_layer_nn(test, device):
252
265
  l_tc.backward()
253
266
 
254
267
  # test torch
255
- assert_np_equal(z_tc.cpu().detach().numpy(), output.numpy()[:, indices], tol=1.0e-2)
268
+ assert_np_equal(
269
+ z_tc.cpu().detach().numpy(), output.numpy()[:, indices].astype(npdtype), tol=1.0e-2
270
+ )
256
271
  assert_np_equal(weights_0.grad.numpy(), weights_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
257
272
  assert_np_equal(bias_0.grad.numpy(), bias_0_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
258
273
  assert_np_equal(weights_1.grad.numpy(), weights_1_tc.grad.cpu().detach().numpy(), tol=1.0e-2)
@@ -269,7 +284,6 @@ def test_multi_layer_nn(test, device):
269
284
  test.assertLess(loss.numpy()[0], 0.002)
270
285
 
271
286
 
272
- @unittest.skipUnless(wp.context.runtime.core.is_mathdx_enabled(), "Warp was not built with MathDx support")
273
287
  def test_single_layer_nn(test, device):
274
288
  import torch as tc
275
289
 
@@ -279,11 +293,16 @@ def test_single_layer_nn(test, device):
279
293
 
280
294
  NUM_BLOCKS = 56
281
295
 
296
+ if device.is_cpu:
297
+ NUM_THREADS = 1
298
+ else:
299
+ NUM_THREADS = 32
300
+
282
301
  @wp.func
283
302
  def relu(x: float):
284
303
  return wp.max(x, 0.0)
285
304
 
286
- @wp.kernel
305
+ @wp.kernel(module="unique")
287
306
  def compute(
288
307
  input: wp.array2d(dtype=float),
289
308
  weights: wp.array2d(dtype=float),
@@ -345,7 +364,6 @@ try:
345
364
  import torch
346
365
 
347
366
  # check which Warp devices work with Torch
348
- # CUDA devices may fail if Torch was not compiled with CUDA support
349
367
  torch_compatible_devices = []
350
368
  torch_compatible_cuda_devices = []
351
369
 
@@ -364,7 +382,7 @@ try:
364
382
  "test_single_layer_nn",
365
383
  test_single_layer_nn,
366
384
  check_output=False,
367
- devices=torch_compatible_cuda_devices,
385
+ devices=torch_compatible_devices,
368
386
  )
369
387
  add_function_test(
370
388
  TestTileMLP,
@@ -380,4 +398,5 @@ except Exception as e:
380
398
 
381
399
  if __name__ == "__main__":
382
400
  wp.clear_kernel_cache()
401
+ wp.clear_lto_cache()
383
402
  unittest.main(verbosity=2, failfast=True)
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -168,6 +176,64 @@ def test_tile_reduce_custom(test, device):
168
176
  test.assertAlmostEqual(prod_wp[i], prod_np, places=4)
169
177
 
170
178
 
179
+ @wp.struct
180
+ class KeyValue:
181
+ key: wp.int32
182
+ value: wp.float32
183
+
184
+
185
+ @wp.func
186
+ def kv_max(a: KeyValue, b: KeyValue) -> KeyValue:
187
+ return wp.where(a.value < b.value, b, a)
188
+
189
+
190
+ @wp.kernel
191
+ def initialize_key_value(values: wp.array2d(dtype=wp.float32), keyvalues: wp.array2d(dtype=KeyValue)):
192
+ batch, idx = wp.tid()
193
+ keyvalues[batch, idx] = KeyValue(idx, values[batch, idx])
194
+
195
+
196
+ @wp.kernel(enable_backward=False)
197
+ def tile_reduce_custom_struct_kernel(values: wp.array2d(dtype=KeyValue), res: wp.array(dtype=KeyValue)):
198
+ # output tile index
199
+ i = wp.tid()
200
+
201
+ t = wp.tile_load(values, shape=(1, TILE_DIM), offset=(i, 0))
202
+
203
+ max_el = wp.tile_reduce(kv_max, t)
204
+ wp.tile_store(res, max_el, offset=i)
205
+
206
+
207
+ def test_tile_reduce_custom_struct(test, device):
208
+ batch_count = 56
209
+
210
+ N = TILE_DIM
211
+
212
+ rng = np.random.default_rng(42)
213
+ input = rng.random((batch_count, N), dtype=np.float32)
214
+
215
+ input_wp = wp.array(input, dtype=wp.float32, device=device)
216
+ keyvalues_wp = wp.empty(input_wp.shape, dtype=KeyValue, device=device)
217
+
218
+ wp.launch(initialize_key_value, dim=[batch_count, N], inputs=[input_wp], outputs=[keyvalues_wp], device=device)
219
+
220
+ output_wp = wp.empty(batch_count, dtype=KeyValue, device=device)
221
+
222
+ wp.launch_tiled(
223
+ tile_reduce_custom_struct_kernel,
224
+ dim=[batch_count],
225
+ inputs=[keyvalues_wp],
226
+ outputs=[output_wp],
227
+ block_dim=TILE_DIM,
228
+ device=device,
229
+ )
230
+
231
+ prod_wp = np.array([k for k, v in output_wp.numpy()])
232
+ expected = np.argmax(input, axis=1)
233
+
234
+ assert_np_equal(prod_wp, expected)
235
+
236
+
171
237
  @wp.kernel
172
238
  def tile_grouped_sum_kernel(input: wp.array3d(dtype=float), output: wp.array(dtype=float)):
173
239
  # output tile index
@@ -357,7 +423,7 @@ def test_tile_arange(test, device):
357
423
  assert_np_equal(output.numpy()[4], np.arange(17, 0, -1))
358
424
 
359
425
 
360
- devices = get_cuda_test_devices()
426
+ devices = get_test_devices()
361
427
 
362
428
 
363
429
  class TestTileReduce(unittest.TestCase):
@@ -368,6 +434,7 @@ add_function_test(TestTileReduce, "test_tile_reduce_sum", test_tile_reduce_sum,
368
434
  add_function_test(TestTileReduce, "test_tile_reduce_min", test_tile_reduce_min, devices=devices)
369
435
  add_function_test(TestTileReduce, "test_tile_reduce_max", test_tile_reduce_max, devices=devices)
370
436
  add_function_test(TestTileReduce, "test_tile_reduce_custom", test_tile_reduce_custom, devices=devices)
437
+ add_function_test(TestTileReduce, "test_tile_reduce_custom_struct", test_tile_reduce_custom_struct, devices=devices)
371
438
  add_function_test(TestTileReduce, "test_tile_reduce_grouped_sum", test_tile_reduce_sum, devices=devices)
372
439
  add_function_test(TestTileReduce, "test_tile_reduce_simt", test_tile_reduce_simt, devices=devices)
373
440
  add_function_test(TestTileReduce, "test_tile_ones", test_tile_ones, devices=devices)