warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -7,23 +7,40 @@
 
 from __future__ import annotations
 
-import re
-import sys
 import ast
-import inspect
+import builtins
 import ctypes
+import inspect
+import math
+import re
+import sys
 import textwrap
 import types
+from typing import Any, Callable, Mapping
 
-import warp
+import warp.config
+from warp.types import *
 
-from typing import Any
-from typing import Callable
-from typing import Mapping
-from typing import Union
 
-import warp.config
-from warp.types import *
+class WarpCodegenError(RuntimeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenTypeError(TypeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenAttributeError(AttributeError):
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class WarpCodegenKeyError(KeyError):
+    def __init__(self, message):
+        super().__init__(message)
+
 
 # map operator to function name
 builtin_operators = {}
@@ -57,6 +74,19 @@ builtin_operators[ast.Invert] = "invert"
 builtin_operators[ast.LShift] = "lshift"
 builtin_operators[ast.RShift] = "rshift"
 
+comparison_chain_strings = [
+    builtin_operators[ast.Gt],
+    builtin_operators[ast.Lt],
+    builtin_operators[ast.LtE],
+    builtin_operators[ast.GtE],
+    builtin_operators[ast.Eq],
+    builtin_operators[ast.NotEq],
+]
+
+
+def op_str_is_chainable(op: str) -> builtins.bool:
+    return op in comparison_chain_strings
+
 
 def get_annotations(obj: Any) -> Mapping[str, Any]:
     """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
@@ -70,16 +100,14 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
 def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
     indent = "\t"
 
-    if inst._cls.vars == {}:
+    # handle empty structs
+    if len(inst._cls.vars) == 0:
         return f"{inst._cls.key}()"
 
     lines = []
     lines.append(f"{inst._cls.key}(")
 
    for field_name, _ in inst._cls.ctype._fields_:
-        if field_name == "_dummy_":
-            continue
-
        field_value = getattr(inst, field_name, None)
 
        if isinstance(field_value, StructInstance):
@@ -126,9 +154,7 @@ class StructInstance:
             assert isinstance(value, array)
             assert types_equal(
                 value.dtype, var.type.dtype
-            ), "assign to struct member variable {} failed, expected type {}, got type {}".format(
-                name, type_repr(var.type.dtype), type_repr(value.dtype)
-            )
+            ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
             setattr(self._ctype, name, value.__ctype__())
 
         elif isinstance(var.type, Struct):
@@ -247,7 +273,7 @@ class Struct:
 
         class StructType(ctypes.Structure):
             # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
-            _fields_ = fields or [("_dummy_", ctypes.c_int)]
+            _fields_ = fields or [("_dummy_", ctypes.c_byte)]
 
         self.ctype = StructType
 
@@ -368,21 +394,38 @@ class Struct:
         return instance
 
 
+class Reference:
+    def __init__(self, value_type):
+        self.value_type = value_type
+
+
+def is_reference(type):
+    return isinstance(type, Reference)
+
+
+def strip_reference(arg):
+    if is_reference(arg):
+        return arg.value_type
+    else:
+        return arg
+
+
 def compute_type_str(base_name, template_params):
-    if len(template_params) == 0:
+    if not template_params:
         return base_name
-    else:
 
-        def param2str(p):
-            if isinstance(p, int):
-                return str(p)
-            return p.__name__
+    def param2str(p):
+        if isinstance(p, int):
+            return str(p)
+        elif hasattr(p, "_type_"):
+            return f"wp::{p.__name__}"
+        return p.__name__
 
-        return f"{base_name}<{','.join(map(param2str, template_params))}>"
+    return f"{base_name}<{','.join(map(param2str, template_params))}>"
 
 
 class Var:
-    def __init__(self, label, type, requires_grad=False, constant=None):
+    def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
         # convert built-in types to wp types
         if type == float:
             type = float32
@@ -393,26 +436,49 @@ class Var:
         self.type = type
         self.requires_grad = requires_grad
         self.constant = constant
+        self.prefix = prefix
 
     def __str__(self):
         return self.label
 
-    def ctype(self):
-        if is_array(self.type):
-            if hasattr(self.type.dtype, "_wp_generic_type_str_"):
-                dtypestr = compute_type_str(self.type.dtype._wp_generic_type_str_, self.type.dtype._wp_type_params_)
-            elif isinstance(self.type.dtype, Struct):
-                dtypestr = make_full_qualified_name(self.type.dtype.cls)
+    @staticmethod
+    def type_to_ctype(t, value_type=False):
+        if is_array(t):
+            if hasattr(t.dtype, "_wp_generic_type_str_"):
+                dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
+            elif isinstance(t.dtype, Struct):
+                dtypestr = make_full_qualified_name(t.dtype.cls)
+            elif t.dtype.__name__ in ("bool", "int", "float"):
+                dtypestr = t.dtype.__name__
             else:
-                dtypestr = str(self.type.dtype.__name__)
-            classstr = type(self.type).__name__
+                dtypestr = f"wp::{t.dtype.__name__}"
+            classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
-        elif isinstance(self.type, Struct):
-            return make_full_qualified_name(self.type.cls)
-        elif hasattr(self.type, "_wp_generic_type_str_"):
-            return compute_type_str(self.type._wp_generic_type_str_, self.type._wp_type_params_)
+        elif isinstance(t, Struct):
+            return make_full_qualified_name(t.cls)
+        elif is_reference(t):
+            if not value_type:
+                return Var.type_to_ctype(t.value_type) + "*"
+            else:
+                return Var.type_to_ctype(t.value_type)
+        elif hasattr(t, "_wp_generic_type_str_"):
+            return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
+        elif t.__name__ in ("bool", "int", "float"):
+            return t.__name__
+        else:
+            return f"wp::{t.__name__}"
+
+    def ctype(self, value_type=False):
+        return Var.type_to_ctype(self.type, value_type)
+
+    def emit(self, prefix: str = "var"):
+        if self.prefix:
+            return f"{prefix}_{self.label}"
         else:
-            return str(self.type.__name__)
+            return self.label
+
+    def emit_adj(self):
+        return self.emit("adj")
 
 
 class Block:
@@ -429,35 +495,65 @@ class Block:
         self.vars = []
 
 
+def is_local_value(value) -> bool:
+    """Check whether a variable is defined inside a kernel."""
+    return isinstance(value, (warp.context.Function, Var))
+
+
 class Adjoint:
     # Source code transformer, this class takes a Python function and
     # generates forward and backward SSA forms of the function instructions
 
-    def __init__(adj, func, overload_annotations=None, transformers: List[ast.NodeTransformer] = []):
+    def __init__(
+        adj,
+        func,
+        overload_annotations=None,
+        is_user_function=False,
+        skip_forward_codegen=False,
+        skip_reverse_codegen=False,
+        custom_reverse_mode=False,
+        custom_reverse_num_input_args=-1,
+        transformers: List[ast.NodeTransformer] = [],
+    ):
         adj.func = func
 
-        # get function source code
-        adj.source = inspect.getsource(func)
+        adj.is_user_function = is_user_function
 
-        # get source code lines
-        adj.source_lines = adj.source.splitlines()
+        # whether the generation of the forward code is skipped for this function
+        adj.skip_forward_codegen = skip_forward_codegen
+        # whether the generation of the adjoint code is skipped for this function
+        adj.skip_reverse_codegen = skip_reverse_codegen
 
-        # for keeping track of line number in function code
-        adj.lineno = None
+        # extract name of source file
+        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        # get source file line number where function starts
+        _, adj.fun_lineno = inspect.getsourcelines(func)
 
+        # get function source code
+        adj.source = inspect.getsource(func)
         # ensures that indented class methods can be parsed as kernels
         adj.source = textwrap.dedent(adj.source)
 
-        # get function source file
-        adj.filename = inspect.getsourcefile(func) or "unknown source file"
+        adj.source_lines = adj.source.splitlines()
 
         # build AST and apply node transformers
         adj.tree = ast.parse(adj.source)
+        adj.transformers = transformers
         for transformer in transformers:
            adj.tree = transformer.visit(adj.tree)
 
         adj.fun_name = adj.tree.body[0].name
 
+        # for keeping track of line number in function code
+        adj.lineno = None
+
+        # whether the forward code shall be used for the reverse pass and a custom
+        # function signature is applied to the reverse version of the function
+        adj.custom_reverse_mode = custom_reverse_mode
+        # the number of function arguments that pertain to the forward function
+        # input arguments (i.e. the number of arguments that are not adjoint arguments)
+        adj.custom_reverse_num_input_args = custom_reverse_num_input_args
+
         # parse argument types
         argspec = inspect.getfullargspec(func)
 
@@ -465,16 +561,17 @@ class Adjoint:
         if overload_annotations is None:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
-                raise RuntimeError(f"Incomplete argument annotations on function {adj.fun_name}")
+                raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
             adj.arg_types = argspec.annotations
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
                 if arg_name not in overload_annotations:
-                    raise RuntimeError(f"Incomplete overload annotations for function {adj.fun_name}")
+                    raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
             adj.arg_types = overload_annotations.copy()
 
         adj.args = []
+        adj.symbols = {}
 
         for name, type in adj.arg_types.items():
             # skip return hint
@@ -485,8 +582,23 @@ class Adjoint:
             arg = Var(name, type, False)
             adj.args.append(arg)
 
+            # pre-populate symbol dictionary with function argument names
+            # this is to avoid registering false references to overshadowed modules
+            adj.symbols[name] = arg
+
+        # There are cases where a same module might be rebuilt multiple times,
+        # for example when kernels are nested inside of functions, or when
+        # a kernel's launch raises an exception. Ideally we'd always want to
+        # avoid rebuilding kernels but some corner cases seem to depend on it,
+        # so we only avoid rebuilding kernels that errored out to give a chance
+        # for unit testing errors being spit out from kernels.
+        adj.skip_build = False
+
     # generate function ssa form and adjoint
     def build(adj, builder):
+        if adj.skip_build:
+            return
+
         adj.builder = builder
 
         adj.symbols = {}  # map from symbols to adjoint variables
@@ -500,7 +612,7 @@ class Adjoint:
         adj.loop_blocks = []
 
         # holds current indent level
-        adj.prefix = ""
+        adj.indentation = ""
 
         # used to generate new label indices
         adj.label_count = 0
@@ -514,19 +626,25 @@ class Adjoint:
             adj.eval(adj.tree.body[0])
         except Exception as e:
             try:
+                if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
+                    msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
+                else:
+                    msg = "Error"
                 lineno = adj.lineno + adj.fun_lineno
-                line = adj.source.splitlines()[adj.lineno]
-                msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
+                line = adj.source_lines[adj.lineno]
+                msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
                 ex, data, traceback = sys.exc_info()
-                e = ex("".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
+                e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
+                adj.skip_build = True
                 raise e
 
-        for a in adj.args:
-            if isinstance(a.type, Struct):
-                builder.build_struct_recursive(a.type)
-            elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
-                builder.build_struct_recursive(a.type.dtype)
+        if builder is not None:
+            for a in adj.args:
+                if isinstance(a.type, Struct):
+                    builder.build_struct_recursive(a.type)
+                elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
+                    builder.build_struct_recursive(a.type.dtype)
 
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
@@ -541,44 +659,56 @@ class Adjoint:
         arg_strs = []
 
         for a in args:
-            if type(a) == warp.context.Function:
+            if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
-                if prefix == "var_":
+                if prefix == "var":
                     arg_strs.append(a.key)
                 else:
-                    arg_strs.append(prefix + a.key)
-
+                    arg_strs.append(f"{prefix}_{a.key}")
+            elif is_reference(a.type):
+                arg_strs.append(f"{prefix}_{a}")
+            elif isinstance(a, Var):
+                arg_strs.append(a.emit(prefix))
             else:
-                arg_strs.append(prefix + str(a))
+                raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
 
         return arg_strs
 
     # generates argument string for a forward function call
     def format_forward_call_args(adj, args, use_initializer_list):
-        arg_str = ", ".join(adj.format_args("var_", args))
+        arg_str = ", ".join(adj.format_args("var", args))
         if use_initializer_list:
-            return "{{{}}}".format(arg_str)
+            return f"{{{arg_str}}}"
         return arg_str
 
     # generates argument string for a reverse function call
-    def format_reverse_call_args(adj, args, args_out, use_initializer_list):
-        formatted_var = adj.format_args("var_", args)
+    def format_reverse_call_args(
+        adj,
+        args_var,
+        args,
+        args_out,
+        use_initializer_list,
+        has_output_args=True,
+        require_original_output_arg=False,
+    ):
+        formatted_var = adj.format_args("var", args_var)
         formatted_out = []
-        if len(args_out) > 1:
-            formatted_out = adj.format_args("var_", args_out)
+        if has_output_args and (require_original_output_arg or len(args_out) > 1):
+            formatted_out = adj.format_args("var", args_out)
         formatted_var_adj = adj.format_args(
-            "&adj_" if use_initializer_list else "adj_", args
+            "&adj" if use_initializer_list else "adj",
+            args,
         )
-        formatted_out_adj = adj.format_args("adj_", args_out)
 
+        formatted_out_adj = adj.format_args("adj", args_out)
 
         if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
             # there are no adjoint arguments, so we don't need to call the reverse function
             return None
 
         if use_initializer_list:
-            var_str = "{{{}}}".format(", ".join(formatted_var))
-            out_str = "{{{}}}".format(", ".join(formatted_out))
-            adj_str = "{{{}}}".format(", ".join(formatted_var_adj))
+            var_str = f"{{{', '.join(formatted_var)}}}"
+            out_str = f"{{{', '.join(formatted_out)}}}"
+            adj_str = f"{{{', '.join(formatted_var_adj)}}}"
             out_adj_str = ", ".join(formatted_out_adj)
             if len(args_out) > 1:
                 arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
@@ -589,10 +719,10 @@ class Adjoint:
         return arg_str
 
     def indent(adj):
-        adj.prefix = adj.prefix + "    "
+        adj.indentation = adj.indentation + "    "
 
     def dedent(adj):
-        adj.prefix = adj.prefix[:-4]
+        adj.indentation = adj.indentation[:-4]
 
     def begin_block(adj):
         b = Block()
@@ -607,10 +737,9 @@ class Adjoint:
     def end_block(adj):
         return adj.blocks.pop()
 
-    def add_var(adj, type=None, constant=None, name=None):
-        if name is None:
-            index = len(adj.variables)
-            name = str(index)
+    def add_var(adj, type=None, constant=None):
+        index = len(adj.variables)
+        name = str(index)
 
         # allocate new variable
         v = Var(name, type=type, constant=constant)
@@ -623,30 +752,54 @@ class Adjoint:
 
     # append a statement to the forward pass
     def add_forward(adj, statement, replay=None, skip_replay=False):
-        adj.blocks[-1].body_forward.append(adj.prefix + statement)
+        adj.blocks[-1].body_forward.append(adj.indentation + statement)
 
         if not skip_replay:
             if replay:
                 # if custom replay specified then output it
-                adj.blocks[-1].body_replay.append(adj.prefix + replay)
+                adj.blocks[-1].body_replay.append(adj.indentation + replay)
             else:
                 # by default just replay the original statement
-                adj.blocks[-1].body_replay.append(adj.prefix + statement)
+                adj.blocks[-1].body_replay.append(adj.indentation + statement)
 
     # append a statement to the reverse pass
     def add_reverse(adj, statement):
-        adj.blocks[-1].body_reverse.append(adj.prefix + statement)
+        adj.blocks[-1].body_reverse.append(adj.indentation + statement)
 
     def add_constant(adj, n):
         output = adj.add_var(type=type(n), constant=n)
         return output
 
+    def load(adj, var):
+        if is_reference(var.type):
+            var = adj.add_builtin_call("load", [var])
+        return var
+
     def add_comp(adj, op_strings, left, comps):
-        output = adj.add_var(bool)
+        output = adj.add_var(builtins.bool)
+
+        left = adj.load(left)
+        s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
+
+        prev_comp = None
 
-        s = "var_" + str(output) + " = " + ("(" * len(comps)) + "var_" + str(left) + " "
         for op, comp in zip(op_strings, comps):
-            s += op + " var_" + str(comp) + ") "
+            comp_chainable = op_str_is_chainable(op)
+            if comp_chainable and prev_comp:
+                # We restrict chaining to operands of the same type
+                if prev_comp.type is comp.type:
+                    prev_comp = adj.load(prev_comp)
+                    comp = adj.load(comp)
+                    s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
+                else:
+                    raise WarpCodegenTypeError(
+                        f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
+                    )
+            else:
+                comp = adj.load(comp)
+                s += op + " " + comp.emit() + ") "
+
+            prev_comp = comp
 
         s = s.rstrip() + ";"
 
@@ -655,110 +808,106 @@ class Adjoint:
         return output
 
     def add_bool_op(adj, op_string, exprs):
-        output = adj.add_var(bool)
-        command = "var_{} = {};".format(
-            output, (" " + op_string + " ").join(["var_" + str(expr) for expr in exprs])
-        )
+        exprs = [adj.load(expr) for expr in exprs]
+        output = adj.add_var(builtins.bool)
+        command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
         adj.add_forward(command)
 
         return output
 
-    def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
-        # if func is overloaded then perform overload resolution here
-        # we validate argument types before they go to generated native code
-        resolved_func = None
+    def resolve_func(adj, func, args, min_outputs, templates, kwds):
+        arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
 
-        if func.is_builtin():
+        if not func.is_builtin():
+            # user-defined function
+            overload = func.get_overload(arg_types)
+            if overload is not None:
+                return overload
+        else:
+            # if func is overloaded then perform overload resolution here
+            # we validate argument types before they go to generated native code
             for f in func.overloads:
-                match = True
-
                 # skip type checking for variadic functions
                 if not f.variadic:
                     # check argument counts match are compatible (may be some default args)
                     if len(f.input_types) < len(args):
-                        match = False
                         continue
 
-                    # check argument types equal
-                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                        # if arg type registered as Any, treat as
-                        # template allowing any type to match
-                        if arg_type == Any:
-                            continue
-
-                        # handle function refs as a special case
-                        if arg_type == Callable and type(args[i]) is warp.context.Function:
-                            continue
-
-                        # look for default values for missing args
-                        if i >= len(args):
-                            if arg_name not in f.defaults:
-                                match = False
-                                break
-                        else:
-                            # otherwise check arg type matches input variable type
-                            if not types_equal(arg_type, args[i].type, match_generic=True):
-                                match = False
-                                break
+                    def match_args(args, f):
+                        # check argument types equal
+                        for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
+                            # if arg type registered as Any, treat as
+                            # template allowing any type to match
+                            if arg_type == Any:
+                                continue
+
+                            # handle function refs as a special case
+                            if arg_type == Callable and type(args[i]) is warp.context.Function:
+                                continue
+
+                            if arg_type == Reference and is_reference(args[i].type):
+                                continue
+
+                            # look for default values for missing args
+                            if i >= len(args):
+                                if arg_name not in f.defaults:
+                                    return False
+                            else:
+                                # otherwise check arg type matches input variable type
+                                if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
+                                    return False
+
+                        return True
+
+                    if not match_args(args, f):
+                        continue
 
                 # check output dimensions match expectations
                 if min_outputs:
                     try:
                         value_type = f.value_func(args, kwds, templates)
-                        if len(value_type) != min_outputs:
-                            match = False
+                        if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
                             continue
                     except Exception:
                         # value func may fail if the user has given
                         # incorrect args, so we need to catch this
-                        match = False
                         continue
 
                 # found a match, use it
-                if match:
-                    resolved_func = f
-                    break
-
-        else:
-            # user-defined function
-            arg_types = [a.type for a in args if not isinstance(a, warp.context.Function)]
-            resolved_func = func.get_overload(arg_types)
-
-        if resolved_func is None:
-            arg_types = []
-
-            for x in args:
-                if isinstance(x, Var):
-                    # shorten Warp primitive type names
-                    if isinstance(x.type, list):
-                        if len(x.type) != 1:
-                            raise Exception("Argument must not be the result from a multi-valued function")
-                        arg_type = x.type[0]
-                    else:
-                        arg_type = x.type
-                    if arg_type.__module__ == "warp.types":
-                        arg_types.append(arg_type.__name__)
-                    else:
-                        arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
-
-                if isinstance(x, warp.context.Function):
-                    arg_types.append("function")
-
-            raise Exception(
-                f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
-            )
+                return f
+
+        # unresolved function, report error
+        arg_types = []
+
+        for x in args:
+            if isinstance(x, Var):
+                # shorten Warp primitive type names
+                if isinstance(x.type, list):
+                    if len(x.type) != 1:
+                        raise WarpCodegenError("Argument must not be the result from a multi-valued function")
+                    arg_type = x.type[0]
+                else:
+                    arg_type = x.type
 
-        else:
-            func = resolved_func
+                arg_types.append(type_repr(arg_type))
+
+            if isinstance(x, warp.context.Function):
+                arg_types.append("function")
+
+        raise WarpCodegenError(
+            f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
+        )
+
+    def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
+        func = adj.resolve_func(func, args, min_outputs, templates, kwds)
 
         # push any default values onto args
         for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
             if i >= len(args):
-                if arg_name in f.defaults:
+                if arg_name in func.defaults:
                     const = adj.add_constant(func.defaults[arg_name])
                     args.append(const)
                 else:
-                    match = False
                     break
 
         # if it is a user-function then build it recursively
@@ -766,93 +915,105 @@ class Adjoint:
             adj.builder.build_function(func)
 
         # evaluate the function type based on inputs
-        value_type = func.value_func(args, kwds, templates)
+        arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
+        return_type = func.value_func(arg_types, kwds, templates)
 
         func_name = compute_type_str(func.native_func, templates)
+        param_types = list(func.input_types.values())
 
         use_initializer_list = func.initializer_list_func(args, templates)
 
-        if value_type is None:
-            # handles expression (zero output) functions, e.g.: void do_something();
-            output = None
-            forward_call = "{}{}({});".format(
-                func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
-            )
-            if func.skip_replay:
-                adj.add_forward(forward_call, replay="//" + forward_call)
-            else:
-                adj.add_forward(forward_call)
+        args_var = [
+            adj.load(a)
+            if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
+            else a
+            for i, a in enumerate(args)
+        ]
 
-            if not func.missing_grad and len(args):
-                arg_str = adj.format_reverse_call_args(args, [], use_initializer_list)
-                if arg_str is not None:
-                    reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
-                    adj.add_reverse(reverse_call)
+        if return_type is None:
+            # handles expression (zero output) functions, e.g.: void do_something();
 
-            return None
+            output = None
+            output_list = []
 
-        elif not isinstance(value_type, list) or len(value_type) == 1:
-            # handle simple function (one output)
-
-            if isinstance(value_type, list):
-                value_type = value_type[0]
-            output = adj.add_var(value_type)
-            forward_call = "var_{} = {}{}({});".format(
-                output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
+            forward_call = (
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
             )
+            replay_call = forward_call
+            if func.custom_replay_func is not None:
+                replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
 
-            if func.skip_replay:
-                adj.add_forward(forward_call, replay="//" + forward_call)
-            else:
-                adj.add_forward(forward_call)
+        elif not isinstance(return_type, list) or len(return_type) == 1:
+            # handle simple function (one output)
 
-            if not func.missing_grad and len(args):
-                arg_str = adj.format_reverse_call_args(args, [output], use_initializer_list)
-                if arg_str is not None:
-                    reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
-                    adj.add_reverse(reverse_call)
+            if isinstance(return_type, list):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
 
-            return output
+            forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+            replay_call = forward_call
+            if func.custom_replay_func is not None:
+                replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
 
         else:
             # handle multiple value functions
 
-            output = [adj.add_var(v) for v in value_type]
-            forward_call = "{}{}({});".format(
-                func.namespace, func_name, adj.format_forward_call_args(args + output, use_initializer_list)
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
+
+            forward_call = (
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
            )
-            adj.add_forward(forward_call)
+            replay_call = forward_call
 
-            if not func.missing_grad and len(args):
-                arg_str = adj.format_reverse_call_args(args, output, use_initializer_list)
-                if arg_str is not None:
-                    reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
-                    adj.add_reverse(reverse_call)
+        if func.skip_replay:
+            adj.add_forward(forward_call, replay="// " + replay_call)
+        else:
+            adj.add_forward(forward_call, replay=replay_call)
+
+        if not func.missing_grad and len(args):
+            reverse_has_output_args = (
+                func.require_original_output_arg or len(output_list) > 1
+            ) and func.custom_grad_func is None
+            arg_str = adj.format_reverse_call_args(
+                args_var,
+                args,
+                output_list,
+                use_initializer_list,
+                has_output_args=reverse_has_output_args,
+                require_original_output_arg=func.require_original_output_arg,
+            )
+            if arg_str is not None:
+                reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
+                adj.add_reverse(reverse_call)
 
-        if len(output) == 1:
-            return output[0]
+        return output
 
-        return output
+    def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
+        func = warp.context.builtin_functions[func_name]
+        return adj.add_call(func, args, min_outputs, templates, kwds)
 
     def add_return(adj, var):
         if var is None or len(var) == 0:
-            adj.add_forward("return;", "goto label{};".format(adj.label_count))
+            adj.add_forward("return;", f"goto label{adj.label_count};")
         elif len(var) == 1:
-            adj.add_forward("return var_{};".format(var[0]), "goto label{};".format(adj.label_count))
+            adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
             adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
         else:
             for i, v in enumerate(var):
-                adj.add_forward("ret_{} = var_{};".format(i, v))
-                adj.add_reverse("adj_{} += adj_ret_{};".format(v, i))
-            adj.add_forward("return;", "goto label{};".format(adj.label_count))
+                adj.add_forward(f"ret_{i} = {v.emit()};")
+                adj.add_reverse(f"adj_{v} += adj_ret_{i};")
+            adj.add_forward("return;", f"goto label{adj.label_count};")
 
-        adj.add_reverse("label{}:;".format(adj.label_count))
+        adj.add_reverse(f"label{adj.label_count}:;")
 
         adj.label_count += 1
 
     # define an if statement
     def begin_if(adj, cond):
-        adj.add_forward("if (var_{}) {{".format(cond))
+        cond = adj.load(cond)
+        adj.add_forward(f"if ({cond.emit()}) {{")
         adj.add_reverse("}")
 
         adj.indent()
@@ -861,10 +1022,12 @@ class Adjoint:
         adj.dedent()
 
         adj.add_forward("}")
-        adj.add_reverse("if (var_{}) {{".format(cond))
+        cond = adj.load(cond)
+        adj.add_reverse(f"if ({cond.emit()}) {{")
 
     def begin_else(adj, cond):
-        adj.add_forward("if (!var_{}) {{".format(cond))
+        cond = adj.load(cond)
+        adj.add_forward(f"if (!{cond.emit()}) {{")
         adj.add_reverse("}")
 
         adj.indent()
@@ -873,7 +1036,8 @@ class Adjoint:
         adj.dedent()
 
         adj.add_forward("}")
-        adj.add_reverse("if (!var_{}) {{".format(cond))
+        cond = adj.load(cond)
+        adj.add_reverse(f"if (!{cond.emit()}) {{")
 
     # define a for-loop
     def begin_for(adj, iter):
@@ -883,10 +1047,10 @@ class Adjoint:
         adj.indent()
 
         # evaluate cond
-        adj.add_forward(f"if (iter_cmp(var_{iter}) == 0) goto for_end_{cond_block.label};")
+        adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")
 
         # evaluate iter
-        val = adj.add_call(warp.context.builtin_functions["iter_next"], [iter])
+        val = adj.add_builtin_call("iter_next", [iter])
 
         adj.begin_block()
 
@@ -917,17 +1081,14 @@ class Adjoint:
         reverse = []
 
         # reverse iterator
-        reverse.append(adj.prefix + f"var_{iter} = wp::iter_reverse(var_{iter});")
+        reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")
 
         for i in cond_block.body_forward:
             reverse.append(i)
 
         # zero adjoints
         for i in body_block.vars:
-            if isinstance(i.type, Struct):
-                reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
-            else:
-                reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
+            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -937,14 +1098,14 @@ class Adjoint:
         for i in reversed(body_block.body_reverse):
             reverse.append(i)
 
-        reverse.append(adj.prefix + f"\tgoto for_start_{cond_block.label};")
-        reverse.append(adj.prefix + f"for_end_{cond_block.label}:;")
+        reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
+        reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")
 
         adj.blocks[-1].body_reverse.extend(reversed(reverse))
 
     # define a while loop
     def begin_while(adj, cond):
-        # evaulate condition in its own block
+        # evaluate condition in its own block
         # so we can control replay
         cond_block = adj.begin_block()
         adj.loop_blocks.append(cond_block)
@@ -952,7 +1113,7 @@ class Adjoint:
 
         c = adj.eval(cond)
 
-        cond_block.body_forward.append(f"if ((var_{c}) == false) goto while_end_{cond_block.label};")
+        cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")
 
         # being block around loop
         adj.begin_block()
@@ -986,10 +1147,7 @@ class Adjoint:
 
         # zero adjoints of local vars
         for i in body_block.vars:
-            if isinstance(i.type, Struct):
-                reverse.append(f"adj_{i} = {i.ctype()}{{}};")
-            else:
-                reverse.append(f"adj_{i} = {i.ctype()}(0);")
+            reverse.append(f"{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -1009,6 +1167,10 @@ class Adjoint:
         for f in node.body:
             adj.eval(f)
 
+        if adj.return_var is not None and len(adj.return_var) == 1:
+            if not isinstance(node.body[-1], ast.Return):
+                adj.add_forward("return {};", skip_replay=True)
+
     def emit_If(adj, node):
         if len(node.body) == 0:
             return None
@@ -1036,7 +1198,7 @@ class Adjoint:
 
         if var1 != var2:
             # insert a phi function that selects var1, var2 based on cond
-            out = adj.add_call(warp.context.builtin_functions["select"], [cond, var1, var2])
+            out = adj.add_builtin_call("select", [cond, var1, var2])
             adj.symbols[sym] = out
 
         symbols_prev = adj.symbols.copy()
@@ -1060,7 +1222,7 @@ class Adjoint:
         if var1 != var2:
             # insert a phi function that selects var1, var2 based on cond
             # note the reversed order of vars since we want to use !cond as our select
-            out = adj.add_call(warp.context.builtin_functions["select"], [cond, var2, var1])
+            out = adj.add_builtin_call("select", [cond, var2, var1])
             adj.symbols[sym] = out
 
     def emit_Compare(adj, node):
@@ -1082,7 +1244,7 @@ class Adjoint:
         elif isinstance(op, ast.Or):
             func = "||"
         else:
-            raise KeyError("Op {} is not supported".format(op))
+            raise WarpCodegenKeyError(f"Op {op} is not supported")
 
         return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])
 
@@ -1102,7 +1264,7 @@ class Adjoint:
         obj = capturedvars.get(str(node.id), None)
 
         if obj is None:
-            raise KeyError("Referencing undefined symbol: " + str(node.id))
+            raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))
 
         if warp.types.is_value(obj):
             # evaluate constant
@@ -1114,26 +1276,96 @@ class Adjoint:
         # pass it back to the caller for processing
         return obj
 
+    @staticmethod
+    def resolve_type_attribute(var_type: type, attr: str):
+        if isinstance(var_type, type) and type_is_value(var_type):
+            if attr == "dtype":
+                return type_scalar_type(var_type)
+            elif attr == "length":
+                return type_length(var_type)
+
+        return getattr(var_type, attr, None)
+
+    def vector_component_index(adj, component, vector_type):
+        if len(component) != 1:
+            raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
+
+        dim = vector_type._shape_[0]
+        swizzles = "xyzw"[0:dim]
+        if component not in swizzles:
+            raise WarpCodegenAttributeError(
+                f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
+            )
+
+        index = swizzles.index(component)
+        index = adj.add_constant(index)
+        return index
+
+    @staticmethod
+    def is_differentiable_value_type(var_type):
+        # checks that the argument type is a value type (i.e., not an array)
+        # possibly holding differentiable values (for which gradients must be accumulated)
+        return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
+
     def emit_Attribute(adj, node):
-
-
+        if hasattr(node, "is_adjoint"):
+            node.value.is_adjoint = True
+
+        aggregate = adj.eval(node.value)
 
-
-
+        try:
+            if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
+                out = getattr(aggregate, node.attr)
 
                 if warp.types.is_value(out):
                     return adj.add_constant(out)
 
                 return out
 
-
-
-
+            if hasattr(node, "is_adjoint"):
+                # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
+                attr_name = aggregate.label + "." + node.attr
+                attr_type = aggregate.type.vars[node.attr].type
+
+                return Var(attr_name, attr_type)
+
+            aggregate_type = strip_reference(aggregate.type)
+
+            # reading a vector component
+            if type_is_vector(aggregate_type):
+                index = adj.vector_component_index(node.attr, aggregate_type)
+
+                return adj.add_builtin_call("extract", [aggregate, index])
+
+            else:
+                attr_type = Reference(aggregate_type.vars[node.attr].type)
+                attr = adj.add_var(attr_type)
+
+                if is_reference(aggregate.type):
+                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+                else:
+                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+
+                if adj.is_differentiable_value_type(strip_reference(attr_type)):
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+                else:
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+
+                return attr
 
-
+        except (KeyError, AttributeError):
+            # Try resolving as type attribute
+            aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
 
-
-
+            type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
+            if type_attribute is not None:
+                return type_attribute
+
+            if isinstance(aggregate, Var):
+                raise WarpCodegenAttributeError(
+                    f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
+                )
+            raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
 
     def emit_String(adj, node):
         # string constant
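Note on the swizzle helpers added above: `vector_component_index()` validates a single-character `.x/.y/.z/.w` access against the vector width and turns it into a constant index, which `emit_Attribute()` then lowers to an `extract` builtin call. A minimal kernel-side sketch of what this enables (the kernel itself is illustrative; the API calls are standard Warp):

    import warp as wp

    @wp.kernel
    def read_components(points: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        i = wp.tid()
        p = points[i]
        # ".y" is mapped by vector_component_index() to constant index 1,
        # then emitted as a builtin "extract" call during codegen
        out[i] = p.y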
@@ -1150,19 +1382,25 @@ class Adjoint:
         adj.symbols[key] = out
         return out
 
+    def emit_Ellipsis(adj, node):
+        # stubbed @wp.native_func
+        return
+
     def emit_NameConstant(adj, node):
-        if node.value
+        if node.value:
             return adj.add_constant(True)
-        elif node.value == False:
-            return adj.add_constant(False)
         elif node.value is None:
-            raise
+            raise WarpCodegenTypeError("None type unsupported")
+        else:
+            return adj.add_constant(False)
 
     def emit_Constant(adj, node):
         if isinstance(node, ast.Str):
             return adj.emit_String(node)
         elif isinstance(node, ast.Num):
             return adj.emit_Num(node)
+        elif isinstance(node, ast.Ellipsis):
+            return adj.emit_Ellipsis(node)
         else:
             assert isinstance(node, ast.NameConstant)
             return adj.emit_NameConstant(node)
@@ -1173,18 +1411,16 @@ class Adjoint:
         right = adj.eval(node.right)
 
         name = builtin_operators[type(node.op)]
-        func = warp.context.builtin_functions[name]
 
-        return adj.
+        return adj.add_builtin_call(name, [left, right])
 
     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)
 
         name = builtin_operators[type(node.op)]
-        func = warp.context.builtin_functions[name]
 
-        return adj.
+        return adj.add_builtin_call(name, [arg])
 
     def materialize_redefinitions(adj, symbols):
         # detect symbols with conflicting definitions (assigned inside the for loop)
@@ -1194,21 +1430,19 @@ class Adjoint:
             var2 = adj.symbols[sym]
 
             if var1 != var2:
-                if warp.config.verbose:
+                if warp.config.verbose and not adj.custom_reverse_mode:
                     lineno = adj.lineno + adj.fun_lineno
-                    line = adj.
-                    msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this
+                    line = adj.source_lines[adj.lineno]
+                    msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                     print(msg)
 
                 if var1.constant is not None:
-                    raise
-                        "Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
-                        sym
-                    )
+                    raise WarpCodegenError(
+                        f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
                     )
 
                 # overwrite the old variable value (violates SSA)
-                adj.
+                adj.add_builtin_call("assign", [var1, var2])
 
                 # reset the symbol to point to the original variable
                 adj.symbols[sym] = var1
@@ -1227,35 +1461,20 @@ class Adjoint:
 
         adj.end_while()
 
-    def is_num(adj, a):
-        # simple constant
-        if isinstance(a, ast.Num):
-            return True
-        # expression of form -constant
-        elif isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
-            return True
-        else:
-            # try and resolve the expression to an object
-            # e.g.: wp.constant in the globals scope
-            obj, path = adj.resolve_path(a)
-            if warp.types.is_int(obj):
-                return True
-            else:
-                return False
-
     def eval_num(adj, a):
         if isinstance(a, ast.Num):
-            return a.n
-
-            return -a.operand.n
-
-
-
-
-
-
+            return True, a.n
+        if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
+            return True, -a.operand.n
+
+        # try and resolve the expression to an object
+        # e.g.: wp.constant in the globals scope
+        obj, _ = adj.resolve_static_expression(a)
+
+        if isinstance(obj, Var) and obj.constant is not None:
+            obj = obj.constant
+
+        return warp.types.is_int(obj), obj
 
     # detects whether a loop contains a break (or continue) statement
     def contains_break(adj, body):
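Note on the `eval_num()` change above: it now returns an `(is_numeric, value)` pair instead of relying on a separate `is_num()` predicate, so each argument is evaluated exactly once. That matters because evaluation can emit generated code as a side effect, and evaluating twice would duplicate it. A schematic, self-contained illustration (plain Python; the `emitted` list is a hypothetical stand-in for generated code):

    emitted = []

    def eval_num(a):
        # pretend evaluating "a" appends generated code as a side effect
        emitted.append(f"eval {a!r}")
        return isinstance(a, int), a

    # single pass: each argument is evaluated exactly once
    flags, values = zip(*(eval_num(a) for a in (0, 10, 2)))
    assert emitted == ["eval 0", "eval 10", "eval 2"]
    assert all(flags) and values == (0, 10, 2)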
@@ -1278,61 +1497,82 @@ class Adjoint:
 
     # returns a constant range() if unrollable, otherwise None
     def get_unroll_range(adj, loop):
-        if
+        if (
+            not isinstance(loop.iter, ast.Call)
+            or not isinstance(loop.iter.func, ast.Name)
+            or loop.iter.func.id != "range"
+            or len(loop.iter.args) == 0
+            or len(loop.iter.args) > 3
+        ):
             return None
 
-
-
-
-            # constant compile-time expressions e.g.: range(0, 3*2)
-            if not adj.is_num(a):
-                return None
-
-        # range(end)
-        if len(loop.iter.args) == 1:
-            start = 0
-            end = adj.eval_num(loop.iter.args[0])
-            step = 1
-
-        # range(start, end)
-        elif len(loop.iter.args) == 2:
-            start = adj.eval_num(loop.iter.args[0])
-            end = adj.eval_num(loop.iter.args[1])
-            step = 1
-
-        # range(start, end, step)
-        elif len(loop.iter.args) == 3:
-            start = adj.eval_num(loop.iter.args[0])
-            end = adj.eval_num(loop.iter.args[1])
-            step = adj.eval_num(loop.iter.args[2])
-
-        # test if we're above max unroll count
-        max_iters = abs(end - start) // abs(step)
-        max_unroll = adj.builder.options["max_unroll"]
-
-        if max_iters > max_unroll:
-            if warp.config.verbose:
-                print(
-                    f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
-                )
-            return None
+        # if all range() arguments are numeric constants we will unroll
+        # note that this only handles trivial constants, it will not unroll
+        # constant compile-time expressions e.g.: range(0, 3*2)
 
-
-
-
-
+        # Evaluate the arguments and check that they are numeric constants
+        # It is important to do that in one pass, so that if evaluating these arguments have side effects
+        # the code does not get generated more than once
+        range_args = [adj.eval_num(arg) for arg in loop.iter.args]
+        arg_is_numeric, arg_values = zip(*range_args)
+
+        if all(arg_is_numeric):
+            # All argument are numeric constants
+
+            # range(end)
+            if len(loop.iter.args) == 1:
+                start = 0
+                end = arg_values[0]
+                step = 1
+
+            # range(start, end)
+            elif len(loop.iter.args) == 2:
+                start = arg_values[0]
+                end = arg_values[1]
+                step = 1
+
+            # range(start, end, step)
+            elif len(loop.iter.args) == 3:
+                start = arg_values[0]
+                end = arg_values[1]
+                step = arg_values[2]
+
+            # test if we're above max unroll count
+            max_iters = abs(end - start) // abs(step)
+            max_unroll = adj.builder.options["max_unroll"]
 
-
-
+            ok_to_unroll = True
+
+            if max_iters > max_unroll:
+                if warp.config.verbose:
+                    print(
+                        f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
+                    )
+                ok_to_unroll = False
+
+            elif adj.contains_break(loop.body):
+                if warp.config.verbose:
+                    print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
+                ok_to_unroll = False
+
+            if ok_to_unroll:
+                return range(start, end, step)
+
+        # Unroll is not possible, range needs to be valuated dynamically
+        range_call = adj.add_builtin_call(
+            "range",
+            [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
+        )
+        return range_call
 
     def emit_For(adj, node):
         # try and unroll simple range() statements that use constant args
         unroll_range = adj.get_unroll_range(node)
 
-        if unroll_range:
+        if isinstance(unroll_range, range):
             for i in unroll_range:
                 const_iter = adj.add_constant(i)
-                var_iter = adj.
+                var_iter = adj.add_builtin_call("int", [const_iter])
                 adj.symbols[node.target.id] = var_iter
 
                 # eval body
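Note on `get_unroll_range()`/`emit_For()` above: a `for` loop over `range()` with literal constant bounds is unrolled at code-generation time, up to the module's `max_unroll` option; loops that exceed the limit or contain `break`/`continue` fall back to a dynamic loop built from the already-evaluated `range` call. A user-facing sketch:

    import warp as wp

    @wp.kernel
    def accumulate(out: wp.array(dtype=float)):
        i = wp.tid()
        s = float(0.0)
        for j in range(4):  # literal bounds: eligible for unrolling
            s = s + float(j)
        out[i] = s

    # loops longer than this many iterations are kept as dynamic loops
    wp.set_module_options({"max_unroll": 2})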
@@ -1341,8 +1581,12 @@ class Adjoint:
 
         # otherwise generate a dynamic loop
         else:
-            # evaluate the Iterable
-
+            # evaluate the Iterable -- only if not previously evaluated when trying to unroll
+            if unroll_range is not None:
+                # Range has already been evaluated when trying to unroll, do not re-evaluate
+                iter = unroll_range
+            else:
+                iter = adj.eval(node.iter)
 
             adj.symbols[node.target.id] = adj.begin_for(iter)
 
@@ -1371,15 +1615,28 @@ class Adjoint:
     def emit_Expr(adj, node):
         return adj.eval(node.value)
 
+    def check_tid_in_func_error(adj, node):
+        if adj.is_user_function:
+            if hasattr(node.func, "attr") and node.func.attr == "tid":
+                lineno = adj.lineno + adj.fun_lineno
+                line = adj.source_lines[adj.lineno]
+                raise WarpCodegenError(
+                    "tid() may only be called from a Warp kernel, not a Warp function. "
+                    "Instead, obtain the indices from a @wp.kernel and pass them as "
+                    f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
+                )
+
     def emit_Call(adj, node):
+        adj.check_tid_in_func_error(node)
+
         # try and lookup function in globals by
         # resolving path (e.g.: module.submodule.attr)
-        func, path = adj.
+        func, path = adj.resolve_static_expression(node.func)
         templates = []
 
-        if isinstance(func, warp.context.Function)
+        if not isinstance(func, warp.context.Function):
            if len(path) == 0:
-                raise
+                raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")
 
             attr = path[-1]
             caller = func
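Note on `check_tid_in_func_error()` above: calling `wp.tid()` inside a `@wp.func` now raises a `WarpCodegenError` at module build time; the thread index must be obtained in the kernel and passed down as an argument. For example:

    import warp as wp

    @wp.func
    def scale(values: wp.array(dtype=float), i: int, s: float) -> float:
        # correct: the index arrives as a parameter; calling wp.tid()
        # here would trigger the new WarpCodegenError
        return values[i] * s

    @wp.kernel
    def apply_scale(values: wp.array(dtype=float), out: wp.array(dtype=float)):
        i = wp.tid()
        out[i] = scale(values, i, 2.0)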
@@ -1404,7 +1661,7 @@ class Adjoint:
             func = caller.initializer()
 
             if func is None:
-                raise
+                raise WarpCodegenError(
                     f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
                 )
 
@@ -1413,16 +1670,25 @@ class Adjoint:
         # eval all arguments
         for arg in node.args:
             var = adj.eval(arg)
+            if not is_local_value(var):
+                raise RuntimeError(
+                    "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+                )
             args.append(var)
 
-        # eval all keyword
+        # eval all keyword args
        def kwval(kw):
             if isinstance(kw.value, ast.Num):
                 return kw.value.n
             elif isinstance(kw.value, ast.Tuple):
-
+                arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
+                if not all(arg_is_numeric):
+                    raise WarpCodegenError(
+                        f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
+                    )
+                return arg_values
             else:
-                return adj.
+                return adj.resolve_static_expression(kw.value)[0]
 
         kwds = {kw.arg: kwval(kw) for kw in node.keywords}
 
@@ -1439,10 +1705,26 @@ class Adjoint:
         # the ast.Index node appears in 3.7 versions
         # when performing array slices, e.g.: x = arr[i]
         # but in version 3.8 and higher it does not appear
+
+        if hasattr(node, "is_adjoint"):
+            node.value.is_adjoint = True
+
         return adj.eval(node.value)
 
     def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
         target = adj.eval(node.value)
+        if not is_local_value(target):
+            raise RuntimeError(
+                "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+            )
 
         indices = []
 
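Note on the `wp.adjoint[...]` handling above: subscripting the special `adjoint` attribute resolves to the `adj_`-prefixed variable that codegen maintains for a value, letting user code read or overwrite a gradient directly. This is intended for custom or replayed adjoint code; the decorator plumbing around it is outside this diff, so treat the following as a hypothetical sketch only:

    import warp as wp

    @wp.kernel
    def square_with_manual_adjoint(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i = wp.tid()
        v = x[i]
        y[i] = v * v
        # hypothetical use: wp.adjoint[v] is rewritten by codegen to the
        # generated "adj_v" variable for this value
        wp.adjoint[v] = 2.0 * v * wp.adjoint[v]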
@@ -1462,28 +1744,34 @@ class Adjoint:
             var = adj.eval(node.slice)
             indices.append(var)
 
-
-
+        target_type = strip_reference(target.type)
+        if is_array(target_type):
+            if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
-                out = adj.
+                out = adj.add_builtin_call("address", [target, *indices])
             else:
                 # handles array views (fewer indices than dimensions)
-                out = adj.
+                out = adj.add_builtin_call("view", [target, *indices])
 
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
-            out = adj.
+            out = adj.add_builtin_call("extract", [target, *indices])
 
         return out
 
     def emit_Assign(adj, node):
+        if len(node.targets) != 1:
+            raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
+
+        lhs = node.targets[0]
+
         # handle the case where we are assigning multiple output variables
-        if isinstance(
+        if isinstance(lhs, ast.Tuple):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
             if isinstance(node.value, ast.Call):
-                node.value.expects = len(
+                node.value.expects = len(lhs.elts)
 
             # evaluate values
             if isinstance(node.value, ast.Tuple):
@@ -1492,40 +1780,47 @@ class Adjoint:
             out = adj.eval(node.value)
 
             names = []
-            for v in
+            for v in lhs.elts:
                 if isinstance(v, ast.Name):
                     names.append(v.id)
                 else:
-                    raise
+                    raise WarpCodegenError(
                         "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                     )
 
             if len(names) != len(out):
-                raise
-                    "Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})"
-                    len(out), len(names)
-                )
+                raise WarpCodegenError(
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
                 )
 
             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
-                        raise
-                            "Error, assigning to existing symbol {} ({}) with different type ({})"
-                            name, adj.symbols[name].type, rhs.type
-                        )
+                        raise WarpCodegenTypeError(
+                            f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                         )
 
                 adj.symbols[name] = rhs
 
-            return out
-
         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
-        elif isinstance(
-
+        elif isinstance(lhs, ast.Subscript):
+            if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
+                # handle adjoint of a variable, i.e. wp.adjoint[var]
+                lhs.slice.is_adjoint = True
+                src_var = adj.eval(lhs.slice)
+                var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
+                value = adj.eval(node.value)
+                adj.add_forward(f"{var.emit()} = {value.emit()};")
+                return
+
+            target = adj.eval(lhs.value)
             value = adj.eval(node.value)
+            if not is_local_value(value):
+                raise RuntimeError(
+                    "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+                )
 
-            slice =
+            slice = lhs.slice
             indices = []
 
             if isinstance(slice, ast.Tuple):
@@ -1533,7 +1828,6 @@ class Adjoint:
                 for arg in slice.elts:
                     var = adj.eval(arg)
                     indices.append(var)
-
             elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
                 # handles the x[i, j] case (Python 3.7.x)
                 for arg in slice.value.elts:
@@ -1544,64 +1838,84 @@ class Adjoint:
                 var = adj.eval(slice)
                 indices.append(var)
 
-
-            adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
+            target_type = strip_reference(target.type)
 
-
-            adj.
+            if is_array(target_type):
+                adj.add_builtin_call("array_store", [target, *indices, value])
 
-
+            elif type_is_vector(target_type) or type_is_matrix(target_type):
+                if is_reference(target.type):
+                    attr = adj.add_builtin_call("indexref", [target, *indices])
+                else:
+                    attr = adj.add_builtin_call("index", [target, *indices])
+
+                adj.add_builtin_call("store", [attr, value])
+
+                if warp.config.verbose and not adj.custom_reverse_mode:
                     lineno = adj.lineno + adj.fun_lineno
-                    line = adj.
+                    line = adj.source_lines[adj.lineno]
+                    node_source = adj.get_node_source(lhs.value)
                     print(
-                        f"Warning: mutating {
+                        f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                     )
 
             else:
-                raise
-
-            return var
+                raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
 
-        elif isinstance(
+        elif isinstance(lhs, ast.Name):
             # symbol name
-            name =
+            name = lhs.id
 
             # evaluate rhs
             rhs = adj.eval(node.value)
 
             # check type matches if symbol already defined
             if name in adj.symbols:
-                if not types_equal(rhs.type, adj.symbols[name].type):
-                    raise
-                        "Error, assigning to existing symbol {} ({}) with different type ({})"
-                        name, adj.symbols[name].type, rhs.type
-                    )
+                if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
+                    raise WarpCodegenTypeError(
+                        f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                     )
 
             # handle simple assignment case (a = b), where we generate a value copy rather than reference
-            if isinstance(node.value, ast.Name):
-                out = adj.
-                adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
+            if isinstance(node.value, ast.Name) or is_reference(rhs.type):
+                out = adj.add_builtin_call("copy", [rhs])
             else:
                 out = rhs
 
             # update symbol map (assumes lhs is a Name node)
             adj.symbols[name] = out
-            return out
 
-        elif isinstance(
+        elif isinstance(lhs, ast.Attribute):
             rhs = adj.eval(node.value)
-
-
+            aggregate = adj.eval(lhs.value)
+            aggregate_type = strip_reference(aggregate.type)
 
-
-
-
-
-
+            # assigning to a vector component
+            if type_is_vector(aggregate_type):
+                index = adj.vector_component_index(lhs.attr, aggregate_type)
+
+                if is_reference(aggregate.type):
+                    attr = adj.add_builtin_call("indexref", [aggregate, index])
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+
+                adj.add_builtin_call("store", [attr, rhs])
+
+            else:
+                attr = adj.emit_Attribute(lhs)
+                if is_reference(attr.type):
+                    adj.add_builtin_call("store", [attr, rhs])
+                else:
+                    adj.add_builtin_call("assign", [attr, rhs])
+
+                if warp.config.verbose and not adj.custom_reverse_mode:
+                    lineno = adj.lineno + adj.fun_lineno
+                    line = adj.source_lines[adj.lineno]
+                    msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
+                    print(msg)
 
         else:
-            raise
+            raise WarpCodegenError("Error, unsupported assignment statement.")
 
     def emit_Return(adj, node):
         if node.value is None:
@@ -1612,30 +1926,26 @@ class Adjoint:
             var = (adj.eval(node.value),)
 
         if adj.return_var is not None:
-            old_ctypes = tuple(v.ctype() for v in adj.return_var)
-            new_ctypes = tuple(v.ctype() for v in var)
+            old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
+            new_ctypes = tuple(v.ctype(value_type=True) for v in var)
             if old_ctypes != new_ctypes:
-                raise
+                raise WarpCodegenTypeError(
                     f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
                 )
-        else:
-            adj.return_var = var
-
-        adj.add_return(var)
 
-
-
-
-
-
-        name = builtin_operators[type(node.op)]
-        func = warp.context.builtin_functions[name]
+        if var is not None:
+            adj.return_var = tuple()
+            for ret in var:
+                if is_reference(ret.type):
+                    ret = adj.add_builtin_call("copy", [ret])
+                adj.return_var += (ret,)
 
-
+        adj.add_return(adj.return_var)
 
-
-
+    def emit_AugAssign(adj, node):
+        # replace augmented assignment with assignment statement + binary op
+        new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
+        adj.eval(new_node)
 
     def emit_Tuple(adj, node):
         # LHS for expressions, such as i, j, k = 1, 2, 3
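Note on `emit_AugAssign()` above: augmented assignment is desugared at the AST level into a plain assignment plus a binary op, so `a += b` reuses the existing `emit_Assign`/`emit_BinOp` paths. The same transformation in isolation (standard library only; Python 3.9+ for `ast.unparse`):

    import ast

    node = ast.parse("a += b").body[0]  # ast.AugAssign
    assign = ast.Assign(
        targets=[node.target],
        value=ast.BinOp(node.target, node.op, node.value),
    )
    # attach missing locations (needed if the tree is later compiled)
    module = ast.fix_missing_locations(ast.Module(body=[assign], type_ignores=[]))
    print(ast.unparse(module))  # prints: a = a + b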
@@ -1645,115 +1955,160 @@ class Adjoint:
     def emit_Pass(adj, node):
         pass
 
+    node_visitors = {
+        ast.FunctionDef: emit_FunctionDef,
+        ast.If: emit_If,
+        ast.Compare: emit_Compare,
+        ast.BoolOp: emit_BoolOp,
+        ast.Name: emit_Name,
+        ast.Attribute: emit_Attribute,
+        ast.Str: emit_String,  # Deprecated in 3.8; use Constant
+        ast.Num: emit_Num,  # Deprecated in 3.8; use Constant
+        ast.NameConstant: emit_NameConstant,  # Deprecated in 3.8; use Constant
+        ast.Constant: emit_Constant,
+        ast.BinOp: emit_BinOp,
+        ast.UnaryOp: emit_UnaryOp,
+        ast.While: emit_While,
+        ast.For: emit_For,
+        ast.Break: emit_Break,
+        ast.Continue: emit_Continue,
+        ast.Expr: emit_Expr,
+        ast.Call: emit_Call,
+        ast.Index: emit_Index,  # Deprecated in 3.8; Use the index value directly instead.
+        ast.Subscript: emit_Subscript,
+        ast.Assign: emit_Assign,
+        ast.Return: emit_Return,
+        ast.AugAssign: emit_AugAssign,
+        ast.Tuple: emit_Tuple,
+        ast.Pass: emit_Pass,
+        ast.Ellipsis: emit_Ellipsis,
+    }
+
     def eval(adj, node):
         if hasattr(node, "lineno"):
             adj.set_lineno(node.lineno - 1)
 
-
-
-
-            ast.Compare: Adjoint.emit_Compare,
-            ast.BoolOp: Adjoint.emit_BoolOp,
-            ast.Name: Adjoint.emit_Name,
-            ast.Attribute: Adjoint.emit_Attribute,
-            ast.Str: Adjoint.emit_String,  # Deprecated in 3.8; use Constant
-            ast.Num: Adjoint.emit_Num,  # Deprecated in 3.8; use Constant
-            ast.NameConstant: Adjoint.emit_NameConstant,  # Deprecated in 3.8; use Constant
-            ast.Constant: Adjoint.emit_Constant,
-            ast.BinOp: Adjoint.emit_BinOp,
-            ast.UnaryOp: Adjoint.emit_UnaryOp,
-            ast.While: Adjoint.emit_While,
-            ast.For: Adjoint.emit_For,
-            ast.Break: Adjoint.emit_Break,
-            ast.Continue: Adjoint.emit_Continue,
-            ast.Expr: Adjoint.emit_Expr,
-            ast.Call: Adjoint.emit_Call,
-            ast.Index: Adjoint.emit_Index,  # Deprecated in 3.8; Use the index value directly instead.
-            ast.Subscript: Adjoint.emit_Subscript,
-            ast.Assign: Adjoint.emit_Assign,
-            ast.Return: Adjoint.emit_Return,
-            ast.AugAssign: Adjoint.emit_AugAssign,
-            ast.Tuple: Adjoint.emit_Tuple,
-            ast.Pass: Adjoint.emit_Pass,
-        }
-
-        emit_node = node_visitors.get(type(node))
-
-        if emit_node is not None:
-            return emit_node(adj, node)
-        else:
-            raise Exception("Error, ast node of type {} not supported".format(type(node)))
+        emit_node = adj.node_visitors[type(node)]
+
+        return emit_node(adj, node)
 
     # helper to evaluate expressions of the form
     # obj1.obj2.obj3.attr in the function's global scope
-    def resolve_path(adj,
-
+    def resolve_path(adj, path):
+        if len(path) == 0:
+            return None
 
-
-
-
+        # if root is overshadowed by local symbols, bail out
+        if path[0] in adj.symbols:
+            return None
 
-        if
-
+        if path[0] in __builtins__:
+            return __builtins__[path[0]]
 
-        #
-
+        # Look up the closure info and append it to adj.func.__globals__
+        # in case you want to define a kernel inside a function and refer
+        # to variables you've declared inside that function:
+        extract_contents = (
+            lambda contents: contents
+            if isinstance(contents, warp.context.Function) or not callable(contents)
+            else contents
+        )
+        capturedvars = dict(
+            zip(
+                adj.func.__code__.co_freevars,
+                [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
+            )
+        )
+        vars_dict = {**adj.func.__globals__, **capturedvars}
 
-        if
-
+        if path[0] in vars_dict:
+            func = vars_dict[path[0]]
 
-        #
-
-
-        # in case you want to define a kernel inside a function and refer
-        # to variables you've declared inside that function:
-        extract_contents = (
-            lambda contents: contents
-            if isinstance(contents, warp.context.Function) or not callable(contents)
-            else contents
-        )
-        capturedvars = dict(
-            zip(
-                adj.func.__code__.co_freevars,
-                [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
-            )
-        )
+        # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
+        else:
+            func = getattr(warp, path[0], None)
 
-
-
-
-
-            pass
+        if func:
+            for i in range(1, len(path)):
+                if hasattr(func, path[i]):
+                    func = getattr(func, path[i])
 
-
-        # in a kernel:
+        return func
 
-
+    # Evaluates a static expression that does not depend on runtime values
+    # if eval_types is True, try resolving the path using evaluated type information as well
+    def resolve_static_expression(adj, root_node, eval_types=True):
+        attributes = []
 
-
-
-
-
+        node = root_node
+        while isinstance(node, ast.Attribute):
+            attributes.append(node.attr)
+            node = node.value
 
-
-
-
-
-
-
-
-
-
+        if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+            # support for operators returning modules
+            # i.e. operator_name(*operator_args).x.y.z
+            operator_args = node.args
+            operator_name = node.func.id
+
+            if operator_name == "type":
+                if len(operator_args) != 1:
+                    raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
+
+                # type() operator
+                var = adj.eval(operator_args[0])
+
+                if isinstance(var, Var):
+                    var_type = strip_reference(var.type)
+                    # Allow accessing type attributes, for instance array.dtype
+                    while attributes:
+                        attr_name = attributes.pop()
+                        var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
+
+                        if var_type is None:
+                            raise WarpCodegenAttributeError(
+                                f"{attr_name} is not an attribute of {type_repr(prev_type)}"
+                            )
+
+                    return var_type, [type_repr(var_type)]
+                else:
+                    raise WarpCodegenError(f"Cannot deduce the type of {var}")
+
+        # reverse list since ast presents it backward order
+        path = [*reversed(attributes)]
+        if isinstance(node, ast.Name):
+            path.insert(0, node.id)
+
+        # Try resolving path from captured context
+        captured_obj = adj.resolve_path(path)
+        if captured_obj is not None:
+            return captured_obj, path
+
+        # Still nothing found, maybe this is a predefined type attribute like `dtype`
+        if eval_types:
+            try:
+                val = adj.eval(root_node)
+                if val:
+                    return [val, type_repr(val)]
+
+            except Exception:
+                pass
+
+        return None, path
 
     # annotate generated code with the original source code line
     def set_lineno(adj, lineno):
         if adj.lineno is None or adj.lineno != lineno:
             line = lineno + adj.fun_lineno
-            source = adj.
+            source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
             adj.add_forward(f"// {source} <L {line}>")
             adj.add_reverse(f"// adj: {source} <L {line}>")
         adj.lineno = lineno
 
+    def get_node_source(adj, node):
+        # return the Python code corresponding to the given AST node
+        return ast.get_source_segment(adj.source, node)
+
 
 # ----------------
 # code generation
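Note on `resolve_static_expression()` above: besides plain attribute paths it now statically resolves `type(expr)` and type attributes such as `dtype`, which is what allows dtype-generic code inside kernels. A sketch of the kind of kernel this supports (assuming the `type()` operator is usable from kernels exactly as the codegen above suggests):

    import warp as wp

    @wp.kernel
    def zero_like(a: wp.array(dtype=wp.float64), out: wp.array(dtype=wp.float64)):
        i = wp.tid()
        # "type(a[i])" is resolved at code-generation time, so the
        # literal below is constructed with the array's scalar type
        out[i] = type(a[i])(0.0)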
@@ -1769,7 +2124,10 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-
+#define builtin_tid1d() wp::tid(wp::s_threadIdx)
+#define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
 
 """
 
@@ -1784,8 +2142,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-
-
+#define builtin_tid1d() wp::tid(_idx)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 """
 
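Note on the `builtin_tid{1..4}d` macros above: they implement `wp.tid()` for 1-4 dimensional launches; on CUDA they now read the grid-stride `_idx` variable (see the kernel templates below) instead of a raw thread index. Nothing changes on the Python side:

    import warp as wp

    wp.init()

    @wp.kernel
    def fill(img: wp.array2d(dtype=float), value: float):
        i, j = wp.tid()  # expands to builtin_tid2d(i, j) in the generated C++
        img[i, j] = value

    img = wp.zeros((480, 640), dtype=float)
    wp.launch(fill, dim=img.shape, inputs=[img, 1.0])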
@@ -1799,54 +2159,56 @@ struct {name}
 {{
 }}
 
-    CUDA_CALLABLE {name}& operator += (const {name}&)
+    CUDA_CALLABLE {name}& operator += (const {name}& rhs)
+    {{{prefix_add_body}
+        return *this;}}
 
 }};
 
 static CUDA_CALLABLE void adj_{name}({reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}
 
-CUDA_CALLABLE void
+CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
 {{
-    {atomic_add_body}
-}}
+{atomic_add_body}}}
 
 
 """
 
-
+cpu_forward_function_template = """
 // {filename}:{lineno}
 static {return_type} {name}(
     {forward_args})
 {{
-    {forward_body}
-
+{forward_body}}}
+
+"""
 
+cpu_reverse_function_template = """
 // {filename}:{lineno}
 static void adj_{name}(
     {reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}
 
 """
 
-
+cuda_forward_function_template = """
 // {filename}:{lineno}
 static CUDA_CALLABLE {return_type} {name}(
     {forward_args})
 {{
-    {forward_body}
-}}
+{forward_body}}}
 
+"""
+
+cuda_reverse_function_template = """
 // {filename}:{lineno}
 static CUDA_CALLABLE void adj_{name}(
     {reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}
 
 """
 
@@ -1855,25 +2217,21 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    size_t _idx =
-
-
-
-
-
-    {forward_body}
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    {{
+{forward_body} }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    size_t _idx =
-
-
-
-
-
-    {reverse_body}
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    {{
+{reverse_body} }}
 }}
 
 """
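Note on the CUDA kernel templates above: the single guarded thread index is replaced by a grid-stride loop, so one launched grid can cover an arbitrary `dim.size`; `{forward_body}` runs once per `_idx`, and the reverse body now ends each iteration with `continue` instead of `return` (see `codegen_func_reverse` further down). Callers are unaffected:

    import warp as wp

    wp.init()

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    n = 10_000_000  # may exceed the grid's thread count; the stride loop covers the rest
    x = wp.zeros(n, dtype=float, device="cuda")
    y = wp.zeros(n, dtype=float, device="cuda")
    wp.launch(saxpy, dim=n, inputs=[x, y, 2.0], device="cuda")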
@@ -1883,14 +2241,12 @@ cpu_kernel_template = """
 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
-    {forward_body}
-}}
+{forward_body}}}
 
 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}
 
 """
 
@@ -1902,11 +2258,9 @@ extern "C" {{
 WP_API void {name}_cpu_forward(
     {forward_args})
 {{
-    set_launch_bounds(dim);
-
     for (size_t i=0; i < dim.size; ++i)
     {{
-        s_threadIdx = i;
+        wp::s_threadIdx = i;
 
         {name}_cpu_kernel_forward(
             {forward_params});
@@ -1916,11 +2270,9 @@ WP_API void {name}_cpu_forward(
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
-    set_launch_bounds(dim);
-
     for (size_t i=0; i < dim.size; ++i)
     {{
-        s_threadIdx = i;
+        wp::s_threadIdx = i;
 
         {name}_cpu_kernel_backward(
             {reverse_params});
@@ -1966,7 +2318,7 @@ WP_API void {name}_cpu_backward(
 def constant_str(value):
     value_type = type(value)
 
-    if value_type == bool:
+    if value_type == bool or value_type == builtins.bool:
         if value:
             return "true"
         else:
@@ -1983,7 +2335,9 @@ def constant_str(value):
 
         scalar_value = runtime.core.half_bits_to_float
     else:
-
+
+        def scalar_value(x):
+            return x
 
     # list of scalar initializer values
     initlist = []
@@ -2000,6 +2354,9 @@ def constant_str(value):
         # make sure we emit the value of objects, e.g. uint32
         return str(value.value)
 
+    elif value == math.inf:
+        return "INFINITY"
+
     else:
         # otherwise just convert constant to string
         return str(value)
@@ -2008,7 +2365,7 @@ def constant_str(value):
 def indent(args, stops=1):
     sep = ",\n"
     for i in range(stops):
-        sep += "
+        sep += "    "
 
     # return sep + args.replace(", ", "," + sep)
     return sep.join(args)
@@ -2016,7 +2373,9 @@ def indent(args, stops=1):
 
 # generates a C function name based on the python function name
 def make_full_qualified_name(func):
-
+    if not isinstance(func, str):
+        func = func.__qualname__
+    return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))
 
 
 def codegen_struct(struct, device="cpu", indent_size=4):
@@ -2024,8 +2383,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):
 
     body = []
     indent_block = " " * indent_size
-
-
+
+    if len(struct.vars) > 0:
+        for label, var in struct.vars.items():
+            body.append(var.ctype() + " " + label + ";\n")
+    else:
+        # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
+        body.append("char _dummy_;\n")
 
     forward_args = []
     reverse_args = []
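Note on the empty-struct handling above: a `@wp.struct` with no fields now emits a `char _dummy_;` member so the generated C++ struct has a well-defined size on every compiler. For example:

    import warp as wp

    @wp.struct
    class Empty:
        pass  # generated C++ contains "char _dummy_;" so its size is well-defined

    @wp.struct
    class Particle:
        pos: wp.vec3
        mass: float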
@@ -2033,24 +2397,32 @@ def codegen_struct(struct, device="cpu", indent_size=4):
     forward_initializers = []
     reverse_body = []
     atomic_add_body = []
+    prefix_add_body = []
 
     # forward args
     for label, var in struct.vars.items():
-
-
+        var_ctype = var.ctype()
+        forward_args.append(f"{var_ctype} const& {label} = {{}}")
+        reverse_args.append(f"{var_ctype} const&")
 
-
+        namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
+        atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
 
         prefix = f"{indent_block}," if forward_initializers else ":"
         forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
 
+    # prefix-add operator
+    for label, var in struct.vars.items():
+        if not is_array(var.type):
+            prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
     # reverse args
     for label, var in struct.vars.items():
         reverse_args.append(var.ctype() + " & adj_" + label)
         if is_array(var.type):
-            reverse_body.append(f"adj_{label} =
+            reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
         else:
-            reverse_body.append(f"adj_{label} +=
+            reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
 
     reverse_args.append(name + " & adj_ret")
 
@@ -2061,109 +2433,101 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         forward_initializers="".join(forward_initializers),
         reverse_args=indent(reverse_args),
         reverse_body="".join(reverse_body),
+        prefix_add_body="".join(prefix_add_body),
         atomic_add_body="".join(atomic_add_body),
     )
 
 
-def codegen_func_forward_body(adj, device="cpu", indent=4):
-    body = []
-    indent_block = " " * indent
-
-    for f in adj.blocks[0].body_forward:
-        body += [f + "\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
 def codegen_func_forward(adj, func_type="kernel", device="cpu"):
-
+    if device == "cpu":
+        indent = 4
+    elif device == "cuda":
+        if func_type == "kernel":
+            indent = 8
+        else:
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
+
+    indent_block = " " * indent
 
     # primal vars
-
-
+    lines = []
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # forward pass
-
-
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
 
-
-
+    for f in adj.blocks[0].body_forward:
+        lines += [f + "\n"]
 
+    return "".join([indent_block + l for l in lines])
+
+
+def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
+    if device == "cpu":
+        indent = 4
     elif device == "cuda":
         if func_type == "kernel":
-
+            indent = 8
         else:
-
-
-
-
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
 
-def codegen_func_reverse_body(adj, device="cpu", indent=4):
-    body = []
     indent_block = " " * indent
 
-
-    body += ["//---------\n"]
-    body += ["// forward\n"]
-
-    for f in adj.blocks[0].body_replay:
-        body += [f + "\n"]
-
-    # reverse pass
-    body += ["//---------\n"]
-    body += ["// reverse\n"]
-
-    for l in reversed(adj.blocks[0].body_reverse):
-        body += [l + "\n"]
-
-    body += ["return;\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
-def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
-    s = ""
+    lines = []
 
     # primal vars
-
-
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # dual vars
-
-
+    lines += ["//---------\n"]
+    lines += ["// dual vars\n"]
 
     for var in adj.variables:
-
-        s += " " + var.ctype() + " adj_" + str(var.label) + ";\n"
-    else:
-        s += " " + var.ctype() + " adj_" + str(var.label) + "(0);\n"
+        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
 
-
-
-
-
-
-
-
+    # forward pass
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
+
+    for f in adj.blocks[0].body_replay:
+        lines += [f + "\n"]
+
+    # reverse pass
+    lines += ["//---------\n"]
+    lines += ["// reverse\n"]
+
+    for l in reversed(adj.blocks[0].body_reverse):
+        lines += [l + "\n"]
+
+    # In grid-stride kernels the reverse body is in a for loop
+    if device == "cuda" and func_type == "kernel":
+        lines += ["continue;\n"]
     else:
-
+        lines += ["return;\n"]
 
-    return
+    return "".join([indent_block + l for l in lines])
 
 
-def codegen_func(adj,
+def codegen_func(adj, c_func_name: str, device="cpu", options={}):
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
@@ -2176,16 +2540,20 @@ def codegen_func(adj, name, device="cpu", options={}):
     reverse_args = []
 
     # forward args
-    for arg in adj.args:
-
-
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit()}"
+        forward_args.append(s)
+        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+            reverse_args.append(s)
     if has_multiple_outputs:
         for i, arg in enumerate(adj.return_var):
             forward_args.append(arg.ctype() + " & ret_" + str(i))
             reverse_args.append(arg.ctype() + " & ret_" + str(i))
 
     # reverse args
-    for arg in adj.args:
+    for i, arg in enumerate(adj.args):
+        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+            break
         # indexed array gradients are regular arrays
         if isinstance(arg.type, indexedarray):
             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
@@ -2197,28 +2565,96 @@ def codegen_func(adj, name, device="cpu", options={}):
         reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
     elif return_type != "void":
         reverse_args.append(return_type + " & adj_ret")
+    # custom output reverse args (user-declared)
+    if adj.custom_reverse_mode:
+        for arg in adj.args[adj.custom_reverse_num_input_args :]:
+            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
+
+    if device == "cpu":
+        forward_template = cpu_forward_function_template
+        reverse_template = cpu_reverse_function_template
+    elif device == "cuda":
+        forward_template = cuda_forward_function_template
+        reverse_template = cuda_reverse_function_template
+    else:
+        raise ValueError(f"Device {device} is not supported")
 
     # codegen body
     forward_body = codegen_func_forward(adj, func_type="function", device=device)
 
-
-
-
-
+    s = ""
+    if not adj.skip_forward_codegen:
+        s += forward_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            forward_args=indent(forward_args),
+            forward_body=forward_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
 
-    if
-
-
-
-
-
+    if not adj.skip_reverse_codegen:
+        if adj.custom_reverse_mode:
+            reverse_body = "\t// user-defined adjoint code\n" + forward_body
+        else:
+            if options.get("enable_backward", True):
+                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+            else:
+                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+        s += reverse_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            reverse_args=indent(reverse_args),
+            forward_body=forward_body,
+            reverse_body=reverse_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
 
-    s
+    return s
+
+
+def codegen_snippet(adj, name, snippet, adj_snippet):
+    forward_args = []
+    reverse_args = []
+
+    # forward args
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+        forward_args.append(s)
+        reverse_args.append(s)
+
+    # reverse args
+    for i, arg in enumerate(adj.args):
+        if isinstance(arg.type, indexedarray):
+            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+        else:
+            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+
+    forward_template = cuda_forward_function_template
+    reverse_template = cuda_reverse_function_template
+
+    s = ""
+    s += forward_template.format(
         name=name,
-        return_type=
+        return_type="void",
         forward_args=indent(forward_args),
+        forward_body=snippet,
+        filename=adj.filename,
+        lineno=adj.fun_lineno,
+    )
+
+    if adj_snippet:
+        reverse_body = adj_snippet
+    else:
+        reverse_body = ""
+
+    s += reverse_template.format(
+        name=name,
+        return_type="void",
         reverse_args=indent(reverse_args),
-        forward_body=
+        forward_body=snippet,
         reverse_body=reverse_body,
         filename=adj.filename,
         lineno=adj.fun_lineno,
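Note on `codegen_snippet()` above: it generates forward/adjoint functions whose bodies are raw CUDA C++ strings supplied by the user, with the Python definition contributing only the signature (which is why `emit_Ellipsis` stubs out `...` bodies earlier in this diff). A hedged sketch of the intended usage, based on the `@wp.func_native` decorator introduced in this release; treat the exact argument handling as an assumption:

    import warp as wp

    snippet = """
        out[tid] = a * x[tid] + y[tid];
    """

    @wp.func_native(snippet)
    def native_saxpy(a: float, x: wp.array(dtype=float), y: wp.array(dtype=float),
                     out: wp.array(dtype=float), tid: int):
        ...  # body is the CUDA snippet above; Python supplies only the signature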
@@ -2234,8 +2670,8 @@ def codegen_kernel(kernel, device, options):
 
     adj = kernel.adj
 
-    forward_args = ["launch_bounds_t dim"]
-    reverse_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
+    reverse_args = ["wp::launch_bounds_t dim"]
 
     # forward args
     for arg in adj.args:
@@ -2264,7 +2700,7 @@ def codegen_kernel(kernel, device, options):
     elif device == "cuda":
         template = cuda_kernel_template
     else:
-        raise ValueError("Device {} is not supported"
+        raise ValueError(f"Device {device} is not supported")
 
     s = template.format(
         name=kernel.get_mangled_name(),
@@ -2284,7 +2720,7 @@ def codegen_module(kernel, device="cpu"):
     adj = kernel.adj
 
     # build forward signature
-    forward_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim"]
 
     for arg in adj.args: