warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: Copyright (c)
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
2
|
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
#
|
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
@@ -13,4257 +13,12 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
# TODO: Remove after cleaning up the public API.
|
|
17
17
|
|
|
18
|
-
import
|
|
19
|
-
import builtins
|
|
20
|
-
import ctypes
|
|
21
|
-
import enum
|
|
22
|
-
import functools
|
|
23
|
-
import hashlib
|
|
24
|
-
import inspect
|
|
25
|
-
import itertools
|
|
26
|
-
import math
|
|
27
|
-
import re
|
|
28
|
-
import sys
|
|
29
|
-
import textwrap
|
|
30
|
-
import types
|
|
31
|
-
from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
|
|
18
|
+
from warp._src import codegen as _codegen
|
|
32
19
|
|
|
33
|
-
import warp.config
|
|
34
|
-
from warp.types import *
|
|
35
20
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
options = {}
|
|
21
|
+
def __getattr__(name):
|
|
22
|
+
from warp._src.utils import get_deprecated_api
|
|
39
23
|
|
|
40
|
-
|
|
41
|
-
class WarpCodegenError(RuntimeError):
|
|
42
|
-
def __init__(self, message):
|
|
43
|
-
super().__init__(message)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
class WarpCodegenTypeError(TypeError):
|
|
47
|
-
def __init__(self, message):
|
|
48
|
-
super().__init__(message)
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
class WarpCodegenAttributeError(AttributeError):
|
|
52
|
-
def __init__(self, message):
|
|
53
|
-
super().__init__(message)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class WarpCodegenKeyError(KeyError):
|
|
57
|
-
def __init__(self, message):
|
|
58
|
-
super().__init__(message)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
# map operator to function name
|
|
62
|
-
builtin_operators: dict[type[ast.AST], str] = {}
|
|
63
|
-
|
|
64
|
-
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
|
|
65
|
-
# nice overview of python operators
|
|
66
|
-
|
|
67
|
-
builtin_operators[ast.Add] = "add"
|
|
68
|
-
builtin_operators[ast.Sub] = "sub"
|
|
69
|
-
builtin_operators[ast.Mult] = "mul"
|
|
70
|
-
builtin_operators[ast.MatMult] = "mul"
|
|
71
|
-
builtin_operators[ast.Div] = "div"
|
|
72
|
-
builtin_operators[ast.FloorDiv] = "floordiv"
|
|
73
|
-
builtin_operators[ast.Pow] = "pow"
|
|
74
|
-
builtin_operators[ast.Mod] = "mod"
|
|
75
|
-
builtin_operators[ast.UAdd] = "pos"
|
|
76
|
-
builtin_operators[ast.USub] = "neg"
|
|
77
|
-
builtin_operators[ast.Not] = "unot"
|
|
78
|
-
|
|
79
|
-
builtin_operators[ast.Gt] = ">"
|
|
80
|
-
builtin_operators[ast.Lt] = "<"
|
|
81
|
-
builtin_operators[ast.GtE] = ">="
|
|
82
|
-
builtin_operators[ast.LtE] = "<="
|
|
83
|
-
builtin_operators[ast.Eq] = "=="
|
|
84
|
-
builtin_operators[ast.NotEq] = "!="
|
|
85
|
-
|
|
86
|
-
builtin_operators[ast.BitAnd] = "bit_and"
|
|
87
|
-
builtin_operators[ast.BitOr] = "bit_or"
|
|
88
|
-
builtin_operators[ast.BitXor] = "bit_xor"
|
|
89
|
-
builtin_operators[ast.Invert] = "invert"
|
|
90
|
-
builtin_operators[ast.LShift] = "lshift"
|
|
91
|
-
builtin_operators[ast.RShift] = "rshift"
|
|
92
|
-
|
|
93
|
-
comparison_chain_strings = [
|
|
94
|
-
builtin_operators[ast.Gt],
|
|
95
|
-
builtin_operators[ast.Lt],
|
|
96
|
-
builtin_operators[ast.LtE],
|
|
97
|
-
builtin_operators[ast.GtE],
|
|
98
|
-
builtin_operators[ast.Eq],
|
|
99
|
-
builtin_operators[ast.NotEq],
|
|
100
|
-
]
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
def values_check_equal(a, b):
|
|
104
|
-
if isinstance(a, Sequence) and isinstance(b, Sequence):
|
|
105
|
-
if len(a) != len(b):
|
|
106
|
-
return False
|
|
107
|
-
|
|
108
|
-
return all(x == y for x, y in zip(a, b))
|
|
109
|
-
|
|
110
|
-
return a == b
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def op_str_is_chainable(op: str) -> builtins.bool:
|
|
114
|
-
return op in comparison_chain_strings
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
def get_closure_cell_contents(obj):
|
|
118
|
-
"""Retrieve a closure's cell contents or `None` if it's empty."""
|
|
119
|
-
try:
|
|
120
|
-
return obj.cell_contents
|
|
121
|
-
except ValueError:
|
|
122
|
-
pass
|
|
123
|
-
|
|
124
|
-
return None
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
|
|
128
|
-
"""Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
|
|
129
|
-
# Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
|
|
130
|
-
if not annotations:
|
|
131
|
-
return {}
|
|
132
|
-
|
|
133
|
-
if not any(isinstance(x, str) for x in annotations.values()):
|
|
134
|
-
# No annotation to un-stringize.
|
|
135
|
-
return annotations
|
|
136
|
-
|
|
137
|
-
if isinstance(obj, type):
|
|
138
|
-
# class
|
|
139
|
-
globals = {}
|
|
140
|
-
module_name = getattr(obj, "__module__", None)
|
|
141
|
-
if module_name:
|
|
142
|
-
module = sys.modules.get(module_name, None)
|
|
143
|
-
if module:
|
|
144
|
-
globals = getattr(module, "__dict__", {})
|
|
145
|
-
locals = dict(vars(obj))
|
|
146
|
-
unwrap = obj
|
|
147
|
-
elif isinstance(obj, types.ModuleType):
|
|
148
|
-
# module
|
|
149
|
-
globals = obj.__dict__
|
|
150
|
-
locals = {}
|
|
151
|
-
unwrap = None
|
|
152
|
-
elif callable(obj):
|
|
153
|
-
# function
|
|
154
|
-
globals = getattr(obj, "__globals__", {})
|
|
155
|
-
# Capture the variables from the surrounding scope.
|
|
156
|
-
closure_vars = zip(
|
|
157
|
-
obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
|
|
158
|
-
)
|
|
159
|
-
locals = {k: v for k, v in closure_vars if v is not None}
|
|
160
|
-
unwrap = obj
|
|
161
|
-
else:
|
|
162
|
-
raise TypeError(f"{obj!r} is not a module, class, or callable.")
|
|
163
|
-
|
|
164
|
-
if unwrap is not None:
|
|
165
|
-
while True:
|
|
166
|
-
if hasattr(unwrap, "__wrapped__"):
|
|
167
|
-
unwrap = unwrap.__wrapped__
|
|
168
|
-
continue
|
|
169
|
-
if isinstance(unwrap, functools.partial):
|
|
170
|
-
unwrap = unwrap.func
|
|
171
|
-
continue
|
|
172
|
-
break
|
|
173
|
-
if hasattr(unwrap, "__globals__"):
|
|
174
|
-
globals = unwrap.__globals__
|
|
175
|
-
|
|
176
|
-
# "Inject" type parameters into the local namespace
|
|
177
|
-
# (unless they are shadowed by assignments *in* the local namespace),
|
|
178
|
-
# as a way of emulating annotation scopes when calling `eval()`
|
|
179
|
-
type_params = getattr(obj, "__type_params__", ())
|
|
180
|
-
if type_params:
|
|
181
|
-
locals = {param.__name__: param for param in type_params} | locals
|
|
182
|
-
|
|
183
|
-
return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def get_annotations(obj: Any) -> Mapping[str, Any]:
|
|
187
|
-
"""Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
|
|
188
|
-
# This backports `inspect.get_annotations()` for Python 3.9 and older.
|
|
189
|
-
# See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
|
|
190
|
-
if isinstance(obj, type):
|
|
191
|
-
annotations = obj.__dict__.get("__annotations__", {})
|
|
192
|
-
else:
|
|
193
|
-
annotations = getattr(obj, "__annotations__", {})
|
|
194
|
-
|
|
195
|
-
# Evaluating annotations can be done using the `eval_str` parameter with
|
|
196
|
-
# the official function from the `inspect` module.
|
|
197
|
-
return eval_annotations(annotations, obj)
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
|
|
201
|
-
"""Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
|
|
202
|
-
# See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
|
|
203
|
-
spec = inspect.getfullargspec(func)
|
|
204
|
-
return spec._replace(annotations=eval_annotations(spec.annotations, func))
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
|
|
208
|
-
indent = "\t"
|
|
209
|
-
|
|
210
|
-
# handle empty structs
|
|
211
|
-
if len(inst._cls.vars) == 0:
|
|
212
|
-
return f"{inst._cls.key}()"
|
|
213
|
-
|
|
214
|
-
lines = []
|
|
215
|
-
lines.append(f"{inst._cls.key}(")
|
|
216
|
-
|
|
217
|
-
for field_name, _ in inst._cls.ctype._fields_:
|
|
218
|
-
field_value = getattr(inst, field_name, None)
|
|
219
|
-
|
|
220
|
-
if isinstance(field_value, StructInstance):
|
|
221
|
-
field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
|
|
222
|
-
|
|
223
|
-
if use_repr:
|
|
224
|
-
lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
|
|
225
|
-
else:
|
|
226
|
-
lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
|
|
227
|
-
|
|
228
|
-
lines.append(f"{indent * depth})")
|
|
229
|
-
return "\n".join(lines)
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
class StructInstance:
|
|
233
|
-
def __init__(self, cls: Struct, ctype):
|
|
234
|
-
super().__setattr__("_cls", cls)
|
|
235
|
-
|
|
236
|
-
# maintain a c-types object for the top-level instance the struct
|
|
237
|
-
if not ctype:
|
|
238
|
-
super().__setattr__("_ctype", cls.ctype())
|
|
239
|
-
else:
|
|
240
|
-
super().__setattr__("_ctype", ctype)
|
|
241
|
-
|
|
242
|
-
# create Python attributes for each of the struct's variables
|
|
243
|
-
for field, var in cls.vars.items():
|
|
244
|
-
if isinstance(var.type, warp.codegen.Struct):
|
|
245
|
-
self.__dict__[field] = var.type.instance_type(ctype=getattr(self._ctype, field))
|
|
246
|
-
elif isinstance(var.type, warp.types.array):
|
|
247
|
-
self.__dict__[field] = None
|
|
248
|
-
else:
|
|
249
|
-
self.__dict__[field] = var.type()
|
|
250
|
-
|
|
251
|
-
def __getattribute__(self, name):
|
|
252
|
-
cls = super().__getattribute__("_cls")
|
|
253
|
-
if name == "native_name":
|
|
254
|
-
return cls.native_name
|
|
255
|
-
|
|
256
|
-
var = cls.vars.get(name)
|
|
257
|
-
if var is not None:
|
|
258
|
-
if isinstance(var.type, type) and issubclass(var.type, ctypes.Array):
|
|
259
|
-
# Each field stored in a `StructInstance` is exposed as
|
|
260
|
-
# a standard Python attribute but also has a `ctypes`
|
|
261
|
-
# equivalent that is being updated in `__setattr__`.
|
|
262
|
-
# However, when assigning in place an object such as a vec/mat
|
|
263
|
-
# (e.g.: `my_struct.my_vec[0] = 1.23`), the `__setattr__` method
|
|
264
|
-
# from `StructInstance` isn't called, and the synchronization
|
|
265
|
-
# mechanism has no chance of updating the underlying ctype data.
|
|
266
|
-
# As a workaround, we catch here all attempts at accessing such
|
|
267
|
-
# objects and directly return their underlying ctype since
|
|
268
|
-
# the Python-facing Warp vectors and matrices are implemented
|
|
269
|
-
# using `ctypes.Array` anyways.
|
|
270
|
-
return getattr(self._ctype, name)
|
|
271
|
-
|
|
272
|
-
return super().__getattribute__(name)
|
|
273
|
-
|
|
274
|
-
def __setattr__(self, name, value):
|
|
275
|
-
if name not in self._cls.vars:
|
|
276
|
-
raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")
|
|
277
|
-
|
|
278
|
-
var = self._cls.vars[name]
|
|
279
|
-
|
|
280
|
-
# update our ctype flat copy
|
|
281
|
-
if isinstance(var.type, array):
|
|
282
|
-
if value is None:
|
|
283
|
-
# create array with null pointer
|
|
284
|
-
setattr(self._ctype, name, array_t())
|
|
285
|
-
else:
|
|
286
|
-
# wp.array
|
|
287
|
-
assert isinstance(value, array)
|
|
288
|
-
assert types_equal(value.dtype, var.type.dtype), (
|
|
289
|
-
f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
|
|
290
|
-
)
|
|
291
|
-
setattr(self._ctype, name, value.__ctype__())
|
|
292
|
-
|
|
293
|
-
# workaround to prevent gradient buffers being garbage collected
|
|
294
|
-
# since users can do struct.array.requires_grad = False the gradient array
|
|
295
|
-
# would be collected while the struct ctype still holds a reference to it
|
|
296
|
-
super().__setattr__("_" + name + "_grad", value.grad)
|
|
297
|
-
|
|
298
|
-
elif isinstance(var.type, Struct):
|
|
299
|
-
# assign structs by-value, otherwise we would have problematic cases transferring ownership
|
|
300
|
-
# of the underlying ctypes data between shared Python struct instances
|
|
301
|
-
|
|
302
|
-
if not isinstance(value, StructInstance):
|
|
303
|
-
raise RuntimeError(
|
|
304
|
-
f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
|
|
305
|
-
)
|
|
306
|
-
|
|
307
|
-
# destination attribution on self
|
|
308
|
-
dest = getattr(self, name)
|
|
309
|
-
|
|
310
|
-
if dest._cls.key is not value._cls.key:
|
|
311
|
-
raise RuntimeError(
|
|
312
|
-
f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
|
|
313
|
-
)
|
|
314
|
-
|
|
315
|
-
# update all nested ctype vars by deep copy
|
|
316
|
-
for n in dest._cls.vars:
|
|
317
|
-
setattr(dest, n, getattr(value, n))
|
|
318
|
-
|
|
319
|
-
# early return to avoid updating our Python StructInstance
|
|
320
|
-
return
|
|
321
|
-
|
|
322
|
-
elif issubclass(var.type, ctypes.Array):
|
|
323
|
-
# vector/matrix type, e.g. vec3
|
|
324
|
-
if value is None:
|
|
325
|
-
setattr(self._ctype, name, var.type())
|
|
326
|
-
elif type(value) == var.type:
|
|
327
|
-
setattr(self._ctype, name, value)
|
|
328
|
-
else:
|
|
329
|
-
# conversion from list/tuple, ndarray, etc.
|
|
330
|
-
setattr(self._ctype, name, var.type(value))
|
|
331
|
-
|
|
332
|
-
else:
|
|
333
|
-
# primitive type
|
|
334
|
-
if value is None:
|
|
335
|
-
# zero initialize
|
|
336
|
-
setattr(self._ctype, name, var.type._type_())
|
|
337
|
-
else:
|
|
338
|
-
if hasattr(value, "_type_"):
|
|
339
|
-
# assigning warp type value (e.g.: wp.float32)
|
|
340
|
-
value = value.value
|
|
341
|
-
# float16 needs conversion to uint16 bits
|
|
342
|
-
if var.type == warp.float16:
|
|
343
|
-
setattr(self._ctype, name, float_to_half_bits(value))
|
|
344
|
-
else:
|
|
345
|
-
setattr(self._ctype, name, value)
|
|
346
|
-
|
|
347
|
-
# update Python instance
|
|
348
|
-
super().__setattr__(name, value)
|
|
349
|
-
|
|
350
|
-
def __ctype__(self):
|
|
351
|
-
return self._ctype
|
|
352
|
-
|
|
353
|
-
def __repr__(self):
|
|
354
|
-
return struct_instance_repr_recursive(self, 0, use_repr=True)
|
|
355
|
-
|
|
356
|
-
def __str__(self):
|
|
357
|
-
return struct_instance_repr_recursive(self, 0, use_repr=False)
|
|
358
|
-
|
|
359
|
-
def to(self, device):
|
|
360
|
-
"""Copies this struct with all array members moved onto the given device.
|
|
361
|
-
|
|
362
|
-
Arrays already living on the desired device are referenced as-is, while
|
|
363
|
-
arrays being moved are copied.
|
|
364
|
-
"""
|
|
365
|
-
out = self._cls()
|
|
366
|
-
stack = [(self, out, k, v) for k, v in self._cls.vars.items()]
|
|
367
|
-
while stack:
|
|
368
|
-
src, dst, name, var = stack.pop()
|
|
369
|
-
value = getattr(src, name)
|
|
370
|
-
if isinstance(var.type, array):
|
|
371
|
-
# array_t
|
|
372
|
-
setattr(dst, name, value.to(device))
|
|
373
|
-
elif isinstance(var.type, Struct):
|
|
374
|
-
# nested struct
|
|
375
|
-
new_struct = value._cls()
|
|
376
|
-
setattr(dst, name, new_struct)
|
|
377
|
-
# The call to `setattr()` just above makes a copy of `new_struct`
|
|
378
|
-
# so we need to reference that new instance of the struct.
|
|
379
|
-
new_struct = getattr(dst, name)
|
|
380
|
-
stack.extend((value, new_struct, k, v) for k, v in value._cls.vars.items())
|
|
381
|
-
else:
|
|
382
|
-
setattr(dst, name, value)
|
|
383
|
-
|
|
384
|
-
return out
|
|
385
|
-
|
|
386
|
-
# type description used in numpy structured arrays
|
|
387
|
-
def numpy_dtype(self):
|
|
388
|
-
return self._cls.numpy_dtype()
|
|
389
|
-
|
|
390
|
-
# value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
|
|
391
|
-
def numpy_value(self):
|
|
392
|
-
npvalue = []
|
|
393
|
-
for name, var in self._cls.vars.items():
|
|
394
|
-
# get the attribute value
|
|
395
|
-
value = getattr(self._ctype, name)
|
|
396
|
-
|
|
397
|
-
if isinstance(var.type, array):
|
|
398
|
-
# array_t
|
|
399
|
-
npvalue.append(value.numpy_value())
|
|
400
|
-
elif isinstance(var.type, Struct):
|
|
401
|
-
# nested struct
|
|
402
|
-
npvalue.append(value.numpy_value())
|
|
403
|
-
elif issubclass(var.type, ctypes.Array):
|
|
404
|
-
if len(var.type._shape_) == 1:
|
|
405
|
-
# vector
|
|
406
|
-
npvalue.append(list(value))
|
|
407
|
-
else:
|
|
408
|
-
# matrix
|
|
409
|
-
npvalue.append([list(row) for row in value])
|
|
410
|
-
else:
|
|
411
|
-
# scalar
|
|
412
|
-
if var.type == warp.float16:
|
|
413
|
-
npvalue.append(half_bits_to_float(value))
|
|
414
|
-
else:
|
|
415
|
-
npvalue.append(value)
|
|
416
|
-
|
|
417
|
-
return tuple(npvalue)
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
class Struct:
    """Compile-time description of a Warp struct type.

    Built from the annotations of a user class, it provides:

    - ``vars``: one :class:`Var` per annotated field, in declaration order;
    - ``ctype``: a ``ctypes.Structure`` subclass mirroring the field layout;
    - ``hash`` / ``native_name``: a SHA-256 digest of the key and field type codes, and
      the identifier derived from it that names the struct in native code;
    - ``default_constructor`` / ``value_constructor``: ``Function`` overloads for
      zero-initialized and field-by-field construction;
    - ``instance_type``: the Python class used for host-side instances.
    """

    hash: bytes

    def __init__(self, key: str, cls: type, module: warp.context.Module):
        """
        Args:
            key: Struct name; used as the prefix of ``native_name`` and hashed.
            cls: User class whose annotations define the struct fields.
            module: Owning module; the struct is registered with it when it is a
                ``warp.context.Module``.

        Raises:
            RuntimeError: If ``cls`` is a ``Sequence`` instance instead of a class.
        """
        self.key = key
        self.cls = cls
        self.module = module
        self.vars: dict[str, Var] = {}

        if isinstance(self.cls, Sequence):
            raise RuntimeError("Warp structs must be defined as base classes")

        annotations = get_annotations(self.cls)
        # ``field_type`` (not ``type``) so the builtin isn't shadowed in this scope
        for label, field_type in annotations.items():
            self.vars[label] = Var(label, field_type)

        # ctypes field layout, one entry per field in declaration order
        fields = []
        for label, var in self.vars.items():
            if isinstance(var.type, array):
                fields.append((label, array_t))
            elif isinstance(var.type, Struct):
                fields.append((label, var.type.ctype))
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix types are ctypes arrays themselves
                fields.append((label, var.type))
            else:
                # HACK: fp16 requires conversion functions from warp.so
                if var.type is warp.float16:
                    warp.init()
                fields.append((label, var.type._type_))

        class StructType(ctypes.Structure):
            # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
            _fields_ = fields or [("_dummy_", ctypes.c_byte)]

        self.ctype = StructType

        # Compute the hash. We can cache the hash because it's static, even with nested structs.
        # All field types are specified in the annotations, so they're resolved at declaration time.
        # NOTE: the exact byte sequence fed to the hash determines ``native_name``.
        ch = hashlib.sha256()

        ch.update(bytes(self.key, "utf-8"))

        for name, type_hint in annotations.items():
            s = f"{name}:{warp.types.get_type_code(type_hint)}"
            ch.update(bytes(s, "utf-8"))

            # recurse on nested structs
            if isinstance(type_hint, Struct):
                ch.update(type_hint.hash)

        self.hash = ch.digest()

        # generate unique identifier for structs in native code
        hash_suffix = f"{self.hash.hex()[:8]}"
        self.native_name = f"{self.key}_{hash_suffix}"

        # create default constructor (zero-initialize)
        self.default_constructor = warp.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types={},
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        # build a constructor that takes each param as a value
        input_types = {label: var.type for label, var in self.vars.items()}

        self.value_constructor = warp.context.Function(
            func=None,
            key=self.native_name,
            namespace="",
            value_func=lambda *_: self,
            input_types=input_types,
            initializer_list_func=lambda *_: False,
            native_func=self.native_name,
        )

        self.default_constructor.add_overload(self.value_constructor)

        if isinstance(module, warp.context.Module):
            module.register_struct(self)

        # Define class for instances of this struct
        # To enable autocomplete on s, we inherit from self.cls.
        # For example,
        #
        # @wp.struct
        # class A:
        #     # annotations
        #     ...
        #
        # The type annotations are inherited in A(), allowing autocomplete in kernels
        class NewStructInstance(self.cls, StructInstance):
            def __init__(inst, ctype=None):
                StructInstance.__init__(inst, self, ctype)

        # make sure warp.types.get_type_code works with this StructInstance
        NewStructInstance.cls = self.cls
        NewStructInstance.native_name = self.native_name

        self.instance_type = NewStructInstance

    def __call__(self):
        """Create a new host-side instance of this struct.

        The instance is an ``instance_type`` object, which subclasses both the user
        class ``self.cls`` and ``StructInstance``.
        """
        return self.instance_type()

    def initializer(self):
        """Return the zero-initializing constructor ``Function``.

        The field-by-field value constructor is registered as an overload of it.
        """
        return self.default_constructor

    # return structured NumPy dtype, including field names, formats, and offsets
    def numpy_dtype(self):
        """Return a NumPy structured-dtype dict (names, formats, offsets, itemsize) for this struct."""
        names = []
        formats = []
        offsets = []
        for name, var in self.vars.items():
            names.append(name)
            offsets.append(getattr(self.ctype, name).offset)
            if isinstance(var.type, array):
                # array_t
                formats.append(array_t.numpy_dtype())
            elif isinstance(var.type, Struct):
                # nested struct
                formats.append(var.type.numpy_dtype())
            elif issubclass(var.type, ctypes.Array):
                scalar_typestr = type_typestr(var.type._wp_scalar_type_)
                if len(var.type._shape_) == 1:
                    # vector
                    formats.append(f"{var.type._length_}{scalar_typestr}")
                else:
                    # matrix
                    formats.append(f"{var.type._shape_}{scalar_typestr}")
            else:
                # scalar
                formats.append(type_typestr(var.type))

        return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}

    # constructs a Warp struct instance from a pointer to the ctype
    def from_ptr(self, ptr):
        """Build a new struct instance by reading field values from the address ``ptr``.

        Array fields cannot be reconstructed and are set to stub (non-usable) array
        annotations.

        Raises:
            RuntimeError: If ``ptr`` is null/zero.
        """
        if not ptr:
            raise RuntimeError("NULL pointer exception")

        # create a new struct instance
        instance = self()

        for name, var in self.vars.items():
            offset = getattr(self.ctype, name).offset
            if isinstance(var.type, array):
                # We could reconstruct wp.array from array_t, but it's problematic.
                # There's no guarantee that the original wp.array is still allocated and
                # no easy way to make a backref.
                # Instead, we just create a stub annotation, which is not a fully usable array object.
                setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
            elif isinstance(var.type, Struct):
                # nested struct
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            else:
                # scalar
                cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
                if var.type == warp.float16:
                    # Pass the raw integer bit pattern, as StructInstance.numpy_value() does,
                    # rather than the ctypes wrapper object.
                    setattr(instance, name, half_bits_to_float(cvalue.value))
                else:
                    setattr(instance, name, cvalue.value)

        return instance
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
class Reference:
    """Wrapper marking ``value_type`` as a reference to a value of that type.

    Code generation emits such types as pointers unless the value type is requested.
    """

    def __init__(self, value_type):
        # the type of the referenced value
        self.value_type = value_type
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
def is_reference(type: Any) -> builtins.bool:
    """Tell whether ``type`` is a :class:`Reference` wrapper."""
    wrapped = isinstance(type, Reference)
    return wrapped
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
def strip_reference(arg: Any) -> Any:
    """Unwrap a :class:`Reference` to its ``value_type``; return any other ``arg`` unchanged."""
    return arg.value_type if is_reference(arg) else arg
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
def compute_type_str(base_name, template_params):
    """Return the native (C++) spelling ``base_name<p0, p1, ...>`` of a templated type.

    Each template parameter is rendered by the first matching rule:

    - Python bools become ``true``/``false``;
    - ints become their decimal string;
    - generic Warp types recurse as ``wp::<generic name><...>``;
    - ctypes-backed scalars become ``wp::<name>``, except ``bool`` which stays ``bool``;
    - tiles use their ``ctype()``;
    - structs use their ``native_name``;
    - anything else uses ``__name__``.

    With no template parameters, ``base_name`` is returned unchanged.
    """
    if not template_params:
        return base_name

    def _format_param(param):
        if isinstance(param, builtins.bool):
            return "true" if param else "false"
        if isinstance(param, int):
            return str(param)
        if hasattr(param, "_wp_generic_type_str_"):
            return compute_type_str(f"wp::{param._wp_generic_type_str_}", param._wp_type_params_)
        if hasattr(param, "_type_"):
            return "bool" if param.__name__ == "bool" else f"wp::{param.__name__}"
        if is_tile(param):
            return param.ctype()
        if isinstance(param, Struct):
            return param.native_name
        return param.__name__

    rendered = ", ".join(_format_param(param) for param in template_params)
    return f"{base_name}<{rendered}>"
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
class Var:
|
|
641
|
-
def __init__(
|
|
642
|
-
self,
|
|
643
|
-
label: str,
|
|
644
|
-
type: type,
|
|
645
|
-
requires_grad: builtins.bool = False,
|
|
646
|
-
constant: builtins.bool | None = None,
|
|
647
|
-
prefix: builtins.bool = True,
|
|
648
|
-
relative_lineno: int | None = None,
|
|
649
|
-
):
|
|
650
|
-
# convert built-in types to wp types
|
|
651
|
-
if type == float:
|
|
652
|
-
type = float32
|
|
653
|
-
elif type == int:
|
|
654
|
-
type = int32
|
|
655
|
-
elif type == builtins.bool:
|
|
656
|
-
type = bool
|
|
657
|
-
|
|
658
|
-
self.label = label
|
|
659
|
-
self.type = type
|
|
660
|
-
self.requires_grad = requires_grad
|
|
661
|
-
self.constant = constant
|
|
662
|
-
self.prefix = prefix
|
|
663
|
-
|
|
664
|
-
# records whether this Var has been read from in a kernel function (array only)
|
|
665
|
-
self.is_read = False
|
|
666
|
-
# records whether this Var has been written to in a kernel function (array only)
|
|
667
|
-
self.is_write = False
|
|
668
|
-
|
|
669
|
-
# used to associate a view array Var with its parent array Var
|
|
670
|
-
self.parent = None
|
|
671
|
-
|
|
672
|
-
# Used to associate the variable with the Python statement that resulted in it being created.
|
|
673
|
-
self.relative_lineno = relative_lineno
|
|
674
|
-
|
|
675
|
-
def __str__(self):
|
|
676
|
-
return self.label
|
|
677
|
-
|
|
678
|
-
@staticmethod
|
|
679
|
-
def dtype_to_ctype(t: type) -> str:
|
|
680
|
-
if hasattr(t, "_wp_generic_type_str_"):
|
|
681
|
-
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
|
|
682
|
-
elif isinstance(t, Struct):
|
|
683
|
-
return t.native_name
|
|
684
|
-
elif hasattr(t, "_wp_native_name_"):
|
|
685
|
-
return f"wp::{t._wp_native_name_}"
|
|
686
|
-
elif t.__name__ in ("bool", "int", "float"):
|
|
687
|
-
return t.__name__
|
|
688
|
-
|
|
689
|
-
return f"wp::{t.__name__}"
|
|
690
|
-
|
|
691
|
-
@staticmethod
|
|
692
|
-
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
693
|
-
if isinstance(t, fixedarray):
|
|
694
|
-
template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
|
|
695
|
-
dtypestr = ", ".join(template_args)
|
|
696
|
-
classstr = f"wp::{type(t).__name__}"
|
|
697
|
-
return f"{classstr}_t<{dtypestr}>"
|
|
698
|
-
elif is_array(t):
|
|
699
|
-
dtypestr = Var.dtype_to_ctype(t.dtype)
|
|
700
|
-
classstr = f"wp::{type(t).__name__}"
|
|
701
|
-
return f"{classstr}_t<{dtypestr}>"
|
|
702
|
-
elif get_origin(t) is tuple:
|
|
703
|
-
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
|
|
704
|
-
return f"wp::tuple_t<{dtypestr}>"
|
|
705
|
-
elif is_tuple(t):
|
|
706
|
-
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
|
|
707
|
-
classstr = f"wp::{type(t).__name__}"
|
|
708
|
-
return f"{classstr}<{dtypestr}>"
|
|
709
|
-
elif is_tile(t):
|
|
710
|
-
return t.ctype()
|
|
711
|
-
elif isinstance(t, type) and issubclass(t, StructInstance):
|
|
712
|
-
# ensure the actual Struct name is used instead of "NewStructInstance"
|
|
713
|
-
return t.native_name
|
|
714
|
-
elif is_reference(t):
|
|
715
|
-
if not value_type:
|
|
716
|
-
return Var.type_to_ctype(t.value_type) + "*"
|
|
717
|
-
|
|
718
|
-
return Var.type_to_ctype(t.value_type)
|
|
719
|
-
|
|
720
|
-
return Var.dtype_to_ctype(t)
|
|
721
|
-
|
|
722
|
-
def ctype(self, value_type: builtins.bool = False) -> str:
|
|
723
|
-
return Var.type_to_ctype(self.type, value_type)
|
|
724
|
-
|
|
725
|
-
def emit(self, prefix: str = "var"):
|
|
726
|
-
if self.prefix:
|
|
727
|
-
return f"{prefix}_{self.label}"
|
|
728
|
-
else:
|
|
729
|
-
return self.label
|
|
730
|
-
|
|
731
|
-
def emit_adj(self):
|
|
732
|
-
return self.emit("adj")
|
|
733
|
-
|
|
734
|
-
def mark_read(self):
|
|
735
|
-
"""Marks this Var as having been read from in a kernel (array only)."""
|
|
736
|
-
if not is_array(self.type):
|
|
737
|
-
return
|
|
738
|
-
|
|
739
|
-
self.is_read = True
|
|
740
|
-
|
|
741
|
-
# recursively update all parent states
|
|
742
|
-
parent = self.parent
|
|
743
|
-
while parent is not None:
|
|
744
|
-
parent.is_read = True
|
|
745
|
-
parent = parent.parent
|
|
746
|
-
|
|
747
|
-
def mark_write(self, **kwargs):
|
|
748
|
-
"""Marks this Var has having been written to in a kernel (array only)."""
|
|
749
|
-
if not is_array(self.type):
|
|
750
|
-
return
|
|
751
|
-
|
|
752
|
-
# detect if we are writing to an array after reading from it within the same kernel
|
|
753
|
-
if self.is_read and warp.config.verify_autograd_array_access:
|
|
754
|
-
if "kernel_name" and "filename" and "lineno" in kwargs:
|
|
755
|
-
print(
|
|
756
|
-
f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
|
|
757
|
-
)
|
|
758
|
-
else:
|
|
759
|
-
print(
|
|
760
|
-
f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
|
|
761
|
-
)
|
|
762
|
-
self.is_write = True
|
|
763
|
-
|
|
764
|
-
# recursively update all parent states
|
|
765
|
-
parent = self.parent
|
|
766
|
-
while parent is not None:
|
|
767
|
-
parent.is_write = True
|
|
768
|
-
parent = parent.parent
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
class Block:
    """A basic block of generated code.

    This is a straight-line run of statements, e.g. the body of a for-loop or of a
    conditional.
    """

    def __init__(self):
        # statement lists for the forward, replay and reverse bodies of this block
        self.body_forward, self.body_replay, self.body_reverse = [], [], []

        # variables declared in this block
        self.vars = []
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
def apply_defaults(
    bound_args: inspect.BoundArguments,
    values: Mapping[str, Any],
):
    """Fill unbound parameters of ``bound_args`` from ``values``, in place.

    Similar to Python's ``inspect.BoundArguments.apply_defaults()``, but the fallback
    values come from ``values`` (allowing an augmented set of default values) instead
    of the signature's own defaults. Parameters that are neither bound nor present in
    ``values`` stay unbound. The resulting ``arguments`` follow the signature's
    parameter order.
    """
    arguments = bound_args.arguments
    new_arguments = {}
    # ``signature`` is the public, documented accessor (previously the private ``_signature``).
    for name in bound_args.signature.parameters:
        if name in arguments:
            new_arguments[name] = arguments[name]
        elif name in values:
            new_arguments[name] = values[name]

    bound_args.arguments = new_arguments
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
def func_match_args(func, arg_types, kwarg_types):
    """Return whether ``func`` accepts arguments of the given types.

    The steps are:

    1. Bind ``arg_types`` / ``kwarg_types`` to ``func.signature``. This only assigns
       each argument to a parameter and fails if the call shape doesn't fit.
    2. Fill unbound parameters from ``func.defaults``: ``None`` stays ``None``, other
       defaults are converted with ``get_arg_type``.
    3. Compare each bound type against ``func.input_types``, in order.

    ``None`` types (left for the ``value_func`` callback to infer), ``Any`` parameters,
    and ``Function`` objects passed to ``Callable`` parameters are accepted as-is. Every
    other type must satisfy ``types_equal(..., match_generic=True)`` once references
    are stripped.
    """
    try:
        bound = func.signature.bind(*arg_types, **kwarg_types)
    except TypeError:
        return False

    # default types for the parameters that were not given explicitly
    fallback_types = {}
    for param_name, default in func.defaults.items():
        if param_name in bound.arguments:
            continue
        fallback_types[param_name] = None if default is None else get_arg_type(default)
    apply_defaults(bound, fallback_types)
    given_types = tuple(bound.arguments.values())

    for given, expected in zip(given_types, func.input_types.values()):
        # unknown type (inferred later) or template parameter accepting anything
        if given is None or expected == Any:
            continue

        # function refs are a special case
        if expected == Callable and isinstance(given, warp.context.Function):
            continue

        if not types_equal(expected, strip_reference(given), match_generic=True):
            return False

    return True
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
def get_arg_type(arg: Var | Any) -> type:
    """Return the type descriptor used to match ``arg`` against function signatures.

    The rules, checked in this order, are:

    - strings map to ``str``;
    - sequences map to a tuple of their items' types (recursively);
    - arrays are returned as-is;
    - ``tuple[...]`` annotations map to a tuple of their argument types (recursively);
    - tuples, types and ``Function`` objects are returned as-is;
    - a ``Var`` maps to its type, with a ``tuple[...]`` type unpacked into its arguments;
    - anything else maps to ``type(arg)``.
    """
    if isinstance(arg, str):
        return str

    if isinstance(arg, Sequence):
        return tuple(map(get_arg_type, arg))

    if is_array(arg):
        return arg

    if get_origin(arg) is tuple:
        return tuple(map(get_arg_type, get_args(arg)))

    if is_tuple(arg) or isinstance(arg, (type, warp.context.Function)):
        return arg

    if isinstance(arg, Var):
        var_type = arg.type
        return get_args(var_type) if get_origin(var_type) is tuple else var_type

    return type(arg)
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
def get_arg_value(arg: Any) -> Any:
    """Return the compile-time value of ``arg`` where one is known.

    The rules, checked in this order, are:

    - strings are returned unchanged;
    - sequences map to a tuple of their items' values (recursively);
    - types and ``Function`` objects are returned as-is;
    - a ``Var`` of tuple type maps to a tuple of its element values;
    - a ``Var`` carrying a constant maps to that constant;
    - anything else (including non-constant ``Var`` objects) is returned unchanged.
    """
    if isinstance(arg, str):
        # A str is a Sequence whose items are again (1-char) strs, so the recursive
        # branch below would never terminate and would end in RecursionError.
        # get_arg_type special-cases str in the same way.
        return arg

    if isinstance(arg, Sequence):
        return tuple(get_arg_value(x) for x in arg)

    if isinstance(arg, (type, warp.context.Function)):
        return arg

    if isinstance(arg, Var):
        if is_tuple(arg.type):
            return tuple(get_arg_value(x) for x in arg.type.values)

        if arg.constant is not None:
            return arg.constant

    return arg
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
class Adjoint:
|
|
889
|
-
# Source code transformer, this class takes a Python function and
|
|
890
|
-
# generates forward and backward SSA forms of the function instructions
|
|
891
|
-
|
|
892
|
-
def __init__(
    adj,
    func: Callable[..., Any],
    overload_annotations=None,
    is_user_function=False,
    skip_forward_codegen=False,
    skip_reverse_codegen=False,
    custom_reverse_mode=False,
    custom_reverse_num_input_args=-1,
    transformers: list[ast.NodeTransformer] | None = None,
    source: str | None = None,
):
    """Parse ``func`` and prepare it for code generation.

    The source is extracted (or taken from ``source``), parsed into an AST and run
    through ``transformers``. Argument types are then resolved into ``Var`` objects,
    and ``wp.static()`` expressions are resolved where possible.

    Args:
        func: The Python function to transform.
        overload_annotations: Argument types to use instead of the source annotations;
            must contain every positional argument.
        is_user_function: Stored as ``adj.is_user_function``.
        skip_forward_codegen: Skip generating the forward code for this function.
        skip_reverse_codegen: Skip generating the adjoint code for this function.
        custom_reverse_mode: Use the forward code for the reverse pass and apply a
            custom signature to the reverse version of the function.
        custom_reverse_num_input_args: Number of arguments that are forward inputs
            (i.e. not adjoint arguments).
        transformers: ``ast.NodeTransformer`` instances applied to the parsed tree, in order.
        source: Source code to use instead of extracting it from ``func``.

    Raises:
        WarpCodegenError: If an argument has no type annotation (or overload annotation).
    """
    adj.func = func

    adj.is_user_function = is_user_function

    # whether the generation of the forward code is skipped for this function
    adj.skip_forward_codegen = skip_forward_codegen
    # whether the generation of the adjoint code is skipped for this function
    adj.skip_reverse_codegen = skip_reverse_codegen
    # Whether this function is used by a kernel that has the backward pass enabled.
    adj.used_by_backward_kernel = False

    # extract name of source file
    adj.filename = inspect.getsourcefile(func) or "unknown source file"
    # get source file line number where function starts
    adj.fun_lineno = 0
    adj.source = source
    if adj.source is None:
        adj.source, adj.fun_lineno = adj.extract_function_source(func)

    assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"

    # Indicates where the function definition starts (excludes decorators)
    adj.fun_def_lineno = None

    # get function source code
    # ensures that indented class methods can be parsed as kernels
    adj.source = textwrap.dedent(adj.source)

    adj.source_lines = adj.source.splitlines()

    if transformers is None:
        transformers = []

    # build AST and apply node transformers
    adj.tree = ast.parse(adj.source)
    adj.transformers = transformers
    for transformer in transformers:
        adj.tree = transformer.visit(adj.tree)

    adj.fun_name = adj.tree.body[0].name

    # for keeping track of line number in function code
    adj.lineno = None

    # whether the forward code shall be used for the reverse pass and a custom
    # function signature is applied to the reverse version of the function
    adj.custom_reverse_mode = custom_reverse_mode
    # the number of function arguments that pertain to the forward function
    # input arguments (i.e. the number of arguments that are not adjoint arguments)
    adj.custom_reverse_num_input_args = custom_reverse_num_input_args

    # parse argument types
    argspec = get_full_arg_spec(func)

    # ensure all arguments are annotated
    if overload_annotations is None:
        # use source-level argument annotations
        # Check each argument by name: comparing counts is fooled by the "return"
        # (and keyword-only) entries that ``argspec.annotations`` may also contain.
        if any(arg_name not in argspec.annotations for arg_name in argspec.args):
            raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
        adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
    else:
        # use overload argument annotations
        for arg_name in argspec.args:
            if arg_name not in overload_annotations:
                raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
        adj.arg_types = overload_annotations.copy()

    adj.args = []
    adj.symbols = {}

    for name, arg_type in adj.arg_types.items():
        # skip return hint
        if name == "return":
            continue

        # add variable for argument
        arg = Var(name, arg_type, requires_grad=False)
        adj.args.append(arg)

        # pre-populate symbol dictionary with function argument names
        # this is to avoid registering false references to overshadowed modules
        adj.symbols[name] = arg

    # Indicates whether there are unresolved static expressions in the function.
    # These stem from wp.static() expressions that could not be evaluated at declaration time.
    # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
    adj.has_unresolved_static_expressions = False

    # try to replace static expressions by their constant result if the
    # expression can be evaluated at declaration time
    adj.static_expressions: dict[str, Any] = {}
    # cheap textual pre-check before walking the AST
    if "static" in adj.source:
        adj.replace_static_expressions()

    # There are cases where a same module might be rebuilt multiple times,
    # for example when kernels are nested inside of functions, or when
    # a kernel's launch raises an exception. Ideally we'd always want to
    # avoid rebuilding kernels but some corner cases seem to depend on it,
    # so we only avoid rebuilding kernels that errored out to give a chance
    # for unit testing errors being spit out from kernels.
    adj.skip_build = False
|
|
1005
|
-
|
|
1006
|
-
# allocate extra space for a function call that requires its
# own shared memory space, we treat shared memory as a stack
# where each function pushes and pops space off, the extra
# quantity is the 'roofline' amount required for the entire kernel
def alloc_shared_extra(adj, num_bytes):
    """Raise the recorded extra shared-memory requirement to at least ``num_bytes``.

    Since shared memory is used like a stack by nested calls, only the peak request
    across all calls needs to be kept.
    """
    if num_bytes > adj.max_required_extra_shared_memory:
        adj.max_required_extra_shared_memory = num_bytes
|
|
1012
|
-
|
|
1013
|
-
# returns the total number of bytes for a function
# based on it's own requirements + worst case
# requirements of any dependent functions
def get_total_required_shared(adj):
    """Return the shared memory, in bytes, required by this function.

    The total is the size of every shared-storage tile variable this function owns, plus
    the worst-case extra amount recorded for its dependent calls.
    """
    owned_tile_bytes = sum(
        var.type.size_in_bytes()
        for var in adj.variables
        if is_tile(var.type) and var.type.storage == "shared" and var.type.owner
    )
    return owned_tile_bytes + adj.max_required_extra_shared_memory
|
|
1024
|
-
|
|
1025
|
-
@staticmethod
|
|
1026
|
-
def extract_function_source(func: Callable) -> tuple[str, int]:
|
|
1027
|
-
try:
|
|
1028
|
-
_, fun_lineno = inspect.getsourcelines(func)
|
|
1029
|
-
source = inspect.getsource(func)
|
|
1030
|
-
except OSError as e:
|
|
1031
|
-
raise RuntimeError(
|
|
1032
|
-
"Directly evaluating Warp code defined as a string using `exec()` is not supported, "
|
|
1033
|
-
"please save it to a file and use `importlib` if needed."
|
|
1034
|
-
) from e
|
|
1035
|
-
return source, fun_lineno
|
|
1036
|
-
|
|
1037
|
-
# generate function ssa form and adjoint
def build(adj, builder, default_builder_options=None):
    """Generate the SSA form and adjoint of the function by evaluating its AST.

    This resets the per-build state (symbols, variables, blocks, labels, shared-memory
    bookkeeping), then evaluates the function definition node. When a ``builder`` is
    given, it asks the builder to build every struct type used by the arguments,
    directly or as an array dtype.

    After the first attempt, whether it succeeded or raised, ``adj.skip_build`` is set,
    so later calls only reset the argument read/write flags.

    Args:
        builder: Module builder; its ``options`` become the builder options. May be
            ``None``/falsy, in which case ``default_builder_options`` is used.
        default_builder_options: Options used when there is no builder (``{}`` if ``None``).

    Raises:
        Exception: An exception raised while evaluating the function body is re-raised
            as the same exception type, with the file/line context prepended to its
            message.
    """
    # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
    for arg in adj.args:
        arg.is_read = False
        arg.is_write = False

    if adj.skip_build:
        return

    adj.builder = builder

    if default_builder_options is None:
        default_builder_options = {}

    if adj.builder:
        adj.builder_options = adj.builder.options
    else:
        adj.builder_options = default_builder_options

    # Publish the options through the module-level `options` global; presumably read by
    # other code-generation helpers in this module (TODO confirm).
    global options
    options = adj.builder_options

    adj.symbols = {}  # map from symbols to adjoint variables
    adj.variables = []  # list of local variables (in order)

    adj.return_var = None  # return type for function or kernel
    adj.loop_symbols = []  # symbols at the start of each loop
    adj.loop_const_iter_symbols = (
        set()
    )  # constant iteration variables for static loops (mutating them does not raise an error)

    # blocks
    adj.blocks = [Block()]
    adj.loop_blocks = []

    # holds current indent level
    adj.indentation = ""

    # used to generate new label indices
    adj.label_count = 0

    # tracks how much additional shared memory is required by any dependent function calls
    adj.max_required_extra_shared_memory = 0

    # update symbol map for each argument
    for a in adj.args:
        adj.symbols[a.label] = a

    # recursively evaluate function body
    try:
        adj.eval(adj.tree.body[0])
    except Exception as original_exc:
        try:
            # NOTE(review): adj.lineno starts out as None (set in __init__). If the error
            # happens before any statement updated it, this addition raises TypeError,
            # and the original exception only survives as __context__. Consider guarding.
            lineno = adj.lineno + adj.fun_lineno
            line = adj.source_lines[adj.lineno]
            msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'

            # Combine the new message with the original exception's arguments
            new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)

            # Enhance the original exception with parser context before re-raising.
            # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
            raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
        finally:
            # never retry a build that failed; also drop the builder reference
            adj.skip_build = True
            adj.builder = None

    if builder is not None:
        # make sure every struct type reachable from the arguments gets built
        for a in adj.args:
            if isinstance(a.type, Struct):
                builder.build_struct_recursive(a.type)
            elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                builder.build_struct_recursive(a.type.dtype)

    # release builder reference for GC
    adj.builder = None
|
|
1114
|
-
|
|
1115
|
-
# code generation methods
def format_template(adj, template, input_vars, output_var):
    """Fill ``template`` with ``str.format``.

    Placeholder ``{0}`` is ``output_var``; ``{1}``, ``{2}``, ... are the items of
    ``input_vars``, in order.
    """
    return template.format(output_var, *input_vars)
|
|
1122
|
-
|
|
1123
|
-
# generates a list of formatted args
def format_args(adj, prefix, args):
    """Return the generated-code spelling of each item in ``args``, in order.

    Args:
        prefix: Name prefix for variables (e.g. ``"var"`` or ``"adj"``).
        args: ``Function`` and ``Var`` objects.

    Returns:
        A list of strings, built per item as follows:

        - ``Function``: its namespaced native name, with ``{prefix}_`` inserted unless
          ``prefix`` is ``"var"`` (functions carry no ``var_`` prefix);
        - items whose ``type`` is a reference: ``f"{prefix}_{item}"``;
        - other ``Var`` objects: ``var.emit(prefix)``.

    Raises:
        WarpCodegenTypeError: If an item is neither a function nor a variable.
    """
    arg_strs = []

    for a in args:
        if isinstance(a, warp.context.Function):
            # functions don't have a var_ prefix so strip it off here
            if prefix == "var":
                arg_strs.append(f"{a.namespace}{a.native_func}")
            else:
                arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
        elif is_reference(getattr(a, "type", None)):
            # getattr: objects lacking a ``type`` attribute must reach the
            # WarpCodegenTypeError below instead of raising AttributeError here
            arg_strs.append(f"{prefix}_{a}")
        elif isinstance(a, Var):
            arg_strs.append(a.emit(prefix))
        else:
            raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")

    return arg_strs
|
|
1142
|
-
|
|
1143
|
-
# generates argument string for a forward function call
def format_forward_call_args(adj, args, use_initializer_list):
    """Return the comma-separated argument list for a forward call.

    The list is wrapped in braces when ``use_initializer_list`` is true.
    """
    joined = ", ".join(adj.format_args("var", args))
    return "{" + joined + "}" if use_initializer_list else joined
|
|
1149
|
-
|
|
1150
|
-
# generates argument string for a reverse function call
def format_reverse_call_args(
    adj,
    args_var,
    args,
    args_out,
    use_initializer_list,
    has_output_args=True,
    require_original_output_arg=False,
):
    """Return the argument string for a reverse (adjoint) function call.

    Args:
        args_var: Forward input variables, formatted with the ``var`` prefix.
        args: Inputs whose adjoints are passed (``adj`` prefix, or ``&adj`` inside an
            initializer list).
        args_out: Outputs; their adjoints are always passed.
        use_initializer_list: Group arguments into brace-enclosed initializer lists.
        has_output_args: Whether the original outputs may be passed at all.
        require_original_output_arg: Pass the original outputs even when there is a single
            output (otherwise they are passed only for more than one output).

    Returns:
        The argument string, or ``None`` when there are no adjoint arguments, in which case
        no reverse call is needed.
    """
    fwd_vars = adj.format_args("var", args_var)
    pass_outputs = has_output_args and (require_original_output_arg or len(args_out) > 1)
    fwd_outs = adj.format_args("var", args_out) if pass_outputs else []
    adj_vars = adj.format_args("&adj" if use_initializer_list else "adj", args)
    adj_outs = adj.format_args("adj", args_out)

    if not adj_vars and not adj_outs:
        # there are no adjoint arguments, so we don't need to call the reverse function
        return None

    if not use_initializer_list:
        return ", ".join(fwd_vars + fwd_outs + adj_vars + adj_outs)

    def braced(items):
        return "{" + ", ".join(items) + "}"

    groups = [braced(fwd_vars)]
    if len(args_out) > 1:
        groups.append(braced(fwd_outs))
    groups.append(braced(adj_vars))
    groups.append(", ".join(adj_outs))
    return ", ".join(groups)
|
|
1186
|
-
|
|
1187
|
-
def indent(adj):
    """Add one indentation level (four spaces) to ``adj.indentation``.

    Four spaces keep this symmetric with ``dedent``, which strips the last four characters.
    """
    adj.indentation += "    "
|
|
1189
|
-
|
|
1190
|
-
def dedent(adj):
    """Decrease the indentation of subsequently emitted statements by one level.

    Drops the last four characters of ``adj.indentation``. Slicing makes this
    yield an empty string, rather than fail, when fewer characters remain.
    """
    level_width = 4
    adj.indentation = adj.indentation[:-level_width]
|
|
1192
|
-
|
|
1193
|
-
def begin_block(adj, name="block"):
    """Open a new code block and push it onto ``adj.blocks``.

    The block is labelled ``<name>_<n>``, where ``n`` is the current
    ``adj.label_count``. The counter is then incremented so every block gets
    a unique label.

    Returns:
        The newly created block.
    """
    block = Block()

    # give block a unique id
    block.label = "_".join([name, str(adj.label_count)])
    adj.label_count += 1

    adj.blocks.append(block)
    return block
|
|
1202
|
-
|
|
1203
|
-
def end_block(adj):
    """Close the innermost open block: pop it from ``adj.blocks`` and return it."""
    innermost = adj.blocks.pop()
    return innermost
|
|
1205
|
-
|
|
1206
|
-
def add_var(adj, type=None, constant=None):
    """Allocate a new primal variable and record it in the current block.

    The variable is named after its index in ``adj.variables`` (``"0"``,
    ``"1"``, ...) and is tagged with the current ``adj.lineno``.

    Args:
        type: Type of the new variable.
        constant: Compile-time constant value, if any.

    Returns:
        The new ``Var``.
    """
    new_var = Var(str(len(adj.variables)), type=type, constant=constant, relative_lineno=adj.lineno)

    adj.variables.append(new_var)
    adj.blocks[-1].vars.append(new_var)

    return new_var
|
|
1218
|
-
|
|
1219
|
-
def register_var(adj, var):
    """Ensure ``var`` is registered among the primal variables and return it.

    Some ``Var`` instances are created speculatively and may be thrown away.
    Their registration is therefore deferred to this call, instead of
    happening immediately through ``adj.add_var()`` / ``adj.add_constant()``.

    - ``Reference`` and ``Function`` objects are returned untouched.
    - Plain ``int`` values become constants.
    - A ``Var`` without a label is re-created through ``adj.add_var`` with the
      same type and constant; a labelled one is returned as-is.
    """
    passthrough_types = (Reference, warp.context.Function)
    if isinstance(var, passthrough_types):
        return var

    if isinstance(var, int):
        return adj.add_constant(var)

    if var.label is not None:
        return var

    return adj.add_var(var.type, var.constant)
|
|
1235
|
-
|
|
1236
|
-
def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
|
|
1237
|
-
"""Get a line directive for the given statement.
|
|
1238
|
-
|
|
1239
|
-
Args:
|
|
1240
|
-
statement: The statement to get the line directive for.
|
|
1241
|
-
relative_lineno: The line number of the statement relative to the function.
|
|
1242
|
-
|
|
1243
|
-
Returns:
|
|
1244
|
-
A line directive for the given statement, or None if no line directive is needed.
|
|
1245
|
-
"""
|
|
1246
|
-
|
|
1247
|
-
if adj.filename == "unknown source file" or adj.fun_lineno == 0:
|
|
1248
|
-
# Early return if function is not associated with a source file or is otherwise invalid
|
|
1249
|
-
# TODO: Get line directives working with wp.map() functions
|
|
1250
|
-
return None
|
|
1251
|
-
|
|
1252
|
-
# lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
|
|
1253
|
-
# emit line directives in generated code if it's not being compiled with line information
|
|
1254
|
-
build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
|
|
1255
|
-
|
|
1256
|
-
lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"
|
|
1257
|
-
|
|
1258
|
-
if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
|
|
1259
|
-
is_comment = statement.strip().startswith("//")
|
|
1260
|
-
if not is_comment:
|
|
1261
|
-
line = relative_lineno + adj.fun_lineno
|
|
1262
|
-
# Convert backslashes to forward slashes for CUDA compatibility
|
|
1263
|
-
normalized_path = adj.filename.replace("\\", "/")
|
|
1264
|
-
return f'#line {line} "{normalized_path}"'
|
|
1265
|
-
return None
|
|
1266
|
-
|
|
1267
|
-
def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
|
|
1268
|
-
"""Append a statement to the forward pass."""
|
|
1269
|
-
|
|
1270
|
-
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
1271
|
-
adj.blocks[-1].body_forward.append(line_directive)
|
|
1272
|
-
|
|
1273
|
-
adj.blocks[-1].body_forward.append(adj.indentation + statement)
|
|
1274
|
-
|
|
1275
|
-
if not skip_replay:
|
|
1276
|
-
if line_directive:
|
|
1277
|
-
adj.blocks[-1].body_replay.append(line_directive)
|
|
1278
|
-
|
|
1279
|
-
if replay:
|
|
1280
|
-
# if custom replay specified then output it
|
|
1281
|
-
adj.blocks[-1].body_replay.append(adj.indentation + replay)
|
|
1282
|
-
else:
|
|
1283
|
-
# by default just replay the original statement
|
|
1284
|
-
adj.blocks[-1].body_replay.append(adj.indentation + statement)
|
|
1285
|
-
|
|
1286
|
-
def add_reverse(adj, statement: str) -> None:
    """Append a statement to the reverse pass.

    Reverse bodies are stored back-to-front; the loop helpers splice them in
    with ``reversed(...)``. Any ``#line`` directive is therefore appended
    *after* the statement, so that it precedes the statement once the body
    is reversed.
    """
    reverse_body = adj.blocks[-1].body_reverse
    reverse_body.append(adj.indentation + statement)

    line_directive = adj.get_line_directive(statement, adj.lineno)
    if line_directive:
        reverse_body.append(line_directive)
|
|
1294
|
-
|
|
1295
|
-
def add_constant(adj, n):
    """Allocate a variable holding the compile-time constant ``n``, typed as ``type(n)``."""
    return adj.add_var(type=type(n), constant=n)
|
|
1298
|
-
|
|
1299
|
-
def load(adj, var):
    """Return a value-typed view of ``var``.

    If ``var``'s type is a reference, a ``load`` builtin call is emitted and
    its result returned; otherwise ``var`` is returned unchanged.
    """
    return adj.add_builtin_call("load", [var]) if is_reference(var.type) else var
|
|
1303
|
-
|
|
1304
|
-
def add_comp(adj, op_strings, left, comps):
    """Emit a (possibly chained) comparison and return its boolean result variable.

    ``left op0 c0 op1 c1 ...`` is emitted as a single parenthesized native
    expression. An operator that ``op_str_is_chainable`` accepts, and that
    follows a previous comparison, compares the previous operand with the next
    one; the result is joined with ``&&``. For example, if ``<`` is chainable,
    ``a < b < c`` becomes ``((a < b) && (b < c))``.

    Args:
        op_strings: Native operator strings, one per comparator.
        left: Left-most operand.
        comps: Remaining operands.

    Returns:
        The ``bool`` output variable.

    Raises:
        WarpCodegenTypeError: If two chained operands have different types.
    """
    output = adj.add_var(builtins.bool)

    left = adj.load(left)
    s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "

    prev_comp_var = None

    for op, comp in zip(op_strings, comps):
        comp_chainable = op_str_is_chainable(op)
        # Load the operand before any type check: `prev_comp_var` is always a
        # loaded value, so comparing it against a not-yet-loaded (reference)
        # operand would report a spurious type mismatch.
        comp_var = adj.load(comp)
        if comp_chainable and prev_comp_var is not None:
            # We restrict chaining to operands of the same type
            if prev_comp_var.type is not comp_var.type:
                raise WarpCodegenTypeError(
                    f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp_var.type}."
                )
            s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
        else:
            s += op + " " + comp_var.emit() + ") "

        prev_comp_var = comp_var

    s = s.rstrip() + ";"

    adj.add_forward(s)

    return output
|
|
1335
|
-
|
|
1336
|
-
def add_bool_op(adj, op_string, exprs):
    """Emit ``out = e0 <op> e1 <op> ...;`` for a boolean operator and return ``out``.

    Operands are loaded first, then the ``bool`` output variable is allocated.
    """
    operands = [adj.load(e) for e in exprs]
    result = adj.add_var(builtins.bool)

    lhs = result.emit()
    separator = " " + op_string + " "
    rhs = separator.join([operand.emit() for operand in operands])
    adj.add_forward(lhs + " = " + rhs + ";")

    return result
|
|
1343
|
-
|
|
1344
|
-
def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
    """Pick the concrete overload of ``func`` that matches the given argument types.

    User-defined functions are resolved through ``func.get_overload``. For
    built-ins, ``func.overloads`` is scanned in order and the first acceptable
    candidate is returned. Variadic candidates are accepted without any type
    checking. Argument types are validated here, before they reach the
    generated native code.

    Args:
        func: The function being called.
        arg_types: Types of the positional arguments.
        kwarg_types: Mapping of keyword-argument names to types.
        min_outputs: If truthy, a non-variadic built-in candidate is only
            accepted when ``value_func(None, None)`` is a sequence of exactly
            this length.

    Raises:
        WarpCodegenError: If no overload matches, or if an argument is the
            result of a multi-valued function.
    """
    if func.is_builtin():
        for candidate in func.overloads:
            # skip type checking for variadic functions
            if candidate.variadic:
                return candidate

            # the candidate may accept more inputs (default args), but never fewer
            if len(candidate.input_types) < len(arg_types) + len(kwarg_types):
                continue

            if not func_match_args(candidate, arg_types, kwarg_types):
                continue

            # check output dimensions match expectations
            if min_outputs:
                value_type = candidate.value_func(None, None)
                if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
                    continue

            # found a match, use it
            return candidate
    else:
        # user-defined function
        overload = func.get_overload(arg_types, kwarg_types)
        if overload is not None:
            return overload

    # unresolved function, report error
    arg_type_reprs = []
    for x in itertools.chain(arg_types, kwarg_types.values()):
        if isinstance(x, warp.context.Function):
            arg_type_reprs.append("function")
            continue

        # shorten Warp primitive type names
        arg_type = x
        if isinstance(x, Sequence):
            if len(x) != 1:
                raise WarpCodegenError("Argument must not be the result from a multi-valued function")
            arg_type = x[0]

        arg_type_reprs.append(type_repr(arg_type))

    raise WarpCodegenError(
        f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
    )
|
|
1392
|
-
|
|
1393
|
-
def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
    """Emit a call to ``func`` in the forward pass and its adjoint in the reverse pass.

    Steps, in order:

    1. Resolve the overload that matches the argument types.
    2. Bind the positional/keyword arguments to its signature, merging in
       ``type_args`` and ``func.defaults``.
    3. Build user-defined callees.
    4. Allocate output variables from ``func.value_func``.
    5. Emit the forward and replay calls and, unless ``func.missing_grad`` is
       set or there are no arguments, the reverse (adjoint) call.

    Args:
        func: Function (built-in or user-defined) being called.
        args: Positional arguments.
        kwargs: Keyword arguments.
        type_args: Compile-time argument values supplied by codegen (may be empty).
        min_outputs: Passed to ``resolve_func``; if truthy, only built-in
            overloads returning exactly this many values are accepted.

    Returns:
        ``None`` for a void call, the single output ``Var``, a list of output
        ``Var``s for a multi-value call, or the registered ``Var`` (or tuple of
        ``Var``s) when ``value_func`` returns ``Var`` instances directly.

    Raises:
        RuntimeError: If a ``type_args`` entry conflicts with an explicitly
            passed argument of the same name.
        TypeError: From ``Signature.bind`` when the arguments do not fit the
            resolved signature.
    """
    # Extract the types and values passed as arguments to the function call.
    arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
    kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

    # Resolve the exact function signature among any existing overload.
    func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

    # Bind the positional and keyword arguments to the function's signature
    # in order to process them as Python does it.
    bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

    # Type args are the "compile time" argument values we get from codegen.
    # For example, when calling `wp.vec3f(...)` from within a kernel,
    # this translates in fact to calling the `vector()` built-in augmented
    # with the type args `length=3, dtype=float`.
    # Eventually, these need to be passed to the underlying C++ function,
    # so we update the arguments with the type args here.
    if type_args:
        for arg in type_args:
            if arg in bound_args.arguments:
                # In case of conflict, ideally we'd throw an error since
                # what comes from codegen should be the source of truth
                # and users also passing the same value as an argument
                # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
                # However, for backward compatibility, we allow that form
                # as long as the values are equal.
                if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
                    continue

                raise RuntimeError(
                    f"Remove the extraneous `{arg}` parameter "
                    f"when calling the templated version of "
                    f"`wp.{func.native_func}()`"
                )

        # Wrap each compile-time value in a constant Var and hand it to apply_defaults
        # (presumably it only fills parameters not already bound -- see apply_defaults).
        type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
        apply_defaults(bound_args, type_vars)

    # Declared defaults for parameters that were not passed; defaults whose value
    # is None are left out.
    if func.defaults:
        default_vars = {
            k: Var(None, type=type(v), constant=v)
            for k, v in func.defaults.items()
            if k not in bound_args.arguments and v is not None
        }
        apply_defaults(bound_args, default_vars)

    # From here on, work with the plain parameter-name -> argument mapping.
    bound_args = bound_args.arguments

    # if it is a user-function then build it recursively
    if not func.is_builtin():
        # If the function called is a user function,
        # we need to ensure its adjoint is also being generated.
        if adj.used_by_backward_kernel:
            func.adj.used_by_backward_kernel = True

        if adj.builder is None:
            func.build(None)

        elif func not in adj.builder.functions:
            adj.builder.build_function(func)
            # add custom grad, replay functions to the list of functions
            # to be built later (invalid code could be generated if we built them now)
            # so that they are not missed when only the forward function is imported
            # from another module
            if func.custom_grad_func:
                adj.builder.deferred_functions.append(func.custom_grad_func)
            if func.custom_replay_func:
                adj.builder.deferred_functions.append(func.custom_replay_func)

    # Resolve the return value based on the types and values of the given arguments.
    bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
    bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}

    return_type = func.value_func(
        {k: strip_reference(v) for k, v in bound_arg_types.items()},
        bound_arg_values,
    )

    # Handle the special case where a Var instance is returned from the `value_func`
    # callback, in which case we replace the call with a reference to that variable.
    if isinstance(return_type, Var):
        return adj.register_var(return_type)
    elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
        return tuple(adj.register_var(x) for x in return_type)

    # A `tuple[...]` annotation is converted into a Warp tuple type with unknown values.
    if get_origin(return_type) is tuple:
        types = get_args(return_type)
        return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))

    # immediately allocate output variables so we can pass them into the dispatch method
    if return_type is None:
        # void function
        output = None
        output_list = []
    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # single return value function
        if isinstance(return_type, Sequence):
            return_type = return_type[0]
        output = adj.add_var(return_type)
        output_list = [output]
    else:
        # multiple return value function
        output = [adj.add_var(v) for v in return_type]
        output_list = output

    # If we have a built-in that requires special handling to dispatch
    # the arguments to the underlying C++ function, then we can resolve
    # these using the `dispatch_func`. Since this is only called from
    # within codegen, we pass it directly `codegen.Var` objects,
    # which allows for some more advanced resolution to be performed,
    # for example by checking whether an argument corresponds to
    # a literal value or references a variable.
    extra_shared_memory = 0
    if func.lto_dispatch_func is not None:
        # The LTO dispatcher also reports extra shared memory required by the call
        # (`ltoirs` is not used further in this method).
        func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
            func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
        )
    elif func.dispatch_func is not None:
        func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
    else:
        func_args = tuple(bound_args.values())
        template_args = ()

    func_args = tuple(adj.register_var(x) for x in func_args)
    func_name = compute_type_str(func.native_func, template_args)
    use_initializer_list = func.initializer_list_func(bound_args, return_type)

    # Forward arguments: load value arguments; references and functions are passed as-is.
    fwd_args = []
    for func_arg in func_args:
        if not isinstance(func_arg, (Reference, warp.context.Function)):
            func_arg_var = adj.load(func_arg)
        else:
            func_arg_var = func_arg

        # if the argument is a function (and not a builtin), then build it recursively
        if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
            if adj.used_by_backward_kernel:
                func_arg_var.adj.used_by_backward_kernel = True

            # NOTE(review): unlike the callee build above, this path does not handle
            # `adj.builder is None` -- confirm it is never reached without a builder.
            adj.builder.build_function(func_arg_var)

        fwd_args.append(strip_reference(func_arg_var))

    if return_type is None:
        # handles expression (zero output) functions, e.g.: void do_something();
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        )
        replay_call = forward_call
        if func.custom_replay_func is not None or func.replay_snippet is not None:
            replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    elif not isinstance(return_type, Sequence) or len(return_type) == 1:
        # handle simple function (one output)
        forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
        replay_call = forward_call
        if func.custom_replay_func is not None:
            replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"

    else:
        # handle multiple value functions: outputs are appended to the argument list
        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
        )
        replay_call = forward_call

    # With `skip_replay`, the replay line is emitted commented out.
    if func.skip_replay:
        adj.add_forward(forward_call, replay="// " + replay_call)
    else:
        adj.add_forward(forward_call, replay=replay_call)

    if not func.missing_grad and len(func_args):
        adj_args = tuple(strip_reference(x) for x in func_args)
        # Primal outputs are passed to the adjoint only without a custom grad function.
        reverse_has_output_args = (
            func.require_original_output_arg or len(output_list) > 1
        ) and func.custom_grad_func is None
        arg_str = adj.format_reverse_call_args(
            fwd_args,
            adj_args,
            output_list,
            use_initializer_list,
            has_output_args=reverse_has_output_args,
            require_original_output_arg=func.require_original_output_arg,
        )
        # None means there are no adjoint arguments, so no reverse call is emitted.
        if arg_str is not None:
            reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
            adj.add_reverse(reverse_call)

    # update our smem roofline requirements based on any
    # shared memory required by the dependent function call
    if not func.is_builtin():
        adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
    else:
        adj.alloc_shared_extra(extra_shared_memory)

    return output
|
|
1590
|
-
|
|
1591
|
-
def add_builtin_call(adj, func_name, args, min_outputs=None):
    """Emit a call to the built-in registered under ``func_name``.

    Only positional ``args`` are passed; keyword arguments and type args are
    empty. ``min_outputs`` is forwarded to ``add_call``.
    """
    builtin = warp.context.builtin_functions[func_name]
    no_kwargs, no_type_args = {}, {}
    return adj.add_call(builtin, args, no_kwargs, no_type_args, min_outputs=min_outputs)
|
|
1594
|
-
|
|
1595
|
-
def add_return(adj, var):
    """Emit a ``return`` for the values in ``var`` (``None`` or empty for a bare return).

    Behaviour by number of values:

    - One value: returned directly, and its adjoint accumulates ``adj_ret``.
    - Several values: each is stored in ``ret_<i>`` before a bare ``return;``,
      and each adjoint accumulates ``adj_ret_<i>``.

    In replay, the return is replaced by ``goto label<n>;``. The matching
    ``label<n>:;`` is added to the reverse pass, and ``adj.label_count`` is
    advanced.
    """
    goto_label = f"goto label{adj.label_count};"
    num_values = 0 if var is None else len(var)

    if num_values == 1:
        adj.add_forward(f"return {var[0].emit()};", goto_label)
        adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
    else:
        for i, v in enumerate(var if num_values else ()):
            adj.add_forward(f"ret_{i} = {v.emit()};")
            adj.add_reverse(f"adj_{v} += adj_ret_{i};")
        # NOTE: If this kernel gets compiled for a CUDA device, then we need
        # to convert the return; into a continue; in codegen_func_forward()
        adj.add_forward("return;", goto_label)

    adj.add_reverse(f"label{adj.label_count}:;")
    adj.label_count += 1
|
|
1612
|
-
|
|
1613
|
-
def begin_if(adj, cond):
    """Open an ``if (cond) {`` scope and indent the following statements.

    The reverse pass receives the closing ``}``. Reverse bodies are stored
    back-to-front, so the reverse-pass ``if`` header is added by ``end_if``.
    """
    loaded = adj.load(cond)
    adj.add_forward("if (" + loaded.emit() + ") {")
    adj.add_reverse("}")
    adj.indent()
|
|
1620
|
-
|
|
1621
|
-
def end_if(adj, cond):
    """Close the scope opened by ``begin_if``.

    Dedents and emits the forward ``}``. The condition is then loaded, after
    the brace, and the reverse-pass ``if (cond) {`` header is emitted.
    """
    adj.dedent()
    adj.add_forward("}")

    loaded = adj.load(cond)
    adj.add_reverse("if (" + loaded.emit() + ") {")
|
|
1627
|
-
|
|
1628
|
-
def begin_else(adj, cond):
    """Open the ``else`` branch as an ``if (!cond) {`` scope and indent.

    The reverse pass receives the closing ``}``. Its ``if (!cond) {`` header
    is added later by ``end_else``.
    """
    loaded = adj.load(cond)
    adj.add_forward("if (!" + loaded.emit() + ") {")
    adj.add_reverse("}")
    adj.indent()
|
|
1634
|
-
|
|
1635
|
-
def end_else(adj, cond):
    """Close the scope opened by ``begin_else``.

    Dedents and emits the forward ``}``. The condition is then loaded and the
    reverse-pass ``if (!cond) {`` header is emitted.
    """
    adj.dedent()
    adj.add_forward("}")

    loaded = adj.load(cond)
    adj.add_reverse("if (!" + loaded.emit() + ") {")
|
|
1641
|
-
|
|
1642
|
-
def begin_for(adj, iter):
    """Open a for-loop over ``iter`` and return the per-iteration loop value.

    Two blocks are pushed:

    - A condition block, labelled ``for_<n>`` and also pushed on
      ``adj.loop_blocks``. It receives the ``start_<label>`` label, a jump to
      ``end_<label>`` when ``iter_cmp(iter) == 0``, and the ``iter_next`` call.
    - A body block for the loop statements.

    ``end_for`` pops both blocks.
    """
    cond_block = adj.begin_block("for")
    adj.loop_blocks.append(cond_block)
    label = cond_block.label

    adj.add_forward(f"start_{label}:;")
    adj.indent()

    # evaluate cond
    adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto end_{label};")

    # evaluate iter
    loop_value = adj.add_builtin_call("iter_next", [iter])

    # body block
    adj.begin_block()

    return loop_value
|
|
1658
|
-
|
|
1659
|
-
def end_for(adj, iter):
    """Close the loop opened by ``begin_for`` and splice it into the enclosing block.

    Forward pass: the condition block, the body, a jump back to
    ``start_<label>``, then (after dedenting) the ``end_<label>`` exit label.

    Reverse pass, in execution order:

    1. Reverse the iterator with ``wp::iter_reverse``.
    2. Re-run the condition block.
    3. Zero the adjoints of body-local variables. Tile variables are only
       zeroed (via ``grad_zero()``) when their type has an owner.
    4. Replay the body's primal statements.
    5. Run the body's adjoint statements, reading the back-to-front
       ``body_reverse`` in reverse.
    6. Loop back.

    These lines are collected top-to-bottom and appended reversed, following
    the same back-to-front convention for reverse bodies.
    """
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    label = cond_block.label
    outer = adj.blocks[-1]

    ####################
    # forward pass

    outer.body_forward.extend(cond_block.body_forward)
    outer.body_forward.extend(body_block.body_forward)

    adj.add_forward(f"goto start_{label};", skip_replay=True)

    adj.dedent()
    adj.add_forward(f"end_{label}:;", skip_replay=True)

    ####################
    # reverse pass (indentation is read after the dedent above)

    pad = adj.indentation

    # reverse iterator
    reverse = [pad + f"{iter.emit()} = wp::iter_reverse({iter.emit()});"]

    reverse.extend(cond_block.body_forward)

    # zero adjoints
    for local in body_block.vars:
        if not is_tile(local.type):
            reverse.append(pad + f"\t{local.emit_adj()} = {{}};")
        elif local.type.owner:
            reverse.append(pad + f"\t{local.emit_adj()}.grad_zero();")

    # replay
    reverse.extend(body_block.body_replay)

    # reverse
    reverse.extend(reversed(body_block.body_reverse))

    reverse.append(pad + f"\tgoto start_{label};")
    reverse.append(pad + f"end_{label}:;")

    outer.body_reverse.extend(reversed(reverse))
|
|
1709
|
-
|
|
1710
|
-
def begin_while(adj, cond):
    """Open a while-loop whose condition is the AST expression ``cond``.

    The condition is evaluated inside its own block so that its replay can be
    controlled separately. That block, labelled ``while_<n>``, is also pushed
    on ``adj.loop_blocks``. It starts with ``start_<label>`` and jumps to
    ``end_<label>`` when the condition is false. A body block is then opened
    and the indentation increased.
    """
    cond_block = adj.begin_block("while")
    adj.loop_blocks.append(cond_block)

    header = cond_block.body_forward
    header.append(f"start_{cond_block.label}:;")

    test = adj.load(adj.eval(cond))
    header.append(f"if (({test.emit()}) == false) goto end_{cond_block.label};")

    # begin block around loop
    adj.begin_block()
    adj.indent()
|
|
1726
|
-
|
|
1727
|
-
def end_while(adj):
    """Close the loop opened by ``begin_while`` and splice it into the enclosing block.

    Forward pass: the condition block, the body, a jump back to
    ``start_<label>``, and the ``end_<label>`` exit label.

    Reverse pass, in execution order:

    1. Re-evaluate the condition.
    2. Set the adjoints of all body-local variables to ``{}``.
    3. Replay the body.
    4. Run the body's adjoint statements, reading the back-to-front
       ``body_reverse`` in reverse.
    5. Loop back.

    The collected lines are appended reversed, matching the back-to-front
    storage of reverse bodies.
    """
    adj.dedent()
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    label = cond_block.label
    outer = adj.blocks[-1]

    ####################
    # forward pass

    outer.body_forward.extend(cond_block.body_forward)
    outer.body_forward.extend(body_block.body_forward)
    outer.body_forward.extend([f"goto start_{label};", f"end_{label}:;"])

    ####################
    # reverse pass

    # cond
    reverse = list(cond_block.body_forward)

    # zero adjoints of local vars
    reverse.extend(f"{local.emit_adj()} = {{}};" for local in body_block.vars)

    # replay
    reverse.extend(body_block.body_replay)

    # reverse
    reverse.extend(reversed(body_block.body_reverse))

    reverse.append(f"goto start_{label};")
    reverse.append(f"end_{label}:;")

    # output
    outer.body_reverse.extend(reversed(reverse))
|
|
1770
|
-
|
|
1771
|
-
def emit_FunctionDef(adj, node):
    """Generate code for the body of a function definition node.

    Steps:

    1. Record the definition's line number.
    2. Evaluate every statement except standalone constant expressions such
       as docstrings.
    3. When ``adj.return_var`` holds exactly one value and the body does not
       end in an explicit ``return``, append ``return {};`` (without replay).
    4. For a definition with a single ``<...>.func_native(...)`` decorator
       call and a return annotation, set ``adj.return_var`` from that
       annotation.

    Raises:
        WarpCodegenTypeError: If a native function's return annotation is
            neither a plain name nor an attribute.
    """
    adj.fun_def_lineno = node.lineno

    for stmt in node.body:
        # Skip variable creation for standalone constants, including docstrings
        is_constant_expr = isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant)
        if not is_constant_expr:
            adj.eval(stmt)

    single_return = adj.return_var is not None and len(adj.return_var) == 1
    if single_return and not isinstance(node.body[-1], ast.Return):
        adj.add_forward("return {};", skip_replay=True)

    # native function case: return type is specified, eg -> int or -> wp.float32
    is_func_native = False
    if node.decorator_list is not None and len(node.decorator_list) == 1:
        deco = node.decorator_list[0]
        is_func_native = (
            isinstance(deco, ast.Call) and isinstance(deco.func, ast.Attribute) and deco.func.attr == "func_native"
        )

    if not is_func_native or node.returns is None:
        return

    annotation = node.returns
    if isinstance(annotation, ast.Name):  # python built-in type
        type_name = annotation.id
    elif isinstance(annotation, ast.Attribute):  # warp type
        type_name = annotation.attr
    else:
        raise WarpCodegenTypeError("Native function return type not recognized")
    adj.return_var = (Var(label="return_type", type=eval(type_name)),)
|
|
1800
|
-
|
|
1801
|
-
def emit_If(adj, node):
    """Generate code for an ``if`` statement.

    A constant condition is folded at codegen time by evaluating only the
    taken branch. Otherwise:

    - The body, and the ``else`` branch if present, are emitted in their own
      conditional scopes.
    - Every pre-existing symbol rebound inside a branch is merged back through
      a ``where`` select on the condition, acting as a phi function.

    Returns:
        Always ``None``.
    """
    if not node.body:
        return None

    # eval condition
    cond = adj.eval(node.test)

    if cond.constant is not None:
        # resolve constant condition
        taken = node.body if cond.constant else node.orelse
        for stmt in taken:
            adj.eval(stmt)
        return None

    # save symbol map, then eval body
    before = adj.symbols.copy()
    adj.begin_if(cond)
    for stmt in node.body:
        adj.eval(stmt)
    adj.end_if(cond)

    # symbols rebound inside the branch: take the new value when cond holds
    for sym, prev_var in before.items():
        new_var = adj.symbols[sym]
        if prev_var != new_var:
            adj.symbols[sym] = adj.add_builtin_call("where", [cond, new_var, prev_var])

    # evaluate 'else' statement as if (!cond)
    if node.orelse:
        before = adj.symbols.copy()
        adj.begin_else(cond)
        for stmt in node.orelse:
            adj.eval(stmt)
        adj.end_else(cond)

        # reversed select order: the else-branch value wins when cond is false
        for sym, prev_var in before.items():
            new_var = adj.symbols[sym]
            if prev_var != new_var:
                adj.symbols[sym] = adj.add_builtin_call("where", [cond, prev_var, new_var])

    return None
|
|
1864
|
-
|
|
1865
|
-
def emit_IfExp(adj, node):
    """Generate code for a conditional expression ``a if cond else b``.

    A constant condition evaluates only the taken branch. Otherwise both
    branches are evaluated in their own conditional scopes, and the result is
    ``where(cond, body, orelse)``.
    """
    cond = adj.eval(node.test)

    if cond.constant is not None:
        branch = node.body if cond.constant else node.orelse
        return adj.eval(branch)

    adj.begin_if(cond)
    then_value = adj.eval(node.body)
    adj.end_if(cond)

    adj.begin_else(cond)
    else_value = adj.eval(node.orelse)
    adj.end_else(cond)

    return adj.add_builtin_call("where", [cond, then_value, else_value])
|
|
1880
|
-
|
|
1881
|
-
def emit_Compare(adj, node):
    """Generate code for a (possibly chained) comparison such as ``a < b <= c``.

    ``node.left`` is compared through ``node.ops`` against ``node.comparators``,
    i.e. ``(left ops[0] comparators[0]) ops[1] comparators[1] ...``. Each AST
    operator is mapped to its native string via ``builtin_operators``.
    """
    left_operand = adj.eval(node.left)
    right_operands = list(map(adj.eval, node.comparators))
    native_ops = [builtin_operators[type(operator)] for operator in node.ops]

    return adj.add_comp(native_ops, left_operand, right_operands)
|
|
1890
|
-
|
|
1891
|
-
def emit_BoolOp(adj, node):
    """Generate code for ``and`` / ``or`` over ``node.values``.

    ``ast.And`` maps to ``&&`` and ``ast.Or`` to ``||``.

    Raises:
        WarpCodegenKeyError: For any other boolean operator.
    """
    for op_type, native_op in ((ast.And, "&&"), (ast.Or, "||")):
        if isinstance(node.op, op_type):
            break
    else:
        raise WarpCodegenKeyError(f"Op {node.op} is not supported")

    return adj.add_bool_op(native_op, [adj.eval(value) for value in node.values])
|
|
1903
|
-
|
|
1904
|
-
def emit_Name(adj, node):
    """Resolve a bare name, either from the local symbol table or externally.

    Resolution order:

    1. Names already bound in ``adj.symbols`` return their mapping.
    2. External values are turned into constants and cached in the symbol
       table.
    3. Functions, classes and modules are returned unchanged. Structs are
       built recursively first, then returned.

    Raises:
        WarpCodegenKeyError: If the name cannot be resolved.
        TypeError: If the external object is of an unsupported kind.
    """
    name = node.id

    # lookup symbol, if it has already been assigned to a variable then return the existing mapping
    if name in adj.symbols:
        return adj.symbols[name]

    obj = adj.resolve_external_reference(name)
    if obj is None:
        raise WarpCodegenKeyError("Referencing undefined symbol: " + str(name))

    if warp.types.is_value(obj):
        # evaluate constant
        const_var = adj.add_constant(obj)
        adj.symbols[name] = const_var
        return const_var

    # the named object is either a function, class name, or module
    # pass it back to the caller for processing
    if isinstance(obj, (warp.context.Function, type)):
        return obj
    if isinstance(obj, Struct):
        adj.builder.build_struct_recursive(obj)
        return obj
    if isinstance(obj, types.ModuleType):
        return obj

    raise TypeError(f"Invalid external reference type: {type(obj)}")
|
|
1933
|
-
|
|
1934
|
-
@staticmethod
def resolve_type_attribute(var_type: type, attr: str):
    """Resolve ``var_type.<attr>`` at codegen time.

    For value types, ``dtype`` yields the scalar type and ``length`` the type
    size. Anything else falls back to ``getattr``, returning ``None`` when the
    attribute does not exist.
    """
    if isinstance(var_type, type) and type_is_value(var_type):
        if attr == "dtype":
            return type_scalar_type(var_type)
        if attr == "length":
            return type_size(var_type)

    return getattr(var_type, attr, None)
|
|
1943
|
-
|
|
1944
|
-
def vector_component_index(adj, component, vector_type):
    """Translate a swizzle letter into a constant component index.

    Valid letters are the first ``vector_type._shape_[0]`` characters of ``"xyzw"``.
    The returned value is the result of ``adj.add_constant`` for the index.

    Raises:
        WarpCodegenAttributeError: if ``component`` is not a single valid letter.
    """
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")

    valid = "xyzw"[: vector_type._shape_[0]]
    position = valid.find(component)
    if position < 0:
        raise WarpCodegenAttributeError(
            f"Vector swizzle for {vector_type} must be one of {valid}, got {component}"
        )

    return adj.add_constant(position)
|
|
1958
|
-
|
|
1959
|
-
def transform_component(adj, component):
    """Validate a transform attribute name (``p`` or ``q``) and return it.

    Raises:
        WarpCodegenAttributeError: if ``component`` is not ``"p"`` or ``"q"``.
    """
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")

    if component in ("p", "q"):
        return component

    raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
|
|
1967
|
-
|
|
1968
|
-
@staticmethod
def is_differentiable_value_type(var_type):
    """Return True for value (non-array) types that may carry gradients.

    A type qualifies when its scalar type is a float type, or when it is a Struct.
    Gradients for such types are accumulated rather than overwritten.
    """
    scalar = type_scalar_type(var_type)
    return scalar in float_types or isinstance(var_type, Struct)
|
|
1973
|
-
|
|
1974
|
-
def emit_Attribute(adj, node):
    """Emit an attribute access ``value.attr``.

    Cases are handled in this order:
      * constant Vars are returned as-is (e.g. ``IntEnum.A.value``);
      * attributes of modules / Python types are looked up directly and folded
        into constants when they are values or int enums;
      * adjoint access produces a Var naming ``struct.attr`` directly;
      * vector/quaternion swizzles (``v.x``) are emitted as ``extract`` calls;
      * transform components (``t.p`` / ``t.q``) call the matching builtin;
      * struct fields are emitted as a reference to the field, with reverse-mode
        accumulation (``+=``) for differentiable field types and assignment otherwise.

    If the lookup raises ``KeyError``/``AttributeError``, type attributes
    such as ``dtype``/``length`` are tried before reporting an error.

    Raises:
        WarpCodegenAttributeError: if ``attr`` cannot be resolved.
    """
    if hasattr(node, "is_adjoint"):
        node.value.is_adjoint = True

    aggregate = adj.eval(node.value)

    try:
        if isinstance(aggregate, Var) and aggregate.constant is not None:
            # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
            return aggregate

        if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
            out = getattr(aggregate, node.attr)

            if warp.types.is_value(out):
                return adj.add_constant(out)
            if isinstance(out, (enum.IntEnum, enum.IntFlag)):
                return adj.add_constant(int(out))

            return out

        if hasattr(node, "is_adjoint"):
            # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
            attr_name = aggregate.label + "." + node.attr
            attr_type = aggregate.type.vars[node.attr].type

            return Var(attr_name, attr_type)

        aggregate_type = strip_reference(aggregate.type)

        # reading a vector or quaternion component
        if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
            index = adj.vector_component_index(node.attr, aggregate_type)

            return adj.add_builtin_call("extract", [aggregate, index])

        elif type_is_transformation(aggregate_type):
            component = adj.transform_component(node.attr)

            if component == "p":
                return adj.add_builtin_call("transform_get_translation", [aggregate])
            else:
                return adj.add_builtin_call("transform_get_rotation", [aggregate])

        else:
            attr_var = aggregate_type.vars[node.attr]

            # represent pointer types as uint64
            if isinstance(attr_var.type, pointer_t):
                cast = f"({Var.dtype_to_ctype(uint64)}*)"
                adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
                attr_type = Reference(uint64)
            else:
                cast = ""
                adj_cast = ""
                attr_type = Reference(attr_var.type)

            attr = adj.add_var(attr_type)

            if is_reference(aggregate.type):
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
            else:
                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

            if adj.is_differentiable_value_type(strip_reference(attr_type)):
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
            else:
                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

            return attr

    except (KeyError, AttributeError) as e:
        # Try resolving as type attribute
        aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate

        type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
        if type_attribute is not None:
            return type_attribute

        if isinstance(aggregate, Var):
            # `node.value` is only an `ast.Name` for simple accesses like `s.attr`.
            # For chained or computed values (`s.inner.attr`, `f().attr`) it has no
            # `.id`, so describe it by its source text instead. Otherwise the user
            # would see an unrelated AttributeError from inside this handler.
            if isinstance(node.value, ast.Name):
                value_desc = node.value.id
            else:
                value_desc = ast.get_source_segment(adj.source, node.value) or type(node.value).__name__
            raise WarpCodegenAttributeError(
                f"Error, `{node.attr}` is not an attribute of '{value_desc}' ({type_repr(aggregate.type)})"
            ) from e
        raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'") from e
|
|
2058
|
-
|
|
2059
|
-
def emit_Assert(adj, node):
    """Emit a runtime ``assert`` for an ``assert`` statement.

    The condition is evaluated and loaded. The Python source of the statement is
    embedded as a C string literal next to it, so a failing assertion reports the
    original code.
    """
    # eval condition
    cond = adj.eval(node.test)
    cond = adj.load(cond)

    source_segment = ast.get_source_segment(adj.source, node)
    # The segment is placed inside a C string literal, so characters that are
    # special there must be escaped.
    # - Backslashes go first, so the escapes added below are not doubled and
    #   sequences like `\"` in the assert message survive intact.
    # - Raw line breaks (from multi-line asserts) are not allowed in a C string literal.
    escaped_segment = (
        source_segment.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )

    adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
|
|
2069
|
-
|
|
2070
|
-
def emit_Constant(adj, node):
    """Emit a literal constant.

    Raises:
        WarpCodegenTypeError: for ``None`` literals.
    """
    value = node.value
    if value is not None:
        return adj.add_constant(value)
    raise WarpCodegenTypeError("None type unsupported")
|
|
2075
|
-
|
|
2076
|
-
def emit_BinOp(adj, node):
    """Emit a binary operator expression ``left <op> right``.

    A user-defined ``@wp.func`` whose name matches the operator's builtin name
    (resolved via ``adj.resolve_external_reference``) takes precedence. If
    resolving or calling that overload raises ``WarpCodegenError``, the builtin
    is used instead.
    """
    # Array overwrite tracking: for in-place forms like `x[tid] = x[tid] + 1`
    # only the write should be recorded, so the read flags set while evaluating
    # the left operand are rolled back afterwards.
    saved_is_read = None
    if warp.config.verify_autograd_array_access:
        saved_is_read = [arg.is_read for arg in adj.args]

    left = adj.eval(node.left)

    if warp.config.verify_autograd_array_access:
        for arg, was_read in zip(adj.args, saved_is_read):
            arg.is_read = was_read

    right = adj.eval(node.right)
    op_name = builtin_operators[type(node.op)]

    try:
        overload = adj.resolve_external_reference(op_name)
        if isinstance(overload, warp.context.Function):
            return adj.add_call(overload, (left, right), {}, {})
    except WarpCodegenError:
        pass

    return adj.add_builtin_call(op_name, [left, right])
|
|
2109
|
-
|
|
2110
|
-
def emit_UnaryOp(adj, node):
    """Emit a unary operator expression such as ``-x``, ``+x``, ``not x`` or ``~x``.

    Negating a finite real compile-time constant is folded into a new constant.
    Every other case is lowered to the builtin registered for the operator in
    ``builtin_operators``.
    """
    # evaluate unary op arguments
    arg = adj.eval(node.operand)

    # Fold `-constant` at compile time. Only real scalars can be folded:
    # `math.isfinite` rejects non-real constants (e.g. vector or matrix values)
    # with a TypeError. Those, and non-negation operators, fall through to the builtin.
    if isinstance(node.op, ast.USub) and arg.constant is not None:
        try:
            foldable = math.isfinite(arg.constant)
        except TypeError:
            foldable = False
        if foldable:
            return adj.add_constant(-arg.constant)

    name = builtin_operators[type(node.op)]

    return adj.add_builtin_call(name, [arg])
|
|
2122
|
-
|
|
2123
|
-
def materialize_redefinitions(adj, symbols):
    """Reconcile symbols redefined inside a dynamic loop with their pre-loop variables.

    ``symbols`` is a snapshot of ``adj.symbols`` taken at loop entry. For each
    symbol whose current variable differs from the snapshot, an ``assign``
    builtin copies the new value into the original variable. The symbol is then
    pointed back at that variable, so the generated code is not in SSA form.
    Symbols recorded as unrolled-loop iterators are skipped.

    Raises:
        WarpCodegenError: if the pre-loop variable is a compile-time constant.
    """
    for sym, original_var in symbols.items():
        # unrolled (static) loop iterators are rebound to a new constant on every
        # unrolled iteration, which is harmless
        if adj.is_constant_iter_symbol(sym):
            continue

        current_var = adj.symbols[sym]
        if original_var != current_var:
            if warp.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                print(msg)

            if original_var.constant is not None:
                raise WarpCodegenError(
                    f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
                )

            # overwrite the old variable value (violates SSA)
            adj.add_builtin_call("assign", [original_var, current_var])

            # point the symbol back at the original variable
            adj.symbols[sym] = original_var
|
|
2152
|
-
|
|
2153
|
-
def emit_While(adj, node):
    """Emit a dynamic ``while`` loop.

    A copy of the symbol table is pushed onto ``adj.loop_symbols`` before the
    body is evaluated. Afterwards it is used to write back any redefinitions
    made in the body, then popped.
    """
    adj.begin_while(node.test)

    snapshot = adj.symbols.copy()
    adj.loop_symbols.append(snapshot)

    for statement in node.body:
        adj.eval(statement)

    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.loop_symbols.pop()

    adj.end_while()
|
|
2166
|
-
|
|
2167
|
-
def eval_num(adj, a):
    """Try to evaluate ``a`` as a compile-time number.

    Returns:
        ``(flag, value)``.
        * A literal (``ast.Constant``), or a negated literal, gives ``flag=True``
          whatever the literal's type.
        * Otherwise the expression is resolved statically or evaluated (constant
          Vars are unwrapped to their value), and ``flag`` is
          ``warp.types.is_int(value)``.
    """
    if isinstance(a, ast.Constant):
        return True, a.value

    is_negated_literal = (
        isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Constant)
    )
    if is_negated_literal:
        return True, -a.operand.value

    # e.g.: a wp.constant in the globals scope
    obj, _ = adj.resolve_static_expression(a)
    if obj is None:
        obj = adj.eval(a)

    if isinstance(obj, Var) and obj.constant is not None:
        obj = obj.constant

    return warp.types.is_int(obj), obj
|
|
2185
|
-
|
|
2186
|
-
def contains_break(adj, body):
    """Return True if ``body`` contains a ``break`` or ``continue`` for the current loop.

    ``if``/``else`` branches are searched recursively. Nested loops are not,
    because their ``break``/``continue`` statements do not affect this loop.
    """
    for stmt in body:
        if isinstance(stmt, (ast.Break, ast.Continue)):
            return True
        if isinstance(stmt, ast.If) and (adj.contains_break(stmt.body) or adj.contains_break(stmt.orelse)):
            return True
    return False
|
|
2204
|
-
|
|
2205
|
-
def get_unroll_range(adj, loop):
    """Decide how a ``for`` loop over ``range(...)`` should be generated.

    Returns:
        * ``None`` if ``loop.iter`` is not a call to ``range`` with 1-3 positional arguments.
        * A Python ``range`` if the loop can be unrolled. This requires that every
          argument is a numeric literal/constant, that the iteration count does not
          exceed ``max_unroll`` (the module builder option if set, else
          ``warp.config.max_unroll``), and that the body has no ``break``/``continue``.
        * Otherwise, the Var produced by emitting a dynamic ``range`` builtin call.

    Raises:
        ValueError: if all arguments are constant and the step is zero (from ``range``).
    """
    if (
        not isinstance(loop.iter, ast.Call)
        or not isinstance(loop.iter.func, ast.Name)
        or loop.iter.func.id != "range"
        or len(loop.iter.args) == 0
        or len(loop.iter.args) > 3
    ):
        return None

    # if all range() arguments are numeric constants we will unroll
    # note that this only handles trivial constants, it will not unroll
    # constant compile-time expressions e.g.: range(0, 3*2)

    # Evaluate the arguments and check that they are numeric constants.
    # This is done in one pass so that, if evaluating the arguments has side
    # effects, their code is not generated more than once.
    range_args = [adj.eval_num(arg) for arg in loop.iter.args]
    arg_is_numeric, arg_values = zip(*range_args)

    if all(arg_is_numeric):
        if len(arg_values) == 1:
            # range(end)
            start, end, step = 0, arg_values[0], 1
        elif len(arg_values) == 2:
            # range(start, end)
            start, end, step = arg_values[0], arg_values[1], 1
        else:
            # range(start, end, step)
            start, end, step = arg_values

        # The iteration count is the length of the range itself. An estimate of
        # `abs(end - start) // abs(step)` is wrong in both directions:
        # - it undercounts when the span is not a multiple of the step
        #   (range(0, 10, 3) runs 4 times, not 3);
        # - it overcounts empty ranges such as range(10, 0).
        unroll_range = range(start, end, step)
        max_iters = len(unroll_range)

        if "max_unroll" in adj.builder_options:
            max_unroll = adj.builder_options["max_unroll"]
        else:
            max_unroll = warp.config.max_unroll

        ok_to_unroll = True

        if max_iters > max_unroll:
            if warp.config.verbose:
                print(
                    f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
                )
            ok_to_unroll = False

        elif adj.contains_break(loop.body):
            if warp.config.verbose:
                print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
            ok_to_unroll = False

        if ok_to_unroll:
            return unroll_range

    # Unrolling is not possible, so the range is evaluated dynamically
    range_call = adj.add_builtin_call(
        "range",
        [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
    )
    return range_call
|
|
2278
|
-
|
|
2279
|
-
def record_constant_iter_symbol(adj, sym):
    """Remember ``sym`` as the iterator name of an unrolled (static) loop."""
    registry = adj.loop_const_iter_symbols
    registry.add(sym)
|
|
2281
|
-
|
|
2282
|
-
def is_constant_iter_symbol(adj, sym):
    """Return True if ``sym`` was recorded as an unrolled-loop iterator."""
    known = adj.loop_const_iter_symbols
    return sym in known
|
|
2284
|
-
|
|
2285
|
-
def emit_For(adj, node):
    """Emit a ``for`` loop, unrolling it when ``get_unroll_range`` allows.

    For an unrolled loop, the body is evaluated once per value of the range, with
    the iterator name bound to a fresh constant each time. Otherwise a dynamic
    loop is emitted:
      * the iterable comes from ``get_unroll_range`` when it already emitted a
        dynamic ``range``, so its arguments are not evaluated twice;
      * loop-body redefinitions are written back via ``materialize_redefinitions``.
    """
    unroll_range = adj.get_unroll_range(node)

    if not isinstance(unroll_range, range):
        loop_iter = unroll_range if unroll_range is not None else adj.eval(node.iter)

        adj.symbols[node.target.id] = adj.begin_for(loop_iter)

        # for loops should be side-effect free, here we store a copy
        adj.loop_symbols.append(adj.symbols.copy())

        for stmt in node.body:
            adj.eval(stmt)

        adj.materialize_redefinitions(adj.loop_symbols[-1])
        adj.loop_symbols.pop()

        adj.end_for(loop_iter)
        return

    iter_name = node.target.id
    # prevent constant conflicts in `materialize_redefinitions()`
    adj.record_constant_iter_symbol(iter_name)

    for value in unroll_range:
        adj.symbols[iter_name] = adj.add_constant(value)
        for stmt in node.body:
            adj.eval(stmt)
|
|
2325
|
-
|
|
2326
|
-
def emit_Break(adj, node):
    """Emit ``break``: write back loop redefinitions, then jump to the innermost loop's end label."""
    adj.materialize_redefinitions(adj.loop_symbols[-1])
    end_label = f"end_{adj.loop_blocks[-1].label}"
    adj.add_forward(f"goto {end_label};")
|
|
2330
|
-
|
|
2331
|
-
def emit_Continue(adj, node):
    """Emit ``continue``: write back loop redefinitions, then jump to the innermost loop's start label."""
    adj.materialize_redefinitions(adj.loop_symbols[-1])
    start_label = f"start_{adj.loop_blocks[-1].label}"
    adj.add_forward(f"goto {start_label};")
|
|
2335
|
-
|
|
2336
|
-
def emit_Expr(adj, node):
    """Evaluate an expression statement and return its result."""
    value_node = node.value
    return adj.eval(value_node)
|
|
2338
|
-
|
|
2339
|
-
def check_tid_in_func_error(adj, node):
    """Reject calls to ``<something>.tid()`` inside a user ``@wp.func``.

    Only calls whose callee is an attribute named ``tid`` are checked.

    Raises:
        WarpCodegenError: if such a call appears in a user function.
    """
    if not adj.is_user_function:
        return
    if getattr(node.func, "attr", None) != "tid":
        return

    lineno = adj.lineno + adj.fun_lineno
    line = adj.source_lines[adj.lineno]
    raise WarpCodegenError(
        "tid() may only be called from a Warp kernel, not a Warp function. "
        "Instead, obtain the indices from a @wp.kernel and pass them as "
        f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
    )
|
|
2349
|
-
|
|
2350
|
-
def resolve_arg(adj, arg):
    """Resolve a call argument node to a Var, type, struct, function or constant.

    The argument is always evaluated first, because that surfaces problems
    such as accesses to global variables. If it also resolves as a static
    expression:
      * types, Structs, Vars and Functions are returned as-is;
      * int enums are converted to ``int`` constants;
      * any other value becomes a constant.
    Otherwise the evaluated Var is returned, or the evaluation error is re-raised.
    """
    eval_error = None
    try:
        var = adj.eval(arg)
    except (WarpCodegenError, WarpCodegenKeyError) as e:
        eval_error = e

    expr, _ = adj.resolve_static_expression(arg)
    if expr is None:
        if eval_error is not None:
            raise eval_error
        return var

    if isinstance(expr, (type, Struct, Var, warp.context.Function)):
        return expr

    if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
        expr = int(expr)

    return adj.add_constant(expr)
|
|
2376
|
-
|
|
2377
|
-
def emit_Call(adj, node):
    """Emit a call expression.

    The callee is resolved from ``node.func``: a pre-attached ``warp_func``, a
    static lookup, or evaluation of the expression.
      * ``wp.static(...)`` calls are evaluated at codegen time. They return the
        resulting Function, or a constant.
      * Otherwise, if the callee was found via a non-empty attribute path and is
        not a Function, it is looked up as a builtin, a vector/scalar type
        constructor, or a struct constructor.
    All arguments are then resolved and the call is emitted with ``adj.add_call``.

    When ``warp.config.verify_autograd_array_access`` is enabled, the callee's
    per-argument read/write flags are propagated to the caller's positional
    arguments.

    Raises:
        NotImplementedError: if the callee is a lambda.
        WarpCodegenError: if no builtin or user function can be found.
    """
    adj.check_tid_in_func_error(node)

    # try and lookup function in globals by
    # resolving path (e.g.: module.submodule.attr)
    if hasattr(node.func, "warp_func"):
        func = node.func.warp_func
        # an empty path skips the builtin/constructor lookup below
        path = []
    else:
        func, path = adj.resolve_static_expression(node.func)
        if func is None:
            func = adj.eval(node.func)

    if adj.is_static_expression(func):
        # try to evaluate wp.static() expressions
        obj, code = adj.evaluate_static_expression(node)
        if obj is not None:
            adj.static_expressions[code] = obj
            if isinstance(obj, warp.context.Function):
                # special handling for wp.static() evaluating to a function
                return obj
            else:
                out = adj.add_constant(obj)
                return out

    type_args = {}

    if len(path) > 0 and not isinstance(func, warp.context.Function):
        # `func` currently holds the resolved object (e.g. a type or struct);
        # keep it as `caller` and search for the callable to emit
        attr = path[-1]
        caller = func
        func = None

        # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
        if attr in warp.context.builtin_functions:
            func = warp.context.builtin_functions[attr]

        # vector class type e.g.: wp.vec3f constructor
        if func is None and hasattr(caller, "_wp_generic_type_str_"):
            func = warp.context.builtin_functions.get(caller._wp_constructor_)

        # scalar class type e.g.: wp.int8 constructor
        if func is None and hasattr(caller, "__name__") and caller.__name__ in warp.context.builtin_functions:
            func = warp.context.builtin_functions.get(caller.__name__)

        # struct constructor
        if func is None and isinstance(caller, Struct):
            if adj.builder is not None:
                adj.builder.build_struct_recursive(caller)
            # any argument selects the value constructor; none selects the default one
            if node.args or node.keywords:
                func = caller.value_constructor
            else:
                func = caller.default_constructor

        # lambda function
        if func is None and getattr(caller, "__name__", None) == "<lambda>":
            raise NotImplementedError("Lambda expressions are not yet supported")

        if hasattr(caller, "_wp_type_args_"):
            type_args = caller._wp_type_args_

        if func is None:
            raise WarpCodegenError(
                f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
            )

    # get expected return count, e.g.: for multi-assignment
    # (`emit_Assign` sets `node.expects` when the target is a tuple)
    min_outputs = None
    if hasattr(node, "expects"):
        min_outputs = node.expects

    # Evaluate all positional and keywords arguments.
    args = tuple(adj.resolve_arg(x) for x in node.args)
    kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}

    out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)

    if warp.config.verify_autograd_array_access:
        # Extract the types and values passed as arguments to the function call.
        arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
        kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}

        # Resolve the exact function signature among any existing overload.
        resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)

        # update arg read/write states according to what happens to that arg in the called function
        # NOTE(review): only positional `args` are propagated (matched by index);
        # arguments passed in `kwargs` are not marked read/write here. Confirm
        # whether that is intended.
        if hasattr(resolved_func, "adj"):
            for i, arg in enumerate(args):
                if resolved_func.adj.args[i].is_write:
                    kernel_name = adj.fun_name
                    filename = adj.filename
                    lineno = adj.lineno + adj.fun_lineno
                    arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
                if resolved_func.adj.args[i].is_read:
                    arg.mark_read()

    return out
