warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429)
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
warp/native/svd.h ADDED
@@ -0,0 +1,702 @@
1
+ /*
2
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: Apache-2.0
4
+ *
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ */
17
+
18
+ // The MIT License (MIT)
19
+
20
+ // Copyright (c) 2014 Eric V. Jang
21
+
22
+ // Permission is hereby granted, free of charge, to any person obtaining a copy
23
+ // of this software and associated documentation files (the "Software"), to deal
24
+ // in the Software without restriction, including without limitation the rights
25
+ // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
+ // copies of the Software, and to permit persons to whom the Software is
27
+ // furnished to do so, subject to the following conditions:
28
+
29
+ // The above copyright notice and this permission notice shall be included in all
30
+ // copies or substantial portions of the Software.
31
+
32
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ // SOFTWARE.
39
+
40
+ // Source: https://github.com/ericjang/svd3/blob/master/svd3_cuda/svd3_cuda.h
41
+
42
+
43
+ #pragma once
44
+
45
+ #include "builtin.h"
46
+
47
+ namespace wp
48
+ {
49
+
50
+
51
// Precision-dependent tuning constants for the 3x3 SVD helpers in this file.
//
// The primary template applies to every scalar type except double.
//  - QR_GIVENS_EPSILON: threshold that QRGivensQuaternion compares rho against;
//    when rho is not above it, the Givens sine is forced to zero.
//  - JACOBI_ITERATIONS: number of sweeps jacobiEigenanlysis performs, each
//    sweep consisting of three jacobiConjugation calls.
template<typename Type>
struct _svd_config
{
    static constexpr float QR_GIVENS_EPSILON{1.e-6f};
    static constexpr int JACOBI_ITERATIONS{4};
};

// double precision: a much tighter Givens threshold and twice as many sweeps.
template<>
struct _svd_config<double>
{
    static constexpr double QR_GIVENS_EPSILON{1.e-12};
    static constexpr int JACOBI_ITERATIONS{8};
};
62
+
63
+
64
+
65
+ // TODO: replace sqrt with rsqrt
66
+
67
+ template<typename Type>
68
+ inline CUDA_CALLABLE
69
+ Type accurateSqrt(Type x)
70
+ {
71
+ return x / sqrt(x);
72
+ }
73
+
74
+ template<typename Type>
75
+ inline CUDA_CALLABLE
76
+ void condSwap(bool c, Type &X, Type &Y)
77
+ {
78
+ // used in step 2
79
+ Type Z = X;
80
+ X = c ? Y : X;
81
+ Y = c ? Z : Y;
82
+ }
83
+
84
+ template<typename Type>
85
+ inline CUDA_CALLABLE
86
+ void condNegSwap(bool c, Type &X, Type &Y)
87
+ {
88
+ // used in step 2 and 3
89
+ Type Z = -X;
90
+ X = c ? Y : X;
91
+ Y = c ? Z : Y;
92
+ }
93
+
94
+ // matrix multiplication M = A * B
95
+ template<typename Type>
96
+ inline CUDA_CALLABLE
97
+ void multAB(Type a11, Type a12, Type a13,
98
+ Type a21, Type a22, Type a23,
99
+ Type a31, Type a32, Type a33,
100
+ //
101
+ Type b11, Type b12, Type b13,
102
+ Type b21, Type b22, Type b23,
103
+ Type b31, Type b32, Type b33,
104
+ //
105
+ Type &m11, Type &m12, Type &m13,
106
+ Type &m21, Type &m22, Type &m23,
107
+ Type &m31, Type &m32, Type &m33)
108
+ {
109
+
110
+ m11=a11*b11 + a12*b21 + a13*b31; m12=a11*b12 + a12*b22 + a13*b32; m13=a11*b13 + a12*b23 + a13*b33;
111
+ m21=a21*b11 + a22*b21 + a23*b31; m22=a21*b12 + a22*b22 + a23*b32; m23=a21*b13 + a22*b23 + a23*b33;
112
+ m31=a31*b11 + a32*b21 + a33*b31; m32=a31*b12 + a32*b22 + a33*b32; m33=a31*b13 + a32*b23 + a33*b33;
113
+ }
114
+
115
+ // matrix multiplication M = Transpose[A] * B
116
+ template<typename Type>
117
+ inline CUDA_CALLABLE
118
+ void multAtB(Type a11, Type a12, Type a13,
119
+ Type a21, Type a22, Type a23,
120
+ Type a31, Type a32, Type a33,
121
+ //
122
+ Type b11, Type b12, Type b13,
123
+ Type b21, Type b22, Type b23,
124
+ Type b31, Type b32, Type b33,
125
+ //
126
+ Type &m11, Type &m12, Type &m13,
127
+ Type &m21, Type &m22, Type &m23,
128
+ Type &m31, Type &m32, Type &m33)
129
+ {
130
+ m11=a11*b11 + a21*b21 + a31*b31; m12=a11*b12 + a21*b22 + a31*b32; m13=a11*b13 + a21*b23 + a31*b33;
131
+ m21=a12*b11 + a22*b21 + a32*b31; m22=a12*b12 + a22*b22 + a32*b32; m23=a12*b13 + a22*b23 + a32*b33;
132
+ m31=a13*b11 + a23*b21 + a33*b31; m32=a13*b12 + a23*b22 + a33*b32; m33=a13*b13 + a23*b23 + a33*b33;
133
+ }
134
+
135
+ template<typename Type>
136
+ inline CUDA_CALLABLE
137
+ void quatToMat3(const Type * qV,
138
+ Type &m11, Type &m12, Type &m13,
139
+ Type &m21, Type &m22, Type &m23,
140
+ Type &m31, Type &m32, Type &m33
141
+ )
142
+ {
143
+ Type w = qV[3];
144
+ Type x = qV[0];
145
+ Type y = qV[1];
146
+ Type z = qV[2];
147
+
148
+ Type qxx = x*x;
149
+ Type qyy = y*y;
150
+ Type qzz = z*z;
151
+ Type qxz = x*z;
152
+ Type qxy = x*y;
153
+ Type qyz = y*z;
154
+ Type qwx = w*x;
155
+ Type qwy = w*y;
156
+ Type qwz = w*z;
157
+
158
+ m11=Type(1) - Type(2)*(qyy + qzz); m12=Type(2)*(qxy - qwz); m13=Type(2)*(qxz + qwy);
159
+ m21=Type(2)*(qxy + qwz); m22=Type(1) - Type(2)*(qxx + qzz); m23=Type(2)*(qyz - qwx);
160
+ m31=Type(2)*(qxz - qwy); m32=Type(2)*(qyz + qwx); m33=Type(1) - Type(2)*(qxx + qyy);
161
+ }
162
+
163
+ template<typename Type>
164
+ inline CUDA_CALLABLE
165
+ void approximateGivensQuaternion(Type a11, Type a12, Type a22, Type &ch, Type &sh)
166
+ {
167
+ /*
168
+ * Given givens angle computed by approximateGivensAngles,
169
+ * compute the corresponding rotation quaternion.
170
+ */
171
+ constexpr double _gamma = 5.82842712474619; // FOUR_GAMMA_SQUARED = sqrt(8)+3;
172
+ constexpr double _cstar = 0.9238795325112867; // cos(pi/8)
173
+ constexpr double _sstar = 0.3826834323650898; // sin(p/8)
174
+
175
+ ch = Type(2)*(a11-a22);
176
+ sh = a12;
177
+ bool b = Type(_gamma)*sh*sh < ch*ch;
178
+ Type w = Type(1) / sqrt(ch*ch+sh*sh);
179
+ ch=b?w*ch:Type(_cstar);
180
+ sh=b?w*sh:Type(_sstar);
181
+ }
182
+
183
+ template<typename Type>
184
+ inline CUDA_CALLABLE
185
+ void jacobiConjugation( const int x, const int y, const int z,
186
+ Type &s11,
187
+ Type &s21, Type &s22,
188
+ Type &s31, Type &s32, Type &s33,
189
+ Type * qV)
190
+ {
191
+ Type ch,sh;
192
+ approximateGivensQuaternion(s11,s21,s22,ch,sh);
193
+
194
+ Type scale = ch*ch+sh*sh;
195
+ Type a = (ch*ch-sh*sh)/scale;
196
+ Type b = (Type(2)*sh*ch)/scale;
197
+
198
+ // make temp copy of S
199
+ Type _s11 = s11;
200
+ Type _s21 = s21; Type _s22 = s22;
201
+ Type _s31 = s31; Type _s32 = s32; Type _s33 = s33;
202
+
203
+ // perform conjugation S = Q'*S*Q
204
+ // Q already implicitly solved from a, b
205
+ s11 =a*(a*_s11 + b*_s21) + b*(a*_s21 + b*_s22);
206
+ s21 =a*(-b*_s11 + a*_s21) + b*(-b*_s21 + a*_s22); s22=-b*(-b*_s11 + a*_s21) + a*(-b*_s21 + a*_s22);
207
+ s31 =a*_s31 + b*_s32; s32=-b*_s31 + a*_s32; s33=_s33;
208
+
209
+ // update cumulative rotation qV
210
+ Type tmp[3];
211
+ tmp[0]=qV[0]*sh;
212
+ tmp[1]=qV[1]*sh;
213
+ tmp[2]=qV[2]*sh;
214
+ sh *= qV[3];
215
+
216
+ qV[0] *= ch;
217
+ qV[1] *= ch;
218
+ qV[2] *= ch;
219
+ qV[3] *= ch;
220
+
221
+ // (x,y,z) corresponds to ((0,1,2),(1,2,0),(2,0,1))
222
+ // for (p,q) = ((0,1),(1,2),(0,2))
223
+ qV[z] += sh;
224
+ qV[3] -= tmp[z]; // w
225
+ qV[x] += tmp[y];
226
+ qV[y] -= tmp[x];
227
+
228
+ // re-arrange matrix for next iteration
229
+ _s11 = s22;
230
+ _s21 = s32; _s22 = s33;
231
+ _s31 = s21; _s32 = s31; _s33 = s11;
232
+ s11 = _s11;
233
+ s21 = _s21; s22 = _s22;
234
+ s31 = _s31; s32 = _s32; s33 = _s33;
235
+
236
+ }
237
+
238
+ template<typename Type>
239
+ inline CUDA_CALLABLE
240
+ Type dist2(Type x, Type y, Type z)
241
+ {
242
+ return x*x+y*y+z*z;
243
+ }
244
+
245
+ // finds transformation that diagonalizes a symmetric matrix
246
+ template<typename Type>
247
+ inline CUDA_CALLABLE
248
+ void jacobiEigenanlysis( // symmetric matrix
249
+ Type &s11,
250
+ Type &s21, Type &s22,
251
+ Type &s31, Type &s32, Type &s33,
252
+ // quaternion representation of V
253
+ Type * qV)
254
+ {
255
+ qV[3]=1; qV[0]=0;qV[1]=0;qV[2]=0; // follow same indexing convention as GLM
256
+ constexpr int ITERS = _svd_config<Type>::JACOBI_ITERATIONS;
257
+ for (int i=0;i<ITERS;i++)
258
+ {
259
+ // we wish to eliminate the maximum off-diagonal element
260
+ // on every iteration, but cycling over all 3 possible rotations
261
+ // in fixed order (p,q) = (1,2) , (2,3), (1,3) still retains
262
+ // asymptotic convergence
263
+ jacobiConjugation(0,1,2,s11,s21,s22,s31,s32,s33,qV); // p,q = 0,1
264
+ jacobiConjugation(1,2,0,s11,s21,s22,s31,s32,s33,qV); // p,q = 1,2
265
+ jacobiConjugation(2,0,1,s11,s21,s22,s31,s32,s33,qV); // p,q = 0,2
266
+ }
267
+ }
268
+
269
+ template<typename Type>
270
+ inline CUDA_CALLABLE
271
+ void sortSingularValues(// matrix that we want to decompose
272
+ Type &b11, Type &b12, Type &b13,
273
+ Type &b21, Type &b22, Type &b23,
274
+ Type &b31, Type &b32, Type &b33,
275
+ // sort V simultaneously
276
+ Type &v11, Type &v12, Type &v13,
277
+ Type &v21, Type &v22, Type &v23,
278
+ Type &v31, Type &v32, Type &v33)
279
+ {
280
+ Type rho1 = dist2(b11,b21,b31);
281
+ Type rho2 = dist2(b12,b22,b32);
282
+ Type rho3 = dist2(b13,b23,b33);
283
+ bool c;
284
+ c = rho1 < rho2;
285
+ condNegSwap(c,b11,b12); condNegSwap(c,v11,v12);
286
+ condNegSwap(c,b21,b22); condNegSwap(c,v21,v22);
287
+ condNegSwap(c,b31,b32); condNegSwap(c,v31,v32);
288
+ condSwap(c,rho1,rho2);
289
+ c = rho1 < rho3;
290
+ condNegSwap(c,b11,b13); condNegSwap(c,v11,v13);
291
+ condNegSwap(c,b21,b23); condNegSwap(c,v21,v23);
292
+ condNegSwap(c,b31,b33); condNegSwap(c,v31,v33);
293
+ condSwap(c,rho1,rho3);
294
+ c = rho2 < rho3;
295
+ condNegSwap(c,b12,b13); condNegSwap(c,v12,v13);
296
+ condNegSwap(c,b22,b23); condNegSwap(c,v22,v23);
297
+ condNegSwap(c,b32,b33); condNegSwap(c,v32,v33);
298
+ }
299
+
300
+ template<typename Type>
301
+ inline CUDA_CALLABLE
302
+ void QRGivensQuaternion(Type a1, Type a2, Type &ch, Type &sh)
303
+ {
304
+ // a1 = pivot point on diagonal
305
+ // a2 = lower triangular entry we want to annihilate
306
+ const Type epsilon = _svd_config<Type>::QR_GIVENS_EPSILON;
307
+ Type rho = accurateSqrt(a1*a1 + a2*a2);
308
+
309
+ sh = rho > epsilon ? a2 : Type(0);
310
+ ch = abs(a1) + max(rho,epsilon);
311
+ bool b = a1 < Type(0);
312
+ condSwap(b,sh,ch);
313
+ Type w = Type(1) / sqrt(ch*ch+sh*sh);
314
+ ch *= w;
315
+ sh *= w;
316
+ }
317
+
318
+ template<typename Type>
319
+ inline CUDA_CALLABLE
320
+ void QRDecomposition(// matrix that we want to decompose
321
+ Type b11, Type b12, Type b13,
322
+ Type b21, Type b22, Type b23,
323
+ Type b31, Type b32, Type b33,
324
+ // output Q
325
+ Type &q11, Type &q12, Type &q13,
326
+ Type &q21, Type &q22, Type &q23,
327
+ Type &q31, Type &q32, Type &q33,
328
+ // output R
329
+ Type &r11, Type &r12, Type &r13,
330
+ Type &r21, Type &r22, Type &r23,
331
+ Type &r31, Type &r32, Type &r33)
332
+ {
333
+ Type ch1,sh1,ch2,sh2,ch3,sh3;
334
+ Type a,b;
335
+
336
+ // first givens rotation (ch,0,0,sh)
337
+ QRGivensQuaternion(b11,b21,ch1,sh1);
338
+ a=Type(1)-Type(2)*sh1*sh1;
339
+ b=Type(2)*ch1*sh1;
340
+ // apply B = Q' * B
341
+ r11=a*b11+b*b21; r12=a*b12+b*b22; r13=a*b13+b*b23;
342
+ r21=-b*b11+a*b21; r22=-b*b12+a*b22; r23=-b*b13+a*b23;
343
+ r31=b31; r32=b32; r33=b33;
344
+
345
+ // second givens rotation (ch,0,-sh,0)
346
+ QRGivensQuaternion(r11,r31,ch2,sh2);
347
+ a=Type(1)-Type(2)*sh2*sh2;
348
+ b=Type(2)*ch2*sh2;
349
+ // apply B = Q' * B;
350
+ b11=a*r11+b*r31; b12=a*r12+b*r32; b13=a*r13+b*r33;
351
+ b21=r21; b22=r22; b23=r23;
352
+ b31=-b*r11+a*r31; b32=-b*r12+a*r32; b33=-b*r13+a*r33;
353
+
354
+ // third givens rotation (ch,sh,0,0)
355
+ QRGivensQuaternion(b22,b32,ch3,sh3);
356
+ a=Type(1)-Type(2)*sh3*sh3;
357
+ b=Type(2)*ch3*sh3;
358
+ // R is now set to desired value
359
+ r11=b11; r12=b12; r13=b13;
360
+ r21=a*b21+b*b31; r22=a*b22+b*b32; r23=a*b23+b*b33;
361
+ r31=-b*b21+a*b31; r32=-b*b22+a*b32; r33=-b*b23+a*b33;
362
+
363
+ // construct the cumulative rotation Q=Q1 * Q2 * Q3
364
+ // the number of floating point operations for three quaternion multiplications
365
+ // is more or less comparable to the explicit form of the joined matrix.
366
+ // certainly more memory-efficient!
367
+ Type sh12=sh1*sh1;
368
+ Type sh22=sh2*sh2;
369
+ Type sh32=sh3*sh3;
370
+
371
+ q11=(Type(-1)+Type(2)*sh12)*(Type(-1)+Type(2)*sh22);
372
+ q12=Type(4)*ch2*ch3*(Type(-1)+Type(2)*sh12)*sh2*sh3+Type(2)*ch1*sh1*(Type(-1)+Type(2)*sh32);
373
+ q13=Type(4)*ch1*ch3*sh1*sh3-Type(2)*ch2*(Type(-1)+Type(2)*sh12)*sh2*(Type(-1)+Type(2)*sh32);
374
+
375
+ q21=Type(2)*ch1*sh1*(Type(1)-Type(2)*sh22);
376
+ q22=Type(-8)*ch1*ch2*ch3*sh1*sh2*sh3+(Type(-1)+Type(2)*sh12)*(Type(-1)+Type(2)*sh32);
377
+ q23=Type(-2)*ch3*sh3+Type(4)*sh1*(ch3*sh1*sh3+ch1*ch2*sh2*(Type(-1)+Type(2)*sh32));
378
+
379
+ q31=Type(2)*ch2*sh2;
380
+ q32=Type(2)*ch3*(Type(1)-Type(2)*sh22)*sh3;
381
+ q33=(Type(-1)+Type(2)*sh22)*(Type(-1)+Type(2)*sh32);
382
+ }
383
+
384
+ template<typename Type>
385
+ inline CUDA_CALLABLE
386
+ void _svd(// input A
387
+ Type a11, Type a12, Type a13,
388
+ Type a21, Type a22, Type a23,
389
+ Type a31, Type a32, Type a33,
390
+ // output U
391
+ Type &u11, Type &u12, Type &u13,
392
+ Type &u21, Type &u22, Type &u23,
393
+ Type &u31, Type &u32, Type &u33,
394
+ // output S
395
+ Type &s11, Type &s12, Type &s13,
396
+ Type &s21, Type &s22, Type &s23,
397
+ Type &s31, Type &s32, Type &s33,
398
+ // output V
399
+ Type &v11, Type &v12, Type &v13,
400
+ Type &v21, Type &v22, Type &v23,
401
+ Type &v31, Type &v32, Type &v33)
402
+ {
403
+ // normal equations matrix
404
+ Type ATA11, ATA12, ATA13;
405
+ Type ATA21, ATA22, ATA23;
406
+ Type ATA31, ATA32, ATA33;
407
+
408
+ multAtB(a11,a12,a13,a21,a22,a23,a31,a32,a33,
409
+ a11,a12,a13,a21,a22,a23,a31,a32,a33,
410
+ ATA11,ATA12,ATA13,ATA21,ATA22,ATA23,ATA31,ATA32,ATA33);
411
+
412
+ // symmetric eigenalysis
413
+ Type qV[4];
414
+ jacobiEigenanlysis( ATA11,ATA21,ATA22, ATA31,ATA32,ATA33,qV);
415
+ quatToMat3(qV,v11,v12,v13,v21,v22,v23,v31,v32,v33);
416
+
417
+ Type b11, b12, b13;
418
+ Type b21, b22, b23;
419
+ Type b31, b32, b33;
420
+ multAB(a11,a12,a13,a21,a22,a23,a31,a32,a33,
421
+ v11,v12,v13,v21,v22,v23,v31,v32,v33,
422
+ b11, b12, b13, b21, b22, b23, b31, b32, b33);
423
+
424
+ // sort singular values and find V
425
+ sortSingularValues(b11, b12, b13, b21, b22, b23, b31, b32, b33,
426
+ v11,v12,v13,v21,v22,v23,v31,v32,v33);
427
+
428
+ // QR decomposition
429
+ QRDecomposition(b11, b12, b13, b21, b22, b23, b31, b32, b33,
430
+ u11, u12, u13, u21, u22, u23, u31, u32, u33,
431
+ s11, s12, s13, s21, s22, s23, s31, s32, s33
432
+ );
433
+ }
434
+
435
+
436
+ template<typename Type>
437
+ inline CUDA_CALLABLE
438
+ void _svd_2(// input A
439
+ Type a11, Type a12,
440
+ Type a21, Type a22,
441
+ // output U
442
+ Type &u11, Type &u12,
443
+ Type &u21, Type &u22,
444
+ // output S
445
+ Type &s11, Type &s12,
446
+ Type &s21, Type &s22,
447
+ // output V
448
+ Type &v11, Type &v12,
449
+ Type &v21, Type &v22)
450
+ {
451
+ // Step 1: Compute ATA
452
+ Type ATA11 = a11 * a11 + a21 * a21;
453
+ Type ATA12 = a11 * a12 + a21 * a22;
454
+ Type ATA22 = a12 * a12 + a22 * a22;
455
+
456
+ // Step 2: Eigenanalysis
457
+ Type trace = ATA11 + ATA22;
458
+ Type det = ATA11 * ATA22 - ATA12 * ATA12;
459
+ Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
460
+ Type lambda1 = (trace + sqrt_term) * Type(0.5);
461
+ Type lambda2 = (trace - sqrt_term) * Type(0.5);
462
+
463
+ // Step 3: Singular values
464
+ Type sigma1 = sqrt(lambda1);
465
+ Type sigma2 = sqrt(lambda2);
466
+
467
+ // Step 4: Eigenvectors (find V)
468
+ Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
469
+ Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
470
+ Type norm1 = sqrt(v1x * v1x + v1y * v1y);
471
+ Type norm2 = sqrt(v2x * v2x + v2y * v2y);
472
+
473
+ v11 = v1x / norm1; v12 = v2x / norm2;
474
+ v21 = v1y / norm1; v22 = v2y / norm2;
475
+
476
+ // Step 5: Compute U
477
+ Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
478
+ Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
479
+
480
+ u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
481
+ u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
482
+ u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
483
+ u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
484
+
485
+ // Step 6: Set S
486
+ s11 = sigma1; s12 = Type(0.0);
487
+ s21 = Type(0.0); s22 = sigma2;
488
+ }
489
+
490
+
491
+ template<typename Type>
492
+ inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
493
+ Type s12, s13, s21, s23, s31, s32;
494
+ _svd(A.data[0][0], A.data[0][1], A.data[0][2],
495
+ A.data[1][0], A.data[1][1], A.data[1][2],
496
+ A.data[2][0], A.data[2][1], A.data[2][2],
497
+
498
+ U.data[0][0], U.data[0][1], U.data[0][2],
499
+ U.data[1][0], U.data[1][1], U.data[1][2],
500
+ U.data[2][0], U.data[2][1], U.data[2][2],
501
+
502
+ sigma[0], s12, s13,
503
+ s21, sigma[1], s23,
504
+ s31, s32, sigma[2],
505
+
506
+ V.data[0][0], V.data[0][1], V.data[0][2],
507
+ V.data[1][0], V.data[1][1], V.data[1][2],
508
+ V.data[2][0], V.data[2][1], V.data[2][2]);
509
+ }
510
+
511
+ template<typename Type>
512
+ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
513
+ const mat_t<3,3,Type>& U,
514
+ const vec_t<3,Type>& sigma,
515
+ const mat_t<3,3,Type>& V,
516
+ mat_t<3,3,Type>& adj_A,
517
+ const mat_t<3,3,Type>& adj_U,
518
+ const vec_t<3,Type>& adj_sigma,
519
+ const mat_t<3,3,Type>& adj_V) {
520
+ Type sx2 = sigma[0] * sigma[0];
521
+ Type sy2 = sigma[1] * sigma[1];
522
+ Type sz2 = sigma[2] * sigma[2];
523
+
524
+ Type F01 = Type(1) / min(sy2 - sx2, Type(-1e-6f));
525
+ Type F02 = Type(1) / min(sz2 - sx2, Type(-1e-6f));
526
+ Type F12 = Type(1) / min(sz2 - sy2, Type(-1e-6f));
527
+
528
+ mat_t<3,3,Type> F = mat_t<3,3,Type>(0, F01, F02,
529
+ -F01, 0, F12,
530
+ -F02, -F12, 0);
531
+
532
+ mat_t<3,3,Type> adj_sigma_mat = mat_t<3,3,Type>(adj_sigma[0], 0, 0,
533
+ 0, adj_sigma[1], 0,
534
+ 0, 0, adj_sigma[2]);
535
+ mat_t<3,3,Type> s_mat = mat_t<3,3,Type>(sigma[0], 0, 0,
536
+ 0, sigma[1], 0,
537
+ 0, 0, sigma[2]);
538
+
539
+ // https://github.com/pytorch/pytorch/blob/d7ddae8e4fe66fa1330317673438d1eb5aa99ca4/torch/csrc/autograd/FunctionsManual.cpp
540
+ mat_t<3,3,Type> UT = transpose(U);
541
+ mat_t<3,3,Type> VT = transpose(V);
542
+
543
+ mat_t<3,3,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
544
+
545
+ mat_t<3,3,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
546
+ mat_t<3,3,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
547
+
548
+ adj_A = adj_A + (u_term + v_term + sigma_term);
549
+ }
550
+
551
+ template<typename Type>
552
+ inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
553
+ Type s12, s21;
554
+ _svd_2(A.data[0][0], A.data[0][1],
555
+ A.data[1][0], A.data[1][1],
556
+
557
+ U.data[0][0], U.data[0][1],
558
+ U.data[1][0], U.data[1][1],
559
+
560
+ sigma[0], s12,
561
+ s21, sigma[1],
562
+
563
+ V.data[0][0], V.data[0][1],
564
+ V.data[1][0], V.data[1][1]);
565
+ }
566
+
567
+ template<typename Type>
568
+ inline CUDA_CALLABLE void adj_svd2(const mat_t<2,2,Type>& A,
569
+ const mat_t<2,2,Type>& U,
570
+ const vec_t<2,Type>& sigma,
571
+ const mat_t<2,2,Type>& V,
572
+ mat_t<2,2,Type>& adj_A,
573
+ const mat_t<2,2,Type>& adj_U,
574
+ const vec_t<2,Type>& adj_sigma,
575
+ const mat_t<2,2,Type>& adj_V) {
576
+ Type s1_squared = sigma[0] * sigma[0];
577
+ Type s2_squared = sigma[1] * sigma[1];
578
+
579
+ // Compute inverse of (s1^2 - s2^2) if possible, use small epsilon to prevent division by zero
580
+ Type F01 = Type(1) / min(s2_squared - s1_squared, Type(-1e-6f));
581
+
582
+ // Construct the matrix F for the adjoint
583
+ mat_t<2,2,Type> F = mat_t<2,2,Type>(0.0, F01,
584
+ -F01, 0.0);
585
+
586
+ // Create a matrix to handle the adjoint of the singular values (diagonal matrix)
587
+ mat_t<2,2,Type> adj_sigma_mat = mat_t<2,2,Type>(adj_sigma[0], 0.0,
588
+ 0.0, adj_sigma[1]);
589
+
590
+ // Matrix for handling singular values (diagonal matrix with sigma values)
591
+ mat_t<2,2,Type> s_mat = mat_t<2,2,Type>(sigma[0], 0.0,
592
+ 0.0, sigma[1]);
593
+
594
+ // Compute the transpose of U and V
595
+ mat_t<2,2,Type> UT = transpose(U);
596
+ mat_t<2,2,Type> VT = transpose(V);
597
+
598
+ // Compute the term for sigma (diagonal matrix of adjoint singular values)
599
+ mat_t<2,2,Type> sigma_term = mul(U, mul(adj_sigma_mat, VT));
600
+
601
+ // Compute the adjoint contributions for U (left singular vectors)
602
+ mat_t<2,2,Type> u_term = mul(mul(U, mul(cw_mul(F, (mul(UT, adj_U) - mul(transpose(adj_U), U))), s_mat)), VT);
603
+
604
+ // Compute the adjoint contributions for V (right singular vectors)
605
+ mat_t<2,2,Type> v_term = mul(U, mul(s_mat, mul(cw_mul(F, (mul(VT, adj_V) - mul(transpose(adj_V), V))), VT)));
606
+
607
+ // Combine the terms to compute the adjoint of A
608
+ adj_A = adj_A + (u_term + v_term + sigma_term);
609
+ }
610
+
611
+
612
+ template<typename Type>
613
+ inline CUDA_CALLABLE void qr3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, mat_t<3,3,Type>& R) {
614
+ QRDecomposition(A.data[0][0], A.data[0][1], A.data[0][2],
615
+ A.data[1][0], A.data[1][1], A.data[1][2],
616
+ A.data[2][0], A.data[2][1], A.data[2][2],
617
+
618
+ Q.data[0][0], Q.data[0][1], Q.data[0][2],
619
+ Q.data[1][0], Q.data[1][1], Q.data[1][2],
620
+ Q.data[2][0], Q.data[2][1], Q.data[2][2],
621
+
622
+ R.data[0][0], R.data[0][1], R.data[0][2],
623
+ R.data[1][0], R.data[1][1], R.data[1][2],
624
+ R.data[2][0], R.data[2][1], R.data[2][2]);
625
+ }
626
+
627
+
628
+ template<typename Type>
629
+ inline CUDA_CALLABLE void adj_qr3(const mat_t<3,3,Type>& A,
630
+ const mat_t<3,3,Type>& Q,
631
+ const mat_t<3,3,Type>& R,
632
+ mat_t<3,3,Type>& adj_A,
633
+ const mat_t<3,3,Type>& adj_Q,
634
+ const mat_t<3,3,Type>& adj_R) {
635
+ // Eq 3 of https://arxiv.org/pdf/2009.10071.pdf
636
+ mat_t<3,3,Type> M = mul(R,transpose(adj_R)) - mul(transpose(adj_Q), Q);
637
+ mat_t<3,3,Type> copyltuM = mat_t<3,3,Type>(M.data[0][0], M.data[1][0], M.data[2][0],
638
+ M.data[1][0], M.data[1][1], M.data[2][1],
639
+ M.data[2][0], M.data[2][1], M.data[2][2]);
640
+ adj_A = adj_A + mul(adj_Q + mul(Q,copyltuM), inverse(transpose(R)));
641
+ }
642
+
643
+
644
+ template<typename Type>
645
+ inline CUDA_CALLABLE void eig3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& Q, vec_t<3,Type>& d) {
646
+ Type qV[4];
647
+ Type s11 = A.data[0][0];
648
+ Type s21 = A.data[1][0];
649
+ Type s22 = A.data[1][1];
650
+ Type s31 = A.data[2][0];
651
+ Type s32 = A.data[2][1];
652
+ Type s33 = A.data[2][2];
653
+
654
+ jacobiEigenanlysis(s11, s21, s22, s31, s32, s33, qV);
655
+ quatToMat3(qV, Q.data[0][0], Q.data[0][1], Q.data[0][2], Q.data[1][0], Q.data[1][1], Q.data[1][2], Q.data[2][0], Q.data[2][1], Q.data[2][2]);
656
+ mat_t<3,3,Type> t;
657
+ multAtB(Q.data[0][0], Q.data[0][1], Q.data[0][2], Q.data[1][0], Q.data[1][1], Q.data[1][2], Q.data[2][0], Q.data[2][1], Q.data[2][2],
658
+ A.data[0][0], A.data[0][1], A.data[0][2], A.data[1][0], A.data[1][1], A.data[1][2], A.data[2][0], A.data[2][1], A.data[2][2],
659
+ t.data[0][0], t.data[0][1], t.data[0][2], t.data[1][0], t.data[1][1], t.data[1][2], t.data[2][0], t.data[2][1], t.data[2][2]);
660
+
661
+ mat_t<3,3,Type> u;
662
+ multAB(t.data[0][0], t.data[0][1], t.data[0][2], t.data[1][0], t.data[1][1], t.data[1][2], t.data[2][0], t.data[2][1], t.data[2][2],
663
+ Q.data[0][0], Q.data[0][1], Q.data[0][2], Q.data[1][0], Q.data[1][1], Q.data[1][2], Q.data[2][0], Q.data[2][1], Q.data[2][2],
664
+ u.data[0][0], u.data[0][1], u.data[0][2], u.data[1][0], u.data[1][1], u.data[1][2], u.data[2][0], u.data[2][1], u.data[2][2]
665
+ );
666
+ d = vec_t<3,Type>(u.data[0][0], u.data[1][1], u.data[2][2]);
667
+ }
668
+
669
+ template<typename Type>
670
+ inline CUDA_CALLABLE void adj_eig3(const mat_t<3,3,Type>& A, const mat_t<3,3,Type>& Q, const vec_t<3,Type>& d,
671
+ mat_t<3,3,Type>& adj_A, const mat_t<3,3,Type>& adj_Q, const vec_t<3,Type>& adj_d) {
672
+ // Page 10 of https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
673
+ mat_t<3,3,Type> D = mat_t<3,3,Type>(d[0], 0, 0,
674
+ 0, d[1], 0,
675
+ 0, 0, d[2]);
676
+ mat_t<3,3,Type> D_bar = mat_t<3,3,Type>(adj_d[0], 0, 0,
677
+ 0, adj_d[1], 0,
678
+ 0, 0, adj_d[2]);
679
+
680
+ Type dyx = d[1] - d[0];
681
+ Type dzx = d[2] - d[0];
682
+ Type dzy = d[2] - d[1];
683
+
684
+ if ((dyx < Type(0)) && (dyx > Type(-1e-6))) dyx = -1e-6;
685
+ if ((dyx > Type(0)) && (dyx < Type(1e-6))) dyx = 1e-6;
686
+
687
+ if ((dzx < Type(0)) && (dzx > Type(-1e-6))) dzx = -1e-6;
688
+ if ((dzx > Type(0)) && (dzx < Type(1e-6))) dzx = 1e-6;
689
+
690
+ if ((dzy < Type(0)) && (dzy > Type(-1e-6))) dzy = -1e-6;
691
+ if ((dzy > Type(0)) && (dzy < Type(1e-6))) dzy = 1e-6;
692
+
693
+ Type F01 = Type(1) / dyx;
694
+ Type F02 = Type(1) / dzx;
695
+ Type F12 = Type(1) / dzy;
696
+ mat_t<3,3,Type> F = mat_t<3,3,Type>(0, F01, F02,
697
+ -F01, 0, F12,
698
+ -F02, -F12, 0);
699
+ mat_t<3,3,Type> QT = transpose(Q);
700
+ adj_A = adj_A + mul(Q, mul(D_bar + cw_mul(F, mul(QT, adj_Q)), QT));
701
+ }
702
+ }