warp-lang 1.9.0__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (350)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +2220 -313
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1497 -226
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -471
  95. warp/codegen.py +6 -4246
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -7851
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +3 -2
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -342
  136. warp/jax_experimental/ffi.py +17 -853
  137. warp/jax_experimental/xla_ffi.py +5 -596
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +316 -39
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sort.cu +22 -13
  159. warp/native/sort.h +2 -0
  160. warp/native/sparse.cu +7 -3
  161. warp/native/spatial.h +12 -0
  162. warp/native/tile.h +837 -70
  163. warp/native/tile_radix_sort.h +1 -1
  164. warp/native/tile_reduce.h +394 -46
  165. warp/native/tile_scan.h +4 -4
  166. warp/native/vec.h +469 -53
  167. warp/native/version.h +23 -0
  168. warp/native/volume.cpp +1 -1
  169. warp/native/volume.cu +1 -0
  170. warp/native/volume.h +1 -1
  171. warp/native/volume_builder.cu +2 -0
  172. warp/native/warp.cpp +60 -32
  173. warp/native/warp.cu +313 -201
  174. warp/native/warp.h +14 -11
  175. warp/optim/__init__.py +6 -3
  176. warp/optim/adam.py +6 -145
  177. warp/optim/linear.py +14 -1585
  178. warp/optim/sgd.py +6 -94
  179. warp/paddle.py +6 -388
  180. warp/render/__init__.py +8 -4
  181. warp/render/imgui_manager.py +7 -267
  182. warp/render/render_opengl.py +6 -3616
  183. warp/render/render_usd.py +6 -918
  184. warp/render/utils.py +6 -142
  185. warp/sparse.py +37 -2563
  186. warp/tape.py +6 -1188
  187. warp/tests/__main__.py +1 -1
  188. warp/tests/cuda/test_async.py +4 -4
  189. warp/tests/cuda/test_conditional_captures.py +1 -1
  190. warp/tests/cuda/test_multigpu.py +1 -1
  191. warp/tests/cuda/test_streams.py +58 -1
  192. warp/tests/geometry/test_bvh.py +157 -22
  193. warp/tests/geometry/test_hash_grid.py +38 -0
  194. warp/tests/geometry/test_marching_cubes.py +0 -1
  195. warp/tests/geometry/test_mesh.py +5 -3
  196. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  197. warp/tests/geometry/test_mesh_query_point.py +5 -2
  198. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  199. warp/tests/geometry/test_volume_write.py +5 -5
  200. warp/tests/interop/test_dlpack.py +14 -14
  201. warp/tests/interop/test_jax.py +1382 -79
  202. warp/tests/interop/test_paddle.py +1 -1
  203. warp/tests/test_adam.py +0 -1
  204. warp/tests/test_arithmetic.py +9 -9
  205. warp/tests/test_array.py +529 -100
  206. warp/tests/test_array_reduce.py +3 -3
  207. warp/tests/test_atomic.py +12 -8
  208. warp/tests/test_atomic_bitwise.py +209 -0
  209. warp/tests/test_atomic_cas.py +4 -4
  210. warp/tests/test_bool.py +2 -2
  211. warp/tests/test_builtins_resolution.py +5 -571
  212. warp/tests/test_codegen.py +34 -15
  213. warp/tests/test_conditional.py +1 -1
  214. warp/tests/test_context.py +6 -6
  215. warp/tests/test_copy.py +242 -161
  216. warp/tests/test_ctypes.py +3 -3
  217. warp/tests/test_devices.py +24 -2
  218. warp/tests/test_examples.py +16 -84
  219. warp/tests/test_fabricarray.py +35 -35
  220. warp/tests/test_fast_math.py +0 -2
  221. warp/tests/test_fem.py +60 -14
  222. warp/tests/test_fixedarray.py +3 -3
  223. warp/tests/test_func.py +8 -5
  224. warp/tests/test_generics.py +1 -1
  225. warp/tests/test_indexedarray.py +24 -24
  226. warp/tests/test_intersect.py +39 -9
  227. warp/tests/test_large.py +1 -1
  228. warp/tests/test_lerp.py +3 -1
  229. warp/tests/test_linear_solvers.py +1 -1
  230. warp/tests/test_map.py +49 -4
  231. warp/tests/test_mat.py +52 -62
  232. warp/tests/test_mat_constructors.py +4 -5
  233. warp/tests/test_mat_lite.py +1 -1
  234. warp/tests/test_mat_scalar_ops.py +121 -121
  235. warp/tests/test_math.py +34 -0
  236. warp/tests/test_module_aot.py +4 -4
  237. warp/tests/test_modules_lite.py +28 -2
  238. warp/tests/test_print.py +11 -11
  239. warp/tests/test_quat.py +93 -58
  240. warp/tests/test_runlength_encode.py +1 -1
  241. warp/tests/test_scalar_ops.py +38 -10
  242. warp/tests/test_smoothstep.py +1 -1
  243. warp/tests/test_sparse.py +126 -15
  244. warp/tests/test_spatial.py +105 -87
  245. warp/tests/test_special_values.py +6 -6
  246. warp/tests/test_static.py +7 -7
  247. warp/tests/test_struct.py +13 -2
  248. warp/tests/test_triangle_closest_point.py +48 -1
  249. warp/tests/test_tuple.py +96 -0
  250. warp/tests/test_types.py +82 -9
  251. warp/tests/test_utils.py +52 -52
  252. warp/tests/test_vec.py +29 -29
  253. warp/tests/test_vec_constructors.py +5 -5
  254. warp/tests/test_vec_scalar_ops.py +97 -97
  255. warp/tests/test_version.py +75 -0
  256. warp/tests/tile/test_tile.py +239 -0
  257. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  258. warp/tests/tile/test_tile_cholesky.py +7 -4
  259. warp/tests/tile/test_tile_load.py +26 -2
  260. warp/tests/tile/test_tile_mathdx.py +3 -3
  261. warp/tests/tile/test_tile_matmul.py +1 -1
  262. warp/tests/tile/test_tile_mlp.py +2 -4
  263. warp/tests/tile/test_tile_reduce.py +214 -13
  264. warp/tests/unittest_suites.py +6 -14
  265. warp/tests/unittest_utils.py +10 -9
  266. warp/tests/walkthrough_debug.py +3 -1
  267. warp/torch.py +6 -373
  268. warp/types.py +29 -5750
  269. warp/utils.py +10 -1659
  270. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
  271. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  283. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  284. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  285. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  286. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  287. warp/examples/assets/cartpole.urdf +0 -110
  288. warp/examples/assets/crazyflie.usd +0 -0
  289. warp/examples/assets/nv_ant.xml +0 -92
  290. warp/examples/assets/nv_humanoid.xml +0 -183
  291. warp/examples/assets/quadruped.urdf +0 -268
  292. warp/examples/optim/example_bounce.py +0 -266
  293. warp/examples/optim/example_cloth_throw.py +0 -228
  294. warp/examples/optim/example_drone.py +0 -870
  295. warp/examples/optim/example_inverse_kinematics.py +0 -182
  296. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  297. warp/examples/optim/example_softbody_properties.py +0 -400
  298. warp/examples/optim/example_spring_cage.py +0 -245
  299. warp/examples/optim/example_trajectory.py +0 -227
  300. warp/examples/sim/example_cartpole.py +0 -143
  301. warp/examples/sim/example_cloth.py +0 -225
  302. warp/examples/sim/example_cloth_self_contact.py +0 -316
  303. warp/examples/sim/example_granular.py +0 -130
  304. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  305. warp/examples/sim/example_jacobian_ik.py +0 -244
  306. warp/examples/sim/example_particle_chain.py +0 -124
  307. warp/examples/sim/example_quadruped.py +0 -203
  308. warp/examples/sim/example_rigid_chain.py +0 -203
  309. warp/examples/sim/example_rigid_contact.py +0 -195
  310. warp/examples/sim/example_rigid_force.py +0 -133
  311. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  312. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  313. warp/examples/sim/example_soft_body.py +0 -196
  314. warp/examples/tile/example_tile_walker.py +0 -327
  315. warp/sim/__init__.py +0 -74
  316. warp/sim/articulation.py +0 -793
  317. warp/sim/collide.py +0 -2570
  318. warp/sim/graph_coloring.py +0 -307
  319. warp/sim/import_mjcf.py +0 -791
  320. warp/sim/import_snu.py +0 -227
  321. warp/sim/import_urdf.py +0 -579
  322. warp/sim/import_usd.py +0 -898
  323. warp/sim/inertia.py +0 -357
  324. warp/sim/integrator.py +0 -245
  325. warp/sim/integrator_euler.py +0 -2000
  326. warp/sim/integrator_featherstone.py +0 -2101
  327. warp/sim/integrator_vbd.py +0 -2487
  328. warp/sim/integrator_xpbd.py +0 -3295
  329. warp/sim/model.py +0 -4821
  330. warp/sim/particles.py +0 -121
  331. warp/sim/render.py +0 -431
  332. warp/sim/utils.py +0 -431
  333. warp/tests/sim/disabled_kinematics.py +0 -244
  334. warp/tests/sim/test_cloth.py +0 -863
  335. warp/tests/sim/test_collision.py +0 -743
  336. warp/tests/sim/test_coloring.py +0 -347
  337. warp/tests/sim/test_inertia.py +0 -161
  338. warp/tests/sim/test_model.py +0 -226
  339. warp/tests/sim/test_sim_grad.py +0 -287
  340. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  341. warp/tests/sim/test_sim_kinematics.py +0 -98
  342. warp/thirdparty/__init__.py +0 -0
  343. warp_lang-1.9.0.dist-info/RECORD +0 -456
  344. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  345. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  346. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  347. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  348. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  349. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  350. {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,383 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ import warp as wp
19
+ from warp._src.types import type_scalar_type
20
+
21
+
22
@wp.func
def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
    """Generalized outer product allowing for vector or scalar arguments

    Vector-vector overload: regular outer product of ``x`` and ``y``.
    """
    product = wp.outer(x, y)
    return product


@wp.func
def generalized_outer(x: wp.Scalar, y: wp.vec(Any, wp.Scalar)):
    # Scalar-vector overload: the outer product degenerates to a scaling of ``y``
    scaled = x * y
    return scaled


@wp.func
def generalized_outer(x: wp.vec(Any, wp.Scalar), y: wp.Scalar):
    # Vector-scalar overload: the outer product degenerates to a scaling of ``x``
    scaled = x * y
    return scaled


@wp.func
def generalized_outer(x: wp.quatf, y: wp.vec(Any, wp.Scalar)):
    # Quaternion overload: reinterpret the four quaternion coefficients as a vec4
    x_as_vec4 = wp.vec4(x[0], x[1], x[2], x[3])
    return generalized_outer(x_as_vec4, y)
41
+
42
+
43
@wp.func
def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.vec(Any, wp.Scalar)):
    """Generalized inner product allowing for vector, tensor and scalar arguments

    Vector-vector overload: dot product of ``x`` and ``y``.
    """
    dot_product = wp.dot(x, y)
    return dot_product


@wp.func
def generalized_inner(x: wp.Scalar, y: wp.Scalar):
    # Scalar-scalar overload: plain product
    product = x * y
    return product


@wp.func
def generalized_inner(x: wp.mat((Any, Any), wp.Scalar), y: wp.vec(Any, wp.Scalar)):
    # Matrix-vector overload: computes y @ x, i.e. ``y`` contracted with the row index of ``x``
    contracted = y @ x
    return contracted


@wp.func
def generalized_inner(x: wp.vec(Any, wp.Scalar), y: wp.mat((Any, Any), wp.Scalar)):
    # Vector-matrix overload: computes y @ x, i.e. ``x`` contracted with the column index of ``y``
    contracted = y @ x
    return contracted
62
+
63
+
64
@wp.func
def basis_coefficient(val: wp.Scalar, i: int):
    """Returns the ``i``-th scalar coefficient of ``val``

    Scalar overload: the value itself is its only coefficient, ``i`` is ignored.
    """
    return val


@wp.func
def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int):
    # Matrix overload: ``i`` is a flat, row-major index into the matrix entries
    n_cols = int(type(val[0]).length)
    r = i // n_cols
    return val[r, i - r * n_cols]


@wp.func
def basis_coefficient(val: Any, i: int):
    # Fallback overload for values indexable by a single integer (e.g. vectors)
    return val[i]


@wp.func
def basis_coefficient(val: wp.vec(Any, wp.Scalar), i: int, j: int):
    # Two-index vector overload: ``val`` is treated as a row vector, so ``i`` is ignored
    return val[j]


@wp.func
def basis_coefficient(val: wp.mat((Any, Any), wp.Scalar), i: int, j: int):
    # Two-index matrix overload: direct entry lookup
    return val[i, j]
91
+
92
+
93
@wp.func
def symmetric_part(x: Any):
    """Symmetric part of a square tensor, i.e. (x + x^T) / 2

    The one-half factor is built with the matrix scalar type (``x.dtype``), as done for the
    other constants of this module. A bare ``0.5`` literal would be a float32, which does not
    combine with non-float32 matrices.
    """
    half = x.dtype(0.5)
    return half * (x + wp.transpose(x))
97
+
98
+
99
@wp.func
def spherical_part(x: wp.mat22):
    """Spherical part of a square tensor

    2x2 overload: (trace(x) / 2) times the identity.
    """
    half_trace = 0.5 * wp.trace(x)
    return half_trace * wp.identity(n=2, dtype=float)


@wp.func
def spherical_part(x: wp.mat33):
    """Spherical part of a square tensor

    3x3 overload: (trace(x) / 3) times the identity.
    """
    mean_diagonal = wp.trace(x) / 3.0
    return mean_diagonal * wp.identity(n=3, dtype=float)
109
+
110
+
111
@wp.func
def skew_part(x: wp.mat22):
    """Skew part of a 2x2 tensor as corresponding rotation angle

    Returns (x[1, 0] - x[0, 1]) / 2.
    """
    lower_minus_upper = x[1, 0] - x[0, 1]
    return 0.5 * lower_minus_upper


@wp.func
def skew_part(x: wp.mat33):
    """Skew part of a 3x3 tensor as the corresponding rotation vector

    Component i is half the difference of the two off-diagonal entries that do not
    involve row/column i.
    """
    return wp.vec3(
        0.5 * (x[2, 1] - x[1, 2]),
        0.5 * (x[0, 2] - x[2, 0]),
        0.5 * (x[1, 0] - x[0, 1]),
    )
124
+
125
+
126
@wp.func
def householder_qr_decomposition(A: Any):
    """
    QR decomposition of a square matrix using Householder reflections

    Returns Q and R such that Q R = A, Q orthonormal (such that QQ^T = Id), R upper triangular
    """

    # Work vector for the Householder direction; zero-initialized vector of A's row type
    x = type(A[0])()
    Q = wp.identity(n=type(x).length, dtype=A.dtype)

    # Constants built with the matrix scalar type
    zero = x.dtype(0.0)
    two = x.dtype(2.0)

    for i in range(type(x).length):
        # Copy column i of the current A, masking out the entries above the diagonal
        for k in range(type(x).length):
            x[k] = wp.where(k < i, zero, A[k, i])

        # Reflection direction v = x + sign(x[i]) |x| e_i (sign chosen to avoid cancellation)
        alpha = wp.length(x) * wp.sign(x[i])
        x[i] += alpha
        # Scale factor 2 / |v|^2, or zero (no reflection) when alpha is zero
        two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))

        # Left-multiply by H = I - 2 v v^T / |v|^2 (x * A is the vector-matrix product v^T A)
        A -= wp.outer(two_over_x_sq * x, x * A)
        # Right-multiply the accumulated Q by the same reflection, so that Q R = A holds at the end
        Q -= wp.outer(Q * x, two_over_x_sq * x)

    # A has been overwritten with the upper triangular factor R
    return Q, A
152
+
153
+
154
@wp.func
def householder_make_hessenberg(A: Any):
    """Transforms a square matrix to Hessenberg form (single lower diagonal) using Householder reflections

    Returns:
        Q and H such that Q H Q^T = A, Q orthonormal, H under Hessenberg form
        If A is symmetric, H will be tridiagonal
    """

    # Work vector for the Householder direction; zero-initialized vector of A's row type
    x = type(A[0])()
    Q = wp.identity(n=type(x).length, dtype=A.dtype)

    # Constants built with the matrix scalar type
    zero = x.dtype(0.0)
    two = x.dtype(2.0)

    # Column i - 1 is reduced below its first sub-diagonal entry, hence the loop starts at 1
    for i in range(1, type(x).length):
        # Copy column i - 1 of the current A, masking out rows above i
        for k in range(type(x).length):
            x[k] = wp.where(k < i, zero, A[k, i - 1])

        # Reflection direction v = x + sign(x[i]) |x| e_i
        alpha = wp.length(x) * wp.sign(x[i])
        x[i] += alpha
        # Scale factor 2 / |v|^2, or zero (no reflection) when alpha is zero
        two_over_x_sq = wp.where(alpha == zero, zero, two / wp.length_sq(x))

        # apply on both sides
        # (similarity transform H A H keeps the eigenvalues of A)
        A -= wp.outer(two_over_x_sq * x, x * A)
        A -= wp.outer(A * x, two_over_x_sq * x)
        # Accumulate the reflections so that Q H Q^T = A holds at the end
        Q -= wp.outer(Q * x, two_over_x_sq * x)

    # A has been overwritten with the Hessenberg matrix H
    return Q, A
183
+
184
+
185
@wp.func
def solve_triangular(R: Any, b: Any):
    """Solves for R x = b where R is an upper triangular matrix, by back substitution

    Components whose diagonal entry R[j, j] is exactly zero are set to zero.

    Returns x
    """
    zero = b.dtype(0)
    sol = type(b)(zero)

    # Walk the rows from last to first. Unsolved components of ``sol`` are still zero,
    # so the dot product only accumulates the already-solved ones.
    for i in range(b.length, 0, -1):
        row = i - 1
        residual = b[row] - wp.dot(R[row], sol)
        sol[row] = wp.where(R[row, row] == zero, zero, residual / R[row, row])

    return sol
199
+
200
+
201
@wp.func
def inverse_qr(A: Any):
    """Computes a square matrix inverse using QR factorization

    With A = Q R, A^-1 = R^-1 Q^T. Column i of A^-1 is therefore obtained by solving R against
    column i of Q^T, which is row i of Q.
    """
    Q, R = householder_qr_decomposition(A)

    # Each row of inv_t stores one column of A^-1; transposed before returning
    inv_t = type(A)()
    for col in range(type(A[0]).length):
        inv_t[col] = solve_triangular(R, Q[col])

    return wp.transpose(inv_t)
212
+
213
+
214
@wp.func
def _wilkinson_shift(a: Any, b: Any, c: Any, tol: Any):
    # Wilkinson shift: the eigenvalue of the 2x2 symmetric matrix [[a, c], [c, b]] closest to b.
    # ``tol`` is only used here to obtain the scalar type.
    half_gap = (a - b) * type(tol)(0.5)
    radius = wp.sqrt(half_gap * half_gap + c * c)
    return b + half_gap - wp.sign(half_gap) * radius
219
+
220
+
221
@wp.func
def _givens_rotation(a: Any, b: Any):
    # Givens rotation [[c, -s], [s, c]] such that s * a + c * b = 0, returned as (c, s)
    zero = type(a)(0.0)
    one = type(a)(1.0)

    b_sq = b * b
    if b_sq != zero:
        inv_norm = one / wp.sqrt(a * a + b_sq)
        return a * inv_norm, -b * inv_norm

    # b vanishes: identity rotation
    return one, zero
234
+
235
+
236
@wp.func
def tridiagonal_symmetric_eigenvalues_qr(D: Any, L: Any, Q: Any, tol: Any):
    """
    Computes the eigenvalues and eigenvectors of a symmetric tridiagonal matrix using the
    symmetric tridiagonal QR algorithm with implicit Wilkinson shift

    Args:
        D: Main diagonal of the matrix
        L: Lower diagonal of the matrix, indexed such that L[i] = A[i+1, i]. Expected to have the
           same length as D; its last entry is accessed but is not a coefficient of the matrix
        Q: Initialization for the eigenvectors, useful if a pre-transformation has been applied, otherwise may be identity
        tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)

    Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P

    Ref: Arbenz P, Numerical Methods for Solving Large Scale Eigenvalue Problems, Chapter 4 (QR algorithm, Mar 13, 2018)
    """

    two = D.dtype(2.0)
    # Matrix size. The loops below access D[k + 1], L[k + 1] and Q[k + 1] for k <= m - 2,
    # and D[end] for end <= m - 1, so m must equal len(D) to stay in bounds.
    m = wp.static(len(D))

    start = int(0)
    y = D.dtype(0.0)  # moving bulge
    x = D.dtype(0.0)  # coeff atop bulge

    for _ in range(32 * m):  # failsafe, usually converges faster than that
        # Iterate over all independent (deflated) blocks
        end = int(-1)

        for k in range(m - 1):
            if k >= end:
                # Check if new block is starting; a negligible L[k] decouples rows k and k + 1
                if k == end or wp.abs(L[k]) <= tol * (wp.abs(D[k]) + wp.abs(D[k + 1])):
                    continue

                # Find end of block
                start = k
                end = start + 1
                while end + 1 < m:
                    if wp.abs(L[end]) <= tol * (wp.abs(D[end + 1]) + wp.abs(D[end])):
                        break
                    end += 1

                # Wilkinson shift (an eigenvalue of the last 2x2 block)
                shift = _wilkinson_shift(D[end - 1], D[end], L[end - 1], tol)

                # start with eliminating lower diag of first column of shifted matrix
                # (i.e. first step of explicit QR factorization)
                # Then all further steps eliminate the bulge (second diag) of the non-shifted matrix
                x = D[start] - shift
                y = L[start]

            c, s = _givens_rotation(x, y)

            # Apply Givens rotation on both sides of tridiagonal matrix

            # middle block
            d = D[k] - D[k + 1]
            z = (two * c * L[k] + d * s) * s
            D[k] -= z
            D[k + 1] += z
            L[k] = d * c * s + (c * c - s * s) * L[k]

            if k > start:
                L[k - 1] = c * x - s * y

            x = L[k]
            y = -s * L[k + 1]  # new bulge
            L[k + 1] *= c

            # apply givens rotation on left of Q
            # note: Q is transposed compared to usual impls, as Warp makes it easier to index rows
            Qk0 = Q[k]
            Qk1 = Q[k + 1]
            Q[k] = c * Qk0 - s * Qk1
            Q[k + 1] = c * Qk1 + s * Qk0

        if end <= 0:
            # We did nothing, so diagonalization must have been achieved
            break

    return D, Q
318
+
319
+
320
@wp.func
def symmetric_eigenvalues_qr(A: Any, tol: Any):
    """
    Computes the eigenvalues and eigenvectors of a square symmetric matrix A using the QR algorithm

    The matrix is first reduced to tridiagonal form with Householder reflections, then
    diagonalized with the shifted symmetric tridiagonal QR iteration.

    Args:
        A: square symmetric matrix
        tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)

    Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
    """

    # Reduction A = Q H Q^T, with H tridiagonal since A is symmetric
    Q, H = householder_make_hessenberg(A)

    # Compact storage of H: main diagonal and first sub-diagonal (last slot stays zero)
    diag = wp.get_diag(H)
    sub_diag = type(diag)(A.dtype(0.0))
    for row in range(1, type(diag).length):
        sub_diag[row - 1] = H[row, row - 1]

    # Seed the eigenvector rows with Q^T so the reduction is folded into the result
    eigenvalues, eigenvectors = tridiagonal_symmetric_eigenvalues_qr(diag, sub_diag, wp.transpose(Q), tol)
    return eigenvalues, eigenvectors
344
+
345
+
346
def array_axpy(x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 1.0):
    """Performs y = alpha*x + beta*y

    ``y`` is updated in place. Elements of ``x`` are converted to ``y.dtype`` inside the kernel,
    and ``alpha``/``beta`` are converted to the scalar type of ``y``. When a tape is active and
    either array requires gradients, a custom adjoint is recorded on the tape.

    Args:
        x: Input array
        y: Input/output array, overwritten with the result
        alpha: Scaling factor applied to ``x``
        beta: Scaling factor applied to ``y``

    Raises:
        ValueError: If ``x`` and ``y`` do not share the same shape and device
    """

    from warp._src.context import runtime

    dtype = type_scalar_type(y.dtype)

    alpha = dtype(alpha)
    beta = dtype(beta)

    if x.shape != y.shape or x.device != y.device:
        raise ValueError("x and y arrays must have the same shape and device")

    # array_axpy requires a custom adjoint; unfortunately we cannot use `wp.func_grad`
    # as generic functions are not supported yet. Instead we use a non-differentiable kernel
    # and record a custom adjoint function on the tape.

    # temporarily disable tape to avoid printing warning that kernel is not differentiable
    (tape, runtime.tape) = (runtime.tape, None)
    try:
        wp.launch(kernel=_array_axpy_kernel, dim=x.shape, device=x.device, inputs=[x, y, alpha, beta])
    finally:
        # Restore the tape even if the launch fails, otherwise subsequent operations
        # would silently stop being recorded
        runtime.tape = tape

    if tape is not None and (x.requires_grad or y.requires_grad):

        def backward_axpy():
            # adj_x += adj_y * alpha
            # adj_y = adj_y * beta
            if y.grad is None:
                # No adjoint flows back through y, so there is nothing to propagate
                return
            if x.grad is not None:
                array_axpy(x=y.grad, y=x.grad, alpha=alpha, beta=1.0)
            if beta != 1.0:
                array_axpy(x=y.grad, y=y.grad, alpha=0.0, beta=beta)

        tape.record_func(backward_axpy, arrays=[x, y])
378
+
379
+
380
@wp.kernel(enable_backward=False)
def _array_axpy_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any), alpha: Any, beta: Any):
    # Element-wise y = beta * y + alpha * x, with each x element cast to y's dtype.
    # Launched by array_axpy over x.shape; no adjoint is generated (enable_backward=False).
    idx = wp.tid()
    x_val = y.dtype(x[idx])
    y[idx] = beta * y[idx] + alpha * x_val