|
|
2473
|
-
|
|
2474
|
-
def emit_Index(adj, node):
    """Unwrap a legacy ``ast.Index`` node and evaluate the wrapped value.

    Python 3.8 and earlier wrap simple subscript indices in ``ast.Index``
    (e.g. ``x = arr[i]``). Newer versions no longer emit this node. An adjoint
    marker on the wrapper is forwarded to the wrapped value.
    """
    inner = node.value
    if hasattr(node, "is_adjoint"):
        inner.is_adjoint = True
    return adj.eval(inner)
|
|
2483
|
-
|
|
2484
|
-
def eval_indices(adj, target_type, indices):
    """Evaluate subscript index nodes for a value of ``target_type``.

    For vector/matrix types (those with ``_wp_generic_type_hint_``), each
    ``ast.Slice`` is resolved at compile time into a
    ``slice(start, stop, step)`` builtin call. Its bounds follow Python's
    slicing rules for the size of the corresponding dimension. All other index
    nodes are evaluated normally.

    Returns:
        A tuple with one evaluated entry per index node.

    Raises:
        WarpCodegenError: if a slice bound or step is not a compile-time constant.
        ValueError: if the slice step is zero.
    """
    if not hasattr(target_type, "_wp_generic_type_hint_"):
        return tuple(adj.eval(x) for x in indices)

    def eval_const(expr):
        # Slice bounds of vec/mat types must be known at compile time.
        value = adj.eval(expr).constant
        if value is None:
            raise WarpCodegenError("Slice bounds and step of vector/matrix types must be compile-time constants")
        return value

    evaluated = []
    for dim, node in enumerate(indices):
        if not isinstance(node, ast.Slice):
            evaluated.append(adj.eval(node))
            continue

        length = target_type._shape_[dim]
        step = None if node.step is None else eval_const(node.step)
        lower = None if node.lower is None else eval_const(node.lower)
        upper = None if node.upper is None else eval_const(node.upper)

        # `slice.indices` applies Python's clamping rules, including the
        # negative-step cases: `v[5::-1]` starts at `length - 1`, and
        # `v[:-5:-1]` stops at -1 (i.e. includes element 0).
        start, stop, step = slice(lower, upper, step).indices(length)

        slice_var = adj.add_builtin_call("slice", (start, stop, step))
        evaluated.append(slice_var)

    return tuple(evaluated)
