warp-lang: warp_lang-1.6.1-py3-none-win_amd64.whl → warp_lang-1.7.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/tests/test_mat.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
  from typing import Any
@@ -119,30 +127,6 @@ def test_tpl_constructor_error_incompatible_sizes(test, device):
119
127
  wp.launch(kernel, dim=1, inputs=[], device=device)
120
128
 
121
129
 
122
- def test_tpl_constructor_error_invalid_vector_count(test, device):
123
- @wp.kernel
124
- def kernel():
125
- wp.mat33(wp.vec3(1.0, 2.0, 3.0), wp.vec3(1.0, 2.0, 3.0))
126
-
127
- with test.assertRaisesRegex(
128
- RuntimeError,
129
- r"incompatible number of column vectors given \(2\) when constructing a matrix of shape \(3, 3\)$",
130
- ):
131
- wp.launch(kernel, dim=1, inputs=[], device=device)
132
-
133
-
134
- def test_tpl_constructor_error_invalid_vector_shape(test, device):
135
- @wp.kernel
136
- def kernel():
137
- wp.mat22(wp.vec3(1.0, 2.0, 3.0), wp.vec3(4.0, 5.0, 6.0))
138
-
139
- with test.assertRaisesRegex(
140
- RuntimeError,
141
- r"incompatible column vector lengths given when constructing a matrix of shape \(2, 2\)$",
142
- ):
143
- wp.launch(kernel, dim=1, inputs=[], device=device)
144
-
145
-
146
130
  def test_tpl_constructor_error_invalid_arg_count(test, device):
147
131
  @wp.kernel
148
132
  def kernel():
@@ -226,7 +210,7 @@ def test_quat_constructor(test, device, dtype, register_kernels=False):
226
210
  c0 = s[0][0] * R[0]
227
211
  c1 = s[0][1] * R[1]
228
212
  c2 = s[0][2] * R[2]
