warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/_src/tape.py
ADDED
|
@@ -0,0 +1,1206 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from collections import defaultdict, namedtuple
|
|
19
|
+
|
|
20
|
+
import warp as wp
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Tape:
|
|
24
|
+
"""
|
|
25
|
+
Record kernel launches within a Tape scope to enable automatic differentiation.
|
|
26
|
+
Gradients can be computed after the operations have been recorded on the tape via
|
|
27
|
+
:meth:`Tape.backward()`.
|
|
28
|
+
|
|
29
|
+
Example
|
|
30
|
+
-------
|
|
31
|
+
|
|
32
|
+
.. code-block:: python
|
|
33
|
+
|
|
34
|
+
tape = wp.Tape()
|
|
35
|
+
|
|
36
|
+
# forward pass
|
|
37
|
+
with tape:
|
|
38
|
+
wp.launch(kernel=compute1, inputs=[a, b], device="cuda")
|
|
39
|
+
wp.launch(kernel=compute2, inputs=[c, d], device="cuda")
|
|
40
|
+
wp.launch(kernel=loss, inputs=[d, l], device="cuda")
|
|
41
|
+
|
|
42
|
+
# reverse pass
|
|
43
|
+
tape.backward(l)
|
|
44
|
+
|
|
45
|
+
Gradients can be accessed via the ``tape.gradients`` dictionary, e.g.:
|
|
46
|
+
|
|
47
|
+
.. code-block:: python
|
|
48
|
+
|
|
49
|
+
print(tape.gradients[a])
|
|
50
|
+
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
def __init__(self):
|
|
54
|
+
self.gradients = {}
|
|
55
|
+
self.launches = []
|
|
56
|
+
self.scopes = []
|
|
57
|
+
|
|
58
|
+
self.loss = None
|
|
59
|
+
|
|
60
|
+
def __enter__(self):
|
|
61
|
+
wp._src.context.init()
|
|
62
|
+
|
|
63
|
+
if wp._src.context.runtime.tape is not None:
|
|
64
|
+
raise RuntimeError("Warp: Error, entering a tape while one is already active")
|
|
65
|
+
|
|
66
|
+
wp._src.context.runtime.tape = self
|
|
67
|
+
|
|
68
|
+
return self
|
|
69
|
+
|
|
70
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
|
71
|
+
if wp._src.context.runtime.tape is None:
|
|
72
|
+
raise RuntimeError("Warp: Error, ended tape capture, but tape not present")
|
|
73
|
+
|
|
74
|
+
wp._src.context.runtime.tape = None
|
|
75
|
+
|
|
76
|
+
# adj_outputs is a mapping from output tensor -> adjoint of the output
|
|
77
|
+
# after running backward the gradients of tensors may be retrieved by:
|
|
78
|
+
#
|
|
79
|
+
# adj_tensor = tape.gradients[tensor]
|
|
80
|
+
#
|
|
81
|
+
def backward(self, loss: wp.array | None = None, grads: dict[wp.array, wp.array] | None = None):
|
|
82
|
+
"""Evaluate the backward pass of the recorded operations on the tape.
|
|
83
|
+
|
|
84
|
+
A single-element array ``loss`` or a dictionary of arrays ``grads``
|
|
85
|
+
can be provided to assign the incoming gradients for the reverse-mode
|
|
86
|
+
automatic differentiation pass.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
loss: A single-element array that holds the loss function value whose gradient is to be computed
|
|
90
|
+
grads: A dictionary of arrays that map from Warp arrays to their incoming gradients
|
|
91
|
+
"""
|
|
92
|
+
# if scalar loss is specified then initialize
|
|
93
|
+
# a 'seed' array for it, with gradient of one
|
|
94
|
+
if loss:
|
|
95
|
+
if loss.size > 1 or wp._src.types.type_size(loss.dtype) > 1:
|
|
96
|
+
raise RuntimeError("Can only return gradients for scalar loss functions.")
|
|
97
|
+
|
|
98
|
+
if not loss.requires_grad:
|
|
99
|
+
raise RuntimeError(
|
|
100
|
+
"Scalar loss arrays should have requires_grad=True set before calling Tape.backward()"
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# set the seed grad to 1.0
|
|
104
|
+
loss.grad.fill_(1.0)
|
|
105
|
+
|
|
106
|
+
# simply apply dict grads to objects
|
|
107
|
+
# this is just for backward compat. with
|
|
108
|
+
# existing code before we added wp.array.grad attribute
|
|
109
|
+
if grads:
|
|
110
|
+
for a, g in grads.items():
|
|
111
|
+
if a.grad is None:
|
|
112
|
+
a.grad = g
|
|
113
|
+
else:
|
|
114
|
+
# ensure we can capture this backward pass in a CUDA graph
|
|
115
|
+
a.grad.assign(g)
|
|
116
|
+
|
|
117
|
+
# run launches backwards
|
|
118
|
+
for launch in reversed(self.launches):
|
|
119
|
+
if callable(launch):
|
|
120
|
+
launch()
|
|
121
|
+
|
|
122
|
+
else:
|
|
123
|
+
# kernel option takes precedence over module option
|
|
124
|
+
enable_backward = launch[0].options.get("enable_backward")
|
|
125
|
+
if enable_backward is False:
|
|
126
|
+
msg = f"Running the tape backwards may produce incorrect gradients because recorded kernel {launch[0].key} is configured with the option 'enable_backward=False'."
|
|
127
|
+
wp._src.utils.warn(msg)
|
|
128
|
+
elif enable_backward is None:
|
|
129
|
+
enable_backward = launch[0].module.options.get("enable_backward")
|
|
130
|
+
if enable_backward is False:
|
|
131
|
+
msg = f"Running the tape backwards may produce incorrect gradients because recorded kernel {launch[0].key} is defined in a module with the option 'enable_backward=False' set."
|
|
132
|
+
wp._src.utils.warn(msg)
|
|
133
|
+
|
|
134
|
+
kernel = launch[0]
|
|
135
|
+
dim = launch[1]
|
|
136
|
+
max_blocks = launch[2]
|
|
137
|
+
inputs = launch[3]
|
|
138
|
+
outputs = launch[4]
|
|
139
|
+
device = launch[5]
|
|
140
|
+
block_dim = launch[6]
|
|
141
|
+
|
|
142
|
+
adj_inputs = []
|
|
143
|
+
adj_outputs = []
|
|
144
|
+
|
|
145
|
+
# lookup adjoint inputs
|
|
146
|
+
for a in inputs:
|
|
147
|
+
adj_inputs.append(self.get_adjoint(a))
|
|
148
|
+
|
|
149
|
+
# lookup adjoint outputs, todo: only allocate outputs if necessary
|
|
150
|
+
for a in outputs:
|
|
151
|
+
adj_outputs.append(self.get_adjoint(a))
|
|
152
|
+
|
|
153
|
+
if enable_backward:
|
|
154
|
+
wp.launch(
|
|
155
|
+
kernel=kernel,
|
|
156
|
+
dim=dim,
|
|
157
|
+
inputs=inputs,
|
|
158
|
+
outputs=outputs,
|
|
159
|
+
adj_inputs=adj_inputs,
|
|
160
|
+
adj_outputs=adj_outputs,
|
|
161
|
+
device=device,
|
|
162
|
+
adjoint=True,
|
|
163
|
+
max_blocks=max_blocks,
|
|
164
|
+
block_dim=block_dim,
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
# record a kernel launch on the tape
|
|
168
|
+
def record_launch(self, kernel, dim, max_blocks, inputs, outputs, device, block_dim=0, metadata=None):
|
|
169
|
+
if metadata is None:
|
|
170
|
+
metadata = {}
|
|
171
|
+
self.launches.append([kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata])
|
|
172
|
+
|
|
173
|
+
def record_func(self, backward, arrays):
|
|
174
|
+
"""
|
|
175
|
+
Records a custom function to be executed only in the backward pass.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
backward (Callable): A callable Python object (can be any function) that will be executed in the backward pass.
|
|
179
|
+
arrays (list): A list of arrays that are used by the backward function. The tape keeps track of these to be able to zero their gradients in Tape.zero()
|
|
180
|
+
"""
|
|
181
|
+
self.launches.append(backward)
|
|
182
|
+
|
|
183
|
+
for a in arrays:
|
|
184
|
+
if isinstance(a, wp.array) and a.grad:
|
|
185
|
+
self.gradients[a] = a.grad
|
|
186
|
+
else:
|
|
187
|
+
raise RuntimeError(
|
|
188
|
+
f"Array {a} is not of type wp.array or is missing a gradient array. Set array parameter requires_grad=True during instantiation."
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
def record_scope_begin(self, scope_name, metadata=None):
|
|
192
|
+
"""
|
|
193
|
+
Begin a scope on the tape to group operations together. Scopes are only used in the visualization functions.
|
|
194
|
+
"""
|
|
195
|
+
if metadata is None:
|
|
196
|
+
metadata = {}
|
|
197
|
+
self.scopes.append((len(self.launches), scope_name, metadata))
|
|
198
|
+
|
|
199
|
+
def record_scope_end(self, remove_scope_if_empty=True):
|
|
200
|
+
"""
|
|
201
|
+
End a scope on the tape.
|
|
202
|
+
|
|
203
|
+
Args:
|
|
204
|
+
remove_scope_if_empty (bool): If True, the scope will be removed if no kernel launches were recorded within it.
|
|
205
|
+
"""
|
|
206
|
+
if remove_scope_if_empty and self.scopes[-1][0] == len(self.launches):
|
|
207
|
+
self.scopes = self.scopes[:-1]
|
|
208
|
+
else:
|
|
209
|
+
self.scopes.append((len(self.launches), None, None))
|
|
210
|
+
|
|
211
|
+
def _check_kernel_array_access(self, kernel, args):
|
|
212
|
+
"""Detect illegal inter-kernel write after read access patterns during launch capture"""
|
|
213
|
+
adj = kernel.adj
|
|
214
|
+
kernel_name = adj.fun_name
|
|
215
|
+
filename = adj.filename
|
|
216
|
+
lineno = adj.fun_lineno
|
|
217
|
+
|
|
218
|
+
for i, arg in enumerate(args):
|
|
219
|
+
if isinstance(arg, wp.array):
|
|
220
|
+
arg_name = adj.args[i].label
|
|
221
|
+
|
|
222
|
+
# we check write condition first because we allow (write --> read) within the same kernel
|
|
223
|
+
if adj.args[i].is_write:
|
|
224
|
+
arg.mark_write(arg_name=arg_name, kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
225
|
+
|
|
226
|
+
if adj.args[i].is_read:
|
|
227
|
+
arg.mark_read()
|
|
228
|
+
|
|
229
|
+
# returns the adjoint of a kernel parameter
|
|
230
|
+
def get_adjoint(self, a):
|
|
231
|
+
if not wp._src.types.is_array(a) and not isinstance(a, wp._src.codegen.StructInstance):
|
|
232
|
+
# if input is a simple type (e.g.: float, vec3, etc) or a non-Warp array,
|
|
233
|
+
# then no gradient needed (we only return gradients through Warp arrays and structs)
|
|
234
|
+
return None
|
|
235
|
+
|
|
236
|
+
elif wp._src.types.is_array(a) and a.grad:
|
|
237
|
+
# keep track of all gradients used by the tape (for zeroing)
|
|
238
|
+
# ignore the scalar loss since we don't want to clear its grad
|
|
239
|
+
self.gradients[a] = a.grad
|
|
240
|
+
return a.grad
|
|
241
|
+
|
|
242
|
+
elif isinstance(a, wp._src.codegen.StructInstance):
|
|
243
|
+
adj = a._cls()
|
|
244
|
+
for name, _ in a._cls.ctype._fields_:
|
|
245
|
+
if name.startswith("_"):
|
|
246
|
+
continue
|
|
247
|
+
if isinstance(a._cls.vars[name].type, wp.array):
|
|
248
|
+
arr = getattr(a, name)
|
|
249
|
+
if arr.grad:
|
|
250
|
+
grad = self.gradients[arr] = arr.grad
|
|
251
|
+
else:
|
|
252
|
+
grad = None
|
|
253
|
+
setattr(adj, name, grad)
|
|
254
|
+
elif isinstance(a._cls.vars[name].type, wp._src.codegen.Struct):
|
|
255
|
+
setattr(adj, name, self.get_adjoint(getattr(a, name)))
|
|
256
|
+
else:
|
|
257
|
+
setattr(adj, name, getattr(a, name))
|
|
258
|
+
|
|
259
|
+
self.gradients[a] = adj
|
|
260
|
+
return adj
|
|
261
|
+
|
|
262
|
+
return None
|
|
263
|
+
|
|
264
|
+
def reset(self):
|
|
265
|
+
"""
|
|
266
|
+
Clear all operations recorded on the tape and zero out all gradients.
|
|
267
|
+
"""
|
|
268
|
+
self.launches = []
|
|
269
|
+
self.scopes = []
|
|
270
|
+
self.zero()
|
|
271
|
+
if wp.config.verify_autograd_array_access:
|
|
272
|
+
self._reset_array_read_flags()
|
|
273
|
+
|
|
274
|
+
def zero(self):
|
|
275
|
+
"""
|
|
276
|
+
Zero out all gradients recorded on the tape.
|
|
277
|
+
"""
|
|
278
|
+
for a, g in self.gradients.items():
|
|
279
|
+
if isinstance(a, wp._src.codegen.StructInstance):
|
|
280
|
+
for name in g._cls.vars:
|
|
281
|
+
if isinstance(g._cls.vars[name].type, wp.array) and g._cls.vars[name].requires_grad:
|
|
282
|
+
getattr(g, name).zero_()
|
|
283
|
+
else:
|
|
284
|
+
g.zero_()
|
|
285
|
+
|
|
286
|
+
def _reset_array_read_flags(self):
|
|
287
|
+
"""
|
|
288
|
+
Reset all recorded array read flags to False
|
|
289
|
+
"""
|
|
290
|
+
for a in self.gradients:
|
|
291
|
+
if isinstance(a, wp.array):
|
|
292
|
+
a.mark_init()
|
|
293
|
+
|
|
294
|
+
def visualize(
|
|
295
|
+
self,
|
|
296
|
+
filename: str | None = None,
|
|
297
|
+
simplify_graph: bool = True,
|
|
298
|
+
hide_readonly_arrays: bool = False,
|
|
299
|
+
array_labels: dict[wp.array, str] | None = None,
|
|
300
|
+
choose_longest_node_name: bool = True,
|
|
301
|
+
ignore_graph_scopes: bool = False,
|
|
302
|
+
track_inputs: list[wp.array] | None = None,
|
|
303
|
+
track_outputs: list[wp.array] | None = None,
|
|
304
|
+
track_input_names: list[str] | None = None,
|
|
305
|
+
track_output_names: list[str] | None = None,
|
|
306
|
+
graph_direction: str = "LR",
|
|
307
|
+
) -> str:
|
|
308
|
+
"""Visualize the recorded operations on the tape as a `GraphViz diagram <https://graphviz.org/>`_.
|
|
309
|
+
|
|
310
|
+
Example
|
|
311
|
+
-------
|
|
312
|
+
|
|
313
|
+
.. code-block:: python
|
|
314
|
+
|
|
315
|
+
import warp as wp
|
|
316
|
+
|
|
317
|
+
tape = wp.Tape()
|
|
318
|
+
with tape:
|
|
319
|
+
# record Warp kernel launches here
|
|
320
|
+
wp.launch(...)
|
|
321
|
+
|
|
322
|
+
dot_code = tape.visualize("tape.dot")
|
|
323
|
+
|
|
324
|
+
This function creates a GraphViz dot file that can be rendered into an image using the GraphViz command line tool, e.g. via
|
|
325
|
+
|
|
326
|
+
.. code-block:: bash
|
|
327
|
+
|
|
328
|
+
dot -Tpng tape.dot -o tape.png
|
|
329
|
+
|
|
330
|
+
Args:
|
|
331
|
+
filename: The filename to save the visualization to (optional).
|
|
332
|
+
simplify_graph: If True, simplify the graph by detecting repeated kernel launch sequences and summarizing them in subgraphs.
|
|
333
|
+
hide_readonly_arrays: If True, hide arrays that are not modified by any kernel launch.
|
|
334
|
+
array_labels: A dictionary mapping arrays to custom labels.
|
|
335
|
+
choose_longest_node_name: If True, the automatic name resolution will aim to find the longest name for each array in the computation graph.
|
|
336
|
+
ignore_graph_scopes: If True, ignore the scopes recorded on the tape when visualizing the graph.
|
|
337
|
+
track_inputs: A list of arrays to track as inputs in the graph to ensure they are shown regardless of the `hide_readonly_arrays` setting.
|
|
338
|
+
track_outputs: A list of arrays to track as outputs in the graph so that they remain visible.
|
|
339
|
+
track_input_names: A list of custom names for the input arrays to track in the graph (used in conjunction with `track_inputs`).
|
|
340
|
+
track_output_names: A list of custom names for the output arrays to track in the graph (used in conjunction with `track_outputs`).
|
|
341
|
+
graph_direction: The direction of the graph layout (default: "LR").
|
|
342
|
+
|
|
343
|
+
Returns:
|
|
344
|
+
str: The dot code representing the graph.
|
|
345
|
+
"""
|
|
346
|
+
|
|
347
|
+
if track_output_names is None:
|
|
348
|
+
track_output_names = []
|
|
349
|
+
if track_input_names is None:
|
|
350
|
+
track_input_names = []
|
|
351
|
+
if track_outputs is None:
|
|
352
|
+
track_outputs = []
|
|
353
|
+
if track_inputs is None:
|
|
354
|
+
track_inputs = []
|
|
355
|
+
if array_labels is None:
|
|
356
|
+
array_labels = {}
|
|
357
|
+
return visualize_tape_graphviz(
|
|
358
|
+
self,
|
|
359
|
+
filename,
|
|
360
|
+
simplify_graph,
|
|
361
|
+
hide_readonly_arrays,
|
|
362
|
+
array_labels,
|
|
363
|
+
choose_longest_node_name,
|
|
364
|
+
ignore_graph_scopes,
|
|
365
|
+
track_inputs,
|
|
366
|
+
track_outputs,
|
|
367
|
+
track_input_names,
|
|
368
|
+
track_output_names,
|
|
369
|
+
graph_direction,
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class TapeVisitor:
|
|
374
|
+
def emit_array_node(self, arr: wp.array, label: str, active_scope_stack: list[str], indent_level: int):
|
|
375
|
+
pass
|
|
376
|
+
|
|
377
|
+
def emit_kernel_launch_node(
|
|
378
|
+
self, kernel: wp.Kernel, kernel_launch_id: str, launch_data: dict, rendered: bool, indent_level: int
|
|
379
|
+
):
|
|
380
|
+
pass
|
|
381
|
+
|
|
382
|
+
def emit_edge_array_kernel(self, arr: wp.array, kernel_launch_id: str, kernel_input_id: int, indent_level: int):
|
|
383
|
+
pass
|
|
384
|
+
|
|
385
|
+
def emit_edge_kernel_array(self, kernel_launch_id: str, kernel_output_id: int, arr: wp.array, indent_level: int):
|
|
386
|
+
pass
|
|
387
|
+
|
|
388
|
+
def emit_edge_array_array(self, src: wp.array, dst: wp.array, indent_level: int):
|
|
389
|
+
pass
|
|
390
|
+
|
|
391
|
+
def emit_scope_begin(self, active_scope_id: int, active_scope_name: str, metadata: dict, indent_level: int):
|
|
392
|
+
pass
|
|
393
|
+
|
|
394
|
+
def emit_scope_end(self, indent_level: int):
|
|
395
|
+
pass
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def get_struct_vars(x: wp._src.codegen.StructInstance):
|
|
399
|
+
return {varname: getattr(x, varname) for varname, _ in x._cls.ctype._fields_}
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
class GraphvizTapeVisitor(TapeVisitor):
|
|
403
|
+
def __init__(self):
|
|
404
|
+
self.graphviz_lines = []
|
|
405
|
+
self.indent_str = "\t"
|
|
406
|
+
self.scope_classes = {}
|
|
407
|
+
self.max_indent = 0
|
|
408
|
+
# mapping from array pointer to kernel:port ID
|
|
409
|
+
self.pointer_to_port = {}
|
|
410
|
+
# set of inserted edges between kernel:port IDs
|
|
411
|
+
self.edges = set()
|
|
412
|
+
# set of inserted array nodes
|
|
413
|
+
self.array_nodes = set()
|
|
414
|
+
|
|
415
|
+
@staticmethod
|
|
416
|
+
def sanitize(s):
|
|
417
|
+
return (
|
|
418
|
+
s.replace("\n", " ")
|
|
419
|
+
.replace('"', " ")
|
|
420
|
+
.replace("'", " ")
|
|
421
|
+
.replace("[", "[")
|
|
422
|
+
.replace("]", "]")
|
|
423
|
+
.replace("`", "`")
|
|
424
|
+
.replace(":", ":")
|
|
425
|
+
.replace("\\", "\\\\")
|
|
426
|
+
.replace("/", "/")
|
|
427
|
+
.replace("(", "(")
|
|
428
|
+
.replace(")", ")")
|
|
429
|
+
.replace(",", "")
|
|
430
|
+
.replace("{", "{")
|
|
431
|
+
.replace("}", "}")
|
|
432
|
+
.replace("<", "<")
|
|
433
|
+
.replace(">", ">")
|
|
434
|
+
)
|
|
435
|
+
|
|
436
|
+
@staticmethod
|
|
437
|
+
def dtype2str(dtype):
|
|
438
|
+
type_str = str(dtype)
|
|
439
|
+
if hasattr(dtype, "key"):
|
|
440
|
+
type_str = dtype.key
|
|
441
|
+
elif "'" in type_str:
|
|
442
|
+
type_str = type_str.split("'")[1]
|
|
443
|
+
return type_str
|
|
444
|
+
|
|
445
|
+
def emit_array_node(self, arr: wp.array, label: str, active_scope_stack: list[str], indent_level: int):
|
|
446
|
+
if arr.ptr in self.array_nodes:
|
|
447
|
+
return
|
|
448
|
+
if arr.ptr in self.pointer_to_port:
|
|
449
|
+
return
|
|
450
|
+
self.array_nodes.add(arr.ptr)
|
|
451
|
+
color = "lightgray"
|
|
452
|
+
if arr.requires_grad:
|
|
453
|
+
color = "#76B900"
|
|
454
|
+
options = [
|
|
455
|
+
f'label="{label}"',
|
|
456
|
+
"shape=ellipse",
|
|
457
|
+
"style=filled",
|
|
458
|
+
f'fillcolor="{color}"',
|
|
459
|
+
]
|
|
460
|
+
chart_indent = self.indent_str * indent_level
|
|
461
|
+
arr_id = f"arr{arr.ptr}"
|
|
462
|
+
type_str = self.dtype2str(arr.dtype)
|
|
463
|
+
# type_str = self.sanitize(type_str)
|
|
464
|
+
# class_name = "array" if not arr.requires_grad else "array_grad"
|
|
465
|
+
# self.graphviz_lines.append(chart_indent + f'{arr_id}(["`{label}`"]):::{class_name}')
|
|
466
|
+
tooltip = (
|
|
467
|
+
f"Array {label} / ptr={arr.ptr}, shape={arr.shape}, dtype={type_str}, requires_grad={arr.requires_grad}"
|
|
468
|
+
)
|
|
469
|
+
options.append(f'tooltip="{tooltip}"')
|
|
470
|
+
# self.graphviz_lines.append(chart_indent + f'click {arr_id} callback "{tooltip}"')
|
|
471
|
+
# self.max_indent = max(self.max_indent, indent_level)
|
|
472
|
+
self.graphviz_lines.append(f"{chart_indent}{arr_id} [{','.join(options)}];")
|
|
473
|
+
|
|
474
|
+
def emit_kernel_launch_node(
|
|
475
|
+
self, kernel: wp.Kernel, kernel_launch_id: str, launch_data: dict, rendered: bool, indent_level: int
|
|
476
|
+
):
|
|
477
|
+
if not rendered:
|
|
478
|
+
return
|
|
479
|
+
chart_indent = self.indent_str * indent_level
|
|
480
|
+
|
|
481
|
+
table = []
|
|
482
|
+
table.append(
|
|
483
|
+
'<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" border="0" cellspacing="2" cellpadding="4" bgcolor="#888888" gradientangle="0">'
|
|
484
|
+
)
|
|
485
|
+
table.append(f'<TR><TD BGCOLOR="#ffffffaa" colspan="2" align="center"><b>{kernel.key}</b></TD></TR>')
|
|
486
|
+
num_inputs = len(launch_data["inputs"])
|
|
487
|
+
num_outputs = len(launch_data["outputs"])
|
|
488
|
+
nrows = max(num_inputs, num_outputs)
|
|
489
|
+
nargs = len(kernel.adj.args)
|
|
490
|
+
for i in range(nrows):
|
|
491
|
+
row = []
|
|
492
|
+
if i < num_inputs:
|
|
493
|
+
arg = kernel.adj.args[i]
|
|
494
|
+
port_id = f"in_{i}"
|
|
495
|
+
if isinstance(arg.type, wp.array):
|
|
496
|
+
tooltip = f"array: dtype={self.dtype2str(arg.type.dtype)}"
|
|
497
|
+
else:
|
|
498
|
+
tooltip = f"dtype={self.sanitize(self.dtype2str(arg.type))}"
|
|
499
|
+
row.append(
|
|
500
|
+
f'<TD PORT="{port_id}" BGCOLOR="#BBBBBB" align="left" title="{tooltip}"><font color="black">{arg.label}</font></TD>'
|
|
501
|
+
)
|
|
502
|
+
launch_data["inputs"][i]
|
|
503
|
+
# if var is not None and isinstance(var, wp.array):
|
|
504
|
+
# self.pointer_to_port[var.ptr] = f"{kernel_launch_id}:{port_id}"
|
|
505
|
+
else:
|
|
506
|
+
row.append('<TD BORDER="0"></TD>')
|
|
507
|
+
# if i >= nargs - 1:
|
|
508
|
+
# row.append('<TD></TD>')
|
|
509
|
+
# table.append(f'<TR>{row[0]}{row[1]}</TR>')
|
|
510
|
+
# break
|
|
511
|
+
if i < num_outputs and i + num_inputs < nargs:
|
|
512
|
+
arg = kernel.adj.args[i + num_inputs].label
|
|
513
|
+
port_id = f"out_{i}"
|
|
514
|
+
row.append(
|
|
515
|
+
f'<TD PORT="{port_id}" BGCOLOR="#BBBBBB" align="right"><font color="black">{arg}</font></TD>'
|
|
516
|
+
)
|
|
517
|
+
launch_data["outputs"][i]
|
|
518
|
+
# if var is not None and isinstance(var, wp.array):
|
|
519
|
+
# self.pointer_to_port[var.ptr] = f"{kernel_launch_id}:{port_id}"
|
|
520
|
+
else:
|
|
521
|
+
row.append('<TD BORDER="0"></TD>')
|
|
522
|
+
table.append(f"<TR>{row[0]}{row[1]}</TR>")
|
|
523
|
+
table.append("</TABLE>")
|
|
524
|
+
|
|
525
|
+
label = f"{chart_indent}\n".join(table)
|
|
526
|
+
node_attrs = f"label=<{label}>"
|
|
527
|
+
if "caller" in launch_data:
|
|
528
|
+
caller = launch_data["caller"]
|
|
529
|
+
node_attrs += f',tooltip="{self.sanitize(caller["file"])}:{caller["lineno"]} ({caller["func"]})"'
|
|
530
|
+
|
|
531
|
+
self.graphviz_lines.append(f"{chart_indent}{kernel_launch_id} [{node_attrs}];")
|
|
532
|
+
|
|
533
|
+
def emit_edge_array_kernel(self, arr_ptr: int, kernel_launch_id: str, kernel_input_id: int, indent_level: int):
|
|
534
|
+
chart_indent = self.indent_str * indent_level
|
|
535
|
+
if arr_ptr in self.pointer_to_port:
|
|
536
|
+
arr_id = self.pointer_to_port[arr_ptr]
|
|
537
|
+
elif arr_ptr in self.array_nodes:
|
|
538
|
+
arr_id = f"arr{arr_ptr}"
|
|
539
|
+
else:
|
|
540
|
+
return
|
|
541
|
+
target_id = f"{kernel_launch_id}:in_{kernel_input_id}"
|
|
542
|
+
if (arr_id, target_id) in self.edges:
|
|
543
|
+
return
|
|
544
|
+
self.edges.add((arr_id, target_id))
|
|
545
|
+
self.graphviz_lines.append(f"{chart_indent}{arr_id} -> {target_id}")
|
|
546
|
+
|
|
547
|
+
def emit_edge_kernel_array(self, kernel_launch_id: str, kernel_output_id: int, arr_ptr: int, indent_level: int):
|
|
548
|
+
chart_indent = self.indent_str * indent_level
|
|
549
|
+
if arr_ptr in self.pointer_to_port:
|
|
550
|
+
arr_id = self.pointer_to_port[arr_ptr]
|
|
551
|
+
elif arr_ptr in self.array_nodes:
|
|
552
|
+
arr_id = f"arr{arr_ptr}"
|
|
553
|
+
else:
|
|
554
|
+
return
|
|
555
|
+
source_id = f"{kernel_launch_id}:out_{kernel_output_id}"
|
|
556
|
+
if (source_id, arr_id) in self.edges:
|
|
557
|
+
return
|
|
558
|
+
self.edges.add((source_id, arr_id))
|
|
559
|
+
self.graphviz_lines.append(f"{chart_indent}{source_id} -> {arr_id};")
|
|
560
|
+
|
|
561
|
+
def emit_edge_array_array(self, src: wp.array, dst: wp.array, indent_level: int):
|
|
562
|
+
chart_indent = self.indent_str * indent_level
|
|
563
|
+
src_id = f"arr{src.ptr}"
|
|
564
|
+
dst_id = f"arr{dst.ptr}"
|
|
565
|
+
if (src_id, dst_id) in self.edges:
|
|
566
|
+
return
|
|
567
|
+
self.edges.add((src_id, dst_id))
|
|
568
|
+
self.graphviz_lines.append(f'{chart_indent}{src_id} -> {dst_id} [color="#0072B9",constraint=false];')
|
|
569
|
+
|
|
570
|
+
def emit_scope_begin(self, active_scope_id: int, active_scope_name: str, metadata: dict, indent_level: int):
|
|
571
|
+
chart_indent = self.indent_str * indent_level
|
|
572
|
+
scope_key = f"cluster{active_scope_id}"
|
|
573
|
+
scope_class = metadata.get("type", "scope")
|
|
574
|
+
self.graphviz_lines.append(f"{chart_indent}subgraph {scope_key} {{")
|
|
575
|
+
chart_indent += self.indent_str
|
|
576
|
+
self.graphviz_lines.append(f'{chart_indent}style="rounded,filled";')
|
|
577
|
+
if scope_class == "scope":
|
|
578
|
+
self.graphviz_lines.append(f'{chart_indent}fillcolor="#76B90022";')
|
|
579
|
+
self.graphviz_lines.append(f'{chart_indent}pencolor="#76B900";')
|
|
580
|
+
else:
|
|
581
|
+
self.graphviz_lines.append(f'{chart_indent}fillcolor="#0072B922";')
|
|
582
|
+
self.graphviz_lines.append(f'{chart_indent}pencolor="#0072B9";')
|
|
583
|
+
self.graphviz_lines.append(f"{chart_indent}label=<<b>{active_scope_name}</b>>;\n")
|
|
584
|
+
|
|
585
|
+
def emit_scope_end(self, indent_level: int):
|
|
586
|
+
chart_indent = self.indent_str * indent_level
|
|
587
|
+
self.graphviz_lines.append(f"{chart_indent}}};\n")
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
class ArrayStatsVisitor(TapeVisitor):
|
|
591
|
+
ArrayState = namedtuple("ArrayState", ["mean", "std", "min", "max"])
|
|
592
|
+
|
|
593
|
+
def __init__(self):
|
|
594
|
+
self.array_names = {}
|
|
595
|
+
self.launch_data = {}
|
|
596
|
+
self.launches = []
|
|
597
|
+
self.array_value_stats = []
|
|
598
|
+
self.array_grad_stats = []
|
|
599
|
+
|
|
600
|
+
def emit_array_node(self, arr: wp.array, label: str, active_scope_stack: list[str], indent_level: int):
|
|
601
|
+
if arr.device.is_capturing:
|
|
602
|
+
raise RuntimeError("Cannot record arrays while graph capturing is active.")
|
|
603
|
+
self.array_names[arr.ptr] = label
|
|
604
|
+
|
|
605
|
+
def emit_kernel_launch_node(
|
|
606
|
+
self, kernel: wp.Kernel, kernel_launch_id: str, launch_data: dict, rendered: bool, indent_level: int
|
|
607
|
+
):
|
|
608
|
+
self.launch_data[kernel_launch_id] = launch_data
|
|
609
|
+
self.launches.append(kernel_launch_id)
|
|
610
|
+
value_stats = {}
|
|
611
|
+
grad_stats = {}
|
|
612
|
+
for output in launch_data["outputs"]:
|
|
613
|
+
if isinstance(output, wp.array):
|
|
614
|
+
arr_np = output.numpy()
|
|
615
|
+
value_stats[output.ptr] = self.ArrayState(
|
|
616
|
+
mean=arr_np.mean(), std=arr_np.std(), min=arr_np.min(), max=arr_np.max()
|
|
617
|
+
)
|
|
618
|
+
for input in launch_data["inputs"]:
|
|
619
|
+
if isinstance(input, wp.array) and input.requires_grad and input.grad is not None:
|
|
620
|
+
arr_np = input.grad.numpy()
|
|
621
|
+
grad_stats[input.ptr] = self.ArrayState(
|
|
622
|
+
mean=arr_np.mean(), std=arr_np.std(), min=arr_np.min(), max=arr_np.max()
|
|
623
|
+
)
|
|
624
|
+
self.array_value_stats.append(value_stats)
|
|
625
|
+
self.array_grad_stats.insert(0, grad_stats)
|
|
626
|
+
|
|
627
|
+
|
|
628
|
+
Launch = namedtuple(
|
|
629
|
+
"Launch", ["id", "kernel", "dim", "max_blocks", "inputs", "outputs", "device", "block_dim", "metadata"]
|
|
630
|
+
)
|
|
631
|
+
RepeatedSequence = namedtuple("RepeatedSequence", ["start", "end", "repetitions"])
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def visit_tape(
|
|
635
|
+
tape: Tape,
|
|
636
|
+
visitor: TapeVisitor,
|
|
637
|
+
simplify_graph: bool = True,
|
|
638
|
+
hide_readonly_arrays: bool = False,
|
|
639
|
+
array_labels: dict[wp.array, str] | None = None,
|
|
640
|
+
choose_longest_node_name: bool = True,
|
|
641
|
+
ignore_graph_scopes: bool = False,
|
|
642
|
+
track_inputs: list[wp.array] | None = None,
|
|
643
|
+
track_outputs: list[wp.array] | None = None,
|
|
644
|
+
track_input_names: list[str] | None = None,
|
|
645
|
+
track_output_names: list[str] | None = None,
|
|
646
|
+
):
|
|
647
|
+
if track_output_names is None:
|
|
648
|
+
track_output_names = []
|
|
649
|
+
if track_input_names is None:
|
|
650
|
+
track_input_names = []
|
|
651
|
+
if track_outputs is None:
|
|
652
|
+
track_outputs = []
|
|
653
|
+
if track_inputs is None:
|
|
654
|
+
track_inputs = []
|
|
655
|
+
if array_labels is None:
|
|
656
|
+
array_labels = {}
|
|
657
|
+
|
|
658
|
+
def get_launch_id(launch):
|
|
659
|
+
kernel = launch[0]
|
|
660
|
+
suffix = ""
|
|
661
|
+
if len(launch) > 7:
|
|
662
|
+
metadata = launch[7]
|
|
663
|
+
# calling function helps to identify unique launches
|
|
664
|
+
if "caller" in metadata:
|
|
665
|
+
caller = metadata["caller"]
|
|
666
|
+
suffix = str(hash(caller["file"] + caller["func"] + str(caller["lineno"])))
|
|
667
|
+
return f"{kernel.module.name}.{kernel.key}{suffix}"
|
|
668
|
+
|
|
669
|
+
# exclude function calls, only consider kernel launches
|
|
670
|
+
kernel_launches = []
|
|
671
|
+
kernel_scopes = []
|
|
672
|
+
|
|
673
|
+
next_scope_id = 0
|
|
674
|
+
id_offset = 0
|
|
675
|
+
for i, launch in enumerate(tape.launches):
|
|
676
|
+
if isinstance(launch, list):
|
|
677
|
+
kernel_launches.append(launch)
|
|
678
|
+
else:
|
|
679
|
+
id_offset -= 1
|
|
680
|
+
while next_scope_id < len(tape.scopes) and i == tape.scopes[next_scope_id][0]:
|
|
681
|
+
scope = tape.scopes[next_scope_id]
|
|
682
|
+
# update scope launch index to account for removed function calls
|
|
683
|
+
new_scope = (scope[0] + id_offset, *scope[1:])
|
|
684
|
+
kernel_scopes.append(new_scope)
|
|
685
|
+
next_scope_id += 1
|
|
686
|
+
|
|
687
|
+
launch_structs = [
|
|
688
|
+
Launch(
|
|
689
|
+
id=get_launch_id(launch),
|
|
690
|
+
kernel=launch[0],
|
|
691
|
+
dim=launch[1],
|
|
692
|
+
max_blocks=launch[2],
|
|
693
|
+
inputs=launch[3],
|
|
694
|
+
outputs=launch[4],
|
|
695
|
+
device=launch[5],
|
|
696
|
+
block_dim=launch[6],
|
|
697
|
+
metadata=launch[7] if len(launch) > 7 else {},
|
|
698
|
+
)
|
|
699
|
+
for launch in kernel_launches
|
|
700
|
+
]
|
|
701
|
+
launch_ids = [get_launch_id(launch) for launch in kernel_launches]
|
|
702
|
+
|
|
703
|
+
def get_repeating_sequences(sequence: list[str]):
|
|
704
|
+
# yield all consecutively repeating subsequences in descending order of length
|
|
705
|
+
for length in range(len(sequence) // 2 + 1, 0, -1):
|
|
706
|
+
for start in range(len(sequence) - length):
|
|
707
|
+
if sequence[start : start + length] == sequence[start + length : start + 2 * length]:
|
|
708
|
+
# we found a sequence that repeats at least once
|
|
709
|
+
candidate = RepeatedSequence(start, start + length, 2)
|
|
710
|
+
if length == 1:
|
|
711
|
+
# this repetition cannot be made up of smaller repetitions
|
|
712
|
+
yield candidate
|
|
713
|
+
|
|
714
|
+
# check if this sequence is made up entirely of smaller repetitions
|
|
715
|
+
for sl in range(1, length // 2 + 1):
|
|
716
|
+
# loop over subsequence lengths and check if they repeat
|
|
717
|
+
subseq = sequence[start : start + sl]
|
|
718
|
+
if all(
|
|
719
|
+
sequence[start + i * sl : start + (i + 1) * sl] == subseq for i in range(1, length // sl)
|
|
720
|
+
):
|
|
721
|
+
rep_count = length // sl + 1
|
|
722
|
+
# check whether there are more repetitions beyond the previous end
|
|
723
|
+
for cstart in range(start + length, len(sequence) - sl, sl):
|
|
724
|
+
if sequence[cstart : cstart + sl] != subseq:
|
|
725
|
+
break
|
|
726
|
+
rep_count += 1
|
|
727
|
+
candidate = RepeatedSequence(start, start + sl, rep_count)
|
|
728
|
+
yield candidate
|
|
729
|
+
break
|
|
730
|
+
|
|
731
|
+
def process_sequence(sequence: list[str]) -> RepeatedSequence | None:
|
|
732
|
+
# find the longest contiguous repetition in the sequence
|
|
733
|
+
if len(sequence) < 2:
|
|
734
|
+
return None
|
|
735
|
+
|
|
736
|
+
for r in get_repeating_sequences(sequence):
|
|
737
|
+
rlen = r.end - r.start
|
|
738
|
+
rseq = sequence[r.start : r.end]
|
|
739
|
+
# ensure that the repetitions of this subsequence immediately follow each other
|
|
740
|
+
candidates = defaultdict(int) # mapping from start index to number of repetitions
|
|
741
|
+
curr_start = r.start
|
|
742
|
+
i = r.end
|
|
743
|
+
while i + rlen <= len(sequence):
|
|
744
|
+
if sequence[i : i + rlen] == rseq:
|
|
745
|
+
candidates[curr_start] += 1
|
|
746
|
+
i += rlen
|
|
747
|
+
else:
|
|
748
|
+
try:
|
|
749
|
+
curr_start = sequence.index(rseq, i)
|
|
750
|
+
i = curr_start + rlen
|
|
751
|
+
except ValueError:
|
|
752
|
+
break
|
|
753
|
+
|
|
754
|
+
if len(candidates) > 0:
|
|
755
|
+
start, reps = max(candidates.items(), key=lambda x: x[1])
|
|
756
|
+
return RepeatedSequence(start, start + rlen, reps + 1)
|
|
757
|
+
|
|
758
|
+
return None
|
|
759
|
+
|
|
760
|
+
repetitions = []
|
|
761
|
+
|
|
762
|
+
def find_sequences(sequence):
|
|
763
|
+
# recursively find repetitions in sequence
|
|
764
|
+
nonlocal repetitions
|
|
765
|
+
|
|
766
|
+
if len(sequence) == 0:
|
|
767
|
+
return
|
|
768
|
+
|
|
769
|
+
# find LRS in current sequence
|
|
770
|
+
longest_rep = process_sequence(sequence)
|
|
771
|
+
if longest_rep is None:
|
|
772
|
+
return
|
|
773
|
+
|
|
774
|
+
# process sequence up until the current LRS
|
|
775
|
+
find_sequences(sequence[: longest_rep.start])
|
|
776
|
+
|
|
777
|
+
# process repeated sequence
|
|
778
|
+
rstr = sequence[longest_rep.start : longest_rep.end]
|
|
779
|
+
if longest_rep.repetitions >= 2:
|
|
780
|
+
repetitions.append(longest_rep)
|
|
781
|
+
|
|
782
|
+
find_sequences(rstr)
|
|
783
|
+
|
|
784
|
+
# process remaining sequence
|
|
785
|
+
rlen = longest_rep.end - longest_rep.start
|
|
786
|
+
reps = longest_rep.repetitions
|
|
787
|
+
end_idx = longest_rep.start + (reps + 1) * rlen
|
|
788
|
+
if end_idx < len(sequence):
|
|
789
|
+
find_sequences(sequence[end_idx:])
|
|
790
|
+
|
|
791
|
+
return
|
|
792
|
+
|
|
793
|
+
find_sequences(launch_ids)
|
|
794
|
+
|
|
795
|
+
wrap_around_connections = set()
|
|
796
|
+
|
|
797
|
+
# mapping from array ptr to already existing array in a repetition
|
|
798
|
+
array_repeated = {}
|
|
799
|
+
|
|
800
|
+
array_to_launch = defaultdict(list)
|
|
801
|
+
launch_to_array = defaultdict(list)
|
|
802
|
+
|
|
803
|
+
if simplify_graph:
|
|
804
|
+
# mappings from unique launch string to index of first occurrence and vice versa
|
|
805
|
+
launch_to_index = {}
|
|
806
|
+
index_to_launch = {}
|
|
807
|
+
|
|
808
|
+
# new arrays of launches, scopes without repetitions
|
|
809
|
+
launches = []
|
|
810
|
+
scopes = []
|
|
811
|
+
|
|
812
|
+
def find_scope_end(scope_i):
|
|
813
|
+
opened_scopes = 0
|
|
814
|
+
for i in range(scope_i, len(kernel_scopes)):
|
|
815
|
+
scope = kernel_scopes[i]
|
|
816
|
+
if scope[1] is not None:
|
|
817
|
+
opened_scopes += 1
|
|
818
|
+
else:
|
|
819
|
+
opened_scopes -= 1
|
|
820
|
+
if opened_scopes == 0:
|
|
821
|
+
return scope[0]
|
|
822
|
+
return len(kernel_scopes)
|
|
823
|
+
|
|
824
|
+
def process_launches(kernel_launches, start_i, end_i, rep_i, scope_i, skipped_i):
|
|
825
|
+
nonlocal \
|
|
826
|
+
launches, \
|
|
827
|
+
scopes, \
|
|
828
|
+
launch_to_index, \
|
|
829
|
+
index_to_launch, \
|
|
830
|
+
wrap_around_connections, \
|
|
831
|
+
launch_to_array, \
|
|
832
|
+
array_to_launch
|
|
833
|
+
i = start_i # index of current launch
|
|
834
|
+
opened_scopes = 0
|
|
835
|
+
while i < end_i:
|
|
836
|
+
launch_id = launch_ids[i]
|
|
837
|
+
|
|
838
|
+
while (
|
|
839
|
+
scope_i < len(kernel_scopes)
|
|
840
|
+
and i >= kernel_scopes[scope_i][0]
|
|
841
|
+
and kernel_scopes[scope_i][1] is None
|
|
842
|
+
):
|
|
843
|
+
# add any missing closing scopes before we go into a repeating sequence
|
|
844
|
+
scope = kernel_scopes[scope_i]
|
|
845
|
+
if opened_scopes >= 1:
|
|
846
|
+
scopes.append((scope[0] - skipped_i, *scope[1:]))
|
|
847
|
+
scope_i += 1
|
|
848
|
+
opened_scopes -= 1
|
|
849
|
+
|
|
850
|
+
# keep track of the mapping between arrays and kernel launch arguments
|
|
851
|
+
for arg_i, arg in enumerate(kernel_launches[i].inputs + kernel_launches[i].outputs):
|
|
852
|
+
if isinstance(arg, wp.array):
|
|
853
|
+
array_to_launch[arg.ptr].append((launch_id, arg_i))
|
|
854
|
+
launch_to_array[(launch_id, arg_i)].append(arg)
|
|
855
|
+
|
|
856
|
+
# handle repetitions
|
|
857
|
+
if rep_i < len(repetitions):
|
|
858
|
+
rep = repetitions[rep_i]
|
|
859
|
+
if i == rep.start:
|
|
860
|
+
rep_len = rep.end - rep.start
|
|
861
|
+
after_rep = rep.start + rep.repetitions * rep_len
|
|
862
|
+
# check if there is a scope that matches the entire repetition
|
|
863
|
+
skip_adding_repetition_scope = False
|
|
864
|
+
for scope_j in range(scope_i, len(kernel_scopes)):
|
|
865
|
+
scope = kernel_scopes[scope_j]
|
|
866
|
+
if scope[0] > rep.start:
|
|
867
|
+
break
|
|
868
|
+
if scope[0] == rep.start and scope[1] is not None:
|
|
869
|
+
# check if this scope also ends at the end of the repetition
|
|
870
|
+
scope_end = find_scope_end(scope_j)
|
|
871
|
+
if scope_end == after_rep:
|
|
872
|
+
# replace scope details
|
|
873
|
+
kernel_scopes[scope_j] = (
|
|
874
|
+
rep.start,
|
|
875
|
+
f"{scope[1]} (repeated {rep.repetitions}x)",
|
|
876
|
+
{"type": "repeated", "count": rep.repetitions},
|
|
877
|
+
)
|
|
878
|
+
skip_adding_repetition_scope = True
|
|
879
|
+
break
|
|
880
|
+
|
|
881
|
+
if not skip_adding_repetition_scope:
|
|
882
|
+
# add a new scope marking this repetition
|
|
883
|
+
scope_name = f"repeated {rep.repetitions}x"
|
|
884
|
+
scopes.append((len(launches), scope_name, {"type": "repeated", "count": rep.repetitions}))
|
|
885
|
+
|
|
886
|
+
# process repetition recursively to handle nested repetitions
|
|
887
|
+
process_launches(kernel_launches, rep.start, rep.end, rep_i + 1, scope_i, skipped_i)
|
|
888
|
+
|
|
889
|
+
if not skip_adding_repetition_scope:
|
|
890
|
+
# close the scope we just added marking this repetition
|
|
891
|
+
scopes.append((len(launches), None, None))
|
|
892
|
+
|
|
893
|
+
# collect all output arrays from the first iteration
|
|
894
|
+
output_arrays = {}
|
|
895
|
+
for j in range(i, i + rep_len):
|
|
896
|
+
launch = kernel_launches[j]
|
|
897
|
+
launch_id = launch_ids[j]
|
|
898
|
+
for k, arg in enumerate(launch.outputs):
|
|
899
|
+
arg_i = k + len(launch.inputs)
|
|
900
|
+
if isinstance(arg, wp.array):
|
|
901
|
+
output_arrays[arg.ptr] = arg
|
|
902
|
+
array_to_launch[arg.ptr].append((launch_id, arg_i))
|
|
903
|
+
|
|
904
|
+
# find out which output arrays feed back as inputs to the next iteration
|
|
905
|
+
# so we can add them as wrap-around connections
|
|
906
|
+
for j in range(i + rep_len, i + 2 * rep_len):
|
|
907
|
+
launch = kernel_launches[j]
|
|
908
|
+
launch_id = launch_ids[j]
|
|
909
|
+
for arg_i, arg in enumerate(launch.inputs):
|
|
910
|
+
if isinstance(arg, wp.array) and arg.ptr in output_arrays:
|
|
911
|
+
first_encountered_var = launch_to_array[(launch_id, arg_i)][0]
|
|
912
|
+
# print(array_to_launch[arg.ptr])
|
|
913
|
+
# array_to_launch[arg.ptr].append((launch_id, arg_i))
|
|
914
|
+
# launch_to_array[(launch_id, arg_i)].append(arg)
|
|
915
|
+
src_launch = array_to_launch[arg.ptr][-1]
|
|
916
|
+
src_arr = launch_to_array[src_launch][-1]
|
|
917
|
+
wrap_around_connections.add((src_arr.ptr, first_encountered_var.ptr))
|
|
918
|
+
|
|
919
|
+
# map arrays appearing as launch arguments in following repetitions to their first occurrence
|
|
920
|
+
skip_len = rep.repetitions * rep_len
|
|
921
|
+
for j in range(i + rep_len, i + skip_len):
|
|
922
|
+
launch = kernel_launches[j]
|
|
923
|
+
launch_id = launch_ids[j]
|
|
924
|
+
for arg_i, arg in enumerate(launch.inputs + launch.outputs):
|
|
925
|
+
if isinstance(arg, wp.array):
|
|
926
|
+
array_repeated[arg.ptr] = launch_to_array[(launch_id, arg_i)][0].ptr
|
|
927
|
+
|
|
928
|
+
# skip launches during these repetitions
|
|
929
|
+
i += skip_len
|
|
930
|
+
skipped_i += skip_len - rep_len
|
|
931
|
+
rep_i += 1
|
|
932
|
+
|
|
933
|
+
# skip scopes during the repetitions
|
|
934
|
+
while scope_i < len(kernel_scopes) and i > kernel_scopes[scope_i][0]:
|
|
935
|
+
scope_i += 1
|
|
936
|
+
|
|
937
|
+
continue
|
|
938
|
+
|
|
939
|
+
# add launch
|
|
940
|
+
launch = kernel_launches[i]
|
|
941
|
+
launch_id = launch_ids[i]
|
|
942
|
+
if launch_id not in launch_to_index:
|
|
943
|
+
# we encountered an unseen kernel
|
|
944
|
+
j = len(launch_to_index)
|
|
945
|
+
launch_to_index[launch_id] = j
|
|
946
|
+
index_to_launch[j] = launch_id
|
|
947
|
+
launches.append(launch)
|
|
948
|
+
|
|
949
|
+
while scope_i < len(kernel_scopes) and i >= kernel_scopes[scope_i][0]:
|
|
950
|
+
# add scopes encompassing the kernels we added so far
|
|
951
|
+
scope = kernel_scopes[scope_i]
|
|
952
|
+
if scope[1] is not None:
|
|
953
|
+
scopes.append((scope[0] - skipped_i, *scope[1:]))
|
|
954
|
+
opened_scopes += 1
|
|
955
|
+
else:
|
|
956
|
+
if opened_scopes >= 1:
|
|
957
|
+
# only add closing scope if there was an opening scope
|
|
958
|
+
scopes.append((scope[0] - skipped_i, *scope[1:]))
|
|
959
|
+
opened_scopes -= 1
|
|
960
|
+
scope_i += 1
|
|
961
|
+
|
|
962
|
+
i += 1
|
|
963
|
+
|
|
964
|
+
# close any remaining open scopes
|
|
965
|
+
for _ in range(opened_scopes):
|
|
966
|
+
scopes.append((end_i - skipped_i, None, None))
|
|
967
|
+
|
|
968
|
+
process_launches(launch_structs, 0, len(launch_structs), 0, 0, 0)
|
|
969
|
+
|
|
970
|
+
# end of simplify_graph
|
|
971
|
+
else:
|
|
972
|
+
launches = launch_structs
|
|
973
|
+
scopes = kernel_scopes
|
|
974
|
+
|
|
975
|
+
+    node_labels = {}
+    inserted_arrays = {}  # mapping from array ptr to array
+    kernel_launch_count = defaultdict(int)
+    # array -> list of kernels that modify it
+    manipulated_nodes = defaultdict(list)
+
+    indent_level = 0
+
+    input_output_ptr = set()
+    for input in track_inputs:
+        input_output_ptr.add(input.ptr)
+    for output in track_outputs:
+        input_output_ptr.add(output.ptr)
+
+    def add_array_node(x: wp.array, name: str, active_scope_stack=None):
+        if active_scope_stack is None:
+            active_scope_stack = []
+        nonlocal node_labels
+        if x in array_labels:
+            name = array_labels[x]
+        if x.ptr in node_labels:
+            if x.ptr not in input_output_ptr:
+                # update name unless it is an input/output array
+                if choose_longest_node_name:
+                    if len(name) > len(node_labels[x.ptr]):
+                        node_labels[x.ptr] = name
+                else:
+                    node_labels[x.ptr] = name
+            return
+
+        visitor.emit_array_node(x, name, active_scope_stack, indent_level)
+        node_labels[x.ptr] = name
+        inserted_arrays[x.ptr] = x
+
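`add_array_node` resolves one display label per unique buffer: an explicit `array_labels` entry wins, and when the same pointer is seen under several argument names, `choose_longest_node_name` keeps the longer (typically more descriptive) one, unless the array is a tracked input/output. A distilled, standalone restatement of that rule in plain Python, for illustration only:

    node_labels: dict[int, str] = {}

    def pick_label(ptr: int, name: str, choose_longest: bool = True) -> None:
        if ptr in node_labels:
            # a buffer seen twice keeps the longer (usually more descriptive) name
            if not choose_longest or len(name) > len(node_labels[ptr]):
                node_labels[ptr] = name
            return
        node_labels[ptr] = name

    pick_label(0x1000, "x")
    pick_label(0x1000, "particle_positions")
    assert node_labels[0x1000] == "particle_positions"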
+    for i, x in enumerate(track_inputs):
+        if i < len(track_input_names):
+            name = track_input_names[i]
+        else:
+            name = f"input_{i}"
+        add_array_node(x, name)
+    for i, x in enumerate(track_outputs):
+        if i < len(track_output_names):
+            name = track_output_names[i]
+        else:
+            name = f"output_{i}"
+        add_array_node(x, name)
+    # add arrays which are output of a kernel (used to simplify the graph)
+    computed_nodes = set()
+    for output in track_outputs:
+        computed_nodes.add(output.ptr)
+    active_scope_stack = []
+    active_scope = None
+    active_scope_id = -1
+    active_scope_kernels = {}
+    if not hasattr(tape, "scopes"):
+        ignore_graph_scopes = True
+    if not ignore_graph_scopes and len(scopes) > 0:
+        active_scope = scopes[0]
+        active_scope_id = 0
+    for launch_id, launch in enumerate(launches):
+        if active_scope is not None:
+            if launch_id == active_scope[0]:
+                if active_scope[1] is None:
+                    # end previous scope
+                    indent_level -= 1
+                    visitor.emit_scope_end(indent_level)
+                    active_scope_stack = active_scope_stack[:-1]
+                else:
+                    # begin new scope
+                    active_scope_stack.append(f"scope{active_scope_id}")
+                    visitor.emit_scope_begin(active_scope_id, active_scope[1], active_scope[2], indent_level)
+                    indent_level += 1
+            # check if we are in the next scope now
+            while (
+                not ignore_graph_scopes
+                and active_scope_id < len(scopes) - 1
+                and launch_id == scopes[active_scope_id + 1][0]
+            ):
+                active_scope_id += 1
+                active_scope = scopes[active_scope_id]
+                active_scope_kernels = {}
+                if active_scope[1] is None:
+                    # end previous scope
+                    indent_level -= 1
+                    visitor.emit_scope_end(indent_level)
+                    active_scope_stack = active_scope_stack[:-1]
+                else:
+                    # begin new scope
+                    active_scope_stack.append(f"scope{active_scope_id}")
+                    visitor.emit_scope_begin(active_scope_id, active_scope[1], active_scope[2], indent_level)
+                    indent_level += 1
+
+        kernel = launch.kernel
+        launch_data = {
+            "id": launch_id,
+            "dim": launch.dim,
+            "inputs": launch.inputs,
+            "outputs": launch.outputs,
+            "stack_trace": "",
+            "kernel_launch_count": kernel_launch_count[kernel.key],
+        }
+        launch_data.update(launch.metadata)
+
+        rendered = not hide_readonly_arrays or ignore_graph_scopes or kernel.key not in active_scope_kernels
+        if rendered:
+            active_scope_kernels[kernel.key] = launch_id
+
+        if not ignore_graph_scopes and hide_readonly_arrays:
+            k_id = f"kernel{active_scope_kernels[kernel.key]}"
+        else:
+            k_id = f"kernel{launch_id}"
+
+        visitor.emit_kernel_launch_node(kernel, k_id, launch_data, rendered, indent_level)
+
+        # loop over inputs and outputs to add them to the graph
+        input_arrays = []
+        for id, x in enumerate(launch.inputs):
+            name = kernel.adj.args[id].label
+            if isinstance(x, wp.array):
+                if x.ptr is None:
+                    continue
+                # if x.ptr in array_to_launch and len(array_to_launch[x.ptr]) > 1:
+                #     launch_arg_i = array_to_launch[x.ptr]
+                #     actual_input = launch_to_array[launch_arg_i][0]
+                #     visitor.emit_edge_array_kernel(actual_input.ptr, k_id, id, indent_level)
+                if not hide_readonly_arrays or x.ptr in computed_nodes or x.ptr in input_output_ptr:
+                    xptr = x.ptr
+                    if xptr in array_repeated:
+                        xptr = array_repeated[xptr]
+                    else:
+                        add_array_node(x, name, active_scope_stack)
+                        # input_arrays.append(x.ptr)
+                    visitor.emit_edge_array_kernel(xptr, k_id, id, indent_level)
+            elif isinstance(x, wp._src.codegen.StructInstance):
+                for varname, var in get_struct_vars(x).items():
+                    if isinstance(var, wp.array):
+                        if not hide_readonly_arrays or var.ptr in computed_nodes or var.ptr in input_output_ptr:
+                            add_array_node(var, f"{name}.{varname}", active_scope_stack)
+                            input_arrays.append(var.ptr)
+                            xptr = var.ptr
+                            if xptr in array_repeated:
+                                xptr = array_repeated[xptr]
+                            visitor.emit_edge_array_kernel(xptr, k_id, id, indent_level)
+        output_arrays = []
+        for id, x in enumerate(launch.outputs):
+            name = kernel.adj.args[id + len(launch.inputs)].label
+            if isinstance(x, wp.array) and x.ptr is not None:
+                add_array_node(x, name, active_scope_stack)
+                output_arrays.append(x.ptr)
+                computed_nodes.add(x.ptr)
+                visitor.emit_edge_kernel_array(k_id, id, x.ptr, indent_level)
+            elif isinstance(x, wp._src.codegen.StructInstance):
+                for varname, var in get_struct_vars(x).items():
+                    if isinstance(var, wp.array):
+                        add_array_node(var, f"{name}.{varname}", active_scope_stack)
+                        output_arrays.append(var.ptr)
+                        computed_nodes.add(var.ptr)
+                        visitor.emit_edge_kernel_array(k_id, id, var.ptr, indent_level)
+
+        for output_x in output_arrays:
+            # track how many kernels modify each array
+            manipulated_nodes[output_x].append(kernel.key)
+
+        kernel_launch_count[kernel.key] += 1
+
+    # close any open scopes
+    for _ in range(len(active_scope_stack)):
+        indent_level -= 1
+        visitor.emit_scope_end(indent_level)
+
+    # add additional edges between arrays
+    for src, dst in wrap_around_connections:
+        if src == dst or src not in inserted_arrays or dst not in inserted_arrays:
+            continue
+        visitor.emit_edge_array_array(inserted_arrays[src], inserted_arrays[dst], indent_level)
+
+
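Everything above is driven through the visitor's `emit_*` hooks, so backends other than Graphviz only need to supply those callbacks. Below is a minimal sketch of an alternative visitor that merely tallies graph elements instead of emitting DOT; it is duck-typed against the call sites above, and the parameter names are assumptions, not Warp's own.

    class CountingTapeVisitor:
        """Illustrative only: counts nodes and edges instead of emitting DOT."""

        def __init__(self):
            self.num_arrays = 0
            self.num_kernels = 0
            self.num_edges = 0

        def emit_array_node(self, arr, name, active_scope_stack, indent_level):
            self.num_arrays += 1

        def emit_kernel_launch_node(self, kernel, k_id, launch_data, rendered, indent_level):
            if rendered:
                self.num_kernels += 1

        def emit_scope_begin(self, scope_id, name, metadata, indent_level):
            pass  # scopes do not affect the counts

        def emit_scope_end(self, indent_level):
            pass

        def emit_edge_array_kernel(self, arr_ptr, k_id, input_id, indent_level):
            self.num_edges += 1

        def emit_edge_kernel_array(self, k_id, output_id, arr_ptr, indent_level):
            self.num_edges += 1

        def emit_edge_array_array(self, src, dst, indent_level):
            self.num_edges += 1

Passing an instance of such a class to `visit_tape` in place of the `GraphvizTapeVisitor` used below would walk the recorded launches without producing any DOT output.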
+def visualize_tape_graphviz(
+    tape: Tape,
+    filename: str,
+    simplify_graph: bool = True,
+    hide_readonly_arrays: bool = False,
+    array_labels: dict[wp.array, str] | None = None,
+    choose_longest_node_name: bool = True,
+    ignore_graph_scopes: bool = False,
+    track_inputs: list[wp.array] | None = None,
+    track_outputs: list[wp.array] | None = None,
+    track_input_names: list[str] | None = None,
+    track_output_names: list[str] | None = None,
+    graph_direction: str = "LR",
+) -> str:
+    if track_output_names is None:
+        track_output_names = []
+    if track_input_names is None:
+        track_input_names = []
+    if track_outputs is None:
+        track_outputs = []
+    if track_inputs is None:
+        track_inputs = []
+    if array_labels is None:
+        array_labels = {}
+    visitor = GraphvizTapeVisitor()
+    visit_tape(
+        tape,
+        visitor,
+        simplify_graph,
+        hide_readonly_arrays,
+        array_labels,
+        choose_longest_node_name,
+        ignore_graph_scopes,
+        track_inputs,
+        track_outputs,
+        track_input_names,
+        track_output_names,
+    )
+
+    chart = "\n".join(visitor.graphviz_lines)
+    code = f"""digraph " " {{
+    graph [fontname="Helvetica,Arial,sans-serif",tooltip=" "];
+    node [style=rounded,shape=plaintext,fontname="Helvetica,Arial,sans-serif", margin="0.05,0.02", width=0, height=0, tooltip=" "];
+    edge [fontname="Helvetica,Arial,sans-serif",tooltip=" "];
+    rankdir={graph_direction};
+
+{chart}
+}}
+"""
+
+    if filename is not None:
+        with open(filename, "w") as f:
+            f.write(code)
+
+    return code