warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (346)
  1. warp/__init__.py +301 -287
  2. warp/__init__.pyi +794 -305
  3. warp/_src/__init__.py +14 -0
  4. warp/_src/autograd.py +1075 -0
  5. warp/_src/build.py +618 -0
  6. warp/_src/build_dll.py +640 -0
  7. warp/{builtins.py → _src/builtins.py} +1382 -377
  8. warp/_src/codegen.py +4359 -0
  9. warp/{config.py → _src/config.py} +178 -169
  10. warp/_src/constants.py +57 -0
  11. warp/_src/context.py +8294 -0
  12. warp/_src/dlpack.py +462 -0
  13. warp/_src/fabric.py +355 -0
  14. warp/_src/fem/__init__.py +14 -0
  15. warp/_src/fem/adaptivity.py +508 -0
  16. warp/_src/fem/cache.py +687 -0
  17. warp/_src/fem/dirichlet.py +188 -0
  18. warp/{fem → _src/fem}/domain.py +40 -30
  19. warp/_src/fem/field/__init__.py +131 -0
  20. warp/_src/fem/field/field.py +701 -0
  21. warp/{fem → _src/fem}/field/nodal_field.py +30 -15
  22. warp/{fem → _src/fem}/field/restriction.py +1 -1
  23. warp/{fem → _src/fem}/field/virtual.py +53 -27
  24. warp/_src/fem/geometry/__init__.py +32 -0
  25. warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
  26. warp/_src/fem/geometry/closest_point.py +97 -0
  27. warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
  28. warp/{fem → _src/fem}/geometry/element.py +32 -10
  29. warp/{fem → _src/fem}/geometry/geometry.py +48 -20
  30. warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
  31. warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
  32. warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
  33. warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
  34. warp/{fem → _src/fem}/geometry/partition.py +121 -63
  35. warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
  36. warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
  37. warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
  38. warp/{fem → _src/fem}/integrate.py +164 -158
  39. warp/_src/fem/linalg.py +383 -0
  40. warp/_src/fem/operator.py +396 -0
  41. warp/_src/fem/polynomial.py +229 -0
  42. warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
  43. warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
  44. warp/_src/fem/space/__init__.py +248 -0
  45. warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
  46. warp/_src/fem/space/basis_space.py +679 -0
  47. warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
  48. warp/{fem → _src/fem}/space/function_space.py +14 -13
  49. warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
  50. warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
  51. warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
  52. warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
  53. warp/{fem → _src/fem}/space/partition.py +117 -60
  54. warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
  55. warp/{fem → _src/fem}/space/restriction.py +66 -33
  56. warp/_src/fem/space/shape/__init__.py +152 -0
  57. warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
  58. warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
  59. warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
  60. warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
  61. warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
  62. warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
  63. warp/_src/fem/space/topology.py +459 -0
  64. warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
  65. warp/_src/fem/types.py +112 -0
  66. warp/_src/fem/utils.py +486 -0
  67. warp/_src/jax.py +186 -0
  68. warp/_src/jax_experimental/__init__.py +14 -0
  69. warp/_src/jax_experimental/custom_call.py +387 -0
  70. warp/_src/jax_experimental/ffi.py +1284 -0
  71. warp/_src/jax_experimental/xla_ffi.py +656 -0
  72. warp/_src/marching_cubes.py +708 -0
  73. warp/_src/math.py +414 -0
  74. warp/_src/optim/__init__.py +14 -0
  75. warp/_src/optim/adam.py +163 -0
  76. warp/_src/optim/linear.py +1606 -0
  77. warp/_src/optim/sgd.py +112 -0
  78. warp/_src/paddle.py +406 -0
  79. warp/_src/render/__init__.py +14 -0
  80. warp/_src/render/imgui_manager.py +289 -0
  81. warp/_src/render/render_opengl.py +3636 -0
  82. warp/_src/render/render_usd.py +937 -0
  83. warp/_src/render/utils.py +160 -0
  84. warp/_src/sparse.py +2716 -0
  85. warp/_src/tape.py +1206 -0
  86. warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
  87. warp/_src/torch.py +391 -0
  88. warp/_src/types.py +5870 -0
  89. warp/_src/utils.py +1693 -0
  90. warp/autograd.py +12 -1054
  91. warp/bin/warp-clang.dll +0 -0
  92. warp/bin/warp.dll +0 -0
  93. warp/build.py +8 -588
  94. warp/build_dll.py +6 -721
  95. warp/codegen.py +6 -4251
  96. warp/constants.py +6 -39
  97. warp/context.py +12 -8062
  98. warp/dlpack.py +6 -444
  99. warp/examples/distributed/example_jacobi_mpi.py +4 -5
  100. warp/examples/fem/example_adaptive_grid.py +1 -1
  101. warp/examples/fem/example_apic_fluid.py +1 -1
  102. warp/examples/fem/example_burgers.py +8 -8
  103. warp/examples/fem/example_diffusion.py +1 -1
  104. warp/examples/fem/example_distortion_energy.py +1 -1
  105. warp/examples/fem/example_mixed_elasticity.py +2 -2
  106. warp/examples/fem/example_navier_stokes.py +1 -1
  107. warp/examples/fem/example_nonconforming_contact.py +7 -7
  108. warp/examples/fem/example_stokes.py +1 -1
  109. warp/examples/fem/example_stokes_transfer.py +1 -1
  110. warp/examples/fem/utils.py +2 -2
  111. warp/examples/interop/example_jax_callable.py +1 -1
  112. warp/examples/interop/example_jax_ffi_callback.py +1 -1
  113. warp/examples/interop/example_jax_kernel.py +1 -1
  114. warp/examples/tile/example_tile_mcgp.py +191 -0
  115. warp/fabric.py +6 -337
  116. warp/fem/__init__.py +159 -97
  117. warp/fem/adaptivity.py +7 -489
  118. warp/fem/cache.py +9 -648
  119. warp/fem/dirichlet.py +6 -184
  120. warp/fem/field/__init__.py +8 -109
  121. warp/fem/field/field.py +7 -652
  122. warp/fem/geometry/__init__.py +7 -18
  123. warp/fem/geometry/closest_point.py +11 -77
  124. warp/fem/linalg.py +18 -366
  125. warp/fem/operator.py +11 -369
  126. warp/fem/polynomial.py +9 -209
  127. warp/fem/space/__init__.py +5 -211
  128. warp/fem/space/basis_space.py +6 -662
  129. warp/fem/space/shape/__init__.py +41 -118
  130. warp/fem/space/topology.py +6 -437
  131. warp/fem/types.py +6 -81
  132. warp/fem/utils.py +11 -444
  133. warp/jax.py +8 -165
  134. warp/jax_experimental/__init__.py +14 -1
  135. warp/jax_experimental/custom_call.py +8 -365
  136. warp/jax_experimental/ffi.py +17 -873
  137. warp/jax_experimental/xla_ffi.py +5 -605
  138. warp/marching_cubes.py +5 -689
  139. warp/math.py +16 -393
  140. warp/native/array.h +385 -37
  141. warp/native/builtin.h +314 -37
  142. warp/native/bvh.cpp +43 -9
  143. warp/native/bvh.cu +62 -27
  144. warp/native/bvh.h +310 -309
  145. warp/native/clang/clang.cpp +102 -97
  146. warp/native/coloring.cpp +0 -1
  147. warp/native/crt.h +208 -0
  148. warp/native/exports.h +156 -0
  149. warp/native/hashgrid.cu +2 -0
  150. warp/native/intersect.h +24 -1
  151. warp/native/intersect_tri.h +44 -35
  152. warp/native/mat.h +1456 -276
  153. warp/native/mesh.cpp +4 -4
  154. warp/native/mesh.cu +4 -2
  155. warp/native/mesh.h +176 -61
  156. warp/native/quat.h +0 -52
  157. warp/native/scan.cu +2 -0
  158. warp/native/sparse.cu +7 -3
  159. warp/native/spatial.h +12 -0
  160. warp/native/tile.h +681 -89
  161. warp/native/tile_radix_sort.h +1 -1
  162. warp/native/tile_reduce.h +394 -46
  163. warp/native/tile_scan.h +4 -4
  164. warp/native/vec.h +469 -0
  165. warp/native/version.h +23 -0
  166. warp/native/volume.cpp +1 -1
  167. warp/native/volume.cu +1 -0
  168. warp/native/volume.h +1 -1
  169. warp/native/volume_builder.cu +2 -0
  170. warp/native/warp.cpp +57 -29
  171. warp/native/warp.cu +253 -171
  172. warp/native/warp.h +11 -8
  173. warp/optim/__init__.py +6 -3
  174. warp/optim/adam.py +6 -145
  175. warp/optim/linear.py +14 -1585
  176. warp/optim/sgd.py +6 -94
  177. warp/paddle.py +6 -388
  178. warp/render/__init__.py +8 -4
  179. warp/render/imgui_manager.py +7 -267
  180. warp/render/render_opengl.py +6 -3618
  181. warp/render/render_usd.py +6 -919
  182. warp/render/utils.py +6 -142
  183. warp/sparse.py +37 -2563
  184. warp/tape.py +6 -1188
  185. warp/tests/__main__.py +1 -1
  186. warp/tests/cuda/test_async.py +4 -4
  187. warp/tests/cuda/test_conditional_captures.py +1 -1
  188. warp/tests/cuda/test_multigpu.py +1 -1
  189. warp/tests/cuda/test_streams.py +58 -1
  190. warp/tests/geometry/test_bvh.py +157 -22
  191. warp/tests/geometry/test_marching_cubes.py +0 -1
  192. warp/tests/geometry/test_mesh.py +5 -3
  193. warp/tests/geometry/test_mesh_query_aabb.py +5 -12
  194. warp/tests/geometry/test_mesh_query_point.py +5 -2
  195. warp/tests/geometry/test_mesh_query_ray.py +15 -3
  196. warp/tests/geometry/test_volume_write.py +5 -5
  197. warp/tests/interop/test_dlpack.py +14 -14
  198. warp/tests/interop/test_jax.py +772 -49
  199. warp/tests/interop/test_paddle.py +1 -1
  200. warp/tests/test_adam.py +0 -1
  201. warp/tests/test_arithmetic.py +9 -9
  202. warp/tests/test_array.py +527 -100
  203. warp/tests/test_array_reduce.py +3 -3
  204. warp/tests/test_atomic.py +12 -8
  205. warp/tests/test_atomic_bitwise.py +209 -0
  206. warp/tests/test_atomic_cas.py +4 -4
  207. warp/tests/test_bool.py +2 -2
  208. warp/tests/test_builtins_resolution.py +5 -571
  209. warp/tests/test_codegen.py +33 -14
  210. warp/tests/test_conditional.py +1 -1
  211. warp/tests/test_context.py +6 -6
  212. warp/tests/test_copy.py +242 -161
  213. warp/tests/test_ctypes.py +3 -3
  214. warp/tests/test_devices.py +24 -2
  215. warp/tests/test_examples.py +16 -84
  216. warp/tests/test_fabricarray.py +35 -35
  217. warp/tests/test_fast_math.py +0 -2
  218. warp/tests/test_fem.py +56 -10
  219. warp/tests/test_fixedarray.py +3 -3
  220. warp/tests/test_func.py +8 -5
  221. warp/tests/test_generics.py +1 -1
  222. warp/tests/test_indexedarray.py +24 -24
  223. warp/tests/test_intersect.py +39 -9
  224. warp/tests/test_large.py +1 -1
  225. warp/tests/test_lerp.py +3 -1
  226. warp/tests/test_linear_solvers.py +1 -1
  227. warp/tests/test_map.py +35 -4
  228. warp/tests/test_mat.py +52 -62
  229. warp/tests/test_mat_constructors.py +4 -5
  230. warp/tests/test_mat_lite.py +1 -1
  231. warp/tests/test_mat_scalar_ops.py +121 -121
  232. warp/tests/test_math.py +34 -0
  233. warp/tests/test_module_aot.py +4 -4
  234. warp/tests/test_modules_lite.py +28 -2
  235. warp/tests/test_print.py +11 -11
  236. warp/tests/test_quat.py +93 -58
  237. warp/tests/test_runlength_encode.py +1 -1
  238. warp/tests/test_scalar_ops.py +38 -10
  239. warp/tests/test_smoothstep.py +1 -1
  240. warp/tests/test_sparse.py +126 -15
  241. warp/tests/test_spatial.py +105 -87
  242. warp/tests/test_special_values.py +6 -6
  243. warp/tests/test_static.py +7 -7
  244. warp/tests/test_struct.py +13 -2
  245. warp/tests/test_triangle_closest_point.py +48 -1
  246. warp/tests/test_types.py +27 -15
  247. warp/tests/test_utils.py +52 -52
  248. warp/tests/test_vec.py +29 -29
  249. warp/tests/test_vec_constructors.py +5 -5
  250. warp/tests/test_vec_scalar_ops.py +97 -97
  251. warp/tests/test_version.py +75 -0
  252. warp/tests/tile/test_tile.py +178 -0
  253. warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
  254. warp/tests/tile/test_tile_cholesky.py +7 -4
  255. warp/tests/tile/test_tile_load.py +26 -2
  256. warp/tests/tile/test_tile_mathdx.py +3 -3
  257. warp/tests/tile/test_tile_matmul.py +1 -1
  258. warp/tests/tile/test_tile_mlp.py +2 -4
  259. warp/tests/tile/test_tile_reduce.py +214 -13
  260. warp/tests/unittest_suites.py +6 -14
  261. warp/tests/unittest_utils.py +10 -9
  262. warp/tests/walkthrough_debug.py +3 -1
  263. warp/torch.py +6 -373
  264. warp/types.py +29 -5764
  265. warp/utils.py +10 -1659
  266. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
  267. warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
  268. warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
  269. warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
  270. warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
  271. warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
  272. warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
  273. warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
  274. warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
  275. warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
  276. warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
  277. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
  278. warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
  279. warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
  280. warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
  281. warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
  282. warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
  283. warp/examples/assets/cartpole.urdf +0 -110
  284. warp/examples/assets/crazyflie.usd +0 -0
  285. warp/examples/assets/nv_ant.xml +0 -92
  286. warp/examples/assets/nv_humanoid.xml +0 -183
  287. warp/examples/assets/quadruped.urdf +0 -268
  288. warp/examples/optim/example_bounce.py +0 -266
  289. warp/examples/optim/example_cloth_throw.py +0 -228
  290. warp/examples/optim/example_drone.py +0 -870
  291. warp/examples/optim/example_inverse_kinematics.py +0 -182
  292. warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
  293. warp/examples/optim/example_softbody_properties.py +0 -400
  294. warp/examples/optim/example_spring_cage.py +0 -245
  295. warp/examples/optim/example_trajectory.py +0 -227
  296. warp/examples/sim/example_cartpole.py +0 -143
  297. warp/examples/sim/example_cloth.py +0 -225
  298. warp/examples/sim/example_cloth_self_contact.py +0 -316
  299. warp/examples/sim/example_granular.py +0 -130
  300. warp/examples/sim/example_granular_collision_sdf.py +0 -202
  301. warp/examples/sim/example_jacobian_ik.py +0 -244
  302. warp/examples/sim/example_particle_chain.py +0 -124
  303. warp/examples/sim/example_quadruped.py +0 -203
  304. warp/examples/sim/example_rigid_chain.py +0 -203
  305. warp/examples/sim/example_rigid_contact.py +0 -195
  306. warp/examples/sim/example_rigid_force.py +0 -133
  307. warp/examples/sim/example_rigid_gyroscopic.py +0 -115
  308. warp/examples/sim/example_rigid_soft_contact.py +0 -140
  309. warp/examples/sim/example_soft_body.py +0 -196
  310. warp/examples/tile/example_tile_walker.py +0 -327
  311. warp/sim/__init__.py +0 -74
  312. warp/sim/articulation.py +0 -793
  313. warp/sim/collide.py +0 -2570
  314. warp/sim/graph_coloring.py +0 -307
  315. warp/sim/import_mjcf.py +0 -791
  316. warp/sim/import_snu.py +0 -227
  317. warp/sim/import_urdf.py +0 -579
  318. warp/sim/import_usd.py +0 -898
  319. warp/sim/inertia.py +0 -357
  320. warp/sim/integrator.py +0 -245
  321. warp/sim/integrator_euler.py +0 -2000
  322. warp/sim/integrator_featherstone.py +0 -2101
  323. warp/sim/integrator_vbd.py +0 -2487
  324. warp/sim/integrator_xpbd.py +0 -3295
  325. warp/sim/model.py +0 -4821
  326. warp/sim/particles.py +0 -121
  327. warp/sim/render.py +0 -431
  328. warp/sim/utils.py +0 -431
  329. warp/tests/sim/disabled_kinematics.py +0 -244
  330. warp/tests/sim/test_cloth.py +0 -863
  331. warp/tests/sim/test_collision.py +0 -743
  332. warp/tests/sim/test_coloring.py +0 -347
  333. warp/tests/sim/test_inertia.py +0 -161
  334. warp/tests/sim/test_model.py +0 -226
  335. warp/tests/sim/test_sim_grad.py +0 -287
  336. warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
  337. warp/tests/sim/test_sim_kinematics.py +0 -98
  338. warp/thirdparty/__init__.py +0 -0
  339. warp_lang-1.9.1.dist-info/RECORD +0 -456
  340. /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
  341. /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
  342. /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
  343. /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
  344. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
  345. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
  346. {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,1668 +13,19 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from __future__ import annotations
16
+ # isort: skip_file
17
17
 
18
- import cProfile
19
- import ctypes
20
- import os
21
- import sys
22
- import time
23
- import warnings
24
- from types import ModuleType
25
- from typing import Any, Callable
18
+ from warp._src.utils import array_cast as array_cast
19
+ from warp._src.utils import segmented_sort_pairs as segmented_sort_pairs
20
+ from warp._src.utils import warn as warn
26
21
 
27
- import numpy as np
28
22
 
29
- import warp as wp
30
- import warp.context
31
- import warp.types
32
- from warp.context import Devicelike
33
- from warp.types import Array, DType, type_repr, types_equal
23
+ # TODO: Remove after cleaning up the public API.
34
24
 
35
- warnings_seen = set()
25
+ from warp._src import utils as _utils
36
26
 
37
27
 
38
- def warp_showwarning(message, category, filename, lineno, file=None, line=None):
39
- """Version of warnings.showwarning that always prints to sys.stdout."""
28
+ def __getattr__(name):
29
+ from warp._src.utils import get_deprecated_api
40
30
 
41
- if warp.config.verbose_warnings:
42
- s = f"Warp {category.__name__}: {message} ({filename}:{lineno})\n"
43
-
44
- if line is None:
45
- try:
46
- import linecache
47
-
48
- line = linecache.getline(filename, lineno)
49
- except Exception:
50
- # When a warning is logged during Python shutdown, linecache
51
- # and the import machinery don't work anymore
52
- line = None
53
- linecache = None
54
-
55
- if line:
56
- line = line.strip()
57
- s += f" {line}\n"
58
- else:
59
- # simple warning
60
- s = f"Warp {category.__name__}: {message}\n"
61
-
62
- sys.stdout.write(s)
63
-
64
-
65
- def warn(message, category=None, stacklevel=1):
66
- if (category, message) in warnings_seen:
67
- return
68
-
69
- with warnings.catch_warnings():
70
- warnings.simplefilter("default") # Change the filter in this process
71
- warnings.showwarning = warp_showwarning
72
- warnings.warn(
73
- message,
74
- category,
75
- stacklevel=stacklevel + 1, # Increment stacklevel by 1 since we are in a wrapper
76
- )
77
-
78
- if category is DeprecationWarning:
79
- warnings_seen.add((category, message))
80
-
81
-
82
- # expand a 7-vec to a tuple of arrays
83
- def transform_expand(t):
84
- return wp.transform(np.array(t[0:3]), np.array(t[3:7]))
85
-
86
-
87
- @wp.func
88
- def quat_between_vectors(a: wp.vec3, b: wp.vec3) -> wp.quat:
89
- """
90
- Compute the quaternion that rotates vector a to vector b
91
- """
92
- a = wp.normalize(a)
93
- b = wp.normalize(b)
94
- c = wp.cross(a, b)
95
- d = wp.dot(a, b)
96
- q = wp.quat(c[0], c[1], c[2], 1.0 + d)
97
- return wp.normalize(q)
98
-
99
-
100
- def array_scan(in_array, out_array, inclusive=True):
101
- """Perform a scan (prefix sum) operation on an array.
102
-
103
- This function computes the inclusive or exclusive scan of the input array and stores the result in the output array.
104
- The scan operation computes a running sum of elements in the array.
105
-
106
- Args:
107
- in_array (wp.array): Input array to scan. Must be of type int32 or float32.
108
- out_array (wp.array): Output array to store scan results. Must match input array type and size.
109
- inclusive (bool, optional): If True, performs an inclusive scan (includes current element in sum).
110
- If False, performs an exclusive scan (excludes current element). Defaults to True.
111
-
112
- Raises:
113
- RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
114
- """
115
-
116
- if in_array.device != out_array.device:
117
- raise RuntimeError(f"In and out array storage devices do not match ({in_array.device} vs {out_array.device})")
118
-
119
- if in_array.size != out_array.size:
120
- raise RuntimeError(f"In and out array storage sizes do not match ({in_array.size} vs {out_array.size})")
121
-
122
- if not types_equal(in_array.dtype, out_array.dtype):
123
- raise RuntimeError(
124
- f"In and out array data types do not match ({type_repr(in_array.dtype)} vs {type_repr(out_array.dtype)})"
125
- )
126
-
127
- if in_array.size == 0:
128
- return
129
-
130
- from warp.context import runtime
131
-
132
- if in_array.device.is_cpu:
133
- if in_array.dtype == wp.int32:
134
- runtime.core.wp_array_scan_int_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
135
- elif in_array.dtype == wp.float32:
136
- runtime.core.wp_array_scan_float_host(in_array.ptr, out_array.ptr, in_array.size, inclusive)
137
- else:
138
- raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")
139
- elif in_array.device.is_cuda:
140
- if in_array.dtype == wp.int32:
141
- runtime.core.wp_array_scan_int_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
142
- elif in_array.dtype == wp.float32:
143
- runtime.core.wp_array_scan_float_device(in_array.ptr, out_array.ptr, in_array.size, inclusive)
144
- else:
145
- raise RuntimeError(f"Unsupported data type: {type_repr(in_array.dtype)}")
146
-
147
-
148
- def radix_sort_pairs(keys, values, count: int):
149
- """Sort key-value pairs using radix sort.
150
-
151
- This function sorts pairs of arrays based on the keys array, maintaining the key-value
152
- relationship. The sort is stable and operates in linear time.
153
- The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.
154
-
155
- Args:
156
- keys (wp.array): Array of keys to sort. Must be of type int32, float32, or int64.
157
- values (wp.array): Array of values to sort along with keys. Must be of type int32.
158
- count (int): Number of elements to sort.
159
-
160
- Raises:
161
- RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
162
- """
163
- if keys.device != values.device:
164
- raise RuntimeError(f"Keys and values array storage devices do not match ({keys.device} vs {values.device})")
165
-
166
- if count == 0:
167
- return
168
-
169
- if keys.size < 2 * count or values.size < 2 * count:
170
- raise RuntimeError("Keys and values array storage must be large enough to contain 2*count elements")
171
-
172
- from warp.context import runtime
173
-
174
- if keys.device.is_cpu:
175
- if keys.dtype == wp.int32 and values.dtype == wp.int32:
176
- runtime.core.wp_radix_sort_pairs_int_host(keys.ptr, values.ptr, count)
177
- elif keys.dtype == wp.float32 and values.dtype == wp.int32:
178
- runtime.core.wp_radix_sort_pairs_float_host(keys.ptr, values.ptr, count)
179
- elif keys.dtype == wp.int64 and values.dtype == wp.int32:
180
- runtime.core.wp_radix_sort_pairs_int64_host(keys.ptr, values.ptr, count)
181
- else:
182
- raise RuntimeError(
183
- f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
184
- )
185
- elif keys.device.is_cuda:
186
- if keys.dtype == wp.int32 and values.dtype == wp.int32:
187
- runtime.core.wp_radix_sort_pairs_int_device(keys.ptr, values.ptr, count)
188
- elif keys.dtype == wp.float32 and values.dtype == wp.int32:
189
- runtime.core.wp_radix_sort_pairs_float_device(keys.ptr, values.ptr, count)
190
- elif keys.dtype == wp.int64 and values.dtype == wp.int32:
191
- runtime.core.wp_radix_sort_pairs_int64_device(keys.ptr, values.ptr, count)
192
- else:
193
- raise RuntimeError(
194
- f"Unsupported keys and values data types: {type_repr(keys.dtype)}, {type_repr(values.dtype)}"
195
- )
196
-
197
-
198
- def segmented_sort_pairs(
199
- keys,
200
- values,
201
- count: int,
202
- segment_start_indices: wp.array(dtype=wp.int32),
203
- segment_end_indices: wp.array(dtype=wp.int32) = None,
204
- ):
205
- """Sort key-value pairs within segments.
206
-
207
- This function performs a segmented sort of key-value pairs, where the sorting is done independently within each segment.
208
- The segments are defined by their start and optionally end indices.
209
- The `keys` and `values` arrays must be large enough to accommodate 2*`count` elements.
210
-
211
- Args:
212
- keys: Array of keys to sort. Must be of type int32 or float32.
213
- values: Array of values to sort along with keys. Must be of type int32.
214
- count: Number of elements to sort.
215
- segment_start_indices: Array containing start index of each segment. Must be of type int32.
216
- If segment_end_indices is None, this array must have length at least num_segments + 1,
217
- and segment_end_indices will be inferred as segment_start_indices[1:].
218
- If segment_end_indices is provided, this array must have length at least num_segments.
219
- segment_end_indices: Optional array containing end index of each segment. Must be of type int32 if provided.
220
- If None, segment_end_indices will be inferred from segment_start_indices[1:].
221
- If provided, must have length at least num_segments.
222
-
223
- Raises:
224
- RuntimeError: If array storage devices don't match, if storage size is insufficient,
225
- if segment_start_indices is not of type int32, or if data types are unsupported.
226
- """
227
- if keys.device != values.device:
228
- raise RuntimeError(f"Array storage devices do not match ({keys.device} vs {values.device})")
229
-
230
- if count == 0:
231
- return
232
-
233
- if keys.size < 2 * count or values.size < 2 * count:
234
- raise RuntimeError("Array storage must be large enough to contain 2*count elements")
235
-
236
- from warp.context import runtime
237
-
238
- if segment_start_indices.dtype != wp.int32:
239
- raise RuntimeError("segment_start_indices array must be of type int32")
240
-
241
- # Handle case where segment_end_indices is not provided
242
- if segment_end_indices is None:
243
- num_segments = max(0, segment_start_indices.size - 1)
244
-
245
- segment_end_indices = segment_start_indices[1:]
246
- segment_end_indices_ptr = segment_end_indices.ptr
247
- segment_start_indices_ptr = segment_start_indices.ptr
248
- else:
249
- if segment_end_indices.dtype != wp.int32:
250
- raise RuntimeError("segment_end_indices array must be of type int32")
251
-
252
- num_segments = segment_start_indices.size
253
-
254
- segment_end_indices_ptr = segment_end_indices.ptr
255
- segment_start_indices_ptr = segment_start_indices.ptr
256
-
257
- if keys.device.is_cpu:
258
- if keys.dtype == wp.int32 and values.dtype == wp.int32:
259
- runtime.core.wp_segmented_sort_pairs_int_host(
260
- keys.ptr,
261
- values.ptr,
262
- count,
263
- segment_start_indices_ptr,
264
- segment_end_indices_ptr,
265
- num_segments,
266
- )
267
- elif keys.dtype == wp.float32 and values.dtype == wp.int32:
268
- runtime.core.wp_segmented_sort_pairs_float_host(
269
- keys.ptr,
270
- values.ptr,
271
- count,
272
- segment_start_indices_ptr,
273
- segment_end_indices_ptr,
274
- num_segments,
275
- )
276
- else:
277
- raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")
278
- elif keys.device.is_cuda:
279
- if keys.dtype == wp.int32 and values.dtype == wp.int32:
280
- runtime.core.wp_segmented_sort_pairs_int_device(
281
- keys.ptr,
282
- values.ptr,
283
- count,
284
- segment_start_indices_ptr,
285
- segment_end_indices_ptr,
286
- num_segments,
287
- )
288
- elif keys.dtype == wp.float32 and values.dtype == wp.int32:
289
- runtime.core.wp_segmented_sort_pairs_float_device(
290
- keys.ptr,
291
- values.ptr,
292
- count,
293
- segment_start_indices_ptr,
294
- segment_end_indices_ptr,
295
- num_segments,
296
- )
297
- else:
298
- raise RuntimeError(f"Unsupported data type: {type_repr(keys.dtype)}")
299
-
300
-
301
- def runlength_encode(values, run_values, run_lengths, run_count=None, value_count=None):
302
- """Perform run-length encoding on an array.
303
-
304
- This function compresses an array by replacing consecutive identical values with a single value
305
- and its count. For example, [1,1,1,2,2,3] becomes values=[1,2,3] and lengths=[3,2,1].
306
-
307
- Args:
308
- values (wp.array): Input array to encode. Must be of type int32.
309
- run_values (wp.array): Output array to store unique values. Must be at least value_count in size.
310
- run_lengths (wp.array): Output array to store run lengths. Must be at least value_count in size.
311
- run_count (wp.array, optional): Optional output array to store the number of runs.
312
- If None, returns the count as an integer.
313
- value_count (int, optional): Number of values to process. If None, processes entire array.
314
-
315
- Returns:
316
- int or wp.array: Number of runs if run_count is None, otherwise returns run_count array.
317
-
318
- Raises:
319
- RuntimeError: If array storage devices don't match, if storage size is insufficient, or if data types are unsupported.
320
- """
321
- if run_values.device != values.device or run_lengths.device != values.device:
322
- raise RuntimeError("run_values, run_lengths and values storage devices do not match")
323
-
324
- if value_count is None:
325
- value_count = values.size
326
-
327
- if run_values.size < value_count or run_lengths.size < value_count:
328
- raise RuntimeError(f"Output array storage sizes must be at least equal to value_count ({value_count})")
329
-
330
- if not types_equal(values.dtype, run_values.dtype):
331
- raise RuntimeError(
332
- f"values and run_values data types do not match ({type_repr(values.dtype)} vs {type_repr(run_values.dtype)})"
333
- )
334
-
335
- if run_lengths.dtype != wp.int32:
336
- raise RuntimeError("run_lengths array must be of type int32")
337
-
338
- # User can provide a device output array for storing the number of runs
339
- # For convenience, if no such array is provided, number of runs is returned on host
340
- if run_count is None:
341
- if value_count == 0:
342
- return 0
343
- run_count = wp.empty(shape=(1,), dtype=int, device=values.device)
344
- host_return = True
345
- else:
346
- if run_count.device != values.device:
347
- raise RuntimeError("run_count storage device does not match other arrays")
348
- if run_count.dtype != wp.int32:
349
- raise RuntimeError("run_count array must be of type int32")
350
- if value_count == 0:
351
- run_count.zero_()
352
- return run_count
353
- host_return = False
354
-
355
- from warp.context import runtime
356
-
357
- if values.device.is_cpu:
358
- if values.dtype == wp.int32:
359
- runtime.core.wp_runlength_encode_int_host(
360
- values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
361
- )
362
- else:
363
- raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
364
- elif values.device.is_cuda:
365
- if values.dtype == wp.int32:
366
- runtime.core.wp_runlength_encode_int_device(
367
- values.ptr, run_values.ptr, run_lengths.ptr, run_count.ptr, value_count
368
- )
369
- else:
370
- raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
371
-
372
- if host_return:
373
- return int(run_count.numpy()[0])
374
- return run_count
375
-
376
-
377
- def array_sum(values, out=None, value_count=None, axis=None):
378
- """Compute the sum of array elements.
379
-
380
- This function computes the sum of array elements, optionally along a specified axis.
381
- The operation can be performed on the entire array or along a specific dimension.
382
-
383
- Args:
384
- values (wp.array): Input array to sum. Must be of type float32 or float64.
385
- out (wp.array, optional): Output array to store results. If None, a new array is created.
386
- value_count (int, optional): Number of elements to process. If None, processes entire array.
387
- axis (int, optional): Axis along which to compute sum. If None, computes sum of all elements.
388
-
389
- Returns:
390
- wp.array or float: The sum result. Returns a float if axis is None and out is None,
391
- otherwise returns the output array.
392
-
393
- Raises:
394
- RuntimeError: If output array storage device or data type is incompatible with input array.
395
- """
396
- if value_count is None:
397
- if axis is None:
398
- value_count = values.size
399
- else:
400
- value_count = values.shape[axis]
401
-
402
- if axis is None:
403
- output_shape = (1,)
404
- else:
405
-
406
- def output_dim(ax, dim):
407
- return 1 if ax == axis else dim
408
-
409
- output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(values.shape))
410
-
411
- type_size = wp.types.type_size(values.dtype)
412
- scalar_type = wp.types.type_scalar_type(values.dtype)
413
-
414
- # User can provide a device output array for storing the number of runs
415
- # For convenience, if no such array is provided, number of runs is returned on host
416
- if out is None:
417
- host_return = True
418
- out = wp.empty(shape=output_shape, dtype=values.dtype, device=values.device)
419
- else:
420
- host_return = False
421
- if out.device != values.device:
422
- raise RuntimeError("out storage device should match values array")
423
- if out.dtype != values.dtype:
424
- raise RuntimeError(f"out array should have type {values.dtype.__name__}")
425
- if out.shape != output_shape:
426
- raise RuntimeError(f"out array should have shape {output_shape}")
427
-
428
- if value_count == 0:
429
- out.zero_()
430
- if axis is None and host_return:
431
- return out.numpy()[0]
432
- return out
433
-
434
- from warp.context import runtime
435
-
436
- if values.device.is_cpu:
437
- if scalar_type == wp.float32:
438
- native_func = runtime.core.wp_array_sum_float_host
439
- elif scalar_type == wp.float64:
440
- native_func = runtime.core.wp_array_sum_double_host
441
- else:
442
- raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
443
- elif values.device.is_cuda:
444
- if scalar_type == wp.float32:
445
- native_func = runtime.core.wp_array_sum_float_device
446
- elif scalar_type == wp.float64:
447
- native_func = runtime.core.wp_array_sum_double_device
448
- else:
449
- raise RuntimeError(f"Unsupported data type: {type_repr(values.dtype)}")
450
-
451
- if axis is None:
452
- stride = wp.types.type_size_in_bytes(values.dtype)
453
- native_func(values.ptr, out.ptr, value_count, stride, type_size)
454
-
455
- if host_return:
456
- return out.numpy()[0]
457
- return out
458
-
459
- stride = values.strides[axis]
460
- for idx in np.ndindex(output_shape):
461
- out_offset = sum(i * s for i, s in zip(idx, out.strides))
462
- val_offset = sum(i * s for i, s in zip(idx, values.strides))
463
-
464
- native_func(
465
- values.ptr + val_offset,
466
- out.ptr + out_offset,
467
- value_count,
468
- stride,
469
- type_size,
470
- )
471
-
472
- return out
473
-
474
-
475
- def array_inner(a, b, out=None, count=None, axis=None):
476
- """Compute the inner product of two arrays.
477
-
478
- This function computes the dot product between two arrays, optionally along a specified axis.
479
- The operation can be performed on the entire arrays or along a specific dimension.
480
-
481
- Args:
482
- a (wp.array): First input array.
483
- b (wp.array): Second input array. Must match shape and type of a.
484
- out (wp.array, optional): Output array to store results. If None, a new array is created.
485
- count (int, optional): Number of elements to process. If None, processes entire arrays.
486
- axis (int, optional): Axis along which to compute inner product. If None, computes on flattened arrays.
487
-
488
- Returns:
489
- wp.array or float: The inner product result. Returns a float if axis is None and out is None,
490
- otherwise returns the output array.
491
-
492
- Raises:
493
- RuntimeError: If array storage devices, sizes, or data types are incompatible.
494
- """
495
- if a.size != b.size:
496
- raise RuntimeError(f"A and b array storage sizes do not match ({a.size} vs {b.size})")
497
-
498
- if a.device != b.device:
499
- raise RuntimeError(f"A and b array storage devices do not match ({a.device} vs {b.device})")
500
-
501
- if not types_equal(a.dtype, b.dtype):
502
- raise RuntimeError(f"A and b array data types do not match ({type_repr(a.dtype)} vs {type_repr(b.dtype)})")
503
-
504
- if count is None:
505
- if axis is None:
506
- count = a.size
507
- else:
508
- count = a.shape[axis]
509
-
510
- if axis is None:
511
- output_shape = (1,)
512
- else:
513
-
514
- def output_dim(ax, dim):
515
- return 1 if ax == axis else dim
516
-
517
- output_shape = tuple(output_dim(ax, dim) for ax, dim in enumerate(a.shape))
518
-
519
- type_size = wp.types.type_size(a.dtype)
520
- scalar_type = wp.types.type_scalar_type(a.dtype)
521
-
522
- # User can provide a device output array for storing the number of runs
523
- # For convenience, if no such array is provided, number of runs is returned on host
524
- if out is None:
525
- host_return = True
526
- out = wp.empty(shape=output_shape, dtype=scalar_type, device=a.device)
527
- else:
528
- host_return = False
529
- if out.device != a.device:
530
- raise RuntimeError("out storage device should match values array")
531
- if out.dtype != scalar_type:
532
- raise RuntimeError(f"out array should have type {scalar_type.__name__}")
533
- if out.shape != output_shape:
534
- raise RuntimeError(f"out array should have shape {output_shape}")
535
-
536
- if count == 0:
537
- if axis is None and host_return:
538
- return 0.0
539
- out.zero_()
540
- return out
541
-
542
- from warp.context import runtime
543
-
544
- if a.device.is_cpu:
545
- if scalar_type == wp.float32:
546
- native_func = runtime.core.wp_array_inner_float_host
547
- elif scalar_type == wp.float64:
548
- native_func = runtime.core.wp_array_inner_double_host
549
- else:
550
- raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")
551
- elif a.device.is_cuda:
552
- if scalar_type == wp.float32:
553
- native_func = runtime.core.wp_array_inner_float_device
554
- elif scalar_type == wp.float64:
555
- native_func = runtime.core.wp_array_inner_double_device
556
- else:
557
- raise RuntimeError(f"Unsupported data type: {type_repr(a.dtype)}")
558
-
559
- if axis is None:
560
- stride_a = wp.types.type_size_in_bytes(a.dtype)
561
- stride_b = wp.types.type_size_in_bytes(b.dtype)
562
- native_func(a.ptr, b.ptr, out.ptr, count, stride_a, stride_b, type_size)
563
-
564
- if host_return:
565
- return out.numpy()[0]
566
- return out
567
-
568
- stride_a = a.strides[axis]
569
- stride_b = b.strides[axis]
570
-
571
- for idx in np.ndindex(output_shape):
572
- out_offset = sum(i * s for i, s in zip(idx, out.strides))
573
- a_offset = sum(i * s for i, s in zip(idx, a.strides))
574
- b_offset = sum(i * s for i, s in zip(idx, b.strides))
575
-
576
- native_func(
577
- a.ptr + a_offset,
578
- b.ptr + b_offset,
579
- out.ptr + out_offset,
580
- count,
581
- stride_a,
582
- stride_b,
583
- type_size,
584
- )
585
-
586
- return out
587
-
588
-
589
- @wp.kernel
590
- def _array_cast_kernel(
591
- dest: Any,
592
- src: Any,
593
- ):
594
- i = wp.tid()
595
- dest[i] = dest.dtype(src[i])
596
-
597
-
598
- def array_cast(in_array, out_array, count=None):
599
- """Cast elements from one array to another array with a different data type.
600
-
601
- This function performs element-wise casting from the input array to the output array.
602
- The arrays must have the same number of dimensions and data type shapes. If they don't match,
603
- the arrays will be flattened and casting will be performed at the scalar level.
604
-
605
- Args:
606
- in_array (wp.array): Input array to cast from.
607
- out_array (wp.array): Output array to cast to. Must have the same device as in_array.
608
- count (int, optional): Number of elements to process. If None, processes entire array.
609
- For multi-dimensional arrays, partial casting is not supported.
610
-
611
- Raises:
612
- RuntimeError: If arrays have different devices or if attempting partial casting
613
- on multi-dimensional arrays.
614
-
615
- Note:
616
- If the input and output arrays have the same data type, this function will
617
- simply copy the data without any conversion.
618
- """
619
- if in_array.device != out_array.device:
620
- raise RuntimeError(f"Array storage devices do not match ({in_array.device} vs {out_array.device})")
621
-
622
- in_array_data_shape = getattr(in_array.dtype, "_shape_", ())
623
- out_array_data_shape = getattr(out_array.dtype, "_shape_", ())
624
-
625
- if in_array.ndim != out_array.ndim or in_array_data_shape != out_array_data_shape:
626
- # Number of dimensions or data type shape do not match.
627
- # Flatten arrays and do cast at the scalar level
628
- in_array = in_array.flatten()
629
- out_array = out_array.flatten()
630
-
631
- in_array_data_length = warp.types.type_size(in_array.dtype)
632
- out_array_data_length = warp.types.type_size(out_array.dtype)
633
- in_array_scalar_type = wp.types.type_scalar_type(in_array.dtype)
634
- out_array_scalar_type = wp.types.type_scalar_type(out_array.dtype)
635
-
636
- in_array = wp.array(
637
- data=None,
638
- ptr=in_array.ptr,
639
- capacity=in_array.capacity,
640
- device=in_array.device,
641
- dtype=in_array_scalar_type,
642
- shape=in_array.shape[0] * in_array_data_length,
643
- )
644
-
645
- out_array = wp.array(
646
- data=None,
647
- ptr=out_array.ptr,
648
- capacity=out_array.capacity,
649
- device=out_array.device,
650
- dtype=out_array_scalar_type,
651
- shape=out_array.shape[0] * out_array_data_length,
652
- )
653
-
654
- if count is not None:
655
- count *= in_array_data_length
656
-
657
- if count is None:
658
- count = in_array.size
659
-
660
- if in_array.ndim == 1:
661
- dim = count
662
- elif count < in_array.size:
663
- raise RuntimeError("Partial cast is not supported for arrays with more than one dimension")
664
- else:
665
- dim = in_array.shape
666
-
667
- if in_array.dtype == out_array.dtype:
668
- # Same data type, can simply copy
669
- wp.copy(dest=out_array, src=in_array, count=count)
670
- else:
671
- wp.launch(kernel=_array_cast_kernel, dim=dim, inputs=[out_array, in_array], device=out_array.device)
672
-
673
-
674
- def create_warp_function(func: Callable) -> tuple[wp.Function, warp.context.Module]:
675
- """Create a Warp function from a Python function.
676
-
677
- Args:
678
- func (Callable): A Python function to be converted to a Warp function.
679
-
680
- Returns:
681
- wp.Function: A Warp function created from the input function.
682
- """
683
-
684
- from .codegen import Adjoint, get_full_arg_spec
685
-
686
- def unique_name(code: str):
687
- return "func_" + hex(hash(code))[-8:]
688
-
689
- # Create a Warp function from the input function
690
- source = None
691
- argspec = get_full_arg_spec(func)
692
- key = getattr(func, "__name__", None)
693
- if key is None:
694
- source, _ = Adjoint.extract_function_source(func)
695
- key = unique_name(source)
696
- elif key == "<lambda>":
697
- body = Adjoint.extract_lambda_source(func, only_body=True)
698
- if body is None:
699
- raise ValueError("Could not extract lambda source code")
700
- key = unique_name(body)
701
- source = f"def {key}({', '.join(argspec.args)}):\n return {body}"
702
- else:
703
- # use the qualname of the function as the key
704
- key = getattr(func, "__qualname__", key)
705
- key = key.replace(".", "_").replace(" ", "_").replace("<", "").replace(">", "_")
706
-
707
- module = warp.context.get_module(f"map_{key}")
708
- func = wp.Function(
709
- func,
710
- namespace="",
711
- module=module,
712
- key=key,
713
- source=source,
714
- overloaded_annotations=dict.fromkeys(argspec.args, Any),
715
- )
716
- return func, module
717
-
718
-
719
- def broadcast_shapes(shapes: list[tuple[int]]) -> tuple[int]:
720
- """Broadcast a list of shapes to a common shape.
721
-
722
- Following the broadcasting rules of NumPy, two shapes are compatible when:
723
- starting from the trailing dimension,
724
- 1. the two dimensions are equal, or
725
- 2. one of the dimensions is 1.
726
-
727
- Example:
728
- >>> broadcast_shapes([(3, 1, 4), (5, 4)])
729
- (3, 5, 4)
730
-
731
- Returns:
732
- tuple[int]: The broadcasted shape.
733
-
734
- Raises:
735
- ValueError: If the shapes are not broadcastable.
736
- """
737
- ref = shapes[0]
738
- for shape in shapes[1:]:
739
- broad = []
740
- for j in range(1, max(len(ref), len(shape)) + 1):
741
- if j <= len(ref) and j <= len(shape):
742
- s = shape[-j]
743
- r = ref[-j]
744
- if s == r:
745
- broad.append(s)
746
- elif s == 1 or r == 1:
747
- broad.append(max(s, r))
748
- else:
749
- raise ValueError(f"Shapes {ref} and {shape} are not broadcastable")
750
- elif j <= len(ref):
751
- broad.append(ref[-j])
752
- else:
753
- broad.append(shape[-j])
754
- ref = tuple(reversed(broad))
755
- return ref
756
-
757
-
758
- def map(
759
- func: Callable | wp.Function,
760
- *inputs: Array[DType] | Any,
761
- out: Array[DType] | list[Array[DType]] | None = None,
762
- return_kernel: bool = False,
763
- block_dim=256,
764
- device: Devicelike = None,
765
- ) -> Array[DType] | list[Array[DType]] | wp.Kernel:
766
- """
767
- Map a function over the elements of one or more arrays.
768
-
769
- You can use a Warp function, a regular Python function, or a lambda expression to map it to a set of arrays.
770
-
771
- .. testcode::
772
-
773
- a = wp.array([1, 2, 3], dtype=wp.float32)
774
- b = wp.array([4, 5, 6], dtype=wp.float32)
775
- c = wp.array([7, 8, 9], dtype=wp.float32)
776
- result = wp.map(lambda x, y, z: x + 2.0 * y - z, a, b, c)
777
- print(result)
778
-
779
- .. testoutput::
780
-
781
- [2. 4. 6.]
782
-
783
- Clamp values in an array in place:
784
-
785
- .. testcode::
786
-
787
- xs = wp.array([-1.0, 0.0, 1.0], dtype=wp.float32)
788
- wp.map(wp.clamp, xs, -0.5, 0.5, out=xs)
789
- print(xs)
790
-
791
- .. testoutput::
792
-
793
- [-0.5 0. 0.5]
794
-
795
- Note that only one of the inputs must be a Warp array. For example, it is possible
796
- vectorize the function :func:`warp.transform_point` over a collection of points
797
- with a given input transform as follows:
798
-
799
- .. code-block:: python
800
-
801
- tf = wp.transform((1.0, 2.0, 3.0), wp.quat_rpy(0.2, -0.6, 0.1))
802
- points = wp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=wp.vec3)
803
- transformed = wp.map(wp.transform_point, tf, points)
804
-
805
- Besides regular Warp arrays, other array types, such as the ``indexedarray``, are supported as well:
806
-
807
- .. testcode::
808
-
809
- arr = wp.array(data=np.arange(10, dtype=np.float32))
810
- indices = wp.array([1, 3, 5, 7, 9], dtype=int)
811
- iarr = wp.indexedarray1d(arr, [indices])
812
- out = wp.map(lambda x: x * 10.0, iarr)
813
- print(out)
814
-
815
- .. testoutput::
816
-
817
- [10. 30. 50. 70. 90.]
818
-
819
- If multiple arrays are provided, the
820
- `NumPy broadcasting rules <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
821
- are applied to determine the shape of the output array.
822
- Two shapes are compatible when:
823
- starting from the trailing dimension,
824
-
825
- 1. the two dimensions are equal, or
826
- 2. one of the dimensions is 1.
827
-
828
- For example, given arrays of shapes ``(3, 1, 4)`` and ``(5, 4)``, the broadcasted
829
- shape is ``(3, 5, 4)``.
830
-
831
- If no array(s) are provided to the ``out`` argument, the output array(s) are created automatically.
832
- The data type(s) of the output array(s) are determined by the type of the return value(s) of
833
- the function. The ``requires_grad`` flag for an automatically created output array is set to ``True``
834
- if any of the input arrays have it set to ``True`` and the respective output array's ``dtype`` is a type that
835
- supports differentiation.
836
-
837
- Args:
838
- func (Callable | Function): The function to map over the arrays.
839
- *inputs (array | Any): The input arrays or values to pass to the function.
840
- out (array | list[array] | None): Optional output array(s) to store the result(s). If None, the output array(s) will be created automatically.
841
- return_kernel (bool): If True, only return the generated kernel without performing the mapping operation.
842
- block_dim (int): The block dimension for the kernel launch.
843
- device (Devicelike): The device on which to run the kernel.
844
-
845
- Returns:
846
- array | list[array] | Kernel:
847
- The resulting array(s) of the mapping. If ``return_kernel`` is True, only returns the kernel used for mapping.
848
- """
849
-
850
- import builtins
851
-
852
- from .codegen import Adjoint, Struct, StructInstance
853
- from .types import (
854
- is_array,
855
- type_is_matrix,
856
- type_is_quaternion,
857
- type_is_transformation,
858
- type_is_vector,
859
- type_repr,
860
- type_to_warp,
861
- types_equal,
862
- )
863
-
864
- # mapping from struct name to its Python definition
865
- referenced_modules: dict[str, ModuleType] = {}
866
-
867
- def type_to_code(wp_type) -> str:
868
- """Returns the string representation of a given Warp type."""
869
- if is_array(wp_type):
870
- return f"warp.array(ndim={wp_type.ndim}, dtype={type_to_code(wp_type.dtype)})"
871
- if isinstance(wp_type, Struct):
872
- key = f"{wp_type.__module__}.{wp_type.key}"
873
- module = sys.modules.get(wp_type.__module__, None)
874
- if module is not None:
875
- referenced_modules[wp_type.__module__] = module
876
- return key
877
- if type_is_transformation(wp_type):
878
- return f"warp.types.transformation(dtype={type_to_code(wp_type._wp_scalar_type_)})"
879
- if type_is_quaternion(wp_type):
880
- return f"warp.types.quaternion(dtype={type_to_code(wp_type._wp_scalar_type_)})"
881
- if type_is_vector(wp_type):
882
- return f"warp.types.vector(length={wp_type._shape_[0]}, dtype={type_to_code(wp_type._wp_scalar_type_)})"
883
- if type_is_matrix(wp_type):
884
- return f"warp.types.matrix(shape=({wp_type._shape_[0]}, {wp_type._shape_[1]}), dtype={type_to_code(wp_type._wp_scalar_type_)})"
885
- if wp_type == builtins.bool:
886
- return "bool"
887
- if wp_type == builtins.float:
888
- return "float"
889
- if wp_type == builtins.int:
890
- return "int"
891
-
892
- name = getattr(wp_type, "__name__", None)
893
- if name is None:
894
- return type_repr(wp_type)
895
- name = getattr(wp_type, "__qualname__", name)
896
- module = getattr(wp_type, "__module__", None)
897
- if module is not None:
898
- referenced_modules[wp_type.__module__] = module
899
- return wp_type.__module__ + "." + name
900
-
901
- def get_warp_type(value):
902
- dtype = type(value)
903
- if issubclass(dtype, StructInstance):
904
- # a struct
905
- return value._cls
906
- return type_to_warp(dtype)
907
-
908
- # gather the arrays in the inputs
909
- array_shapes = [a.shape for a in inputs if is_array(a)]
910
- if len(array_shapes) == 0:
911
- raise ValueError("map requires at least one warp.array input")
912
- # broadcast the shapes of the arrays
913
- out_shape = broadcast_shapes(array_shapes)
914
-
915
- module = None
916
- out_dtypes = None
917
- if isinstance(func, wp.Function):
918
- func_name = func.key
919
- wp_func = func
920
- else:
921
- # check if op is a callable function
922
- if not callable(func):
923
- raise TypeError("func must be a callable function or a warp.Function")
924
- wp_func, module = create_warp_function(func)
925
- func_name = wp_func.key
926
- if module is None:
927
- module = warp.context.get_module(f"map_{func_name}")
928
-
929
- arg_names = list(wp_func.input_types.keys())
930
-
931
- if len(inputs) != len(arg_names):
932
- raise TypeError(
933
- f"Number of input arguments ({len(inputs)}) does not match expected number of function arguments ({len(arg_names)})"
934
- )
935
-
936
- # determine output dtype
937
- arg_types = {}
938
- arg_values = {}
939
- for i, arg_name in enumerate(arg_names):
940
- if is_array(inputs[i]):
941
- # we will pass an element of the array to the function
942
- arg_types[arg_name] = inputs[i].dtype
943
- if device is None:
944
- device = inputs[i].device
945
- else:
946
- # we pass the input value directly to the function
947
- arg_types[arg_name] = get_warp_type(inputs[i])
948
- func_or_none = wp_func.get_overload(list(arg_types.values()), {})
949
- if func_or_none is None:
950
- raise TypeError(
951
- f"Function {func_name} does not support the provided argument types {', '.join(type_repr(t) for t in arg_types.values())}"
952
- )
953
- func = func_or_none
954
-
955
- if func.value_type is not None:
956
- out_dtype = func.value_type
957
- elif func.value_func is not None:
958
- out_dtype = func.value_func(arg_types, arg_values)
959
- else:
960
- func.build(None)
961
- out_dtype = func.value_func(arg_types, arg_values)
962
-
963
- if out_dtype is None:
964
- raise TypeError("The provided function must return a value")
965
-
966
- if isinstance(out_dtype, tuple) or isinstance(out_dtype, list):
967
- out_dtypes = out_dtype
968
- else:
969
- out_dtypes = (out_dtype,)
970
-
971
- if out is None:
972
- requires_grad = any(getattr(a, "requires_grad", False) for a in inputs if is_array(a))
973
- outputs = []
974
- for dtype in out_dtypes:
975
- rg = requires_grad and Adjoint.is_differentiable_value_type(dtype)
976
- outputs.append(wp.empty(out_shape, dtype=dtype, requires_grad=rg, device=device))
977
- elif len(out_dtypes) == 1 and is_array(out):
978
- if not types_equal(out.dtype, out_dtypes[0]):
979
- raise TypeError(
980
- f"Output array dtype {type_repr(out.dtype)} does not match expected dtype {type_repr(out_dtypes[0])}"
981
- )
982
- if out.shape != out_shape:
983
- raise TypeError(f"Output array shape {out.shape} does not match expected shape {out_shape}")
984
- outputs = [out]
985
- elif len(out_dtypes) > 1:
986
- if isinstance(out, tuple) or isinstance(out, list):
987
- if len(out) != len(out_dtypes):
988
- raise TypeError(
989
- f"Number of provided output arrays ({len(out)}) does not match expected number of function outputs ({len(out_dtypes)})"
990
- )
991
- for i, a in enumerate(out):
992
- if not types_equal(a.dtype, out_dtypes[i]):
993
- raise TypeError(
994
- f"Output array {i} dtype {type_repr(a.dtype)} does not match expected dtype {type_repr(out_dtypes[i])}"
995
- )
996
- if a.shape != out_shape:
997
- raise TypeError(f"Output array {i} shape {a.shape} does not match expected shape {out_shape}")
998
- outputs = list(out)
999
- else:
1000
- raise TypeError(
1001
- f"Invalid output provided, expected {len(out_dtypes)} Warp arrays with shape {out_shape} and dtypes ({', '.join(type_repr(t) for t in out_dtypes)})"
1002
- )
1003
-
1004
- # create code for a kernel
1005
- code = """def map_kernel({kernel_args}):
1006
- {tids} = wp.tid()
1007
- {load_args}
1008
- """
1009
- if len(outputs) == 1:
1010
- code += "__out_0[{tids}] = {func_name}({arg_names})"
1011
- else:
1012
- code += ", ".join(f"__o_{i}" for i in range(len(outputs)))
1013
- code += " = {func_name}({arg_names})\n"
1014
- for i in range(len(outputs)):
1015
- code += f" __out_{i}" + "[{tids}]" + f" = __o_{i}\n"
1016
-
1017
- tids = [f"__tid_{i}" for i in range(len(out_shape))]
1018
-
1019
- load_args = []
1020
- kernel_args = []
1021
- for arg_name, input in zip(arg_names, inputs):
1022
- if is_array(input):
1023
- arr_name = f"{arg_name}_array"
1024
- array_type_name = type(input).__name__
1025
- kernel_args.append(
1026
- f"{arr_name}: wp.{array_type_name}(dtype={type_to_code(input.dtype)}, ndim={input.ndim})"
1027
- )
1028
- shape = input.shape
1029
- indices = []
1030
- for i in range(1, len(shape) + 1):
1031
- if shape[-i] == 1:
1032
- indices.append("0")
1033
- else:
1034
- indices.append(tids[-i])
1035
-
1036
- load_args.append(f"{arg_name} = {arr_name}[{', '.join(reversed(indices))}]")
1037
- else:
1038
- kernel_args.append(f"{arg_name}: {type_to_code(type(input))}")
1039
- for i, o in enumerate(outputs):
1040
- array_type_name = type(o).__name__
1041
- kernel_args.append(f"__out_{i}: wp.{array_type_name}(dtype={type_to_code(o.dtype)}, ndim={o.ndim})")
1042
- code = code.format(
1043
- func_name=func_name,
1044
- kernel_args=", ".join(kernel_args),
1045
- arg_names=", ".join(arg_names),
1046
- tids=", ".join(tids),
1047
- load_args="\n ".join(load_args),
1048
- )
1049
- namespace = {}
1050
- namespace.update({"wp": wp, "warp": wp, func_name: wp_func, "Any": Any})
1051
- namespace.update(referenced_modules)
1052
- exec(code, namespace)
1053
-
1054
- kernel = wp.Kernel(namespace["map_kernel"], key="map_kernel", source=code, module=module)
1055
- if return_kernel:
1056
- return kernel
1057
-
1058
- wp.launch(
1059
- kernel,
1060
- dim=out_shape,
1061
- inputs=inputs,
1062
- outputs=outputs,
1063
- block_dim=block_dim,
1064
- device=device,
1065
- )
1066
-
1067
- if len(outputs) == 1:
1068
- o = outputs[0]
1069
- else:
1070
- o = outputs
1071
-
1072
- return o
1073
-
1074
-
1075
- # code snippet for invoking cProfile
1076
- # cp = cProfile.Profile()
1077
- # cp.enable()
1078
- # for i in range(1000):
1079
- # self.state = self.integrator.forward(self.model, self.state, self.sim_dt)
1080
-
1081
- # cp.disable()
1082
- # cp.print_stats(sort='tottime')
1083
- # exit(0)
1084
-
1085
-
1086
- # helper kernels for initializing NVDB volumes from a dense array
1087
- @wp.kernel
1088
- def copy_dense_volume_to_nano_vdb_v(volume: wp.uint64, values: wp.array(dtype=wp.vec3, ndim=3)):
1089
- i, j, k = wp.tid()
1090
- wp.volume_store_v(volume, i, j, k, values[i, j, k])
1091
-
1092
-
1093
- @wp.kernel
1094
- def copy_dense_volume_to_nano_vdb_f(volume: wp.uint64, values: wp.array(dtype=wp.float32, ndim=3)):
1095
- i, j, k = wp.tid()
1096
- wp.volume_store_f(volume, i, j, k, values[i, j, k])
1097
-
1098
-
1099
- @wp.kernel
1100
- def copy_dense_volume_to_nano_vdb_i(volume: wp.uint64, values: wp.array(dtype=wp.int32, ndim=3)):
1101
- i, j, k = wp.tid()
1102
- wp.volume_store_i(volume, i, j, k, values[i, j, k])
1103
-
1104
-
1105
- # represent an edge between v0, v1 with connected faces f0, f1, and opposite vertex o0, and o1
1106
- # winding is such that first tri can be reconstructed as {v0, v1, o0}, and second tri as { v1, v0, o1 }
1107
- class MeshEdge:
1108
- def __init__(self, v0, v1, o0, o1, f0, f1):
1109
- self.v0 = v0 # vertex 0
1110
- self.v1 = v1 # vertex 1
1111
- self.o0 = o0 # opposite vertex 1
1112
- self.o1 = o1 # opposite vertex 2
1113
- self.f0 = f0 # index of tri1
1114
- self.f1 = f1 # index of tri2
1115
-
1116
-
1117
- class MeshAdjacency:
1118
- def __init__(self, indices, num_tris):
1119
- # map edges (v0, v1) to faces (f0, f1)
1120
- self.edges = {}
1121
- self.indices = indices
1122
-
1123
- for index, tri in enumerate(indices):
1124
- self.add_edge(tri[0], tri[1], tri[2], index)
1125
- self.add_edge(tri[1], tri[2], tri[0], index)
1126
- self.add_edge(tri[2], tri[0], tri[1], index)
1127
-
1128
- def add_edge(self, i0, i1, o, f): # index1, index2, index3, index of triangle
1129
- key = (min(i0, i1), max(i0, i1))
1130
- edge = None
1131
-
1132
- if key in self.edges:
1133
- edge = self.edges[key]
1134
-
1135
- if edge.f1 != -1:
1136
- print("Detected non-manifold edge")
1137
- return
1138
- else:
1139
- # update other side of the edge
1140
- edge.o1 = o
1141
- edge.f1 = f
1142
- else:
1143
- # create new edge with opposite yet to be filled
1144
- edge = MeshEdge(i0, i1, o, -1, f, -1)
1145
-
1146
- self.edges[key] = edge
1147
-
1148
-
1149
- def mem_report(): # pragma: no cover
1150
- def _mem_report(tensors, mem_type):
1151
- """Print the selected tensors of type
1152
- There are two major storage types in our major concern:
1153
- - GPU: tensors transferred to CUDA devices
1154
- - CPU: tensors remaining on the system memory (usually unimportant)
1155
- Args:
1156
- - tensors: the tensors of specified type
1157
- - mem_type: 'CPU' or 'GPU' in current implementation"""
1158
- total_numel = 0
1159
- total_mem = 0
1160
- visited_data = []
1161
- for tensor in tensors:
1162
- if tensor.is_sparse:
1163
- continue
1164
- # a data_ptr indicates a memory block allocated
1165
- data_ptr = tensor.storage().data_ptr()
1166
- if data_ptr in visited_data:
1167
- continue
1168
- visited_data.append(data_ptr)
1169
-
1170
- numel = tensor.storage().size()
1171
- total_numel += numel
1172
- element_size = tensor.storage().element_size()
1173
- mem = numel * element_size / 1024 / 1024 # 32bit=4Byte, MByte
1174
- total_mem += mem
1175
- print(f"Type: {mem_type:<4} | Total Tensors: {total_numel:>8} | Used Memory: {total_mem:>8.2f} MB")
1176
-
1177
- import gc
1178
-
1179
- import torch
1180
-
1181
- gc.collect()
1182
-
1183
- LEN = 65
1184
- objects = gc.get_objects()
1185
- # print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') )
1186
- tensors = [obj for obj in objects if torch.is_tensor(obj)]
1187
- cuda_tensors = [t for t in tensors if t.is_cuda]
1188
- host_tensors = [t for t in tensors if not t.is_cuda]
1189
- _mem_report(cuda_tensors, "GPU")
1190
- _mem_report(host_tensors, "CPU")
1191
- print("=" * LEN)
1192
-
1193
-
1194
- class ScopedDevice:
1195
- """A context manager to temporarily change the current default device.
1196
-
1197
- For CUDA devices, this context manager makes the device's CUDA context
1198
- current and restores the previous CUDA context on exit. This is handy when
1199
- running Warp scripts as part of a bigger pipeline because it avoids any side
1200
- effects of changing the CUDA context in the enclosed code.
1201
-
1202
- Attributes:
1203
- device (Device): The device that will temporarily become the default
1204
- device within the context.
1205
- saved_device (Device): The previous default device. This is restored as
1206
- the default device on exiting the context.
1207
- """
1208
-
1209
- def __init__(self, device: Devicelike):
1210
- """Initializes the context manager with a device.
1211
-
1212
- Args:
1213
- device: The device that will temporarily become the default device
1214
- within the context.
1215
- """
1216
- self.device = wp.get_device(device)
1217
-
1218
- def __enter__(self):
1219
- # save the previous default device
1220
- self.saved_device = self.device.runtime.default_device
1221
-
1222
- # make this the default device
1223
- self.device.runtime.default_device = self.device
1224
-
1225
- # make it the current CUDA device so that device alias "cuda" will evaluate to this device
1226
- self.device.context_guard.__enter__()
1227
-
1228
- return self.device
1229
-
1230
- def __exit__(self, exc_type, exc_value, traceback):
1231
- # restore original CUDA context
1232
- self.device.context_guard.__exit__(exc_type, exc_value, traceback)
1233
-
1234
- # restore original target device
1235
- self.device.runtime.default_device = self.saved_device
1236
-
1237
-
1238
class ScopedStream:
    """A context manager to temporarily change the current stream on a device.

    If the stream passed to the constructor is ``None``, entering and exiting the
    context does nothing and ``__enter__`` returns ``None``.

    Attributes:
        stream (Stream or None): The stream that will temporarily become the device's
            default stream within the context.
        saved_stream (Stream): The device's previous current stream. This is
            restored as the device's current stream on exiting the context.
        sync_enter (bool): Whether to synchronize this context's stream with
            the device's previous current stream on entering the context.
        sync_exit (bool): Whether to synchronize the device's previous current
            stream with this context's stream on exiting the context.
        device (Device): The device associated with the stream (only set when
            ``stream`` is not ``None``).
    """

    def __init__(self, stream: wp.Stream | None, sync_enter: bool = True, sync_exit: bool = False):
        """Initializes the context manager with a stream and synchronization options.

        Args:
            stream: The stream that will temporarily become the device's
                default stream within the context, or ``None`` for a no-op context.
            sync_enter (bool): Whether to synchronize this context's stream with
                the device's previous current stream on entering the context.
            sync_exit (bool): Whether to synchronize the device's previous current
                stream with this context's stream on exiting the context.
        """

        self.stream = stream
        self.sync_enter = sync_enter
        self.sync_exit = sync_exit
        if stream is not None:
            self.device = stream.device
            # The stream's device also becomes the default device for the duration of the context.
            self.device_scope = ScopedDevice(self.device)

    def __enter__(self):
        if self.stream is not None:
            self.device_scope.__enter__()
            try:
                self.saved_stream = self.device.stream
                self.device.set_stream(self.stream, self.sync_enter)
            except BaseException as exc:
                # __exit__ is not invoked when __enter__ raises, so undo the device scope here
                # instead of leaving the default device changed.
                self.device_scope.__exit__(type(exc), exc, exc.__traceback__)
                raise

        return self.stream

    def __exit__(self, exc_type, exc_value, traceback):
        if self.stream is not None:
            try:
                self.device.set_stream(self.saved_stream, self.sync_exit)
            finally:
                # Restore the previous default device even if restoring the stream failed.
                self.device_scope.__exit__(exc_type, exc_value, traceback)