229
- m_alt = mat44(
213
+ m_alt = wp.matrix_from_cols(
230
214
  vec4(c0[0], c0[1], c0[2], wptype(0.0)),
231
215
  vec4(c1[0], c1[1], c1[2], wptype(0.0)),
232
216
  vec4(c2[0], c2[1], c2[2], wptype(0.0)),
@@ -1058,6 +1042,124 @@ def test_svd(test, device, dtype, register_kernels=False):
1058
1042
  assert_np_equal((plusval - minusval) / (2 * dx), m3grads[ii, jj], tol=fdtol)
1059
1043
 
1060
1044
 
1045
+ def test_svd_2D(test, device, dtype, register_kernels=False):
1046
+ rng = np.random.default_rng(123)
1047
+
1048
+ tol = {
1049
+ np.float16: 1.0e-3,
1050
+ np.float32: 1.0e-6,
1051
+ np.float64: 1.0e-12,
1052
+ }.get(dtype, 0)
1053
+
1054
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1055
+ vec2 = wp.types.vector(length=2, dtype=wptype)
1056
+ mat22 = wp.types.matrix(shape=(2, 2), dtype=wptype)
1057
+
1058
+ def check_mat_svd2(
1059
+ m2: wp.array(dtype=mat22),
1060
+ Uout: wp.array(dtype=mat22),
1061
+ sigmaout: wp.array(dtype=vec2),
1062
+ Vout: wp.array(dtype=mat22),
1063
+ outcomponents: wp.array(dtype=wptype),
1064
+ ):
1065
+ U = mat22()
1066
+ sigma = vec2()
1067
+ V = mat22()
1068
+
1069
+ wp.svd2(m2[0], U, sigma, V) # Assuming there's a 2D SVD kernel
1070
+
1071
+ Uout[0] = U
1072
+ sigmaout[0] = sigma
1073
+ Vout[0] = V
1074
+
1075
+ # multiply outputs by 2 so we've got something to backpropagate:
1076
+ idx = 0
1077
+ for i in range(2):
1078
+ for j in range(2):
1079
+ outcomponents[idx] = wptype(2) * U[i, j]
1080
+ idx = idx + 1
1081
+
1082
+ for i in range(2):
1083
+ outcomponents[idx] = wptype(2) * sigma[i]
1084
+ idx = idx + 1
1085
+
1086
+ for i in range(2):
1087
+ for j in range(2):
1088
+ outcomponents[idx] = wptype(2) * V[i, j]
1089
+ idx = idx + 1
1090
+
1091
+ kernel = getkernel(check_mat_svd2, suffix=dtype.__name__)
1092
+
1093
+ output_select_kernel = get_select_kernel(wptype)
1094
+
1095
+ if register_kernels:
1096
+ return
1097
+
1098
+ m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)
1099
+
1100
+ outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
1101
+ Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1102
+ sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
1103
+ Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1104
+
1105
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1106
+
1107
+ Uout_np = Uout.numpy()[0].astype(np.float64)
1108
+ sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
1109
+ Vout_np = Vout.numpy()[0].astype(np.float64)
1110
+
1111
+ assert_np_equal(
1112
+ np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
1113
+ )
1114
+
1115
+ if dtype == np.float16:
1116
+ # Skip gradient check for float16 due to rounding errors
1117
+ return
1118
+
1119
+ # Check gradients:
1120
+ out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
1121
+ idx = 0
1122
+ for idx in range(2 * 2 + 2 + 2 * 2):
1123
+ tape = wp.Tape()
1124
+ with tape:
1125
+ wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1126
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1127
+ tape.backward(out)
1128
+ m2grads = 1.0 * tape.gradients[m2].numpy()[0]
1129
+
1130
+ tape.zero()
1131
+
1132
+ dx = 0.0001
1133
+ fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
1134
+ for ii in range(2):
1135
+ for jj in range(2):
1136
+ m2test = 1.0 * m2.numpy()
1137
+ m2test[0, ii, jj] += dx
1138
+ wp.launch(
1139
+ kernel,
1140
+ dim=1,
1141
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1142
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1143
+ device=device,
1144
+ )
1145
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1146
+ plusval = out.numpy()[0]
1147
+
1148
+ m2test = 1.0 * m2.numpy()
1149
+ m2test[0, ii, jj] -= dx
1150
+ wp.launch(
1151
+ kernel,
1152
+ dim=1,
1153
+ inputs=[wp.array(m2test, dtype=mat22, device=device)],
1154
+ outputs=[Uout, sigmaout, Vout, outcomponents],
1155
+ device=device,
1156
+ )
1157
+ wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
1158
+ minusval = out.numpy()[0]
1159
+
1160
+ assert_np_equal((plusval - minusval) / (2 * dx), m2grads[ii, jj], tol=fdtol)
1161
+
1162
+
1061
1163
  def test_qr(test, device, dtype, register_kernels=False):
1062
1164
  rng = np.random.default_rng(123)
1063
1165
 
@@ -1505,13 +1607,12 @@ def test_transform_vector(test, device, dtype, register_kernels=False):
1505
1607
  tape.zero()
1506
1608
 
1507
1609
 
1508
- def test_mat_array_type_indexing(test, device, dtype, register_kernels=False):
1610
+ def test_matrix_assign_inplace(test, device, dtype, register_kernels=False):
1509
1611
  np_type = np.dtype(dtype)
1510
1612
  wp_type = wp.types.np_dtype_to_warp_type[np_type]
1511
1613
 
1512
1614
  vec2 = wp.types.vector(length=2, dtype=wp_type)
1513
1615
  mat22 = wp.types.matrix(shape=(2, 2), dtype=wp_type)
1514
- mat33 = wp.types.matrix(shape=(3, 3), dtype=wp_type)
1515
1616
 
1516
1617
  def mattest_read_write_store(x: wp.array(dtype=wp_type), a: wp.array(dtype=mat22)):
1517
1618
  tid = wp.tid()
@@ -1528,17 +1629,8 @@ def test_mat_array_type_indexing(test, device, dtype, register_kernels=False):
1528
1629
  a[1, 1] = wp_type(3.0)
1529
1630
  x[i, j] = a
1530
1631
 
1531
- def mattest_in_register_overwrite(x: wp.array2d(dtype=mat22), y: wp.array(dtype=vec2)):
1532
- i, j = wp.tid()
1533
-
1534
- a = mat22(wp_type(0.0))
1535
- a[0] = y[i]
1536
- a[0, 1] = wp_type(3.0)
1537
- x[i, j] = a
1538
-
1539
1632
  kernel_read_write_store = getkernel(mattest_read_write_store, suffix=dtype.__name__)
1540
1633
  kernel_in_register = getkernel(mattest_in_register, suffix=dtype.__name__)
1541
- kernel_in_register_overwrite = getkernel(mattest_in_register_overwrite, suffix=dtype.__name__)
1542
1634
 
1543
1635
  if register_kernels:
1544
1636
  return
@@ -1568,19 +1660,6 @@ def test_mat_array_type_indexing(test, device, dtype, register_kernels=False):
1568
1660
  assert_np_equal(x.numpy(), np.array([[[[1.0, 1.0], [0.0, 3.0]]]], dtype=np_type))
1569
1661
  assert_np_equal(y.grad.numpy(), np.array([[1.0, 1.0]], dtype=np_type))
1570
1662
 
1571
- tape.reset()
1572
-
1573
- x = wp.zeros((1, 1), dtype=mat22, device=device, requires_grad=True)
1574
- y = wp.ones(1, dtype=vec2, device=device, requires_grad=True)
1575
-
1576
- with tape:
1577
- wp.launch(kernel_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
1578
-
1579
- tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1580
-
1581
- assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=np_type))
1582
- assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=np_type))
1583
-
1584
1663
 