|
|
2518
|
-
|
|
2519
|
-
def emit_indexing(adj, target, indices):
    """Emit an indexing operation on ``target`` with the given index nodes.

    Arrays:
      * a full set of indices gives an ``address`` load;
      * fewer indices give a ``view``, which inherits the read/write tracking
        state of ``target`` when autograd array-access verification is on.
    Tiles:
      * a full set of indices gives ``tile_extract``;
      * fewer indices give ``tile_view``.
    Any other type (vec, mat, ...) uses ``extract``.

    Raises:
        RuntimeError: if more indices than tile dimensions are given.
    """
    target_type = strip_reference(target.type)
    indices = adj.eval_indices(target_type, indices)

    if is_array(target_type):
        full_index = len(indices) == target_type.ndim
        out = adj.add_builtin_call("address" if full_index else "view", [target, *indices])

        if warp.config.verify_autograd_array_access:
            if full_index:
                target.mark_read()
            else:
                # keep a link to the parent Var so downstream read/write state can be propagated back to the root arg
                out.parent = target
                out.is_read = target.is_read
                out.is_write = target.is_write
        return out

    if is_tile(target_type):
        tile_rank = len(target_type.shape)
        if len(indices) == tile_rank:
            return adj.add_builtin_call("tile_extract", [target, *indices])
        if len(indices) < tile_rank:
            return adj.add_builtin_call("tile_view", [target, indices])
        raise RuntimeError(
            f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {tile_rank} dimensional tile."
        )

    return adj.add_builtin_call("extract", [target, *indices])
