warp-lang 1.6.1__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (401)
  1. warp/__init__.py +21 -7
  2. warp/autograd.py +14 -6
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +424 -6
  6. warp/build_dll.py +20 -20
  7. warp/builtins.py +467 -368
  8. warp/codegen.py +193 -125
  9. warp/config.py +56 -12
  10. warp/constants.py +14 -6
  11. warp/context.py +524 -277
  12. warp/dlpack.py +22 -12
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/assets/nonuniform.usd +0 -0
  15. warp/examples/assets/nvidia_logo.png +0 -0
  16. warp/examples/benchmarks/benchmark_api.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  19. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  21. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  24. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  25. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  26. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  27. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  28. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  29. warp/examples/benchmarks/benchmark_launches.py +14 -6
  30. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  31. warp/examples/browse.py +14 -6
  32. warp/examples/core/example_cupy.py +14 -6
  33. warp/examples/core/example_dem.py +14 -6
  34. warp/examples/core/example_fluid.py +14 -6
  35. warp/examples/core/example_graph_capture.py +14 -6
  36. warp/examples/core/example_marching_cubes.py +14 -6
  37. warp/examples/core/example_mesh.py +14 -6
  38. warp/examples/core/example_mesh_intersect.py +14 -6
  39. warp/examples/core/example_nvdb.py +14 -6
  40. warp/examples/core/example_raycast.py +14 -6
  41. warp/examples/core/example_raymarch.py +14 -6
  42. warp/examples/core/example_render_opengl.py +14 -6
  43. warp/examples/core/example_sample_mesh.py +300 -0
  44. warp/examples/core/example_sph.py +14 -6
  45. warp/examples/core/example_torch.py +14 -6
  46. warp/examples/core/example_wave.py +14 -6
  47. warp/examples/fem/example_adaptive_grid.py +14 -6
  48. warp/examples/fem/example_apic_fluid.py +15 -7
  49. warp/examples/fem/example_burgers.py +16 -8
  50. warp/examples/fem/example_convection_diffusion.py +14 -6
  51. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  52. warp/examples/fem/example_deformed_geometry.py +15 -7
  53. warp/examples/fem/example_diffusion.py +14 -6
  54. warp/examples/fem/example_diffusion_3d.py +14 -6
  55. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  56. warp/examples/fem/example_distortion_energy.py +15 -7
  57. warp/examples/fem/example_magnetostatics.py +20 -12
  58. warp/examples/fem/example_mixed_elasticity.py +14 -6
  59. warp/examples/fem/example_navier_stokes.py +14 -6
  60. warp/examples/fem/example_nonconforming_contact.py +14 -6
  61. warp/examples/fem/example_stokes.py +14 -6
  62. warp/examples/fem/example_stokes_transfer.py +14 -6
  63. warp/examples/fem/example_streamlines.py +14 -6
  64. warp/examples/fem/utils.py +24 -3
  65. warp/examples/interop/example_jax_callable.py +116 -0
  66. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  67. warp/examples/interop/example_jax_kernel.py +205 -0
  68. warp/examples/optim/example_bounce.py +14 -6
  69. warp/examples/optim/example_cloth_throw.py +14 -6
  70. warp/examples/optim/example_diffray.py +14 -6
  71. warp/examples/optim/example_drone.py +14 -6
  72. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  73. warp/examples/optim/example_inverse_kinematics.py +14 -6
  74. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  75. warp/examples/optim/example_softbody_properties.py +14 -6
  76. warp/examples/optim/example_spring_cage.py +14 -6
  77. warp/examples/optim/example_trajectory.py +14 -6
  78. warp/examples/sim/example_cartpole.py +14 -6
  79. warp/examples/sim/example_cloth.py +14 -6
  80. warp/examples/sim/example_cloth_self_contact.py +14 -6
  81. warp/examples/sim/example_granular.py +14 -6
  82. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  83. warp/examples/sim/example_jacobian_ik.py +14 -6
  84. warp/examples/sim/example_particle_chain.py +14 -6
  85. warp/examples/sim/example_quadruped.py +14 -6
  86. warp/examples/sim/example_rigid_chain.py +14 -6
  87. warp/examples/sim/example_rigid_contact.py +14 -6
  88. warp/examples/sim/example_rigid_force.py +14 -6
  89. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  90. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  91. warp/examples/sim/example_soft_body.py +14 -6
  92. warp/examples/tile/example_tile_cholesky.py +14 -6
  93. warp/examples/tile/example_tile_convolution.py +14 -6
  94. warp/examples/tile/example_tile_fft.py +14 -6
  95. warp/examples/tile/example_tile_filtering.py +14 -6
  96. warp/examples/tile/example_tile_matmul.py +16 -10
  97. warp/examples/tile/example_tile_mlp.py +14 -6
  98. warp/examples/tile/example_tile_nbody.py +14 -6
  99. warp/examples/tile/example_tile_walker.py +14 -6
  100. warp/fabric.py +15 -0
  101. warp/fem/__init__.py +26 -1
  102. warp/fem/adaptivity.py +19 -4
  103. warp/fem/cache.py +15 -0
  104. warp/fem/dirichlet.py +15 -0
  105. warp/fem/domain.py +15 -0
  106. warp/fem/field/__init__.py +15 -0
  107. warp/fem/field/field.py +15 -0
  108. warp/fem/field/nodal_field.py +37 -68
  109. warp/fem/field/restriction.py +15 -0
  110. warp/fem/field/virtual.py +77 -23
  111. warp/fem/geometry/__init__.py +15 -0
  112. warp/fem/geometry/adaptive_nanogrid.py +24 -10
  113. warp/fem/geometry/closest_point.py +16 -1
  114. warp/fem/geometry/deformed_geometry.py +20 -2
  115. warp/fem/geometry/element.py +15 -0
  116. warp/fem/geometry/geometry.py +20 -0
  117. warp/fem/geometry/grid_2d.py +27 -12
  118. warp/fem/geometry/grid_3d.py +27 -15
  119. warp/fem/geometry/hexmesh.py +20 -7
  120. warp/fem/geometry/nanogrid.py +24 -11
  121. warp/fem/geometry/partition.py +15 -0
  122. warp/fem/geometry/quadmesh.py +28 -13
  123. warp/fem/geometry/tetmesh.py +18 -4
  124. warp/fem/geometry/trimesh.py +18 -8
  125. warp/fem/integrate.py +277 -93
  126. warp/fem/linalg.py +20 -5
  127. warp/fem/operator.py +15 -0
  128. warp/fem/polynomial.py +15 -0
  129. warp/fem/quadrature/__init__.py +15 -0
  130. warp/fem/quadrature/pic_quadrature.py +52 -22
  131. warp/fem/quadrature/quadrature.py +209 -25
  132. warp/fem/space/__init__.py +16 -1
  133. warp/fem/space/basis_function_space.py +19 -2
  134. warp/fem/space/basis_space.py +40 -18
  135. warp/fem/space/dof_mapper.py +15 -0
  136. warp/fem/space/function_space.py +15 -0
  137. warp/fem/space/grid_2d_function_space.py +15 -0
  138. warp/fem/space/grid_3d_function_space.py +15 -0
  139. warp/fem/space/hexmesh_function_space.py +17 -2
  140. warp/fem/space/nanogrid_function_space.py +15 -0
  141. warp/fem/space/partition.py +21 -2
  142. warp/fem/space/quadmesh_function_space.py +23 -8
  143. warp/fem/space/restriction.py +15 -0
  144. warp/fem/space/shape/__init__.py +15 -0
  145. warp/fem/space/shape/cube_shape_function.py +38 -23
  146. warp/fem/space/shape/shape_function.py +15 -0
  147. warp/fem/space/shape/square_shape_function.py +27 -12
  148. warp/fem/space/shape/tet_shape_function.py +15 -0
  149. warp/fem/space/shape/triangle_shape_function.py +16 -1
  150. warp/fem/space/tetmesh_function_space.py +18 -3
  151. warp/fem/space/topology.py +15 -0
  152. warp/fem/space/trimesh_function_space.py +17 -2
  153. warp/fem/types.py +15 -0
  154. warp/fem/utils.py +27 -6
  155. warp/jax.py +28 -7
  156. warp/jax_experimental/__init__.py +16 -0
  157. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -33
  158. warp/jax_experimental/ffi.py +698 -0
  159. warp/jax_experimental/xla_ffi.py +602 -0
  160. warp/math.py +103 -6
  161. warp/native/array.h +28 -6
  162. warp/native/builtin.h +44 -9
  163. warp/native/bvh.cpp +18 -7
  164. warp/native/bvh.cu +57 -20
  165. warp/native/bvh.h +17 -7
  166. warp/native/clang/clang.cpp +45 -9
  167. warp/native/coloring.cpp +15 -6
  168. warp/native/crt.cpp +15 -6
  169. warp/native/crt.h +15 -6
  170. warp/native/cuda_crt.h +15 -6
  171. warp/native/cuda_util.cpp +29 -6
  172. warp/native/cuda_util.h +17 -6
  173. warp/native/error.cpp +15 -6
  174. warp/native/error.h +15 -6
  175. warp/native/exports.h +85 -63
  176. warp/native/fabric.h +15 -6
  177. warp/native/hashgrid.cpp +15 -6
  178. warp/native/hashgrid.cu +15 -6
  179. warp/native/hashgrid.h +15 -6
  180. warp/native/initializer_array.h +15 -6
  181. warp/native/intersect.h +41 -32
  182. warp/native/intersect_adj.h +48 -39
  183. warp/native/intersect_tri.h +17 -0
  184. warp/native/marching.cpp +16 -0
  185. warp/native/marching.cu +16 -7
  186. warp/native/marching.h +17 -0
  187. warp/native/mat.h +528 -15
  188. warp/native/mathdx.cpp +15 -6
  189. warp/native/matnn.h +15 -6
  190. warp/native/mesh.cpp +15 -6
  191. warp/native/mesh.cu +15 -6
  192. warp/native/mesh.h +25 -16
  193. warp/native/noise.h +15 -6
  194. warp/native/quat.h +114 -17
  195. warp/native/rand.h +21 -6
  196. warp/native/range.h +15 -6
  197. warp/native/reduce.cpp +15 -6
  198. warp/native/reduce.cu +15 -6
  199. warp/native/runlength_encode.cpp +15 -6
  200. warp/native/runlength_encode.cu +15 -6
  201. warp/native/scan.cpp +15 -6
  202. warp/native/scan.cu +15 -6
  203. warp/native/scan.h +15 -6
  204. warp/native/solid_angle.h +17 -0
  205. warp/native/sort.cpp +137 -65
  206. warp/native/sort.cu +167 -21
  207. warp/native/sort.h +23 -7
  208. warp/native/sparse.cpp +58 -28
  209. warp/native/sparse.cu +67 -23
  210. warp/native/spatial.h +15 -6
  211. warp/native/svd.h +131 -6
  212. warp/native/temp_buffer.h +15 -6
  213. warp/native/tile.h +316 -111
  214. warp/native/tile_reduce.h +61 -9
  215. warp/native/vec.h +83 -13
  216. warp/native/volume.cpp +100 -119
  217. warp/native/volume.cu +15 -6
  218. warp/native/volume.h +15 -6
  219. warp/native/volume_builder.cu +40 -16
  220. warp/native/volume_builder.h +21 -6
  221. warp/native/volume_impl.h +15 -6
  222. warp/native/warp.cpp +20 -12
  223. warp/native/warp.cu +114 -16
  224. warp/native/warp.h +34 -16
  225. warp/optim/__init__.py +14 -6
  226. warp/optim/adam.py +14 -6
  227. warp/optim/linear.py +25 -10
  228. warp/optim/sgd.py +14 -6
  229. warp/paddle.py +14 -6
  230. warp/render/__init__.py +14 -6
  231. warp/render/render_opengl.py +14 -6
  232. warp/render/render_usd.py +14 -6
  233. warp/render/utils.py +14 -6
  234. warp/sim/__init__.py +14 -7
  235. warp/sim/articulation.py +18 -10
  236. warp/sim/collide.py +35 -16
  237. warp/sim/graph_coloring.py +14 -6
  238. warp/sim/import_mjcf.py +463 -162
  239. warp/sim/import_snu.py +14 -7
  240. warp/sim/import_urdf.py +46 -18
  241. warp/sim/import_usd.py +14 -7
  242. warp/sim/inertia.py +14 -6
  243. warp/sim/integrator.py +14 -6
  244. warp/sim/integrator_euler.py +19 -11
  245. warp/sim/integrator_featherstone.py +17 -16
  246. warp/sim/integrator_vbd.py +222 -8
  247. warp/sim/integrator_xpbd.py +19 -11
  248. warp/sim/model.py +56 -19
  249. warp/sim/particles.py +14 -6
  250. warp/sim/render.py +14 -6
  251. warp/sim/utils.py +17 -2
  252. warp/sparse.py +657 -555
  253. warp/stubs.py +231 -19
  254. warp/tape.py +14 -6
  255. warp/tests/aux_test_class_kernel.py +14 -6
  256. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  257. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  258. warp/tests/aux_test_dependent.py +14 -6
  259. warp/tests/aux_test_grad_customs.py +14 -6
  260. warp/tests/aux_test_instancing_gc.py +14 -6
  261. warp/tests/aux_test_module_unload.py +14 -6
  262. warp/tests/aux_test_name_clash1.py +14 -6
  263. warp/tests/aux_test_name_clash2.py +14 -6
  264. warp/tests/aux_test_unresolved_func.py +14 -6
  265. warp/tests/aux_test_unresolved_symbol.py +14 -6
  266. warp/tests/cuda/__init__.py +0 -0
  267. warp/tests/{test_async.py → cuda/test_async.py} +14 -6
  268. warp/tests/{test_ipc.py → cuda/test_ipc.py} +14 -6
  269. warp/tests/{test_mempool.py → cuda/test_mempool.py} +53 -6
  270. warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +14 -6
  271. warp/tests/{test_peer.py → cuda/test_peer.py} +14 -6
  272. warp/tests/{test_pinned.py → cuda/test_pinned.py} +14 -6
  273. warp/tests/{test_streams.py → cuda/test_streams.py} +85 -6
  274. warp/tests/geometry/__init__.py +0 -0
  275. warp/tests/{test_bvh.py → geometry/test_bvh.py} +14 -6
  276. warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +14 -6
  277. warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +14 -6
  278. warp/tests/{test_mesh.py → geometry/test_mesh.py} +14 -6
  279. warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +14 -6
  280. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +80 -69
  281. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +15 -7
  282. warp/tests/{test_volume.py → geometry/test_volume.py} +55 -12
  283. warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +14 -6
  284. warp/tests/interop/__init__.py +0 -0
  285. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +42 -11
  286. warp/tests/{test_jax.py → interop/test_jax.py} +14 -6
  287. warp/tests/{test_paddle.py → interop/test_paddle.py} +14 -6
  288. warp/tests/{test_torch.py → interop/test_torch.py} +14 -6
  289. warp/tests/run_coverage_serial.py +14 -6
  290. warp/tests/sim/__init__.py +0 -0
  291. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +23 -16
  292. warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +14 -6
  293. warp/tests/{test_collision.py → sim/test_collision.py} +16 -8
  294. warp/tests/{test_coloring.py → sim/test_coloring.py} +14 -7
  295. warp/tests/{test_model.py → sim/test_model.py} +55 -7
  296. warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +14 -6
  297. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +16 -7
  298. warp/tests/sim/test_vbd.py +597 -0
  299. warp/tests/test_adam.py +14 -6
  300. warp/tests/test_arithmetic.py +14 -6
  301. warp/tests/test_array.py +14 -6
  302. warp/tests/test_array_reduce.py +14 -6
  303. warp/tests/test_assert.py +14 -6
  304. warp/tests/test_atomic.py +14 -6
  305. warp/tests/test_bool.py +15 -7
  306. warp/tests/test_builtins_resolution.py +14 -6
  307. warp/tests/test_closest_point_edge_edge.py +14 -6
  308. warp/tests/test_codegen.py +14 -6
  309. warp/tests/test_codegen_instancing.py +14 -6
  310. warp/tests/test_compile_consts.py +14 -6
  311. warp/tests/test_conditional.py +14 -6
  312. warp/tests/test_context.py +14 -6
  313. warp/tests/test_copy.py +14 -6
  314. warp/tests/test_ctypes.py +14 -6
  315. warp/tests/test_dense.py +14 -6
  316. warp/tests/test_devices.py +14 -6
  317. warp/tests/test_examples.py +42 -42
  318. warp/tests/test_fabricarray.py +14 -6
  319. warp/tests/test_fast_math.py +14 -6
  320. warp/tests/test_fem.py +37 -10
  321. warp/tests/test_fp16.py +14 -6
  322. warp/tests/test_func.py +14 -6
  323. warp/tests/test_future_annotations.py +14 -6
  324. warp/tests/test_generics.py +14 -6
  325. warp/tests/test_grad.py +14 -6
  326. warp/tests/test_grad_customs.py +14 -6
  327. warp/tests/test_grad_debug.py +14 -6
  328. warp/tests/test_implicit_init.py +14 -6
  329. warp/tests/test_import.py +14 -6
  330. warp/tests/test_indexedarray.py +14 -6
  331. warp/tests/test_intersect.py +14 -6
  332. warp/tests/test_iter.py +14 -6
  333. warp/tests/test_large.py +14 -6
  334. warp/tests/test_launch.py +14 -6
  335. warp/tests/test_lerp.py +14 -6
  336. warp/tests/test_linear_solvers.py +15 -11
  337. warp/tests/test_lvalue.py +14 -6
  338. warp/tests/test_mat.py +247 -85
  339. warp/tests/test_mat_lite.py +14 -6
  340. warp/tests/test_mat_scalar_ops.py +18 -10
  341. warp/tests/test_math.py +14 -6
  342. warp/tests/test_mlp.py +14 -6
  343. warp/tests/test_module_hashing.py +14 -6
  344. warp/tests/test_modules_lite.py +14 -6
  345. warp/tests/test_noise.py +14 -6
  346. warp/tests/test_operators.py +14 -6
  347. warp/tests/test_options.py +14 -6
  348. warp/tests/test_overwrite.py +15 -60
  349. warp/tests/test_print.py +14 -6
  350. warp/tests/test_quat.py +81 -52
  351. warp/tests/test_rand.py +58 -43
  352. warp/tests/test_reload.py +14 -6
  353. warp/tests/test_rounding.py +14 -6
  354. warp/tests/test_runlength_encode.py +14 -6
  355. warp/tests/test_scalar_ops.py +14 -6
  356. warp/tests/test_smoothstep.py +14 -6
  357. warp/tests/test_snippet.py +15 -0
  358. warp/tests/test_sparse.py +61 -12
  359. warp/tests/test_spatial.py +89 -6
  360. warp/tests/test_special_values.py +14 -6
  361. warp/tests/test_static.py +15 -7
  362. warp/tests/test_struct.py +14 -6
  363. warp/tests/test_tape.py +14 -6
  364. warp/tests/test_transient_module.py +14 -6
  365. warp/tests/test_triangle_closest_point.py +14 -6
  366. warp/tests/test_types.py +14 -6
  367. warp/tests/test_utils.py +98 -10
  368. warp/tests/test_vec.py +60 -40
  369. warp/tests/test_vec_lite.py +14 -6
  370. warp/tests/test_vec_scalar_ops.py +14 -6
  371. warp/tests/test_verify_fp.py +14 -6
  372. warp/tests/tile/__init__.py +0 -0
  373. warp/tests/{test_tile.py → tile/test_tile.py} +150 -57
  374. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +15 -7
  375. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +23 -12
  376. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +39 -20
  377. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +74 -7
  378. warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +14 -6
  379. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +15 -7
  380. warp/tests/unittest_serial.py +15 -6
  381. warp/tests/unittest_suites.py +59 -65
  382. warp/tests/unittest_utils.py +16 -7
  383. warp/tests/walkthrough_debug.py +14 -6
  384. warp/thirdparty/unittest_parallel.py +15 -8
  385. warp/torch.py +14 -6
  386. warp/types.py +124 -664
  387. warp/utils.py +151 -78
  388. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/METADATA +39 -12
  389. warp_lang-1.7.0.dist-info/RECORD +429 -0
  390. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  391. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  392. warp/examples/optim/example_walker.py +0 -309
  393. warp/native/cutlass_gemm.cpp +0 -34
  394. warp/native/cutlass_gemm.cu +0 -373
  395. warp/tests/test_matmul.py +0 -503
  396. warp/tests/test_matmul_lite.py +0 -403
  397. warp/tests/test_vbd.py +0 -378
  398. warp/tests/unused_test_misc.py +0 -69
  399. warp_lang-1.6.1.dist-info/LICENSE.md +0 -126
  400. warp_lang-1.6.1.dist-info/RECORD +0 -419
  401. {warp_lang-1.6.1.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/sparse.cpp CHANGED
@@ -1,9 +1,18 @@
-/** Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
- * NVIDIA CORPORATION and its licensors retain all intellectual property
- * and proprietary rights in and to this software, related documentation
- * and any modifications thereto. Any use, reproduction, disclosure or
- * distribution of this software and related documentation without an express
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #include "warp.h"
@@ -72,7 +81,8 @@ template <typename T> void bsr_dyn_block_transpose(const T* src, T* dest, int ro
 template <typename T>
 int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_block, const int row_count,
                                   const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                  const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns, T* bsr_values)
+                                  const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                  int* bsr_columns, T* bsr_values)
 {
 
     // get specialized accumulator for common block sizes (1,1), (1,2), (1,3),
@@ -115,14 +125,33 @@ int bsr_matrix_from_triplets_host(const int rows_per_block, const int cols_per_b
     std::iota(block_indices.begin(), block_indices.end(), 0);
 
     // remove zero blocks and invalid row indices
-    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(),
-                                       [&](int i)
-                                       {
-                                           return tpl_rows[i] < 0 || tpl_rows[i] >= row_count ||
-                                                  (prune_numerical_zeros && tpl_values &&
-                                                   block_is_zero_func(tpl_values + i * block_size, block_size));
-                                       }),
-                        block_indices.end());
+
+    auto discard_block = [&](int i)
+    {
+        const int row = tpl_rows[i];
+        if (row < 0 || row >= row_count)
+        {
+            return true;
+        }
+
+        if (prune_numerical_zeros && tpl_values && block_is_zero_func(tpl_values + i * block_size, block_size))
+        {
+            return true;
+        }
+
+        if (!masked)
+        {
+            return false;
+        }
+
+        const int* beg = bsr_columns + bsr_offsets[row];
+        const int* end = bsr_columns + bsr_offsets[row + 1];
+        const int col = tpl_columns[i];
+        const int* block = std::lower_bound(beg, end, col);
+        return block == end || *block != col;
+    };
+
+    block_indices.erase(std::remove_if(block_indices.begin(), block_indices.end(), discard_block), block_indices.end());
 
     // sort block indices according to lexico order
     std::sort(block_indices.begin(), block_indices.end(), [tpl_rows, tpl_columns](int i, int j) -> bool
@@ -272,12 +301,12 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
 
 WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                 int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                         bsr_columns, static_cast<float*>(bsr_values));
+                                         static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                         bsr_offsets, bsr_columns, static_cast<float*>(bsr_values));
     if (bsr_nnz)
     {
        *bsr_nnz = bsr_offsets[row_count];
@@ -286,12 +315,12 @@ WP_API void bsr_matrix_from_triplets_float_host(int rows_per_block, int cols_per
 
 WP_API void bsr_matrix_from_triplets_double_host(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                 bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                 void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                 bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                 int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     bsr_matrix_from_triplets_host<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
-                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, bsr_offsets,
-                                          bsr_columns, static_cast<double*>(bsr_values));
+                                          static_cast<const double*>(tpl_values), prune_numerical_zeros, masked,
+                                          bsr_offsets, bsr_columns, static_cast<double*>(bsr_values));
     if (bsr_nnz)
     {
        *bsr_nnz = bsr_offsets[row_count];
@@ -318,16 +347,17 @@ WP_API void bsr_transpose_double_host(int rows_per_block, int cols_per_block, in
 
 #if !WP_ENABLE_CUDA
 WP_API void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
-                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                  bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                  void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                  int* tpl_rows, int* tpl_columns, void* tpl_values,
+                                                  bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                  int* bsr_columns, void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
 }
 
 WP_API void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                                    int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                                   bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                                   void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                                   bool prune_numerical_zeros, bool masked, int* bsr_offsets,
+                                                   int* bsr_columns, void* bsr_values, int* bsr_nnz,
+                                                   void* bsr_nnz_event)
 {
 }
 
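Note on the masked path above: when masked is true, bsr_matrix_from_triplets_host keeps only triplets whose (row, column) block already exists in the pattern described by bsr_offsets/bsr_columns; because each row's column indices are sorted, membership reduces to a std::lower_bound probe over that row's slice. A minimal standalone sketch of the same lookup (block_in_pattern is a hypothetical helper for illustration, not part of Warp):

    // Sketch: membership test against a fixed BSR sparsity pattern,
    // mirroring the std::lower_bound check in discard_block above.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Assumes CSR-style row offsets and per-row sorted column indices.
    bool block_in_pattern(const std::vector<int>& bsr_offsets,
                          const std::vector<int>& bsr_columns,
                          int row, int col)
    {
        const int* beg = bsr_columns.data() + bsr_offsets[row];
        const int* end = bsr_columns.data() + bsr_offsets[row + 1];
        const int* it = std::lower_bound(beg, end, col);
        return it != end && *it == col;
    }

    int main()
    {
        // 3-row pattern: row 0 -> {0, 2}, row 1 -> {1}, row 2 -> {}
        const std::vector<int> offsets = {0, 2, 3, 3};
        const std::vector<int> columns = {0, 2, 1};

        std::printf("%d\n", block_in_pattern(offsets, columns, 0, 2)); // 1: triplet kept
        std::printf("%d\n", block_in_pattern(offsets, columns, 1, 0)); // 0: discarded when masked
        std::printf("%d\n", block_in_pattern(offsets, columns, 2, 0)); // 0: empty row
        return 0;
    }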
warp/native/sparse.cu CHANGED
@@ -1,9 +1,18 @@
-/** Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
- * NVIDIA CORPORATION and its licensors retain all intellectual property
- * and proprietary rights in and to this software, related documentation
- * and any modifications thereto. Any use, reproduction, disclosure or
- * distribution of this software and related documentation without an express
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #include "cuda_util.h"
@@ -52,10 +61,41 @@ template <typename T> struct BsrBlockIsNotZero
     }
 };
 
+struct BsrBlockInMask
+{
+    const int* bsr_offsets;
+    const int* bsr_columns;
+
+    CUDA_CALLABLE_DEVICE bool operator()(int row, int col) const
+    {
+        if (bsr_offsets == nullptr)
+            return true;
+
+        int lower = bsr_offsets[row];
+        int upper = bsr_offsets[row + 1] - 1;
+
+        while (lower < upper)
+        {
+            const int mid = lower + (upper - lower) / 2;
+
+            if (bsr_columns[mid] < col)
+            {
+                lower = mid + 1;
+            }
+            else
+            {
+                upper = mid;
+            }
+        }
+
+        return lower == upper && (bsr_columns[lower] == col);
+    }
+};
+
 template <typename T>
 __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const int* tpl_rows, const int* tpl_columns,
-                                            const BsrBlockIsNotZero<T> nonZero, uint32_t* block_indices,
-                                            BsrRowCol* tpl_row_col)
+                                            const BsrBlockIsNotZero<T> nonZero, const BsrBlockInMask mask,
+                                            uint32_t* block_indices, BsrRowCol* tpl_row_col)
 {
     int block = blockIdx.x * blockDim.x + threadIdx.x;
     if (block >= nnz)
@@ -65,7 +105,8 @@ __global__ void bsr_fill_triplet_key_values(const int nnz, const int nrow, const
     const int col = tpl_columns[block];
     const bool is_valid = row >= 0 && row < nrow;
 
-    const BsrRowCol row_col = is_valid && nonZero(block) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
+    const BsrRowCol row_col =
+        is_valid && nonZero(block) && mask(row, col) ? bsr_combine_row_col(row, col) : PRUNED_ROWCOL;
     tpl_row_col[block] = row_col;
     block_indices[block] = block;
 }
@@ -113,7 +154,7 @@ __global__ void bsr_find_row_offsets(uint32_t row_count, const T* d_nnz, const B
 }
 
 template <typename T>
-__global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const uint32_t* block_offsets,
+__global__ void bsr_merge_blocks(const int* d_nnz, int block_size, const uint32_t* block_offsets,
                                  const uint32_t* sorted_block_indices, const BsrRowCol* unique_row_cols,
                                  const T* tpl_values, int* bsr_cols, T* bsr_values)
 
@@ -154,8 +195,8 @@ __global__ void bsr_merge_blocks(const uint32_t* d_nnz, int block_size, const ui
 template <typename T>
 void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_per_block, const int row_count,
                                      const int nnz, const int* tpl_rows, const int* tpl_columns, const T* tpl_values,
-                                     const bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
-                                     T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
+                                     const bool prune_numerical_zeros, const bool masked, int* bsr_offsets,
+                                     int* bsr_columns, T* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
     const int block_size = rows_per_block * cols_per_block;
 
@@ -177,8 +218,9 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Combine rows and columns so we can sort on them both
     BsrBlockIsNotZero<T> isNotZero{block_size, prune_numerical_zeros ? tpl_values : nullptr};
+    BsrBlockInMask mask{masked ? bsr_offsets : nullptr, bsr_columns};
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_fill_triplet_key_values, nnz,
-                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, d_keys.Current(), d_values.Current()));
+                     (nnz, row_count, tpl_rows, tpl_columns, isNotZero, mask, d_keys.Current(), d_values.Current()));
 
     // Sort
     {
@@ -205,7 +247,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     if (bsr_nnz)
     {
-        // Copy nnz to host, and record an event for the competed transfer if desired
+        // Copy nnz to host, and record an event for the completed transfer if desired
 
         memcpy_d2h(WP_CURRENT_CONTEXT, bsr_nnz, bsr_offsets + row_count, sizeof(int), stream);
 
@@ -227,7 +269,7 @@ void bsr_matrix_from_triplets_device(const int rows_per_block, const int cols_pe
 
     // Accumulate repeated blocks and set column indices
     wp_launch_device(WP_CURRENT_CONTEXT, bsr_merge_blocks, nnz,
-                     (unique_triplet_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
+                     (bsr_offsets + row_count, block_size, d_keys.Alternate(), d_keys.Current(), d_values.Alternate(),
                       tpl_values, bsr_columns, bsr_values));
 }
 
@@ -443,22 +485,24 @@ void bsr_transpose_device(int rows_per_block, int cols_per_block, int row_count,
 
 void bsr_matrix_from_triplets_float_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                            int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                           bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                           bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                            void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<float>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const float*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<float>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns,
+                                                  static_cast<const float*>(tpl_values), prune_numerical_zeros, masked,
+                                                  bsr_offsets, bsr_columns, static_cast<float*>(bsr_values), bsr_nnz,
+                                                  bsr_nnz_event);
 }
 
 void bsr_matrix_from_triplets_double_device(int rows_per_block, int cols_per_block, int row_count, int nnz,
                                             int* tpl_rows, int* tpl_columns, void* tpl_values,
-                                            bool prune_numerical_zeros, int* bsr_offsets, int* bsr_columns,
+                                            bool prune_numerical_zeros, bool masked, int* bsr_offsets, int* bsr_columns,
                                             void* bsr_values, int* bsr_nnz, void* bsr_nnz_event)
 {
-    return bsr_matrix_from_triplets_device<double>(
-        rows_per_block, cols_per_block, row_count, nnz, tpl_rows, tpl_columns, static_cast<const double*>(tpl_values),
-        prune_numerical_zeros, bsr_offsets, bsr_columns, static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
+    return bsr_matrix_from_triplets_device<double>(rows_per_block, cols_per_block, row_count, nnz, tpl_rows,
+                                                   tpl_columns, static_cast<const double*>(tpl_values),
+                                                   prune_numerical_zeros, masked, bsr_offsets, bsr_columns,
+                                                   static_cast<double*>(bsr_values), bsr_nnz, bsr_nnz_event);
}
 
 void bsr_transpose_float_device(int rows_per_block, int cols_per_block, int row_count, int col_count, int nnz,
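The CUDA path implements the same membership test in BsrBlockInMask with an explicit binary search (no std::lower_bound on device); a null bsr_offsets disables masking, so unmasked assembly pays only a pointer check per triplet. A host-side port of that loop, checked against std::lower_bound for agreement (in_mask is a hypothetical test harness, not Warp code):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // Same search as BsrBlockInMask::operator(): find the first column >= col
    // in [offsets[row], offsets[row + 1]) and check that it equals col.
    bool in_mask(const std::vector<int>& offsets, const std::vector<int>& columns, int row, int col)
    {
        int lower = offsets[row];
        int upper = offsets[row + 1] - 1; // empty row => upper < lower, loop skipped
        while (lower < upper)
        {
            const int mid = lower + (upper - lower) / 2;
            if (columns[mid] < col)
                lower = mid + 1;
            else
                upper = mid;
        }
        return lower == upper && columns[lower] == col;
    }

    int main()
    {
        // Pattern: row 0 -> {1, 4, 7}, row 1 -> {}, row 2 -> {0, 2}
        const std::vector<int> offsets = {0, 3, 3, 5};
        const std::vector<int> columns = {1, 4, 7, 0, 2};
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 8; ++col)
            {
                const int* beg = columns.data() + offsets[row];
                const int* end = columns.data() + offsets[row + 1];
                const int* it = std::lower_bound(beg, end, col);
                const bool expected = (it != end && *it == col);
                assert(in_mask(offsets, columns, row, col) == expected);
            }
        return 0;
    }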
warp/native/spatial.h CHANGED
@@ -1,9 +1,18 @@
-/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
- * NVIDIA CORPORATION and its licensors retain all intellectual property
- * and proprietary rights in and to this software, related documentation
- * and any modifications thereto. Any use, reproduction, disclosure or
- * distribution of this software and related documentation without an express
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #pragma once
warp/native/svd.h CHANGED
@@ -1,9 +1,18 @@
-/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
- * NVIDIA CORPORATION and its licensors retain all intellectual property
- * and proprietary rights in and to this software, related documentation
- * and any modifications thereto. Any use, reproduction, disclosure or
- * distribution of this software and related documentation without an express
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 // The MIT License (MIT)
@@ -423,6 +432,62 @@ void _svd(// input A
     );
 }
 
+
+template<typename Type>
+inline CUDA_CALLABLE
+void _svd_2(// input A
+            Type a11, Type a12,
+            Type a21, Type a22,
+            // output U
+            Type &u11, Type &u12,
+            Type &u21, Type &u22,
+            // output S
+            Type &s11, Type &s12,
+            Type &s21, Type &s22,
+            // output V
+            Type &v11, Type &v12,
+            Type &v21, Type &v22)
+{
+    // Step 1: Compute ATA
+    Type ATA11 = a11 * a11 + a21 * a21;
+    Type ATA12 = a11 * a12 + a21 * a22;
+    Type ATA22 = a12 * a12 + a22 * a22;
+
+    // Step 2: Eigenanalysis
+    Type trace = ATA11 + ATA22;
+    Type det = ATA11 * ATA22 - ATA12 * ATA12;
+    Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
+    Type lambda1 = (trace + sqrt_term) * Type(0.5);
+    Type lambda2 = (trace - sqrt_term) * Type(0.5);
+
+    // Step 3: Singular values
+    Type sigma1 = sqrt(lambda1);
+    Type sigma2 = sqrt(lambda2);
+
+    // Step 4: Eigenvectors (find V)
+    Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
+    Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
+    Type norm1 = sqrt(v1x * v1x + v1y * v1y);
+    Type norm2 = sqrt(v2x * v2x + v2y * v2y);
+
+    v11 = v1x / norm1; v12 = v2x / norm2;
+    v21 = v1y / norm1; v22 = v2y / norm2;
+
+    // Step 5: Compute U
+    Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
+    Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
+
+    u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
+    u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
+    u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
+    u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
+
+    // Step 6: Set S
+    s11 = sigma1; s12 = Type(0.0);
+    s21 = Type(0.0); s22 = sigma2;
+}
+
+
 template<typename Type>
 inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
     Type s12, s13, s21, s23, s31, s32;
@@ -483,6 +548,66 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
     adj_A = adj_A + (u_term + v_term + sigma_term);
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
+    Type s12, s21;
+    _svd_2(A.data[0][0], A.data[0][1],
+           A.data[1][0], A.data[1][1],
+
+           U.data[0][0], U.data[0][1],
+           U.data[1][0], U.data[1][1],
+
+           sigma[0], s12,
+           s21, sigma[1],
+
+           V.data[0][0], V.data[0][1],
+           V.data[1][0], V.data[1][1]);
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
+                                   const mat_t<2,2,Type>& U,
+                                   const vec_t<2,Type>& sigma,
+                                   const mat_t<2,2,Type>& V,
+                                   mat_t<2,2,Type>& adj_A,
+                                   const mat_t<2,2,Type>& adj_U,
+                                   const vec_t<2,Type>& adj_sigma,
+                                   const mat_t<2,2,Type>& adj_V) {
+    Type s1_squared = sigma[0] * sigma[0];
+    Type s2_squared = sigma[1] * sigma[1];
+
+    // Compute inverse of (s1^2 - s2^2) if possible, use small epsilon to prevent division by zero
+    Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));
+
+    // Construct the matrix F for the adjoint
+    mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
+                                        -F01, 0.0);
+
+    // Create a matrix to handle the adjoint of the singular values (diagonal matrix)
+    mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
+                                                    0.0, adj_sigma[1]);
+
+    // Matrix for handling singular values (diagonal matrix with sigma values)
+    mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
+                                            0.0, sigma[1]);
+
+    // Compute the transpose of U and V
+    mat_t<2,2,Type> UT = transpose(U);
+    mat_t<2,2,Type> VT = transpose(V);
+
+    // Compute the term for sigma (diagonal matrix of adjoint singular values)
+    mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
+
+    // Compute the adjoint contributions for U (left singular vectors)
+    mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
+
+    // Compute the adjoint contributions for V (right singular vectors)
+    mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
+
+    // Combine the terms to compute the adjoint of A
+    adj_A = adj_A + (u_term + v_term + sigma_term);
+}
+
 
 template<typename Type>
 inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {
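The new _svd_2 builds a 2x2 SVD from the eigen-decomposition of A^T A: the eigenvalues lambda1 >= lambda2 come from the trace/determinant quadratic, the singular values are sigma_i = sqrt(lambda_i), the right singular vectors solve (A^T A - lambda_i * I) v = 0, and the left singular vectors follow from u_i = A v_i / sigma_i. A standalone sketch of the same construction, verifying that U * diag(sigma) * V^T reconstructs A for one sample matrix (assumes A^T A has a nonzero off-diagonal, as the routine's eigenvector formula does; not Warp code):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // Sample matrix A
        const double a11 = 2.0, a12 = 1.0, a21 = -1.0, a22 = 3.0;

        // Eigenanalysis of A^T A, mirroring steps 1-3 of _svd_2
        const double ata11 = a11 * a11 + a21 * a21;
        const double ata12 = a11 * a12 + a21 * a22;
        const double ata22 = a12 * a12 + a22 * a22;
        const double trace = ata11 + ata22;
        const double det = ata11 * ata22 - ata12 * ata12;
        const double root = std::sqrt(trace * trace - 4.0 * det);
        const double lambda1 = 0.5 * (trace + root), lambda2 = 0.5 * (trace - root);
        const double sigma1 = std::sqrt(lambda1), sigma2 = std::sqrt(lambda2);

        // Right singular vectors from (A^T A - lambda * I) v = 0 (step 4)
        double v1x = ata12, v1y = lambda1 - ata11;
        double v2x = ata12, v2y = lambda2 - ata11;
        const double n1 = std::hypot(v1x, v1y), n2 = std::hypot(v2x, v2y);
        v1x /= n1; v1y /= n1;
        v2x /= n2; v2y /= n2;

        // Left singular vectors u_i = A v_i / sigma_i (step 5)
        const double u1x = (a11 * v1x + a12 * v1y) / sigma1, u1y = (a21 * v1x + a22 * v1y) / sigma1;
        const double u2x = (a11 * v2x + a12 * v2y) / sigma2, u2y = (a21 * v2x + a22 * v2y) / sigma2;

        // Reconstruct A = sigma1 * u1 v1^T + sigma2 * u2 v2^T; error should be ~1e-15
        double err = 0.0;
        err = std::fmax(err, std::fabs(sigma1 * u1x * v1x + sigma2 * u2x * v2x - a11));
        err = std::fmax(err, std::fabs(sigma1 * u1x * v1y + sigma2 * u2x * v2y - a12));
        err = std::fmax(err, std::fabs(sigma1 * u1y * v1x + sigma2 * u2y * v2x - a21));
        err = std::fmax(err, std::fabs(sigma1 * u1y * v1y + sigma2 * u2y * v2y - a22));
        std::printf("max reconstruction error: %g\n", err);
        return 0;
    }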
warp/native/temp_buffer.h CHANGED
@@ -1,9 +1,18 @@
-/** Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
- * NVIDIA CORPORATION and its licensors retain all intellectual property
- * and proprietary rights in and to this software, related documentation
- * and any modifications thereto. Any use, reproduction, disclosure or
- * distribution of this software and related documentation without an express
- * license agreement from NVIDIA CORPORATION is strictly prohibited.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 #pragma once