1585
1664
  # Test matrix constructors using explicit type (float16)
1586
1665
  # note that these tests are specifically not using generics / closure
@@ -1615,10 +1694,61 @@ def test_matrix_constructor_value_func():
1615
1694
  c = mat32d()
1616
1695
  d = mat32d(c, shape=(3, 2))
1617
1696
  e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
1618
- f = mat32d(
1619
- wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1620
- wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0)),
1697
+
1698
+
1699
+ @wp.kernel
1700
+ def test_matrix_from_vecs():
1701
+ m1 = wp.matrix_from_cols(
1702
+ wp.vec3(1.0, 2.0, 3.0),
1703
+ wp.vec3(4.0, 5.0, 6.0),
1704
+ wp.vec3(7.0, 8.0, 9.0),
1705
+ )
1706
+ wp.expect_eq(m1[0, 0], 1.0)
1707
+ wp.expect_eq(m1[0, 1], 4.0)
1708
+ wp.expect_eq(m1[0, 2], 7.0)
1709
+ wp.expect_eq(m1[1, 0], 2.0)
1710
+ wp.expect_eq(m1[1, 1], 5.0)
1711
+ wp.expect_eq(m1[1, 2], 8.0)
1712
+ wp.expect_eq(m1[2, 0], 3.0)
1713
+ wp.expect_eq(m1[2, 1], 6.0)
1714
+ wp.expect_eq(m1[2, 2], 9.0)
1715
+
1716
+ m2 = wp.matrix_from_rows(
1717
+ wp.vec3(1.0, 2.0, 3.0),
1718
+ wp.vec3(4.0, 5.0, 6.0),
1719
+ wp.vec3(7.0, 8.0, 9.0),
1720
+ )
1721
+ wp.expect_eq(m2[0, 0], 1.0)
1722
+ wp.expect_eq(m2[0, 1], 2.0)
1723
+ wp.expect_eq(m2[0, 2], 3.0)
1724
+ wp.expect_eq(m2[1, 0], 4.0)
1725
+ wp.expect_eq(m2[1, 1], 5.0)
1726
+ wp.expect_eq(m2[1, 2], 6.0)
1727
+ wp.expect_eq(m2[2, 0], 7.0)
1728
+ wp.expect_eq(m2[2, 1], 8.0)
1729
+ wp.expect_eq(m2[2, 2], 9.0)
1730
+
1731
+ m3 = wp.matrix_from_cols(
1732
+ wp.vec3(1.0, 2.0, 3.0),
1733
+ wp.vec3(4.0, 5.0, 6.0),
1621
1734
  )
1735
+ wp.expect_eq(m3[0, 0], 1.0)
1736
+ wp.expect_eq(m3[0, 1], 4.0)
1737
+ wp.expect_eq(m3[1, 0], 2.0)
1738
+ wp.expect_eq(m3[1, 1], 5.0)
1739
+ wp.expect_eq(m3[2, 0], 3.0)
1740
+ wp.expect_eq(m3[2, 1], 6.0)
1741
+
1742
+ m4 = wp.matrix_from_rows(
1743
+ wp.vec3(1.0, 2.0, 3.0),
1744
+ wp.vec3(4.0, 5.0, 6.0),
1745
+ )
1746
+ wp.expect_eq(m4[0, 0], 1.0)
1747
+ wp.expect_eq(m4[0, 1], 2.0)
1748
+ wp.expect_eq(m4[0, 2], 3.0)
1749
+ wp.expect_eq(m4[1, 0], 4.0)
1750
+ wp.expect_eq(m4[1, 1], 5.0)
1751
+ wp.expect_eq(m4[1, 2], 6.0)
1622
1752
 
1623
1753
 
1624
1754
  # Same as above but with a default (float/int) type
@@ -1735,15 +1865,20 @@ def test_matrix_len(test, device):
1735
1865
 
1736
1866
  @wp.kernel
1737
1867
  def matrix_augassign_kernel(
1738
- a: wp.array(dtype=wp.mat22), b: wp.array(dtype=wp.mat22), c: wp.array(dtype=wp.mat22), d: wp.array(dtype=wp.mat22)
1868
+ a: wp.array(dtype=wp.mat22),
1869
+ b: wp.array(dtype=wp.mat22),
1870
+ x: wp.array(dtype=wp.vec2),
1871
+ c: wp.array(dtype=wp.mat22),
1872
+ d: wp.array(dtype=wp.mat22),
1873
+ y: wp.array(dtype=wp.vec2),
1739
1874
  ):
1740
1875
  i = wp.tid()
1741
1876
 
1742
1877
  m1 = wp.mat22()
1743
1878
  m2 = b[i]
1879
+ v2 = x[i]
1744
1880
 
1745
- m1[0, 0] += m2[0, 0]
1746
- m1[0, 1] += m2[0, 1]
1881
+ m1[0] += v2
1747
1882
  m1[1, 0] += m2[1, 0]
1748
1883
  m1[1, 1] += m2[1, 1]
1749
1884
 