|
|
2560
|
-
|
|
2561
|
-
# from a list of lists of indices, strip the first `count` indices
|
|
2562
|
-
@staticmethod
|
|
2563
|
-
def strip_indices(indices, count):
|
|
2564
|
-
dim = count
|
|
2565
|
-
while count > 0:
|
|
2566
|
-
ij = indices[0]
|
|
2567
|
-
indices = indices[1:]
|
|
2568
|
-
count -= len(ij)
|
|
2569
|
-
|
|
2570
|
-
# report straddling like in `arr2d[0][1,2]` as a syntax error
|
|
2571
|
-
if count < 0:
|
|
2572
|
-
raise WarpCodegenError(
|
|
2573
|
-
f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
|
|
2574
|
-
)
|
|
2575
|
-
|
|
2576
|
-
return indices
|
|
2577
|
-
|
|
2578
|
-
def recurse_subscript(adj, node, indices):
    """Walk a (possibly chained) subscript expression down to the object being indexed.

    Returns ``(target, indices)``. ``indices`` is a list of index groups, one
    list of AST index nodes per ``[...]``, in left-to-right source order.

    When the walk reaches an array whose accumulated indices outnumber its
    ``ndim``, the first ``ndim`` indices are emitted immediately as an indexing
    operation on the array and stripped. Chained forms such as ``arr[i][j]`` on
    a 1-D array therefore index into the resulting element.
    """
    if isinstance(node, ast.Name):
        target = adj.eval(node)
        return target, indices

    if isinstance(node, ast.Subscript):
        # `wp.adjoint[...]` is handled entirely by evaluating the subscript node
        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
            return adj.eval(node), indices

        # normalize this subscript's index expression into a list of index nodes
        if isinstance(node.slice, ast.Tuple):
            ij = node.slice.elts
        elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
            # The node `ast.Index` is deprecated in Python 3.9.
            ij = node.slice.value.elts
        elif isinstance(node.slice, ast.ExtSlice):
            # The node `ast.ExtSlice` is deprecated in Python 3.9.
            ij = node.slice.dims
        else:
            ij = [node.slice]

        indices = [ij, *indices]  # prepend

        target, indices = adj.recurse_subscript(node.value, indices)

        target_type = strip_reference(target.type)
        if is_array(target_type):
            flat_indices = [i for ij in indices for i in ij]
            if len(flat_indices) > target_type.ndim:
                # consume the array's dimensions now; the leftover groups apply to the element
                target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
                indices = adj.strip_indices(indices, target_type.ndim)

        return target, indices

    # any other expression (e.g. an attribute or call) is evaluated as the target
    target = adj.eval(node)
    return target, indices
|
|
2613
|
-
|
|
2614
|
-
def eval_subscript(adj, node):
    """Return the object being indexed by ``node`` together with its flattened list of index nodes."""
    target, index_groups = adj.recurse_subscript(node, [])
    flat = []
    for group in index_groups:
        flat.extend(group)
    return target, flat
|
|
2619
|
-
|
|
2620
|
-
def emit_Subscript(adj, node):
    """Emit a subscript expression ``value[...]``.

    ``wp.adjoint[var]`` produces a Var named ``adj_<label>`` with the same type
    as ``var``. Every other subscript is flattened with ``eval_subscript`` and
    emitted via ``emit_indexing``.
    """
    is_adjoint_lookup = hasattr(node.value, "attr") and node.value.attr == "adjoint"
    if not is_adjoint_lookup:
        target, indices = adj.eval_subscript(node)
        return adj.emit_indexing(target, indices)

    # handle adjoint of a variable, i.e. wp.adjoint[var]
    node.slice.is_adjoint = True
    primal = adj.eval(node.slice)
    return Var(f"adj_{primal.label}", type=primal.type, constant=None, prefix=False)
|
|
2632
|
-
|
|
2633
|
-
def emit_Assign(adj, node):
|
|
2634
|
-
if len(node.targets) != 1:
|
|
2635
|
-
raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
|
|
2636
|
-
|
|
2637
|
-
# Check if the rhs corresponds to an unsupported construct.
|
|
2638
|
-
# Tuples are supported in the context of assigning multiple variables
|
|
2639
|
-
# at once, but not for simple assignments like `x = (1, 2, 3)`.
|
|
2640
|
-
# Therefore, we need to catch this specific case here instead of
|
|
2641
|
-
# more generally in `adj.eval()`.
|
|
2642
|
-
if isinstance(node.value, ast.List):
|
|
2643
|
-
raise WarpCodegenError(
|
|
2644
|
-
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2645
|
-
)
|
|
2646
|
-
|
|
2647
|
-
lhs = node.targets[0]
|
|
2648
|
-
|
|
2649
|
-
if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
|
|
2650
|
-
# record the expected number of outputs on the node
|
|
2651
|
-
# we do this so we can decide which function to
|
|
2652
|
-
# call based on the number of expected outputs
|
|
2653
|
-
node.value.expects = len(lhs.elts)
|
|
2654
|
-
|
|
2655
|
-
# evaluate rhs
|
|
2656
|
-
if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
|
|
2657
|
-
rhs = [adj.eval(v) for v in node.value.elts]
|
|
2658
|
-
else:
|
|
2659
|
-
rhs = adj.eval(node.value)
|
|
2660
|
-
|
|
2661
|
-
# handle the case where we are assigning multiple output variables
|
|
2662
|
-
if isinstance(lhs, ast.Tuple):
|
|
2663
|
-
subtype = getattr(rhs, "type", None)
|
|
2664
|
-
|
|
2665
|
-
if isinstance(subtype, warp.types.tuple_t):
|
|
2666
|
-
if len(rhs.type.types) != len(lhs.elts):
|
|
2667
|
-
raise WarpCodegenError(
|
|
2668
|
-
f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
|
|
2669
|
-
)
|
|
2670
|
-
rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))
|
|
2671
|
-
|
|
2672
|
-
names = []
|
|
2673
|
-
for v in lhs.elts:
|
|
2674
|
-
if isinstance(v, ast.Name):
|
|
2675
|
-
names.append(v.id)
|
|
2676
|
-
else:
|
|
2677
|
-
raise WarpCodegenError(
|
|
2678
|
-
"Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
|
|
2679
|
-
)
|
|
2680
|
-
|
|
2681
|
-
if len(names) != len(rhs):
|
|
2682
|
-
raise WarpCodegenError(
|
|
2683
|
-
f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
|
|
2684
|
-
)
|
|
2685
|
-
|
|
2686
|
-
out = rhs
|
|
2687
|
-
for name, rhs in zip(names, out):
|
|
2688
|
-
if name in adj.symbols:
|
|
2689
|
-
if not types_equal(rhs.type, adj.symbols[name].type):
|
|
2690
|
-
raise WarpCodegenTypeError(
|
|
2691
|
-
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
2692
|
-
)
|
|
2693
|
-
|
|
2694
|
-
adj.symbols[name] = rhs
|
|
2695
|
-
|
|
2696
|
-
# handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
|
|
2697
|
-
elif isinstance(lhs, ast.Subscript):
|
|
2698
|
-
if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
|
|
2699
|
-
# handle adjoint of a variable, i.e. wp.adjoint[var]
|
|
2700
|
-
lhs.slice.is_adjoint = True
|
|
2701
|
-
src_var = adj.eval(lhs.slice)
|
|
2702
|
-
var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
|
|
2703
|
-
adj.add_forward(f"{var.emit()} = {rhs.emit()};")
|
|
2704
|
-
return
|
|
2705
|
-
|
|
2706
|
-
target, indices = adj.eval_subscript(lhs)
|
|
2707
|
-
|
|
2708
|
-
target_type = strip_reference(target.type)
|
|
2709
|
-
indices = adj.eval_indices(target_type, indices)
|
|
2710
|
-
|
|
2711
|
-
if is_array(target_type):
|
|
2712
|
-
adj.add_builtin_call("array_store", [target, *indices, rhs])
|
|
2713
|
-
|
|
2714
|
-
if warp.config.verify_autograd_array_access:
|
|
2715
|
-
kernel_name = adj.fun_name
|
|
2716
|
-
filename = adj.filename
|
|
2717
|
-
lineno = adj.lineno + adj.fun_lineno
|
|
2718
|
-
|
|
2719
|
-
target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
|
|
2720
|
-
|
|
2721
|
-
elif is_tile(target_type):
|
|
2722
|
-
adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2723
|
-
|
|
2724
|
-
elif (
|
|
2725
|
-
type_is_vector(target_type)
|
|
2726
|
-
or type_is_quaternion(target_type)
|
|
2727
|
-
or type_is_matrix(target_type)
|
|
2728
|
-
or type_is_transformation(target_type)
|
|
2729
|
-
):
|
|
2730
|
-
# recursively unwind AST, stopping at penultimate node
|
|
2731
|
-
root = lhs
|
|
2732
|
-
while hasattr(root.value, "value"):
|
|
2733
|
-
root = root.value
|
|
2734
|
-
# lhs is updating a variable adjoint (i.e. wp.adjoint[var])
|
|
2735
|
-
if hasattr(root, "attr") and root.attr == "adjoint":
|
|
2736
|
-
attr = adj.add_builtin_call("index", [target, *indices])
|
|
2737
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2738
|
-
return
|
|
2739
|
-
|
|
2740
|
-
# TODO: array vec component case
|
|
2741
|
-
if is_reference(target.type):
|
|
2742
|
-
attr = adj.add_builtin_call("indexref", [target, *indices])
|
|
2743
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2744
|
-
|
|
2745
|
-
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
2746
|
-
lineno = adj.lineno + adj.fun_lineno
|
|
2747
|
-
line = adj.source_lines[adj.lineno]
|
|
2748
|
-
node_source = adj.get_node_source(lhs.value)
|
|
2749
|
-
print(
|
|
2750
|
-
f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
|
|
2751
|
-
)
|
|
2752
|
-
else:
|
|
2753
|
-
if warp.config.enable_vector_component_overwrites:
|
|
2754
|
-
out = adj.add_builtin_call("assign_copy", [target, *indices, rhs])
|
|
2755
|
-
|
|
2756
|
-
# re-point target symbol to out var
|
|
2757
|
-
for id in adj.symbols:
|
|
2758
|
-
if adj.symbols[id] == target:
|
|
2759
|
-
adj.symbols[id] = out
|
|
2760
|
-
break
|
|
2761
|
-
else:
|
|
2762
|
-
adj.add_builtin_call("assign_inplace", [target, *indices, rhs])
|
|
2763
|
-
|
|
2764
|
-
else:
|
|
2765
|
-
raise WarpCodegenError(
|
|
2766
|
-
f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
|
|
2767
|
-
)
|
|
2768
|
-
|
|
2769
|
-
elif isinstance(lhs, ast.Name):
|
|
2770
|
-
# symbol name
|
|
2771
|
-
name = lhs.id
|
|
2772
|
-
|
|
2773
|
-
# check type matches if symbol already defined
|
|
2774
|
-
if name in adj.symbols:
|
|
2775
|
-
if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
|
|
2776
|
-
raise WarpCodegenTypeError(
|
|
2777
|
-
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
2778
|
-
)
|
|
2779
|
-
|
|
2780
|
-
if isinstance(node.value, ast.Tuple):
|
|
2781
|
-
out = rhs
|
|
2782
|
-
elif isinstance(rhs, Sequence):
|
|
2783
|
-
out = adj.add_builtin_call("tuple", rhs)
|
|
2784
|
-
elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
|
|
2785
|
-
out = adj.add_builtin_call("copy", [rhs])
|
|
2786
|
-
else:
|
|
2787
|
-
out = rhs
|
|
2788
|
-
|
|
2789
|
-
# update symbol map (assumes lhs is a Name node)
|
|
2790
|
-
adj.symbols[name] = out
|
|
2791
|
-
|
|
2792
|
-
elif isinstance(lhs, ast.Attribute):
|
|
2793
|
-
aggregate = adj.eval(lhs.value)
|
|
2794
|
-
aggregate_type = strip_reference(aggregate.type)
|
|
2795
|
-
|
|
2796
|
-
# assigning to a vector or quaternion component
|
|
2797
|
-
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
2798
|
-
index = adj.vector_component_index(lhs.attr, aggregate_type)
|
|
2799
|
-
|
|
2800
|
-
if is_reference(aggregate.type):
|
|
2801
|
-
attr = adj.add_builtin_call("indexref", [aggregate, index])
|
|
2802
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2803
|
-
else:
|
|
2804
|
-
if warp.config.enable_vector_component_overwrites:
|
|
2805
|
-
out = adj.add_builtin_call("assign_copy", [aggregate, index, rhs])
|
|
2806
|
-
|
|
2807
|
-
# re-point target symbol to out var
|
|
2808
|
-
for id in adj.symbols:
|
|
2809
|
-
if adj.symbols[id] == aggregate:
|
|
2810
|
-
adj.symbols[id] = out
|
|
2811
|
-
break
|
|
2812
|
-
else:
|
|
2813
|
-
adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
|
|
2814
|
-
|
|
2815
|
-
elif type_is_transformation(aggregate_type):
|
|
2816
|
-
component = adj.transform_component(lhs.attr)
|
|
2817
|
-
|
|
2818
|
-
# TODO: x[i,j].p = rhs case
|
|
2819
|
-
if is_reference(aggregate.type):
|
|
2820
|
-
raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
|
|
2821
|
-
|
|
2822
|
-
if component == "p":
|
|
2823
|
-
return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
|
|
2824
|
-
else:
|
|
2825
|
-
return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
|
|
2826
|
-
|
|
2827
|
-
else:
|
|
2828
|
-
attr = adj.emit_Attribute(lhs)
|
|
2829
|
-
if is_reference(attr.type):
|
|
2830
|
-
adj.add_builtin_call("store", [attr, rhs])
|
|
2831
|
-
else:
|
|
2832
|
-
adj.add_builtin_call("assign", [attr, rhs])
|
|
2833
|
-
|
|
2834
|
-
if warp.config.verbose and not adj.custom_reverse_mode:
|
|
2835
|
-
lineno = adj.lineno + adj.fun_lineno
|
|
2836
|
-
line = adj.source_lines[adj.lineno]
|
|
2837
|
-
msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
|
|
2838
|
-
print(msg)
|
|
2839
|
-
|
|
2840
|
-
else:
|
|
2841
|
-
raise WarpCodegenError("Error, unsupported assignment statement.")
|
|
2842
|
-
|
|
2843
|
-
def emit_Return(adj, node):
    """Emit code for a ``return`` statement.

    The returned value(s) are evaluated and their C types are compared against those recorded
    by any previous ``return`` of the same function. Values whose type is a reference are
    copied so that a value (not a reference) is returned. The result is stored in
    ``adj.return_var`` and emitted through ``adj.add_return``.

    Raises:
        WarpCodegenTypeError: if this ``return`` yields different types than a previous one,
            including a bare ``return`` in a function that elsewhere returns a value.
    """
    if node.value is None:
        var = None
    elif isinstance(node.value, ast.Tuple):
        var = tuple(adj.eval(arg) for arg in node.value.elts)
    else:
        var = adj.eval(node.value)
        if not isinstance(var, (list, tuple)):
            var = (var,)

    if adj.return_var is not None:
        old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
        # a bare `return` yields no values; iterating `None` here used to raise a TypeError
        new_ctypes = () if var is None else tuple(v.ctype(value_type=True) for v in var)
        if old_ctypes != new_ctypes:
            raise WarpCodegenTypeError(
                f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
            )

    if var is not None:
        adj.return_var = ()
        for ret in var:
            if is_reference(ret.type):
                # return a copy of the referenced value rather than the reference itself
                ret_var = adj.add_builtin_call("copy", [ret])
            else:
                ret_var = ret
            adj.return_var += (ret_var,)

    adj.add_return(adj.return_var)