1284
-
1285
-
1286
# Bit flags for the ``cuda_filter`` argument of timing_begin() / ScopedTimer.
# They can be OR-ed together; TIMING_ALL selects every activity type.
TIMING_KERNEL = 1 << 0
TIMING_KERNEL_BUILTIN = 1 << 1
TIMING_MEMCPY = 1 << 2
TIMING_MEMSET = 1 << 3
TIMING_GRAPH = 1 << 4
TIMING_ALL = 0xFFFFFFFF
1292
-
1293
-
1294
- # timer utils
1295
class ScopedTimer:
    """Context manager that measures the wall-clock duration of a ``with`` block.

    Depending on its options, a timer can also synchronize devices, emit an NVTX range,
    collect cProfile statistics, collect CUDA activity timings, append the elapsed time
    to a dictionary, and mark a scope on the active tape.
    """

    # Nesting depth of currently open printing timers; used to indent their output.
    indent = -1

    # Class-wide switch: timers constructed while this is False are inactive.
    enabled = True

    def __init__(
        self,
        name: str,
        active: bool = True,
        print: bool = True,
        detailed: bool = False,
        dict: dict[str, list[float]] | None = None,
        use_nvtx: bool = False,
        color: int | str = "rapids",
        synchronize: bool = False,
        cuda_filter: int = 0,
        report_func: Callable[[list[TimingResult], str], None] | None = None,
        skip_tape: bool = False,
    ):
        """Context manager object for a timer

        Parameters:
            name: Name of timer
            active: Enables this timer
            print: At context manager exit, print elapsed time to ``sys.stdout``
            detailed: Collects additional profiling data using cProfile and calls ``print_stats()`` at context exit
            dict: A dictionary of lists to which the elapsed time will be appended using ``name`` as a key
            use_nvtx: If true, timing functionality is replaced by an NVTX range
            color: ARGB value (e.g. 0x00FFFF) or color name (e.g. 'cyan') associated with the NVTX range
            synchronize: Synchronize the CPU thread with any outstanding CUDA work to return accurate GPU timings
            cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
            report_func: A callback function to print the activity report.
                If ``None``, :func:`wp.timing_print() <timing_print>` will be used.
            skip_tape: If true, the timer will not be recorded in the tape

        Attributes:
            extra_msg (str): Can be set to a string that will be added to the printout at context exit.
            elapsed (float): The duration of the ``with`` block in milliseconds (``0.0`` if the timer is inactive)
            timing_results (list[TimingResult]): The list of activity timing results, if collection was requested
                using ``cuda_filter``; empty otherwise, including when the timer is inactive
        """
        self.name = name
        self.active = active and self.enabled
        self.print = print
        self.detailed = detailed
        self.dict = dict
        self.use_nvtx = use_nvtx
        self.color = color
        self.synchronize = synchronize
        self.skip_tape = skip_tape
        self.elapsed = 0.0
        self.cuda_filter = cuda_filter
        self.report_func = report_func or wp.timing_print
        self.extra_msg = ""  # Can be used to add to the message printed at manager exit
        # Defined up front so the documented attribute exists even for inactive timers.
        self.timing_results = []

        if self.dict is not None:
            if name not in self.dict:
                self.dict[name] = []

    def __enter__(self):
        if not self.skip_tape and warp.context.runtime is not None and warp.context.runtime.tape is not None:
            warp.context.runtime.tape.record_scope_begin(self.name)
        if self.active:
            if self.synchronize:
                wp.synchronize()

            if self.cuda_filter:
                # begin CUDA activity collection, synchronizing if needed
                # (skipped when we already synchronized just above)
                timing_begin(self.cuda_filter, synchronize=not self.synchronize)

            if self.detailed:
                self.cp = cProfile.Profile()
                self.cp.clear()
                self.cp.enable()

            if self.use_nvtx:
                import nvtx

                self.nvtx_range_id = nvtx.start_range(self.name, color=self.color)

            if self.print:
                # nested printing timers indent their output one level deeper
                ScopedTimer.indent += 1

                if warp.config.verbose:
                    indent = " " * ScopedTimer.indent
                    print(f"{indent}{self.name} ...", flush=True)

            self.start = time.perf_counter_ns()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.skip_tape and warp.context.runtime is not None and warp.context.runtime.tape is not None:
            warp.context.runtime.tape.record_scope_end()
        if self.active:
            if self.synchronize:
                wp.synchronize()

            # nanoseconds -> milliseconds
            self.elapsed = (time.perf_counter_ns() - self.start) / 1000000.0

            if self.use_nvtx:
                import nvtx

                nvtx.end_range(self.nvtx_range_id)

            if self.detailed:
                self.cp.disable()
                self.cp.print_stats(sort="tottime")

            if self.cuda_filter:
                # end CUDA activity collection, synchronizing if needed
                self.timing_results = timing_end(synchronize=not self.synchronize)
            else:
                self.timing_results = []

            if self.dict is not None:
                self.dict[self.name].append(self.elapsed)

            if self.print:
                indent = " " * ScopedTimer.indent

                if self.timing_results:
                    self.report_func(self.timing_results, indent=indent)
                    print()

                if self.extra_msg:
                    print(f"{indent}{self.name} took {self.elapsed:.2f} ms {self.extra_msg}")
                else:
                    print(f"{indent}{self.name} took {self.elapsed:.2f} ms")

                ScopedTimer.indent -= 1
1425
-
1426
-
1427
- # Allow temporarily enabling/disabling mempool allocators
1428
class ScopedMempool:
    """Context manager that temporarily enables or disables the mempool allocator on a device.

    The setting reported by ``wp.is_mempool_enabled`` is saved on entry, replaced by
    ``enable``, and written back on exit.
    """

    def __init__(self, device: Devicelike, enable: bool):
        resolved = wp.get_device(device)
        self.device = resolved
        self.enable = enable

    def __enter__(self):
        target = self.device
        self.saved_setting = wp.is_mempool_enabled(target)
        wp.set_mempool_enabled(target, self.enable)

    def __exit__(self, exc_type, exc_value, traceback):
        previous = self.saved_setting
        wp.set_mempool_enabled(self.device, previous)
1439
-
1440
-
1441
- # Allow temporarily enabling/disabling mempool access
1442
class ScopedMempoolAccess:
    """Context manager that temporarily enables or disables mempool access between two devices.

    Both devices are passed through unchanged, in the same order, to
    ``wp.is_mempool_access_enabled`` / ``wp.set_mempool_access_enabled``. The previous
    setting is saved on entry and restored on exit.
    """

    def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
        self.target_device, self.peer_device, self.enable = target_device, peer_device, enable

    def __enter__(self):
        devices = (self.target_device, self.peer_device)
        self.saved_setting = wp.is_mempool_access_enabled(*devices)
        wp.set_mempool_access_enabled(*devices, self.enable)

    def __exit__(self, exc_type, exc_value, traceback):
        previous = self.saved_setting
        wp.set_mempool_access_enabled(self.target_device, self.peer_device, previous)
1454
-
1455
-
1456
- # Allow temporarily enabling/disabling peer access
1457
class ScopedPeerAccess:
    """Context manager that temporarily enables or disables peer access between two devices.

    Both devices are passed through unchanged, in the same order, to
    ``wp.is_peer_access_enabled`` / ``wp.set_peer_access_enabled``. The previous setting
    is saved on entry and restored on exit.
    """

    def __init__(self, target_device: Devicelike, peer_device: Devicelike, enable: bool):
        self.target_device, self.peer_device, self.enable = target_device, peer_device, enable

    def __enter__(self):
        devices = (self.target_device, self.peer_device)
        self.saved_setting = wp.is_peer_access_enabled(*devices)
        wp.set_peer_access_enabled(*devices, self.enable)

    def __exit__(self, exc_type, exc_value, traceback):
        previous = self.saved_setting
        wp.set_peer_access_enabled(self.target_device, self.peer_device, previous)
1469
-
1470
-
1471
class ScopedCapture:
    """Context manager around ``wp.capture_begin()`` / ``wp.capture_end()``.

    Work issued inside the ``with`` block is captured. On exit, the value returned by
    ``wp.capture_end()`` is stored in ``graph``.

    Args:
        device: Device forwarded to ``wp.capture_begin()`` and ``wp.capture_end()``.
        stream: Stream forwarded to ``wp.capture_begin()`` and ``wp.capture_end()``.
        force_module_load: Forwarded to ``wp.capture_begin()``.
        external: Forwarded to ``wp.capture_begin()``.

    Attributes:
        active (bool): Whether a capture started by this object is still in progress.
        graph: The result of ``wp.capture_end()``, or ``None`` until a capture has ended successfully.
    """

    def __init__(self, device: Devicelike = None, stream=None, force_module_load=None, external=False):
        self.device = device
        self.stream = stream
        self.force_module_load = force_module_load
        self.external = external
        self.active = False
        self.graph = None

    def __enter__(self):
        # If capture_begin raises, the exception propagates and __exit__ is never called,
        # so `active` correctly stays False.
        wp.capture_begin(
            device=self.device, stream=self.stream, force_module_load=self.force_module_load, external=self.external
        )
        self.active = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.active:
            try:
                self.graph = wp.capture_end(device=self.device, stream=self.stream)
            except Exception:
                # Only report this exception if __exit__() was reached without an exception,
                # otherwise re-raise the original exception.
                if exc_type is None:
                    raise
            finally:
                self.active = False
1501
-
1502
-
1503
def check_p2p():
    """Check if the machine is configured properly for peer-to-peer transfers.

    Setting the environment variable ``WARP_DISABLE_P2P_TESTS`` to a non-zero integer
    (or to ``true``/``yes``/``on``, case-insensitive) makes this function return ``False``.

    Returns:
        A Boolean indicating whether the machine is configured properly for peer-to-peer transfers.
        On Linux, this function attempts to determine if IOMMU is enabled and will return `False` if IOMMU is detected.
        On other operating systems, it always returns `True` (unless disabled via the environment variable).
    """

    # HACK: allow disabling P2P tests using an environment variable.
    # Integer values keep their original meaning; common boolean spellings are accepted
    # instead of raising ValueError (as int() alone would for e.g. "true" or "").
    disable_p2p_tests = os.getenv("WARP_DISABLE_P2P_TESTS", default="0").strip()
    try:
        disabled = int(disable_p2p_tests) != 0
    except ValueError:
        disabled = disable_p2p_tests.lower() in ("true", "yes", "on")
    if disabled:
        return False

    if sys.platform == "linux":
        # IOMMU enablement can affect peer-to-peer transfers.
        # On modern Linux, there should be IOMMU-related entries in the /sys file system.
        # This should be more reliable than checking kernel logs like dmesg.
        for iommu_dir in ("/sys/class/iommu", "/sys/kernel/iommu_groups"):
            try:
                if os.listdir(iommu_dir):
                    return False
            except OSError:
                # Missing, not a directory, or unreadable: no IOMMU evidence at this location.
                continue

    return True
1527
-
1528
-
1529
class timing_result_t(ctypes.Structure):
    """CUDA timing struct for fetching values from C++.

    Arrays of this struct are passed by reference to the native ``wp_cuda_timing_end``
    call, so the field order and C types below must match the native layout exactly.
    """

    _fields_ = (
        # CUDA context of the activity (resolved to a Device through runtime.context_map)
        ("context", ctypes.c_void_p),
        # activity name as a C string (decoded to str by timing_end)
        ("name", ctypes.c_char_p),
        # activity type, compared against the TIMING_* flags
        ("filter", ctypes.c_int),
        # elapsed time (reported in milliseconds by TimingResult)
        ("elapsed", ctypes.c_float),
    )
1538
-
1539
-
1540
class TimingResult:
    """Timing result for a single activity."""

    def __init__(self, device, name, filter, elapsed):
        self.device: warp.context.Device = device
        """The device where the activity was recorded."""

        self.name: str = name
        """The activity name."""

        self.filter: int = filter
        """The type of activity (e.g., ``warp.TIMING_KERNEL``)."""

        self.elapsed: float = elapsed
        """The elapsed time in milliseconds."""

    def __repr__(self) -> str:
        # Show the device alias when available; timing_end may store None when the
        # CUDA context is not found in the runtime's context map.
        device = getattr(self.device, "alias", self.device)
        return (
            f"TimingResult(device={device!r}, name={self.name!r}, "
            f"filter={self.filter!r}, elapsed={self.elapsed!r})"
        )
1555
-
1556
-
1557
def timing_begin(cuda_filter: int = TIMING_ALL, synchronize: bool = True) -> None:
    """Start collecting detailed CUDA activity timings.

    The collected results are retrieved with :func:`timing_end`.

    Parameters:
        cuda_filter: Filter flags for CUDA activity timing, e.g. ``warp.TIMING_KERNEL`` or ``warp.TIMING_ALL``
        synchronize: Whether to synchronize all CUDA devices before timing starts
    """
    if synchronize:
        # wait for all outstanding device work before the timing window opens
        warp.synchronize()

    native = warp.context.runtime.core
    native.wp_cuda_timing_begin(cuda_filter)
1569
-
1570
-
1571
def timing_end(synchronize: bool = True) -> list[TimingResult]:
    """End detailed activity timing.

    Parameters:
        synchronize: Whether to synchronize all CUDA devices before timing ends

    Returns:
        A list of :class:`TimingResult` objects for all recorded activities, in the
        order returned by the native runtime.
    """

    if synchronize:
        warp.synchronize()

    core = warp.context.runtime.core

    # get result count
    count = core.wp_cuda_timing_get_result_count()

    # get result array from C++
    result_buffer = (timing_result_t * count)()
    core.wp_cuda_timing_end(ctypes.byref(result_buffer), count)

    # Suffixes carried by Warp kernel names; stripped for display only when actually present,
    # rather than slicing a fixed number of characters.
    forward_suffix = "_cuda_kernel_forward"
    backward_suffix = "_cuda_kernel_backward"

    context_map = warp.context.runtime.context_map

    # prepare Python result list
    results = []
    for r in result_buffer:
        # None if the CUDA context is not in the runtime's context map
        device = context_map.get(r.context)
        activity_filter = r.filter
        elapsed = r.elapsed

        # ctypes reads a NULL c_char_p as None
        name = r.name.decode() if r.name is not None else ""

        if activity_filter == TIMING_KERNEL:
            if name.endswith("forward"):
                if name.endswith(forward_suffix):
                    name = name[: -len(forward_suffix)]
                name = f"forward kernel {name}"
            else:
                if name.endswith(backward_suffix):
                    name = name[: -len(backward_suffix)]
                name = f"backward kernel {name}"
        elif activity_filter == TIMING_KERNEL_BUILTIN:
            if name.startswith("wp::"):
                name = name[len("wp::") :]
            name = f"builtin kernel {name}"

        results.append(TimingResult(device, name, activity_filter, elapsed))

    return results
1615
-
1616
-
1617
def timing_print(results: list[TimingResult], indent: str = "") -> None:
    """Print timing results.

    Prints a timeline with one row per activity, followed by totals per activity name
    and per device.

    Parameters:
        results: List of :class:`TimingResult` objects to print.
        indent: Optional indentation to prepend to all output lines.
    """

    if not results:
        print(f"{indent}No activity")
        return

    class Aggregate:
        def __init__(self, count=0, elapsed=0):
            self.count = count
            self.elapsed = elapsed

    device_totals = {}
    activity_totals = {}

    # The activity column fits the longest name (or its header) plus a leading space.
    max_name_len = max(len("Activity"), max(len(r.name) for r in results))
    activity_width = max_name_len + 1
    activity_dashes = "-" * activity_width

    # Data rows are "{:12.6f} ms " (16 chars) | " {:7} " (9 chars) | " {name}";
    # the separator and header rows below use the same column widths.
    separator = f"{indent}{'-' * 16}+{'-' * 9}+{activity_dashes}"

    def print_header(title, first, second, third):
        print(f"{indent}{title}")
        print(separator)
        print(f"{indent}{first:<16}|{second:>8} | {third}")
        print(separator)

    print_header("CUDA timeline:", "Time", "Device", "Activity")
    for r in results:
        # timing_end stores None when the CUDA context is unknown to Warp
        alias = r.device.alias if r.device is not None else "unknown"

        device_agg = device_totals.setdefault(alias, Aggregate())
        device_agg.count += 1
        device_agg.elapsed += r.elapsed

        activity_agg = activity_totals.setdefault(r.name, Aggregate())
        activity_agg.count += 1
        activity_agg.elapsed += r.elapsed

        print(f"{indent}{r.elapsed:12.6f} ms | {alias:7s} | {r.name}")

    print()
    print_header("CUDA activity summary:", "Total time", "Count", "Activity")
    for name, agg in activity_totals.items():
        print(f"{indent}{agg.elapsed:12.6f} ms | {agg.count:7d} | {name}")

    print()
    print_header("CUDA device summary:", "Total time", "Count", "Device")
    for device, agg in device_totals.items():
        print(f"{indent}{agg.elapsed:12.6f} ms | {agg.count:7d} | {device}")
31
+ return get_deprecated_api(_utils, "wp", name)