@@ -1751,9 +1886,9 @@ def matrix_augassign_kernel(
1751
1886
 
1752
1887
  m3 = wp.mat22()
1753
1888
  m4 = d[i]
1889
+ v4 = y[i]
1754
1890
 
1755
- m3[0, 0] -= m4[0, 0]
1756
- m3[0, 1] -= m4[0, 1]
1891
+ m3[0] -= v4
1757
1892
  m3[1, 0] -= m4[1, 0]
1758
1893
  m3[1, 1] -= m4[1, 1]
1759
1894
 
@@ -1761,27 +1896,61 @@ def matrix_augassign_kernel(
1761
1896
 
1762
1897
 
1763
1898
  def test_matrix_augassign(test, device):
1764
- N = 3
1899
+ N = 1
1765
1900
 
1766
- a = wp.zeros(N, dtype=wp.mat22, requires_grad=True)
1767
- b = wp.ones(N, dtype=wp.mat22, requires_grad=True)
1901
+ a = wp.zeros(N, dtype=wp.mat22, requires_grad=True, device=device)
1902
+ b = wp.ones(N, dtype=wp.mat22, requires_grad=True, device=device)
1903
+ x = wp.ones(N, dtype=wp.vec2, requires_grad=True, device=device)
1768
1904
 
1769
- c = wp.zeros(N, dtype=wp.mat22, requires_grad=True)
1770
- d = wp.ones(N, dtype=wp.mat22, requires_grad=True)
1905
+ c = wp.zeros(N, dtype=wp.mat22, requires_grad=True, device=device)
1906
+ d = wp.ones(N, dtype=wp.mat22, requires_grad=True, device=device)
1907
+ y = wp.ones(N, dtype=wp.vec2, requires_grad=True, device=device)
1771
1908
 
1772
1909
  tape = wp.Tape()
1773
1910
  with tape:
1774
- wp.launch(matrix_augassign_kernel, N, inputs=[a, b, c, d])
1911
+ wp.launch(matrix_augassign_kernel, N, inputs=[a, b, x, c, d, y], device=device)
1775
1912
 
1776
1913
  tape.backward(grads={a: wp.ones_like(a), c: wp.ones_like(c)})
1777
1914
 
1778
1915
  assert_np_equal(a.numpy(), wp.ones_like(a).numpy())
1779
1916
  assert_np_equal(a.grad.numpy(), wp.ones_like(a).numpy())
1780
- assert_np_equal(b.grad.numpy(), wp.ones_like(a).numpy())
1917
+ assert_np_equal(b.grad.numpy(), np.array([[[0, 0], [1, 1]]], dtype=float))
1918
+ assert_np_equal(x.grad.numpy(), np.array([[1, 1]], dtype=float))
1781
1919
 
1782
1920
  assert_np_equal(c.numpy(), -wp.ones_like(c).numpy())
1783
1921
  assert_np_equal(c.grad.numpy(), wp.ones_like(c).numpy())
1784
- assert_np_equal(d.grad.numpy(), -wp.ones_like(d).numpy())
1922
+ assert_np_equal(d.grad.numpy(), np.array([[[0, 0], [-1, -1]]], dtype=float))
1923
+ assert_np_equal(y.grad.numpy(), np.array([[-1, -1]], dtype=float))
1924
+
1925
+
1926
+ def test_matrix_assign_copy(test, device):
1927
+ saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
1928
+ try:
1929
+ wp.config.enable_vector_component_overwrites = True
1930
+
1931
+ @wp.kernel
1932
+ def mat_in_register_overwrite(x: wp.array2d(dtype=wp.mat22), y: wp.array(dtype=wp.vec2)):
1933
+ i, j = wp.tid()
1934
+
1935
+ a = wp.mat22()
1936
+ a[0] = y[i]
1937
+ a[0, 1] = 3.0
1938
+ x[i, j] = a
1939
+
1940
+ x = wp.zeros((1, 1), dtype=wp.mat22, device=device, requires_grad=True)
1941
+ y = wp.ones(1, dtype=wp.vec2, device=device, requires_grad=True)
1942
+
1943
+ tape = wp.Tape()
1944
+ with tape:
1945
+ wp.launch(mat_in_register_overwrite, dim=(1, 1), inputs=[x, y], device=device)
1946
+
1947
+ tape.backward(grads={x: wp.ones_like(x, requires_grad=False)})
1948
+
1949
+ assert_np_equal(x.numpy(), np.array([[[[1.0, 3.0], [0.0, 0.0]]]], dtype=float))
1950
+ assert_np_equal(y.grad.numpy(), np.array([[1.0, 0.0]], dtype=float))
1951
+
1952
+ finally:
1953
+ wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
1785
1954
 
1786
1955
 
1787
1956
  devices = get_test_devices()
@@ -1806,6 +1975,7 @@ add_kernel_test(TestMat, test_constructors_explicit_precision, dim=1, devices=de
1806
1975
  add_kernel_test(TestMat, test_constructors_default_precision, dim=1, devices=devices)
1807
1976
  add_kernel_test(TestMat, test_constructors_constant_shape, dim=1, devices=devices)
1808
1977
  add_kernel_test(TestMat, test_matrix_constructor_value_func, dim=1, devices=devices)
1978
+ add_kernel_test(TestMat, test_matrix_from_vecs, dim=1, devices=devices)
1809
1979
 
1810
1980
  mat103 = wp.types.matrix(shape=(10, 3), dtype=float)
1811
1981
  add_kernel_test(
@@ -1870,18 +2040,6 @@ add_function_test(
1870
2040
  test_tpl_constructor_error_incompatible_sizes,
1871
2041
  devices=devices,
1872
2042
  )
1873
- add_function_test(
1874
- TestMat,
1875
- "test_tpl_constructor_error_invalid_vector_count",
1876
- test_tpl_constructor_error_invalid_vector_count,
1877
- devices=devices,
1878
- )
1879
- add_function_test(
1880
- TestMat,
1881
- "test_tpl_constructor_error_invalid_vector_shape",
1882
- test_tpl_constructor_error_invalid_vector_shape,
1883
- devices=devices,
1884
- )
1885
2043
  add_function_test(
1886
2044
  TestMat,
1887
2045
  "test_tpl_constructor_error_invalid_arg_count",
@@ -1900,6 +2058,9 @@ for dtype in np_float_types:
1900
2058
  TestMat, f"test_inverse_{dtype.__name__}", test_inverse, devices=devices, dtype=dtype
1901
2059
  )
1902
2060
  add_function_test_register_kernel(TestMat, f"test_svd_{dtype.__name__}", test_svd, devices=devices, dtype=dtype)
2061
+ add_function_test_register_kernel(
2062
+ TestMat, f"test_svd_2D{dtype.__name__}", test_svd_2D, devices=devices, dtype=dtype
2063
+ )
1903
2064
  add_function_test_register_kernel(TestMat, f"test_qr_{dtype.__name__}", test_qr, devices=devices, dtype=dtype)
1904
2065
  add_function_test_register_kernel(TestMat, f"test_eig_{dtype.__name__}", test_eig, devices=devices, dtype=dtype)
1905
2066
  add_function_test_register_kernel(
@@ -1914,13 +2075,14 @@ for dtype in np_float_types:
1914
2075
  add_function_test_register_kernel(TestMat, f"test_skew_{dtype.__name__}", test_skew, devices=devices, dtype=dtype)
1915
2076
  add_function_test_register_kernel(
1916
2077
  TestMat,
1917
- f"test_mat_array_type_indexing_{dtype.__name__}",
1918
- test_mat_array_type_indexing,
2078
+ f"test_matrix_assign_inplace_{dtype.__name__}",
2079
+ test_matrix_assign_inplace,
1919
2080
  devices=devices,
1920
2081
  dtype=dtype,
1921
2082
  )
1922
2083
  add_function_test(TestMat, "test_matrix_len", test_matrix_len, devices=devices)
1923
2084
  add_function_test(TestMat, "test_matrix_augassign", test_matrix_augassign, devices=devices)
2085
+ add_function_test(TestMat, "test_matrix_assign_copy", test_matrix_assign_copy, devices=devices)
1924
2086
 
1925
2087
  if __name__ == "__main__":
1926
2088
  wp.clear_kernel_cache()
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -326,19 +334,19 @@ def test_constructors(test, device, dtype, register_kernels=False):
326
334
  outcomponents: wp.array(dtype=wptype),
327
335
  ):
328
336
  # multiply outputs by 2 so we've got something to backpropagate:
329
- m2result = wptype(2) * mat22(vec2(input[0], input[2]), vec2(input[1], input[3]))
330
- m3result = wptype(2) * mat33(
337
+ m2result = wptype(2) * wp.matrix_from_cols(vec2(input[0], input[2]), vec2(input[1], input[3]))
338
+ m3result = wptype(2) * wp.matrix_from_cols(
331
339
  vec3(input[4], input[7], input[10]),
332
340
  vec3(input[5], input[8], input[11]),
333
341
  vec3(input[6], input[9], input[12]),
334
342
  )
335
- m4result = wptype(2) * mat44(
343
+ m4result = wptype(2) * wp.matrix_from_cols(
336
344
  vec4(input[13], input[17], input[21], input[25]),
337
345
  vec4(input[14], input[18], input[22], input[26]),
338
346
  vec4(input[15], input[19], input[23], input[27]),
339
347
  vec4(input[16], input[20], input[24], input[28]),
340
348
  )
341
- m5result = wptype(2) * mat55(
349
+ m5result = wptype(2) * wp.matrix_from_cols(
342
350
  vec5(input[29], input[34], input[39], input[44], input[49]),
343
351
  vec5(input[30], input[35], input[40], input[45], input[50]),
344
352
  vec5(input[31], input[36], input[41], input[46], input[51]),
warp/tests/test_math.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
  from typing import Any, NamedTuple
warp/tests/test_mlp.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  # TODO: add more tests for kernels and generics
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17
 
warp/tests/test_noise.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import unittest
9
17