|
|
2871
|
-
|
|
2872
|
-
def emit_AugAssign(adj, node):
    """Emit code for an augmented assignment such as ``x += y`` or ``arr[i] -= v``.

    ``+=`` / ``-=`` on a subscripted target are lowered to dedicated builtins:
    ``atomic_add`` / ``atomic_sub`` for arrays, ``add_inplace`` / ``sub_inplace`` for vector,
    quaternion, matrix and transformation components, and ``tile_add_inplace`` /
    ``tile_sub_inplace`` for tile elements. A whole-tile name target with a tile right-hand
    side uses ``add_inplace`` / ``sub_inplace``. Every other case is rewritten as
    ``lhs = lhs <op> value`` and evaluated as a regular assignment.

    The right-hand side is evaluated lazily, only on the paths that consume it. The fallback
    rewrite evaluates ``node.value`` itself, so evaluating it up-front would emit the
    right-hand side twice and duplicate any side effects it has
    (e.g. ``x += wp.atomic_add(arr, 0, 1)``).

    Raises:
        WarpCodegenError: if a subscripted target is not an array, vector, quaternion, matrix,
            transformation or tile.
    """
    lhs = node.target

    # replace augmented assignment with assignment statement + binary op (default behaviour)
    def make_new_assign_statement():
        new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
        adj.eval(new_node)

    def eval_rhs():
        # only called on paths that do NOT fall back to make_new_assign_statement()
        return adj.eval(node.value)

    if isinstance(lhs, ast.Subscript):
        # wp.adjoint[var] appears in custom grad functions, and does not require
        # special consideration in the AugAssign case
        if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
            make_new_assign_statement()
            return

        target, indices = adj.eval_subscript(lhs)

        target_type = strip_reference(target.type)
        indices = adj.eval_indices(target_type, indices)

        if is_array(target_type):
            # dtypes in warp.types.non_atomic_types are not suitable for atomic array accumulation
            if target_type.dtype in warp.types.non_atomic_types:
                make_new_assign_statement()
                return

            # the same holds true for vecs/mats/quats that are composed of these types
            if (
                type_is_vector(target_type.dtype)
                or type_is_quaternion(target_type.dtype)
                or type_is_matrix(target_type.dtype)
                or type_is_transformation(target_type.dtype)
            ):
                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
                if dtype in warp.types.non_atomic_types:
                    make_new_assign_statement()
                    return

            if isinstance(node.op, ast.Add):
                atomic_op = "atomic_add"
            elif isinstance(node.op, ast.Sub):
                atomic_op = "atomic_sub"
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

            kernel_name = adj.fun_name
            filename = adj.filename
            lineno = adj.lineno + adj.fun_lineno

            adj.add_builtin_call(atomic_op, [target, *indices, eval_rhs()])

            if warp.config.verify_autograd_array_access:
                target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

        elif (
            type_is_vector(target_type)
            or type_is_quaternion(target_type)
            or type_is_matrix(target_type)
            or type_is_transformation(target_type)
        ):
            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("add_inplace", [target, *indices, eval_rhs()])
            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("sub_inplace", [target, *indices, eval_rhs()])
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

        elif is_tile(target.type):
            if isinstance(node.op, ast.Add):
                adj.add_builtin_call("tile_add_inplace", [target, *indices, eval_rhs()])
            elif isinstance(node.op, ast.Sub):
                adj.add_builtin_call("tile_sub_inplace", [target, *indices, eval_rhs()])
            else:
                if warp.config.verbose:
                    print(f"Warning: in-place op {node.op} is not differentiable")
                make_new_assign_statement()
                return

        else:
            raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")

    elif isinstance(lhs, ast.Name):
        target = adj.eval(node.target)

        if is_tile(target.type) and isinstance(node.op, (ast.Add, ast.Sub)):
            rhs = eval_rhs()
            if is_tile(rhs.type):
                builtin = "add_inplace" if isinstance(node.op, ast.Add) else "sub_inplace"
                adj.add_builtin_call(builtin, [target, rhs])
                return
            # NOTE: a tile target with a non-tile right-hand side still falls back below after
            # the right-hand side was evaluated once here (its type is needed to decide)

        make_new_assign_statement()
        return

    else:
        # attribute targets (TODO: in-place support) and anything else use the generic rewrite
        make_new_assign_statement()
        return
|
|
2986
|
-
|
|
2987
|
-
def emit_Tuple(adj, node):
    """Evaluate every element of a tuple literal, left to right, and pack them with the ``tuple`` builtin."""
    return adj.add_builtin_call("tuple", tuple(map(adj.eval, node.elts)))
|
|
2990
|
-
|
|
2991
|
-
def emit_Pass(adj, node):
    """A ``pass`` statement produces no generated code."""
    return None
|
|
2993
|
-
|
|
2994
|
-
# Dispatch table used by `eval()`: maps each supported AST node class to the method that emits
# code for it. Node types absent from this table are rejected as unsupported constructs.
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = dict(
    (
        (ast.FunctionDef, emit_FunctionDef),
        (ast.If, emit_If),
        (ast.IfExp, emit_IfExp),
        (ast.Compare, emit_Compare),
        (ast.BoolOp, emit_BoolOp),
        (ast.Name, emit_Name),
        (ast.Attribute, emit_Attribute),
        (ast.Constant, emit_Constant),
        (ast.BinOp, emit_BinOp),
        (ast.UnaryOp, emit_UnaryOp),
        (ast.While, emit_While),
        (ast.For, emit_For),
        (ast.Break, emit_Break),
        (ast.Continue, emit_Continue),
        (ast.Expr, emit_Expr),
        (ast.Call, emit_Call),
        (ast.Index, emit_Index),  # Deprecated in 3.9
        (ast.Subscript, emit_Subscript),
        (ast.Assign, emit_Assign),
        (ast.Return, emit_Return),
        (ast.AugAssign, emit_AugAssign),
        (ast.Tuple, emit_Tuple),
        (ast.Pass, emit_Pass),
        (ast.Assert, emit_Assert),
    )
)
|
|
3020
|
-
|
|
3021
|
-
def eval(adj, node):
    """Emit code for ``node`` by dispatching to its visitor in ``adj.node_visitors``.

    If the node carries a source line number, the current line is updated first (0-based).

    Raises:
        WarpCodegenError: if no visitor is registered for the node's exact type.
    """
    if hasattr(node, "lineno"):
        adj.set_lineno(node.lineno - 1)

    node_type = type(node)
    try:
        visitor = adj.node_visitors[node_type]
    except KeyError as err:
        prefix = "ast." if isinstance(node, ast.AST) else ""
        raise WarpCodegenError(f"Construct `{prefix}{node_type.__name__}` not supported in kernels.") from err

    return visitor(adj, node)
|
|
3033
|
-
|
|
3034
|
-
# helper to evaluate expressions of the form
# obj1.obj2.obj3.attr in the function's global scope
def resolve_path(adj, path):
    """Resolve a dotted name (given as a list of components) against the function's outer scopes.

    The root name is looked up in order in:
    - the function's closure and global variables;
    - the ``warp`` module (so e.g. ``vec3`` works without the ``wp.`` prefix);
    - Python builtins.

    The remaining components are then followed as attributes.

    Returns:
        The resolved object, or ``None`` in any of these cases:
        - the path is empty;
        - the root is shadowed by a local symbol;
        - the root cannot be found;
        - any attribute along the path is missing.
        Previously a missing intermediate attribute was silently skipped, so ``m.missing.pi``
        resolved to ``m.pi``.
    """
    if len(path) == 0:
        return None

    root = path[0]

    # if root is overshadowed by local symbols, bail out
    if root in adj.symbols:
        return None

    # look up in closure/global variables
    expr = adj.resolve_external_reference(root)

    # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
    if expr is None:
        expr = getattr(warp, root, None)

    # look up in builtins (via the `builtins` module: `__builtins__` is a module rather than a
    # dict when this code runs as `__main__`)
    if expr is None:
        import builtins

        expr = getattr(builtins, root, None)

    if expr is None:
        return None

    for attr in path[1:]:
        if not hasattr(expr, attr):
            return None
        expr = getattr(expr, attr)

    return expr
|
|
3061
|
-
|
|
3062
|
-
# retrieves a dictionary of all closure and global variables and their values
# to be used in the evaluation context of wp.static() expressions
def get_static_evaluation_context(adj):
    """Build the namespace in which ``wp.static()`` expressions are evaluated.

    The namespace starts as a copy of the function's module globals. Each closure variable then
    overrides it, since captured variables take precedence over globals. Closure cells that are
    still empty (the captured variable has not been assigned yet) are skipped rather than
    raising ``ValueError``, so such a name falls back to its global value, if any.
    """
    vars_dict = dict(adj.func.__globals__)

    for name, cell in zip(adj.func.__code__.co_freevars, adj.func.__closure__ or ()):
        try:
            vars_dict[name] = cell.cell_contents
        except ValueError:
            # empty cell: nothing has been bound to the captured variable yet
            continue

    return vars_dict
|
|
3078
|
-
|
|
3079
|
-
def is_static_expression(adj, func):
    """Tell whether ``func`` is the ``static`` builtin, matched by its defining module and qualified name."""
    if not isinstance(func, types.FunctionType):
        return False
    return func.__module__ == "warp.builtins" and func.__qualname__ == "static"
|
|
3085
|
-
|
|
3086
|
-
# verify the return type of a wp.static() expression is supported inside a Warp kernel
def verify_static_return_value(adj, value):
    """Check that the result of a ``wp.static()`` expression can be embedded in kernel code.

    The following are accepted:
    - Warp values (``warp.types.is_value``);
    - strings, e.g. for ``print(wp.static("test"))``;
    - Warp functions;
    - struct instances whose fields (recursively) are all accepted.

    Returns:
        True when the value is supported.

    Raises:
        ValueError: describing why the value is unsupported. For structs, the message names the
            offending field by its full dotted path.
    """
    if value is None:
        raise ValueError("None is returned")
    if warp.types.is_value(value):
        return True
    if warp.types.is_array(value):
        # more useful explanation for the common case of creating a Warp array
        raise ValueError("a Warp array cannot be created inside Warp kernels")
    if isinstance(value, str):
        # we want to support cases such as `print(wp.static("test"))`
        return True
    if isinstance(value, warp.context.Function):
        return True

    def verify_struct(s: StructInstance, attr_path: list[str]):
        for key in s._cls.vars.keys():
            v = getattr(s, key)
            # include the current field so error messages point at the offending member
            field_path = [*attr_path, key]
            if issubclass(type(v), StructInstance):
                verify_struct(v, field_path)
            else:
                try:
                    adj.verify_static_return_value(v)
                except ValueError as e:
                    raise ValueError(
                        f"the returned Warp struct contains a data type that cannot be constructed inside Warp kernels: {e} at {value._cls.key}.{'.'.join(field_path)}"
                    ) from e

    if issubclass(type(value), StructInstance):
        verify_struct(value, [])
        return True

    raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
|
|
3118
|
-
|
|
3119
|
-
# find the source code string of an AST node
|
|
3120
|
-
@staticmethod
|
|
3121
|
-
def extract_node_source_from_lines(source_lines, node) -> str | None:
|
|
3122
|
-
if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
|
|
3123
|
-
return None
|
|
3124
|
-
|
|
3125
|
-
start_line = node.lineno - 1 # line numbers start at 1
|
|
3126
|
-
start_col = node.col_offset
|
|
3127
|
-
|
|
3128
|
-
if hasattr(node, "end_lineno") and hasattr(node, "end_col_offset"):
|
|
3129
|
-
end_line = node.end_lineno - 1
|
|
3130
|
-
end_col = node.end_col_offset
|
|
3131
|
-
else:
|
|
3132
|
-
# fallback for Python versions before 3.8
|
|
3133
|
-
# we have to find the end line and column manually
|
|
3134
|
-
end_line = start_line
|
|
3135
|
-
end_col = start_col
|
|
3136
|
-
parenthesis_count = 1
|
|
3137
|
-
for lineno in range(start_line, len(source_lines)):
|
|
3138
|
-
if lineno == start_line:
|
|
3139
|
-
c_start = start_col
|
|
3140
|
-
else:
|
|
3141
|
-
c_start = 0
|
|
3142
|
-
line = source_lines[lineno]
|
|
3143
|
-
for i in range(c_start, len(line)):
|
|
3144
|
-
c = line[i]
|
|
3145
|
-
if c == "(":
|
|
3146
|
-
parenthesis_count += 1
|
|
3147
|
-
elif c == ")":
|
|
3148
|
-
parenthesis_count -= 1
|
|
3149
|
-
if parenthesis_count == 0:
|
|
3150
|
-
end_col = i
|
|
3151
|
-
end_line = lineno
|
|
3152
|
-
break
|
|
3153
|
-
if parenthesis_count == 0:
|
|
3154
|
-
break
|
|
3155
|
-
|
|
3156
|
-
if start_line == end_line:
|
|
3157
|
-
# single-line expression
|
|
3158
|
-
return source_lines[start_line][start_col:end_col]
|
|
3159
|
-
else:
|
|
3160
|
-
# multi-line expression
|
|
3161
|
-
lines = []
|
|
3162
|
-
# first line (from start_col to the end)
|
|
3163
|
-
lines.append(source_lines[start_line][start_col:])
|
|
3164
|
-
# middle lines (entire lines)
|
|
3165
|
-
lines.extend(source_lines[start_line + 1 : end_line])
|
|
3166
|
-
# last line (from the start to end_col)
|
|
3167
|
-
lines.append(source_lines[end_line][:end_col])
|
|
3168
|
-
return "\n".join(lines).strip()
|
|
3169
|
-
|
|
3170
|
-
@staticmethod
def extract_lambda_source(func, only_body=False) -> str | None:
    """Return the source text of a lambda function (or of just its body when ``only_body`` is set).

    The defining source lines are trimmed to start at the ``lambda`` keyword. Trailing
    unbalanced closing parentheses are dropped. Anything from the first comma after the
    parameter list / colon is cut off, which handles the lambda being one argument of a call.
    The result is then parsed, and the text of the first ``ast.Lambda`` node found is returned
    (``None`` if there is none).

    Raises:
        WarpCodegenError: if the lambda's source code cannot be retrieved.
    """
    try:
        lines = inspect.getsourcelines(func)[0]
        head = lines[0]
        lines[0] = head[head.index("lambda") :]
    except OSError as e:
        raise WarpCodegenError(
            "Could not access lambda function source code. Please use a named function instead."
        ) from e

    text = "".join(lines)
    text = text[text.index("lambda") :].rstrip()

    # strip trailing characters while closing parentheses outnumber opening ones
    while text.count("(") < text.count(")"):
        text = text[:-1]

    # cut the expression at the next comma, e.g. in
    # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
    search_from = max(text.find(")"), text.find(":"))
    comma_at = text.find(",", search_from)
    if comma_at != -1:
        text = text[:comma_at]

    parsed = ast.parse(text)
    lambda_node = next((n for n in ast.walk(parsed) if isinstance(n, ast.Lambda)), None)
    if lambda_node is None:
        return None

    wanted = lambda_node.body if only_body else lambda_node
    return Adjoint.extract_node_source_from_lines(lines, wanted)
|
|
3202
|
-
|
|
3203
|
-
def extract_node_source(adj, node) -> str | None:
|
|
3204
|
-
return adj.extract_node_source_from_lines(adj.source_lines, node)
|
|
3205
|
-
|
|
3206
|
-
# handles a wp.static() expression and returns the resulting object and a string representing the code
# of the static expression
def evaluate_static_expression(adj, node) -> tuple[Any, str]:
    """Evaluate the single argument of a ``wp.static()`` call at code-generation time.

    The expression's source text is extracted and joined onto one line, then evaluated with
    Python's ``eval``. The namespace contains the function's globals and closure variables plus
    the constant values of local symbols. Constant ``len(...)`` sub-expressions are first
    replaced by their value, computed with ``warp.types.type_length`` on the local symbols'
    types. ``IntEnum``/``IntFlag`` results are converted to plain ``int``.

    Returns:
        ``(value, static_code)``: the evaluated object and the single-line source of the
        expression.

    Raises:
        WarpCodegenError: if the call does not have exactly one argument or keyword, if the
            source cannot be extracted, if evaluation fails, or if the result cannot be used
            inside a kernel.
    """
    if len(node.args) == 1:
        static_code = adj.extract_node_source(node.args[0])
    elif len(node.keywords) == 1:
        # use the keyword's value: the source of the ast.keyword node itself reads `name=expr`,
        # which is not an evaluable expression
        static_code = adj.extract_node_source(node.keywords[0].value)
    else:
        raise WarpCodegenError("warp.static() requires a single argument or keyword")
    if static_code is None:
        raise WarpCodegenError("Error extracting source code from wp.static() expression")

    # Since this is an expression, we can enforce it to be defined on a single line.
    static_code = static_code.replace("\n", "")
    code_to_eval = static_code  # code to be evaluated

    vars_dict = adj.get_static_evaluation_context()
    # add constant variables to the static call context
    constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
    vars_dict.update(constant_vars)

    # Replace all constant `len()` expressions with their value.
    if "len" in static_code:
        len_expr_ctx = vars_dict.copy()
        constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
        len_expr_ctx.update(constant_types)
        len_expr_ctx.update({"len": warp.types.type_length})

        # We want to replace the expression code in-place, so reparse it to get the correct
        # column info. AST column offsets are UTF-8 byte offsets, so slice the encoded source.
        code_bytes = static_code.encode("utf-8")
        len_value_locs: list[tuple[Any, int, int]] = []
        expr_tree = ast.parse(static_code)
        assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
        expr_root = expr_tree.body[0].value
        for expr_node in ast.walk(expr_root):
            if (
                isinstance(expr_node, ast.Call)
                and getattr(expr_node.func, "id", None) == "len"
                and len(expr_node.args) == 1
            ):
                start, end = expr_node.col_offset, expr_node.end_col_offset
                len_expr = code_bytes[start:end].decode("utf-8")
                try:
                    len_value = eval(len_expr, len_expr_ctx)
                except Exception:
                    pass
                else:
                    len_value_locs.append((len_value, start, end))

        if len_value_locs:
            # ast.walk() is breadth-first, so order the replacements by position; a len() call
            # nested inside one already replaced (e.g. `len(a[len(b)])`) is covered by it
            pieces: list[bytes] = []
            loc = 0
            for len_value, start, end in sorted(len_value_locs, key=lambda item: item[1]):
                if start < loc:
                    continue
                pieces.append(code_bytes[loc:start])
                pieces.append(f"{len_value}".encode("utf-8"))
                loc = end
            pieces.append(code_bytes[loc:])
            code_to_eval = b"".join(pieces).decode("utf-8")

    try:
        value = eval(code_to_eval, vars_dict)
        if isinstance(value, (enum.IntEnum, enum.IntFlag)):
            value = int(value)
        if warp.config.verbose:
            print(f"Evaluated static command: {static_code} = {value}")
    except NameError as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e}. Make sure all variables used in the static expression are constant."
        ) from e
    except Exception as e:
        raise WarpCodegenError(
            f"Error evaluating static expression: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    try:
        adj.verify_static_return_value(value)
    except ValueError as e:
        raise WarpCodegenError(
            f"Static expression returns an unsupported value: {e} while evaluating the following code generated from the static expression:\n{static_code}"
        ) from e

    return value, static_code
|
|
3287
|
-
|
|
3288
|
-
# try to replace wp.static() expressions by their evaluated value if the
# expression can be evaluated
def replace_static_expressions(adj):
    """Fold every evaluable ``wp.static()`` call in ``adj.tree`` into its value.

    Each successfully evaluated call is recorded in ``adj.static_expressions`` (keyed by its
    source text) and replaced in the tree:
    - a Warp function becomes an ``ast.Name("__warp_func__")`` node carrying the function in its
      ``warp_func`` attribute;
    - any other value becomes an ``ast.Constant``.

    Calls that fail to evaluate are left in place, and ``adj.has_unresolved_static_expressions``
    is set instead.
    """

    class _StaticCallFolder(ast.NodeTransformer):
        def visit_Call(self, call):
            resolved, _ = adj.resolve_static_expression(call.func, eval_types=False)
            if not adj.is_static_expression(resolved):
                return self.generic_visit(call)

            try:
                # evaluation succeeds as long as the static expression is valid and only
                # depends on global or captured variables
                value, source = adj.evaluate_static_expression(call)
                if source is not None:
                    adj.static_expressions[source] = value

                if not isinstance(value, warp.context.Function):
                    return ast.copy_location(ast.Constant(value=value), call)

                func_ref = ast.Name("__warp_func__")
                # keep a pointer to the Warp function itself: its key alone cannot identify it
                # uniquely, since the function may be redefined between the wp.static()
                # declaration and codegen during module building
                func_ref.warp_func = value
                return ast.copy_location(func_ref, call)
            except Exception:
                # Failing static expressions are either invalid code (the module cannot be
                # built at all) or reference local (possibly constant) variables, whose changes
                # alter the module hash anyway. Flag this Adjoint so code generation still runs
                # even when the module hash is unchanged.
                adj.has_unresolved_static_expressions = True

            return self.generic_visit(call)

    adj.tree = _StaticCallFolder().visit(adj.tree)
|
|
3328
|
-
|
|
3329
|
-
# Evaluates a static expression that does not depend on runtime values
# if eval_types is True, try resolving the path using evaluated type information as well
def resolve_static_expression(adj, root_node, eval_types=True):
    """Resolve an expression that does not depend on runtime values.

    A (possibly dotted) name such as ``a.b.c`` is resolved through ``adj.resolve_path()``.
    When ``eval_types`` is True, ``type(x).attr1.attr2`` is also supported: ``x`` is evaluated
    and the attributes are resolved on its reference-stripped type.

    Returns:
        ``(obj, path)``. ``obj`` is the resolved object, or ``None`` when it cannot be resolved
        statically. ``path`` lists the dotted components, or ``[str(type)]`` for a ``type()``
        expression. A chain rooted at anything other than a plain name (e.g. ``points[i].x``)
        is not resolved. Previously the leftover attribute names were looked up as if they were
        global names.

    Raises:
        WarpCodegenError: if ``type()`` is not given exactly one argument, or its argument does
            not evaluate to a variable.
        WarpCodegenAttributeError: if an attribute does not exist on the evaluated type.
    """
    attributes = []

    node = root_node
    while isinstance(node, ast.Attribute):
        attributes.append(node.attr)
        node = node.value

    if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        # support for operators returning modules
        # i.e. operator_name(*operator_args).x.y.z
        operator_args = node.args
        operator_name = node.func.id

        if operator_name == "type":
            if len(operator_args) != 1:
                raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")

            # type() operator
            var = adj.eval(operator_args[0])

            if isinstance(var, Var):
                var_type = strip_reference(var.type)
                # Allow accessing type attributes, for instance array.dtype
                while attributes:
                    attr_name = attributes.pop()
                    var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type

                    if var_type is None:
                        raise WarpCodegenAttributeError(
                            f"{attr_name} is not an attribute of {type_repr(prev_type)}"
                        )

                return var_type, [str(var_type)]
            else:
                raise WarpCodegenError(f"Cannot deduce the type of {var}")

    # reverse list since ast presents it in backward order
    path = [*reversed(attributes)]

    if not isinstance(node, ast.Name):
        # the chain is rooted at a runtime expression (call, subscript, ...): the remaining
        # attribute names must not be resolved as if they were global names
        return None, path

    path.insert(0, node.id)

    # Try resolving path from captured context
    return adj.resolve_path(path), path
|
|
3379
|
-
|
|
3380
|
-
def resolve_external_reference(adj, name: str):
    """Look up ``name`` among the function's closure variables, falling back to its module globals.

    Returns ``None`` when the name is found in neither. A closure cell that is still empty also
    falls back to the globals.
    """
    freevars = adj.func.__code__.co_freevars
    try:
        cell = adj.func.__closure__[freevars.index(name)]
        return cell.cell_contents
    except ValueError:
        # not a closure variable (tuple.index failed) or its cell is empty
        return adj.func.__globals__.get(name)
|
|
3389
|
-
|
|
3390
|
-
# annotate generated code with the original source code line
def set_lineno(adj, lineno):
    """Record ``lineno`` (0-based within the function) as the current source line.

    When the line changes, a comment quoting that source line and its absolute line number is
    emitted into both the forward and the reverse (adjoint) code.
    """
    if adj.lineno is not None and adj.lineno == lineno:
        return

    absolute_line = adj.fun_lineno + lineno
    text = adj.source_lines[lineno].strip()
    padded = text.ljust(80 - len(adj.indentation), " ")

    adj.add_forward(f"// {padded} <L {absolute_line}>")
    adj.add_reverse(f"// adj: {padded} <L {absolute_line}>")
    adj.lineno = lineno
|
|
3398
|
-
|
|
3399
|
-
def get_node_source(adj, node):
    """Return the exact Python source segment for ``node`` from the function's full source text."""
    full_source = adj.source
    return ast.get_source_segment(full_source, node)
|
|
3402
|
-
|
|
3403
|
-
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
    """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions.

    Every node yielded by ``ast.walk(adj.tree)`` is classified as follows:

    * ``ast.Name`` not recorded as a local assignment target: resolved through
      ``adj.resolve_external_reference``. Values accepted by ``warp.types.is_value`` go into
      ``constants`` under the bare name.
    * ``ast.Attribute``: resolved with ``adj.resolve_static_expression(..., eval_types=False)``.
      Values accepted by ``warp.types.is_value`` go into ``constants`` under the dotted path.
    * ``ast.Call``:
      * A callee resolving to a non-builtin ``warp.context.Function`` is recorded in ``functions``.
      * A callee resolving to a ``Struct``, or to a ``type`` accepted by ``warp.types.type_is_value``,
        is recorded in ``types``.
    * ``ast.Assign``: the names in its first target are added to the local-variable set.

    Returns:
        ``(constants, types, functions)``. ``types`` and ``functions`` always map to ``None``, so
        they act as insertion-ordered sets.

    NOTE(review): ``ast.walk`` yields nodes in no specified order, so whether a given use of a
    name counts as shadowed depends on traversal order rather than program order. Only
    ``ast.Assign`` targets are tracked, and only ``targets[0]``. Augmented/annotated assignments
    and loop targets are not tracked. Confirm this is intended.
    """

    local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed

    constants: dict[str, Any] = {}
    types: dict[Struct | type, Any] = {}
    functions: dict[warp.context.Function, Any] = {}

    for node in ast.walk(adj.tree):
        if isinstance(node, ast.Name) and node.id not in local_variables:
            # look up in closure/global variables
            obj = adj.resolve_external_reference(node.id)
            if warp.types.is_value(obj):
                constants[node.id] = obj

        elif isinstance(node, ast.Attribute):
            # e.g. `module.CONSTANT`: key the constant by its full dotted path
            obj, path = adj.resolve_static_expression(node, eval_types=False)
            if warp.types.is_value(obj):
                constants[".".join(path)] = obj

        elif isinstance(node, ast.Call):
            func, _ = adj.resolve_static_expression(node.func, eval_types=False)
            if isinstance(func, warp.context.Function) and not func.is_builtin():
                # calling user-defined function
                functions[func] = None
            elif isinstance(func, Struct):
                # calling struct constructor
                types[func] = None
            elif isinstance(func, type) and warp.types.type_is_value(func):
                # calling value type constructor
                types[func] = None

        elif isinstance(node, ast.Assign):
            # Add the LHS names to the local_variables so we know any subsequent uses are shadowed
            lhs = node.targets[0]
            if isinstance(lhs, ast.Tuple):
                # tuple unpacking: only plain names are recorded (subscripts/attributes are skipped)
                for v in lhs.elts:
                    if isinstance(v, ast.Name):
                        local_variables.add(v.id)
            elif isinstance(lhs, ast.Name):
                local_variables.add(lhs.id)

    return constants, types, functions


# ----------------
# code generation

# Preamble of every generated CPU module.
# - ``{block_dim}`` is a str.format placeholder; presumably it is filled with the module's tile
#   block dimension (confirm at the call site).
# - ``float(x)``/``int(x)`` are remapped to ``cast_float``/``cast_int`` because ``wp::float(x)``
#   is not valid C++.
# - The ``builtin_tid*`` macros read ``task_index`` and ``dim``, the parameter names used by the
#   CPU kernel signatures.
cpu_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(task_index, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)

#define builtin_block_dim() wp::block_dim()

"""

# Preamble of every generated CUDA module. It matches the CPU preamble with two differences:
# - The ``builtin_tid*`` macros read ``_idx``, the loop index of the CUDA kernel templates.
# - Under nvcc on non-MSVC hosts, ``__debugbreak()`` (what wp.breakpoint() emits, per the
#   inline comment) is mapped to a device ``__brkpt()``.
cuda_module_header = """
#define WP_TILE_BLOCK_DIM {block_dim}
#define WP_NO_CRT
#include "builtin.h"

// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
#if defined(__CUDACC__) && !defined(_MSC_VER)
#define __debugbreak() __brkpt()
#endif

// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)

#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

#define builtin_tid1d() wp::tid(_idx, dim)
#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

#define builtin_block_dim() wp::block_dim()

"""

# C++ definition of a Warp struct, filled by codegen_struct() via str.format, so literal braces
# are doubled. It emits:
# - the struct itself, with its members, value constructor and ``operator +=``;
# - the adjoint constructor ``adj_{name}``;
# - the ``add``/``adj_atomic_add`` overloads needed when compiling adjoints.
# Note that ``add`` returns a default-constructed value rather than ``a + b``.
struct_template = """
struct {name}
{{
{struct_body}

{defaulted_constructor_def}
CUDA_CALLABLE {name}({forward_args})
{forward_initializers}
{{
}}

CUDA_CALLABLE {name}& operator += (const {name}& rhs)
{{{prefix_add_body}
return *this;}}

}};

static CUDA_CALLABLE void adj_{name}({reverse_args})
{{
{reverse_body}}}

// Required when compiling adjoints.
CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
{{
return {name}();
}}

CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
{{
{atomic_add_body}}}


"""

# Forward/adjoint C++ wrappers for user functions.
# - They are filled via str.format by codegen_func(); the CUDA variants are also filled by
#   codegen_snippet(). Literal braces are doubled.
# - ``{forward_args}``/``{reverse_args}`` come from indent(), with one parameter per line.
# - The CUDA variants take ``{line_directive}``: either empty or a ``#line`` directive followed by
#   a newline.
# - The CPU variants have no such placeholder; str.format ignores the unused keyword.
cpu_forward_function_template = """
// {filename}:{lineno}
static {return_type} {name}(
{forward_args})
{{
{forward_body}}}

"""

cpu_reverse_function_template = """
// {filename}:{lineno}
static void adj_{name}(
{reverse_args})
{{
{reverse_body}}}

"""

cuda_forward_function_template = """
// {filename}:{lineno}
{line_directive}static CUDA_CALLABLE {return_type} {name}(
{forward_args})
{{
{forward_body}{line_directive}}}

"""

cuda_reverse_function_template = """
// {filename}:{lineno}
{line_directive}static CUDA_CALLABLE void adj_{name}(
{reverse_args})
{{
{reverse_body}{line_directive}}}

"""

# CUDA kernel entry points, filled from the name/args/body/line_directive values collected in
# codegen_kernel(). Literal braces are doubled for str.format.
# - Each thread walks ``_idx`` from its global thread index up to ``dim.size`` in steps of
#   ``blockDim.x * gridDim.x`` (a grid-stride loop). This is why codegen_func_forward() rewrites
#   ``return;`` to ``continue;`` in CUDA kernel bodies.
# - The tile shared-memory allocator is reset at the start of every iteration.
cuda_kernel_template_forward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_forward(
{forward_args})
{{
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive} _idx < dim.size;
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
{{
// reset shared memory allocator
{line_directive} wp::tile_alloc_shared(0, true);

{forward_body}{line_directive} }}
{line_directive}}}

"""

cuda_kernel_template_backward = """

{line_directive}extern "C" __global__ void {name}_cuda_kernel_backward(
{reverse_args})
{{
{line_directive} for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
{line_directive} _idx < dim.size;
{line_directive} _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
{{
// reset shared memory allocator
{line_directive} wp::tile_alloc_shared(0, true);

{reverse_body}{line_directive} }}
{line_directive}}}

"""

# Per-task CPU kernel functions.
# - Arguments arrive through pointers to the ``wp_args_{name}`` struct built in codegen_kernel().
# - For CPU kernels, codegen_func_forward()/codegen_func_reverse() unpack those pointers into
#   locals.
cpu_kernel_template_forward = """

void {name}_cpu_kernel_forward(
{forward_args},
wp_args_{name} *_wp_args)
{{
{forward_body}}}

"""

cpu_kernel_template_backward = """

void {name}_cpu_kernel_backward(
{reverse_args},
wp_args_{name} *_wp_args,
wp_args_{name} *_wp_adj_args)
{{
{reverse_body}}}

"""

# extern "C" CPU entry points called from Python. Literal braces are doubled for str.format.
# - Each entry point loops ``task_index`` over ``[0, dim.size)``.
# - Each iteration calls the per-task ``{name}_cpu_kernel_forward``/``_backward`` function.
# - Per the inline comments, the tile shared-memory allocator is initialized before each task
#   and checked after it.
cpu_module_template_forward = """

extern "C" {{

// Python CPU entry points
WP_API void {name}_cpu_forward(
wp::launch_bounds_t dim,
wp_args_{name} *_wp_args)
{{
for (size_t task_index = 0; task_index < dim.size; ++task_index)
{{
// init shared memory allocator
wp::tile_alloc_shared(0, true);

{name}_cpu_kernel_forward(dim, task_index, _wp_args);

// check shared memory allocator
wp::tile_alloc_shared(0, false, true);

}}
}}

}} // extern C

"""

cpu_module_template_backward = """

extern "C" {{

WP_API void {name}_cpu_backward(
wp::launch_bounds_t dim,
wp_args_{name} *_wp_args,
wp_args_{name} *_wp_adj_args)
{{
for (size_t task_index = 0; task_index < dim.size; ++task_index)
{{
// initialize shared memory allocator
wp::tile_alloc_shared(0, true);

{name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

// check shared memory allocator
wp::tile_alloc_shared(0, false, true);
}}
}}

}} // extern C

"""


# converts a constant Python value to equivalent C-repr
def constant_str(value):
    """Convert a constant value into the C++ expression that reproduces it.

    Cases, checked in order:

    * Warp/Python booleans -> ``true`` / ``false``.
    * ``str`` -> a double-quoted C++ string literal. Backslashes, control characters and
      non-ASCII characters are escaped with the ``unicode-escape`` codec, and double quotes are
      backslash-escaped.
    * ctypes arrays carrying ``_wp_scalar_type_`` (presumably Warp vector/matrix values) ->
      ``wp::initializer_array<N,T>{...}``. float16 elements are stored as raw bits and are
      converted back to floats first.
    * Warp scalar instances -> the string form of their ``.value``.
    * Warp struct instances -> a constructor call whose arguments are built recursively from the
      members.
    * ``+inf`` / ``-inf`` / NaN -> ``INFINITY`` / ``-INFINITY`` / ``NAN``.
    * Anything else -> ``str(value)``.
    """
    value_type = type(value)

    if value_type == bool or value_type == builtins.bool:
        if value:
            return "true"
        else:
            return "false"

    elif value_type == str:
        # unicode-escape handles backslashes, control characters and non-ASCII text, but leaves
        # double quotes untouched; escape them too so the C++ literal is not terminated early.
        escaped = value.encode("unicode-escape").decode().replace('"', '\\"')
        return '"' + escaped + '"'

    elif isinstance(value, ctypes.Array):
        if value_type._wp_scalar_type_ == float16:
            # special case for float16, which is stored as uint16 in the ctypes.Array
            from warp.context import runtime

            scalar_value = runtime.core.wp_half_bits_to_float
        else:

            def scalar_value(x):
                return x

        # list of scalar initializer values
        initlist = []
        for i in range(value._length_):
            x = ctypes.Array.__getitem__(value, i)
            initlist.append(str(scalar_value(x)).lower())

        if value._wp_scalar_type_ is bool:
            dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
        else:
            dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"

        # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
        return f"{dtypestr}{{{', '.join(initlist)}}}"

    elif value_type in warp.types.scalar_types:
        # make sure we emit the value of objects, e.g. uint32
        return str(value.value)

    elif issubclass(value_type, warp.codegen.StructInstance):
        # constant struct instance
        arg_strs = []
        for key, var in value._cls.vars.items():
            attr = getattr(value, key)
            arg_strs.append(f"{Var.type_to_ctype(var.type)}({constant_str(attr)})")
        arg_str = ", ".join(arg_strs)
        return f"{value.native_name}({arg_str})"

    elif value == math.inf:
        return "INFINITY"

    elif value == -math.inf:
        # str(-math.inf) is "-inf", which is not a valid C++ expression
        return "-INFINITY"

    elif math.isnan(value):
        return "NAN"

    else:
        # otherwise just convert constant to string
        return str(value)


def indent(args, stops=1):
    """Join ``args`` so that each item after the first starts on its own indented line.

    Items are separated by a comma and a newline. The newline is followed by the indentation
    unit repeated ``stops`` times. This is used to lay out generated C++ parameter lists.
    """
    unit = " "
    separator = ",\n" + unit * stops
    return separator.join(args)


# generates a C function name based on the python function name
def make_full_qualified_name(func: Union[str, Callable]) -> str:
    """Turn a qualified Python name (or a callable's ``__qualname__``) into a C identifier.

    Each ``.`` becomes ``__``. All characters outside ``[0-9a-zA-Z_]`` are then removed, which
    strips the angle brackets of ``<locals>``, for example.
    """
    qualified = func if isinstance(func, str) else func.__qualname__
    dunder_joined = qualified.replace(".", "__")
    return re.sub(r"[^0-9a-zA-Z_]+", "", dunder_joined)


def codegen_struct(struct, device="cpu", indent_size=4):
    """Render the C++ source of a Warp struct by filling ``struct_template``.

    Args:
        struct: The struct to generate. Reads ``native_name`` and the ``vars`` mapping of
            member label -> variable.
        device: Accepted for signature compatibility; not used here.
        indent_size: Number of spaces used to indent member-level lines.

    Returns:
        The formatted C++ source text.
    """
    struct_name = struct.native_name
    pad = " " * indent_size
    fields = struct.vars

    # Member declarations; an empty struct still gets a dummy member to avoid
    # compiler-specific alignment issues.
    if len(fields) > 0:
        declarations = [var.ctype() + " " + label + ";\n" for label, var in fields.items()]
    else:
        declarations = ["char _dummy_;\n"]

    ctor_params = []  # value-constructor parameters
    adj_params = []  # adj_<name>() parameters
    ctor_inits = []  # member-initializer list
    atomic_lines = []  # body of adj_atomic_add()

    for label, var in fields.items():
        ctype = var.ctype()

        # Every parameter but the first gets a default value, presumably so the constructor does
        # not clash with the explicitly defaulted `{name}()` constructor.
        default_suffix = " = {}" if ctor_params else ""
        ctor_params.append(f"{ctype} const& {label}{default_suffix}")
        adj_params.append(f"{ctype} const&")

        # Warp types and plain bool use the wp:: overload of adj_atomic_add.
        if ctype.startswith("wp::") or ctype == "bool":
            atomic_lines.append(f"{pad}wp::adj_atomic_add(&p->{label}, t.{label});\n")
        else:
            atomic_lines.append(f"{pad}adj_atomic_add(&p->{label}, t.{label});\n")

        separator = f"{pad}," if ctor_inits else ":"
        ctor_inits.append(f"{pad}{separator} {label}{{{label}}}\n")

    # operator+= accumulates every member except arrays.
    plus_eq_lines = [f"{pad}{label} += rhs.{label};\n" for label, var in fields.items() if not is_array(var.type)]

    # Adjoint constructor: array members are assigned, all other members are accumulated.
    reverse_lines = []
    for label, var in fields.items():
        adj_params.append(var.ctype() + " & adj_" + label)
        if is_array(var.type):
            reverse_lines.append(f"{pad}adj_{label} = adj_ret.{label};\n")
        else:
            reverse_lines.append(f"{pad}adj_{label} += adj_ret.{label};\n")
    adj_params.append(struct_name + " & adj_ret")

    # explicitly defaulted default constructor when a value constructor is emitted
    defaulted_ctor = f"{struct_name}() = default;" if ctor_params else ""

    return struct_template.format(
        name=struct_name,
        struct_body="".join(pad + decl for decl in declarations),
        forward_args=indent(ctor_params),
        forward_initializers="".join(ctor_inits),
        reverse_args=indent(adj_params),
        reverse_body="".join(reverse_lines),
        prefix_add_body="".join(plus_eq_lines),
        atomic_add_body="".join(atomic_lines),
        defaulted_constructor_def=defaulted_ctor,
    )


def codegen_func_forward(adj, func_type="kernel", device="cpu"):
    """Generate the indented C++ body of the forward pass for a kernel or user function.

    Args:
        adj: Codegen state providing ``args``, ``variables``, ``blocks`` and ``get_line_directive``.
        func_type: ``"kernel"``; any other value (callers pass ``"function"``) is treated as a
            plain function.
        device: ``"cpu"`` or ``"cuda"``.

    Returns:
        The body text, built in this order:

        * CPU kernels first copy their arguments out of ``_wp_args``.
        * Every variable is declared. Tiles are initialized, and constants are emitted through
          ``constant_str``.
        * The forward statements of the top-level block follow.

        Lines that start with ``#line`` are left unindented. All other lines are indented by
        8 spaces for CUDA kernels (whose body sits inside the kernel's grid-stride loop) and by
        4 spaces otherwise.

    Raises:
        ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``.
    """
    if not (device == "cpu" or device == "cuda"):
        raise ValueError(f"Device {device} not supported for codegen")

    in_grid_loop = func_type == "kernel" and device == "cuda"
    pad = " " * (8 if in_grid_loop else 4)

    out = []

    # CPU kernels receive their arguments through the packed wp_args struct
    if device == "cpu" and func_type == "kernel":
        out.append("//---------\n")
        out.append("// argument vars\n")
        for a in adj.args:
            out.append(f"{a.ctype()} {a.emit()} = _wp_args->{a.label};\n")

    out.append("//---------\n")
    out.append("// primal vars\n")
    for v in adj.variables:
        if is_tile(v.type):
            decl = f"{v.ctype()} {v.emit()} = {v.type.cinit(requires_grad=False)};\n"
        elif v.constant is None:
            decl = f"{v.ctype()} {v.emit()};\n"
        else:
            decl = f"const {v.ctype()} {v.emit()} = {constant_str(v.constant)};\n"
        out.append(decl)

        # map the declaration back to its Python source line when a directive is available
        directive = adj.get_line_directive(decl, v.relative_lineno)
        if directive:
            out.insert(-1, f"{directive}\n")

    out.append("//---------\n")
    out.append("// forward\n")
    for stmt in adj.blocks[0].body_forward:
        # inside the grid-stride loop, leaving the body early must move on to the next index
        if in_grid_loop and stmt.lstrip().startswith("return;"):
            stmt = stmt.replace("return;", "continue;")
        out.append(stmt + "\n")

    def _place(line):
        stripped = line.lstrip()
        return stripped if stripped.startswith("#line") else pad + line

    return "".join(_place(line) for line in out)


def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    """Generate the indented C++ body of the adjoint (reverse) pass for a kernel or function.

    Args:
        adj: Codegen state providing ``args``, ``variables``, ``blocks`` and ``get_line_directive``.
        func_type: ``"kernel"``; any other value (callers pass ``"function"``) is treated as a
            plain function.
        device: ``"cpu"`` or ``"cuda"``.

    Returns:
        The body text, built in this order:

        * CPU kernels copy primal and adjoint arguments out of ``_wp_args`` / ``_wp_adj_args``.
        * Primal declarations follow, then adjoint (dual) declarations.
        * The replayed forward statements come next.
        * Then the reverse statements, in reverse order.
        * The body ends with ``continue;`` for CUDA kernels (grid-stride loop) and ``return;``
          otherwise.

        Indentation follows the same rule as ``codegen_func_forward``.

    Raises:
        ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``.
    """
    if not (device == "cpu" or device == "cuda"):
        raise ValueError(f"Device {device} not supported for codegen")

    in_grid_loop = device == "cuda" and func_type == "kernel"
    pad = " " * (8 if in_grid_loop else 4)

    out = []

    def _tag_last(var):
        # Precede the most recent line with a #line directive for `var`, when one is available.
        directive = adj.get_line_directive(out[-1], var.relative_lineno)
        if directive:
            out.insert(-1, f"{directive}\n")

    # CPU kernels receive primal and adjoint arguments through the packed args structs
    if device == "cpu" and func_type == "kernel":
        out.append("//---------\n")
        out.append("// argument vars\n")
        for a in adj.args:
            out.append(f"{a.ctype()} {a.emit()} = _wp_args->{a.label};\n")
        for a in adj.args:
            out.append(f"{a.ctype()} {a.emit_adj()} = _wp_adj_args->{a.label};\n")

    out.append("//---------\n")
    out.append("// primal vars\n")
    for v in adj.variables:
        if is_tile(v.type):
            out.append(f"{v.ctype()} {v.emit()} = {v.type.cinit(requires_grad=True)};\n")
        elif v.constant is None:
            out.append(f"{v.ctype()} {v.emit()};\n")
        else:
            out.append(f"const {v.ctype()} {v.emit()} = {constant_str(v.constant)};\n")
        _tag_last(v)

    out.append("//---------\n")
    out.append("// dual vars\n")
    for v in adj.variables:
        adj_name = v.emit_adj()
        value_ctype = v.ctype(value_type=True)

        if not is_tile(v.type):
            out.append(f"{value_ctype} {adj_name} = {{}};\n")
        elif v.type.storage == "register":
            # register tiles get their own zero-initialized adjoint tile
            out.append(f"{v.type.ctype()} {adj_name}(0.0);\n")
        elif v.type.storage == "shared":
            # shared tiles store primal and adjoint together, so the adjoint aliases the primal
            out.append(f"{v.type.ctype()}& {adj_name} = {v.emit()};\n")
        _tag_last(v)

    out.append("//---------\n")
    out.append("// forward\n")
    out.extend(stmt + "\n" for stmt in adj.blocks[0].body_replay)

    out.append("//---------\n")
    out.append("// reverse\n")
    out.extend(stmt + "\n" for stmt in reversed(adj.blocks[0].body_reverse))

    # In grid-stride kernels the reverse body is in a for loop
    out.append("continue;\n" if in_grid_loop else "return;\n")

    return "".join(ln.lstrip() if ln.lstrip().startswith("#line") else pad + ln for ln in out)


def codegen_func(adj, c_func_name: str, device="cpu", options=None):
    """Generate the C++ forward and adjoint definitions of a Warp user function.

    Args:
        adj: Codegen state of the function. Reads ``return_var``, ``arg_types``, ``args``, the
            custom-reverse settings, the skip flags and the source location.
        c_func_name: Name of the emitted C++ function. The adjoint is emitted as
            ``adj_<c_func_name>``.
        device: ``"cpu"`` or ``"cuda"``; selects the function templates.
        options: Module options. Only ``"enable_backward"`` (default ``True``) is read here.

    Returns:
        The C++ source text:

        * the forward definition, unless ``adj.skip_forward_codegen`` is set;
        * followed by the adjoint definition, unless ``adj.skip_reverse_codegen`` is set.

    Raises:
        WarpCodegenError: If the annotated return type disagrees with what the code returns, or
            if a fixed-size array is returned from a function annotated to return ``array``.
        ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``.
    """
    if options is None:
        options = {}

    # Validate the return annotation (if any) against the values the body actually returns.
    if adj.return_var is not None and "return" in adj.arg_types:
        if get_origin(adj.arg_types["return"]) is tuple:
            if len(get_args(adj.arg_types["return"])) != len(adj.return_var):
                raise WarpCodegenError(
                    f"The function `{adj.fun_name}` has its return type "
                    f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
                    f"but the code returns {len(adj.return_var)} values."
                )
            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
                raise WarpCodegenError(
                    f"The function `{adj.fun_name}` has its return type "
                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
                )
        elif len(adj.return_var) > 1 and get_origin(adj.arg_types["return"]) is not tuple:
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` has its return type "
                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                f"but the code returns {len(adj.return_var)} values."
            )
        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` has its return type "
                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
            )
        elif (
            isinstance(adj.return_var[0].type, warp.types.fixedarray)
            and type(adj.arg_types["return"]) is warp.types.array
        ):
            # If the return statement yields a `fixedarray` while the function is annotated
            # to return a standard `array`, then raise an error since the `fixedarray` storage
            # allocated on the stack will be freed once the function exits, meaning that the
            # resulting `array` instance will point to an invalid data.
            raise WarpCodegenError(
                f"The function `{adj.fun_name}` returns a fixed-size array "
                f"whereas it has its return type annotated as "
                f"`{warp.context.type_str(adj.arg_types['return'])}`."
            )

    # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
    # This is used as a catch-all C-to-Python source line mapping for any code that does not have
    # a direct mapping to a Python source line.
    func_line_directive = ""
    if line_directive := adj.get_line_directive("", adj.fun_def_lineno - 1):
        func_line_directive = f"{line_directive}\n"

    # forward header: a single return value is returned directly, otherwise the function is void
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
    else:
        return_type = "void"

    # Any number of return values other than one is passed back through `ret_<i>` reference
    # parameters (and received through `adj_ret_<i>` in the adjoint).
    has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1

    forward_args = []
    reverse_args = []

    # forward args
    # In custom reverse mode, only the first `custom_reverse_num_input_args` arguments are
    # passed to the adjoint as primal inputs.
    for i, arg in enumerate(adj.args):
        s = f"{arg.ctype()} {arg.emit()}"
        forward_args.append(s)
        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
            reverse_args.append(s)
    if has_multiple_outputs:
        for i, arg in enumerate(adj.return_var):
            forward_args.append(arg.ctype() + " & ret_" + str(i))
            reverse_args.append(arg.ctype() + " & ret_" + str(i))

    # reverse args
    for i, arg in enumerate(adj.args):
        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
            break
        # indexed array gradients are regular arrays
        if isinstance(arg.type, indexedarray):
            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
    if has_multiple_outputs:
        for i, arg in enumerate(adj.return_var):
            reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
    elif return_type != "void":
        reverse_args.append(return_type + " & adj_ret")
    # custom output reverse args (user-declared)
    # The remaining arguments are appended last, by reference and under their plain emitted names.
    if adj.custom_reverse_mode:
        for arg in adj.args[adj.custom_reverse_num_input_args :]:
            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")

    if device == "cpu":
        forward_template = cpu_forward_function_template
        reverse_template = cpu_reverse_function_template
    elif device == "cuda":
        forward_template = cuda_forward_function_template
        reverse_template = cuda_reverse_function_template
    else:
        raise ValueError(f"Device {device} is not supported")

    # codegen body
    forward_body = codegen_func_forward(adj, func_type="function", device=device)

    s = ""
    if not adj.skip_forward_codegen:
        s += forward_template.format(
            name=c_func_name,
            return_type=return_type,
            forward_args=indent(forward_args),
            forward_body=forward_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
            line_directive=func_line_directive,
        )

    if not adj.skip_reverse_codegen:
        if adj.custom_reverse_mode:
            # The adjoint body is this function's own forward code; presumably `adj` is the
            # user-supplied adjoint function in this mode (confirm with the caller).
            reverse_body = "\t// user-defined adjoint code\n" + forward_body
        else:
            if options.get("enable_backward", True) and adj.used_by_backward_kernel:
                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
            else:
                # Still emit an (empty) adjoint definition so that the symbol exists.
                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
        s += reverse_template.format(
            name=c_func_name,
            return_type=return_type,
            reverse_args=indent(reverse_args),
            forward_body=forward_body,
            reverse_body=reverse_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
            line_directive=func_line_directive,
        )

    return s


def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
    """Wrap native code snippets into forward, replay and adjoint C++ function definitions.

    Args:
        adj: Codegen state of the Python stub declaring the native function. Reads
            ``return_var``, ``args``, ``filename`` and ``fun_lineno``.
        name: C++ name of the forward function. The adjoint is emitted as ``adj_<name>`` and the
            optional replay function as ``replay_<name>``.
        snippet: Body of the forward function.
        adj_snippet: Body of the adjoint function; an empty body is emitted when it is falsy.
        replay_snippet: Body of the replay function, or ``None`` to omit it.

    Returns:
        The concatenated C++ source of the generated functions. The CUDA function templates
        are always used.
    """
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
    else:
        return_type = "void"

    forward_args = []
    reverse_args = []

    # forward args: snippet code refers to arguments by their Python names, so strip the
    # "var_" prefix from the emitted variable name. Only the leading prefix is removed, so
    # labels that contain "var_" elsewhere (e.g. "my_var_x") are kept intact.
    var_prefix = "var_"
    for arg in adj.args:
        arg_ctype = arg.ctype()
        arg_name = arg.emit()
        if arg_name.startswith(var_prefix):
            arg_name = arg_name[len(var_prefix) :]
        s = f"{arg_ctype} {arg_name}"
        forward_args.append(s)
        reverse_args.append(s)

    # reverse args
    for arg in adj.args:
        # indexed array gradients are regular arrays
        if isinstance(arg.type, indexedarray):
            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
    if return_type != "void":
        reverse_args.append(return_type + " & adj_ret")

    forward_template = cuda_forward_function_template
    replay_template = cuda_forward_function_template
    reverse_template = cuda_reverse_function_template

    s = ""
    s += forward_template.format(
        name=name,
        return_type=return_type,
        forward_args=indent(forward_args),
        forward_body=snippet,
        filename=adj.filename,
        lineno=adj.fun_lineno,
        line_directive="",
    )

    if replay_snippet is not None:
        s += replay_template.format(
            name="replay_" + name,
            return_type=return_type,
            forward_args=indent(forward_args),
            forward_body=replay_snippet,
            filename=adj.filename,
            lineno=adj.fun_lineno,
            line_directive="",
        )

    if adj_snippet:
        reverse_body = adj_snippet
    else:
        reverse_body = ""

    s += reverse_template.format(
        name=name,
        return_type=return_type,
        reverse_args=indent(reverse_args),
        forward_body=snippet,
        reverse_body=reverse_body,
        filename=adj.filename,
        lineno=adj.fun_lineno,
        line_directive="",
    )

    return s


def codegen_kernel(kernel, device, options):
    """Generate the source code for a kernel's forward and (optionally) backward entry points.

    Args:
        kernel: The kernel being compiled. ``kernel.adj`` supplies the argument list and the
            source-location information, ``kernel.options`` overrides entries in *options*,
            and ``kernel.get_mangled_name()`` supplies the ``name`` format field.
        device: Target device string, either ``"cpu"`` or ``"cuda"``.
        options: Module-level code generation options. The mapping is copied, never modified.
            The merged ``"enable_backward"`` entry decides whether the backward entry point
            is emitted.

    Returns:
        The generated source as a string. For ``"cpu"`` the result is prefixed with a
        ``struct wp_args_<mangled name>`` definition holding one member per kernel argument.

    Raises:
        ValueError: If *device* is neither ``"cpu"`` nor ``"cuda"``.
    """
    # Merge into a copy so the caller's module options are left untouched;
    # kernel-level options win over module-level ones.
    opts = dict(options)
    opts.update(kernel.options)

    adj = kernel.adj
    on_cpu = device == "cpu"

    # CPU kernels get a struct declaring every argument. The CPU signatures built below
    # list only the launch bounds and task index instead of the individual arguments.
    args_struct = ""
    if on_cpu:
        struct_lines = [f"struct wp_args_{kernel.get_mangled_name()} {{\n"]
        struct_lines.extend(f" {a.ctype()} {a.label};\n" for a in adj.args)
        struct_lines.append("};\n")
        args_struct = "".join(struct_lines)

    # Catch-all C-to-Python line mapping pointing at the function definition. It is used for
    # generated code that has no more specific Python source line. AST line numbers are
    # 1-based, hence the -1.
    directive = adj.get_line_directive("", adj.fun_def_lineno - 1)
    func_line_directive = f"{directive}\n" if directive else ""

    if on_cpu:
        forward_template = cpu_kernel_template_forward
        backward_template = cpu_kernel_template_backward
    elif device == "cuda":
        forward_template = cuda_kernel_template_forward
        backward_template = cuda_kernel_template_backward
    else:
        raise ValueError(f"Device {device} is not supported")

    fmt_args = {"name": kernel.get_mangled_name()}
    bounds_param = "wp::launch_bounds_t dim"

    # Forward signature: on CUDA every argument is passed directly as var_<label>.
    if on_cpu:
        forward_params = [bounds_param, "size_t task_index"]
    else:
        forward_params = [bounds_param] + [f"{a.ctype()} var_{a.label}" for a in adj.args]

    forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
    fmt_args.update(
        forward_args=indent(forward_params),
        forward_body=forward_body,
        line_directive=func_line_directive,
    )
    source_template = forward_template

    if opts["enable_backward"]:
        # Reverse signature: on CUDA it lists every argument, then every argument's adjoint.
        if on_cpu:
            reverse_params = [bounds_param, "size_t task_index"]
        else:
            reverse_params = [bounds_param] + [f"{a.ctype()} var_{a.label}" for a in adj.args]
            for a in adj.args:
                # Gradients of indexed arrays are stored in regular (non-indexed) arrays.
                if isinstance(a.type, indexedarray):
                    grad = Var(a.label, array(dtype=a.type.dtype, ndim=a.type.ndim))
                else:
                    grad = a
                reverse_params.append(f"{grad.ctype()} adj_{a.label}")

        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
        fmt_args.update(reverse_args=indent(reverse_params), reverse_body=reverse_body)
        source_template += backward_template

    return args_struct + source_template.format(**fmt_args)
|
|
4248
|
-
|
|
4249
|
-
|
|
4250
|
-
def codegen_module(kernel, device, options):
    """Generate the CPU module-level source for ``kernel``.

    Args:
        kernel: The kernel being compiled. ``kernel.options`` overrides entries in *options*,
            and ``kernel.get_mangled_name()`` supplies the ``name`` format field.
        device: Target device string. Only ``"cpu"`` produces any output.
        options: Module-level code generation options. The mapping is copied, never modified.
            The merged ``"enable_backward"`` entry decides whether the backward template
            is appended.

    Returns:
        The formatted CPU module source, or an empty string for any device other than ``"cpu"``.
    """
    if device != "cpu":
        # Nothing is generated for non-CPU devices.
        return ""

    # Kernel-level options take precedence over the module-wide ones.
    opts = dict(options)
    opts.update(kernel.options)

    mangled_name = kernel.get_mangled_name()

    pieces = [cpu_module_template_forward]
    if opts["enable_backward"]:
        pieces.append(cpu_module_template_backward)

    return "".join(pieces).format(name=mangled_name)
|
|
24
|
+
return get_deprecated_api(_codegen, "wp", name)
|