warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -7,23 +7,40 @@
|
|
|
7
7
|
|
|
8
8
|
from __future__ import annotations
|
|
9
9
|
|
|
10
|
-
import re
|
|
11
|
-
import sys
|
|
12
10
|
import ast
|
|
13
|
-
import
|
|
11
|
+
import builtins
|
|
14
12
|
import ctypes
|
|
13
|
+
import inspect
|
|
14
|
+
import math
|
|
15
|
+
import re
|
|
16
|
+
import sys
|
|
15
17
|
import textwrap
|
|
16
18
|
import types
|
|
19
|
+
from typing import Any, Callable, Mapping
|
|
17
20
|
|
|
18
|
-
import
|
|
21
|
+
import warp.config
|
|
22
|
+
from warp.types import *
|
|
19
23
|
|
|
20
|
-
from typing import Any
|
|
21
|
-
from typing import Callable
|
|
22
|
-
from typing import Mapping
|
|
23
|
-
from typing import Union
|
|
24
24
|
|
|
25
|
-
|
|
26
|
-
|
|
25
|
+
class WarpCodegenError(RuntimeError):
|
|
26
|
+
def __init__(self, message):
|
|
27
|
+
super().__init__(message)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class WarpCodegenTypeError(TypeError):
|
|
31
|
+
def __init__(self, message):
|
|
32
|
+
super().__init__(message)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class WarpCodegenAttributeError(AttributeError):
|
|
36
|
+
def __init__(self, message):
|
|
37
|
+
super().__init__(message)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class WarpCodegenKeyError(KeyError):
|
|
41
|
+
def __init__(self, message):
|
|
42
|
+
super().__init__(message)
|
|
43
|
+
|
|
27
44
|
|
|
28
45
|
# map operator to function name
|
|
29
46
|
builtin_operators = {}
|
|
@@ -57,6 +74,19 @@ builtin_operators[ast.Invert] = "invert"
|
|
|
57
74
|
builtin_operators[ast.LShift] = "lshift"
|
|
58
75
|
builtin_operators[ast.RShift] = "rshift"
|
|
59
76
|
|
|
77
|
+
comparison_chain_strings = [
|
|
78
|
+
builtin_operators[ast.Gt],
|
|
79
|
+
builtin_operators[ast.Lt],
|
|
80
|
+
builtin_operators[ast.LtE],
|
|
81
|
+
builtin_operators[ast.GtE],
|
|
82
|
+
builtin_operators[ast.Eq],
|
|
83
|
+
builtin_operators[ast.NotEq],
|
|
84
|
+
]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def op_str_is_chainable(op: str) -> builtins.bool:
|
|
88
|
+
return op in comparison_chain_strings
|
|
89
|
+
|
|
60
90
|
|
|
61
91
|
def get_annotations(obj: Any) -> Mapping[str, Any]:
|
|
62
92
|
"""Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
|
|
@@ -67,97 +97,156 @@ def get_annotations(obj: Any) -> Mapping[str, Any]:
|
|
|
67
97
|
return getattr(obj, "__annotations__", {})
|
|
68
98
|
|
|
69
99
|
|
|
70
|
-
def
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
)
|
|
75
|
-
|
|
76
|
-
return inst._struct_.ctype()
|
|
100
|
+
def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
|
|
101
|
+
indent = "\t"
|
|
102
|
+
|
|
103
|
+
# handle empty structs
|
|
104
|
+
if len(inst._cls.vars) == 0:
|
|
105
|
+
return f"{inst._cls.key}()"
|
|
77
106
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
107
|
+
lines = []
|
|
108
|
+
lines.append(f"{inst._cls.key}(")
|
|
109
|
+
|
|
110
|
+
for field_name, _ in inst._cls.ctype._fields_:
|
|
111
|
+
field_value = getattr(inst, field_name, None)
|
|
112
|
+
|
|
113
|
+
if isinstance(field_value, StructInstance):
|
|
114
|
+
field_value = struct_instance_repr_recursive(field_value, depth + 1)
|
|
115
|
+
|
|
116
|
+
lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
|
|
82
117
|
|
|
83
|
-
|
|
84
|
-
|
|
118
|
+
lines.append(f"{indent * depth})")
|
|
119
|
+
return "\n".join(lines)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class StructInstance:
|
|
123
|
+
def __init__(self, cls: Struct, ctype):
|
|
124
|
+
super().__setattr__("_cls", cls)
|
|
125
|
+
|
|
126
|
+
# maintain a c-types object for the top-level instance the struct
|
|
127
|
+
if not ctype:
|
|
128
|
+
super().__setattr__("_ctype", cls.ctype())
|
|
129
|
+
else:
|
|
130
|
+
super().__setattr__("_ctype", ctype)
|
|
131
|
+
|
|
132
|
+
# create Python attributes for each of the struct's variables
|
|
133
|
+
for field, var in cls.vars.items():
|
|
134
|
+
if isinstance(var.type, warp.codegen.Struct):
|
|
135
|
+
self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
|
|
136
|
+
elif isinstance(var.type, warp.types.array):
|
|
137
|
+
self.__dict__[field] = None
|
|
138
|
+
else:
|
|
139
|
+
self.__dict__[field] = var.type()
|
|
85
140
|
|
|
86
|
-
|
|
87
|
-
if
|
|
141
|
+
def __setattr__(self, name, value):
|
|
142
|
+
if name not in self._cls.vars:
|
|
143
|
+
raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")
|
|
144
|
+
|
|
145
|
+
var = self._cls.vars[name]
|
|
146
|
+
|
|
147
|
+
# update our ctype flat copy
|
|
148
|
+
if isinstance(var.type, array):
|
|
88
149
|
if value is None:
|
|
89
150
|
# create array with null pointer
|
|
90
|
-
setattr(
|
|
151
|
+
setattr(self._ctype, name, array_t())
|
|
91
152
|
else:
|
|
92
153
|
# wp.array
|
|
93
154
|
assert isinstance(value, array)
|
|
94
|
-
assert (
|
|
95
|
-
value.dtype
|
|
96
|
-
), "assign to struct member variable {} failed, expected type {}, got type {}"
|
|
97
|
-
|
|
155
|
+
assert types_equal(
|
|
156
|
+
value.dtype, var.type.dtype
|
|
157
|
+
), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
|
|
158
|
+
setattr(self._ctype, name, value.__ctype__())
|
|
159
|
+
|
|
160
|
+
elif isinstance(var.type, Struct):
|
|
161
|
+
# assign structs by-value, otherwise we would have problematic cases transferring ownership
|
|
162
|
+
# of the underlying ctypes data between shared Python struct instances
|
|
163
|
+
|
|
164
|
+
if not isinstance(value, StructInstance):
|
|
165
|
+
raise RuntimeError(
|
|
166
|
+
f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
|
|
98
167
|
)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
168
|
+
|
|
169
|
+
# destination attribution on self
|
|
170
|
+
dest = getattr(self, name)
|
|
171
|
+
|
|
172
|
+
if dest._cls.key is not value._cls.key:
|
|
173
|
+
raise RuntimeError(
|
|
174
|
+
f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
|
|
175
|
+
)
|
|
176
|
+
|
|
177
|
+
# update all nested ctype vars by deep copy
|
|
178
|
+
for n in dest._cls.vars:
|
|
179
|
+
setattr(dest, n, getattr(value, n))
|
|
180
|
+
|
|
181
|
+
# early return to avoid updating our Python StructInstance
|
|
182
|
+
return
|
|
183
|
+
|
|
184
|
+
elif issubclass(var.type, ctypes.Array):
|
|
106
185
|
# vector/matrix type, e.g. vec3
|
|
107
186
|
if value is None:
|
|
108
|
-
setattr(
|
|
109
|
-
elif types_equal(type(value),
|
|
110
|
-
setattr(
|
|
187
|
+
setattr(self._ctype, name, var.type())
|
|
188
|
+
elif types_equal(type(value), var.type):
|
|
189
|
+
setattr(self._ctype, name, value)
|
|
111
190
|
else:
|
|
112
191
|
# conversion from list/tuple, ndarray, etc.
|
|
113
|
-
setattr(
|
|
192
|
+
setattr(self._ctype, name, var.type(value))
|
|
193
|
+
|
|
114
194
|
else:
|
|
115
195
|
# primitive type
|
|
116
196
|
if value is None:
|
|
117
|
-
|
|
197
|
+
# zero initialize
|
|
198
|
+
setattr(self._ctype, name, var.type._type_())
|
|
118
199
|
else:
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
if inst._struct_.ctype._fields_ == [("_dummy_", ctypes.c_int)]:
|
|
128
|
-
return f"{inst._struct_.key}()"
|
|
129
|
-
|
|
130
|
-
lines = []
|
|
131
|
-
lines.append(f"{inst._struct_.key}(")
|
|
132
|
-
|
|
133
|
-
for field_name, _ in inst._struct_.ctype._fields_:
|
|
134
|
-
if field_name == "_dummy_":
|
|
135
|
-
continue
|
|
136
|
-
|
|
137
|
-
field_value = getattr(inst, field_name, None)
|
|
138
|
-
|
|
139
|
-
if isinstance(field_value, StructInstance):
|
|
140
|
-
field_value = _fmt_struct_instance_repr(field_value, depth + 1)
|
|
200
|
+
if hasattr(value, "_type_"):
|
|
201
|
+
# assigning warp type value (e.g.: wp.float32)
|
|
202
|
+
value = value.value
|
|
203
|
+
# float16 needs conversion to uint16 bits
|
|
204
|
+
if var.type == warp.float16:
|
|
205
|
+
setattr(self._ctype, name, float_to_half_bits(value))
|
|
206
|
+
else:
|
|
207
|
+
setattr(self._ctype, name, value)
|
|
141
208
|
|
|
142
|
-
|
|
209
|
+
# update Python instance
|
|
210
|
+
super().__setattr__(name, value)
|
|
143
211
|
|
|
144
|
-
|
|
145
|
-
|
|
212
|
+
def __ctype__(self):
|
|
213
|
+
return self._ctype
|
|
146
214
|
|
|
215
|
+
def __repr__(self):
|
|
216
|
+
return struct_instance_repr_recursive(self, 0)
|
|
147
217
|
|
|
148
|
-
|
|
149
|
-
def
|
|
150
|
-
self.
|
|
218
|
+
# type description used in numpy structured arrays
|
|
219
|
+
def numpy_dtype(self):
|
|
220
|
+
return self._cls.numpy_dtype()
|
|
151
221
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
222
|
+
# value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
|
|
223
|
+
def numpy_value(self):
|
|
224
|
+
npvalue = []
|
|
225
|
+
for name, var in self._cls.vars.items():
|
|
226
|
+
# get the attribute value
|
|
227
|
+
value = getattr(self._ctype, name)
|
|
155
228
|
|
|
156
|
-
|
|
157
|
-
|
|
229
|
+
if isinstance(var.type, array):
|
|
230
|
+
# array_t
|
|
231
|
+
npvalue.append(value.numpy_value())
|
|
232
|
+
elif isinstance(var.type, Struct):
|
|
233
|
+
# nested struct
|
|
234
|
+
npvalue.append(value.numpy_value())
|
|
235
|
+
elif issubclass(var.type, ctypes.Array):
|
|
236
|
+
if len(var.type._shape_) == 1:
|
|
237
|
+
# vector
|
|
238
|
+
npvalue.append(list(value))
|
|
239
|
+
else:
|
|
240
|
+
# matrix
|
|
241
|
+
npvalue.append([list(row) for row in value])
|
|
242
|
+
else:
|
|
243
|
+
# scalar
|
|
244
|
+
if var.type == warp.float16:
|
|
245
|
+
npvalue.append(half_bits_to_float(value))
|
|
246
|
+
else:
|
|
247
|
+
npvalue.append(value)
|
|
158
248
|
|
|
159
|
-
|
|
160
|
-
return _fmt_struct_instance_repr(self, 0)
|
|
249
|
+
return tuple(npvalue)
|
|
161
250
|
|
|
162
251
|
|
|
163
252
|
class Struct:
|
|
@@ -184,7 +273,7 @@ class Struct:
|
|
|
184
273
|
|
|
185
274
|
class StructType(ctypes.Structure):
|
|
186
275
|
# if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
|
|
187
|
-
_fields_ = fields or [("_dummy_", ctypes.
|
|
276
|
+
_fields_ = fields or [("_dummy_", ctypes.c_byte)]
|
|
188
277
|
|
|
189
278
|
self.ctype = StructType
|
|
190
279
|
|
|
@@ -235,29 +324,108 @@ class Struct:
|
|
|
235
324
|
|
|
236
325
|
class NewStructInstance(self.cls, StructInstance):
|
|
237
326
|
def __init__(inst):
|
|
238
|
-
StructInstance.__init__(inst, self)
|
|
327
|
+
StructInstance.__init__(inst, self, None)
|
|
239
328
|
|
|
240
329
|
return NewStructInstance()
|
|
241
330
|
|
|
242
331
|
def initializer(self):
|
|
243
332
|
return self.default_constructor
|
|
244
333
|
|
|
334
|
+
# return structured NumPy dtype, including field names, formats, and offsets
|
|
335
|
+
def numpy_dtype(self):
|
|
336
|
+
names = []
|
|
337
|
+
formats = []
|
|
338
|
+
offsets = []
|
|
339
|
+
for name, var in self.vars.items():
|
|
340
|
+
names.append(name)
|
|
341
|
+
offsets.append(getattr(self.ctype, name).offset)
|
|
342
|
+
if isinstance(var.type, array):
|
|
343
|
+
# array_t
|
|
344
|
+
formats.append(array_t.numpy_dtype())
|
|
345
|
+
elif isinstance(var.type, Struct):
|
|
346
|
+
# nested struct
|
|
347
|
+
formats.append(var.type.numpy_dtype())
|
|
348
|
+
elif issubclass(var.type, ctypes.Array):
|
|
349
|
+
scalar_typestr = type_typestr(var.type._wp_scalar_type_)
|
|
350
|
+
if len(var.type._shape_) == 1:
|
|
351
|
+
# vector
|
|
352
|
+
formats.append(f"{var.type._length_}{scalar_typestr}")
|
|
353
|
+
else:
|
|
354
|
+
# matrix
|
|
355
|
+
formats.append(f"{var.type._shape_}{scalar_typestr}")
|
|
356
|
+
else:
|
|
357
|
+
# scalar
|
|
358
|
+
formats.append(type_typestr(var.type))
|
|
359
|
+
|
|
360
|
+
return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}
|
|
361
|
+
|
|
362
|
+
# constructs a Warp struct instance from a pointer to the ctype
|
|
363
|
+
def from_ptr(self, ptr):
|
|
364
|
+
if not ptr:
|
|
365
|
+
raise RuntimeError("NULL pointer exception")
|
|
366
|
+
|
|
367
|
+
# create a new struct instance
|
|
368
|
+
instance = self()
|
|
369
|
+
|
|
370
|
+
for name, var in self.vars.items():
|
|
371
|
+
offset = getattr(self.ctype, name).offset
|
|
372
|
+
if isinstance(var.type, array):
|
|
373
|
+
# We could reconstruct wp.array from array_t, but it's problematic.
|
|
374
|
+
# There's no guarantee that the original wp.array is still allocated and
|
|
375
|
+
# no easy way to make a backref.
|
|
376
|
+
# Instead, we just create a stub annotation, which is not a fully usable array object.
|
|
377
|
+
setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
|
|
378
|
+
elif isinstance(var.type, Struct):
|
|
379
|
+
# nested struct
|
|
380
|
+
value = var.type.from_ptr(ptr + offset)
|
|
381
|
+
setattr(instance, name, value)
|
|
382
|
+
elif issubclass(var.type, ctypes.Array):
|
|
383
|
+
# vector/matrix
|
|
384
|
+
value = var.type.from_ptr(ptr + offset)
|
|
385
|
+
setattr(instance, name, value)
|
|
386
|
+
else:
|
|
387
|
+
# scalar
|
|
388
|
+
cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
|
|
389
|
+
if var.type == warp.float16:
|
|
390
|
+
setattr(instance, name, half_bits_to_float(cvalue))
|
|
391
|
+
else:
|
|
392
|
+
setattr(instance, name, cvalue.value)
|
|
393
|
+
|
|
394
|
+
return instance
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class Reference:
|
|
398
|
+
def __init__(self, value_type):
|
|
399
|
+
self.value_type = value_type
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def is_reference(type):
|
|
403
|
+
return isinstance(type, Reference)
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def strip_reference(arg):
|
|
407
|
+
if is_reference(arg):
|
|
408
|
+
return arg.value_type
|
|
409
|
+
else:
|
|
410
|
+
return arg
|
|
411
|
+
|
|
245
412
|
|
|
246
413
|
def compute_type_str(base_name, template_params):
|
|
247
|
-
if
|
|
414
|
+
if not template_params:
|
|
248
415
|
return base_name
|
|
249
|
-
else:
|
|
250
416
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
417
|
+
def param2str(p):
|
|
418
|
+
if isinstance(p, int):
|
|
419
|
+
return str(p)
|
|
420
|
+
elif hasattr(p, "_type_"):
|
|
421
|
+
return f"wp::{p.__name__}"
|
|
422
|
+
return p.__name__
|
|
255
423
|
|
|
256
|
-
|
|
424
|
+
return f"{base_name}<{','.join(map(param2str, template_params))}>"
|
|
257
425
|
|
|
258
426
|
|
|
259
427
|
class Var:
|
|
260
|
-
def __init__(self, label, type, requires_grad=False, constant=None):
|
|
428
|
+
def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
|
|
261
429
|
# convert built-in types to wp types
|
|
262
430
|
if type == float:
|
|
263
431
|
type = float32
|
|
@@ -268,26 +436,49 @@ class Var:
|
|
|
268
436
|
self.type = type
|
|
269
437
|
self.requires_grad = requires_grad
|
|
270
438
|
self.constant = constant
|
|
439
|
+
self.prefix = prefix
|
|
271
440
|
|
|
272
441
|
def __str__(self):
|
|
273
442
|
return self.label
|
|
274
443
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
444
|
+
@staticmethod
|
|
445
|
+
def type_to_ctype(t, value_type=False):
|
|
446
|
+
if is_array(t):
|
|
447
|
+
if hasattr(t.dtype, "_wp_generic_type_str_"):
|
|
448
|
+
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
449
|
+
elif isinstance(t.dtype, Struct):
|
|
450
|
+
dtypestr = make_full_qualified_name(t.dtype.cls)
|
|
451
|
+
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
452
|
+
dtypestr = t.dtype.__name__
|
|
281
453
|
else:
|
|
282
|
-
dtypestr =
|
|
283
|
-
classstr = type(
|
|
454
|
+
dtypestr = f"wp::{t.dtype.__name__}"
|
|
455
|
+
classstr = f"wp::{type(t).__name__}"
|
|
284
456
|
return f"{classstr}_t<{dtypestr}>"
|
|
285
|
-
elif isinstance(
|
|
286
|
-
return make_full_qualified_name(
|
|
287
|
-
elif
|
|
288
|
-
|
|
457
|
+
elif isinstance(t, Struct):
|
|
458
|
+
return make_full_qualified_name(t.cls)
|
|
459
|
+
elif is_reference(t):
|
|
460
|
+
if not value_type:
|
|
461
|
+
return Var.type_to_ctype(t.value_type) + "*"
|
|
462
|
+
else:
|
|
463
|
+
return Var.type_to_ctype(t.value_type)
|
|
464
|
+
elif hasattr(t, "_wp_generic_type_str_"):
|
|
465
|
+
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
|
|
466
|
+
elif t.__name__ in ("bool", "int", "float"):
|
|
467
|
+
return t.__name__
|
|
468
|
+
else:
|
|
469
|
+
return f"wp::{t.__name__}"
|
|
470
|
+
|
|
471
|
+
def ctype(self, value_type=False):
|
|
472
|
+
return Var.type_to_ctype(self.type, value_type)
|
|
473
|
+
|
|
474
|
+
def emit(self, prefix: str = "var"):
|
|
475
|
+
if self.prefix:
|
|
476
|
+
return f"{prefix}_{self.label}"
|
|
289
477
|
else:
|
|
290
|
-
return
|
|
478
|
+
return self.label
|
|
479
|
+
|
|
480
|
+
def emit_adj(self):
|
|
481
|
+
return self.emit("adj")
|
|
291
482
|
|
|
292
483
|
|
|
293
484
|
class Block:
|
|
@@ -304,33 +495,65 @@ class Block:
|
|
|
304
495
|
self.vars = []
|
|
305
496
|
|
|
306
497
|
|
|
498
|
+
def is_local_value(value) -> bool:
|
|
499
|
+
"""Check whether a variable is defined inside a kernel."""
|
|
500
|
+
return isinstance(value, (warp.context.Function, Var))
|
|
501
|
+
|
|
502
|
+
|
|
307
503
|
class Adjoint:
|
|
308
504
|
# Source code transformer, this class takes a Python function and
|
|
309
505
|
# generates forward and backward SSA forms of the function instructions
|
|
310
506
|
|
|
311
|
-
def __init__(
|
|
507
|
+
def __init__(
|
|
508
|
+
adj,
|
|
509
|
+
func,
|
|
510
|
+
overload_annotations=None,
|
|
511
|
+
is_user_function=False,
|
|
512
|
+
skip_forward_codegen=False,
|
|
513
|
+
skip_reverse_codegen=False,
|
|
514
|
+
custom_reverse_mode=False,
|
|
515
|
+
custom_reverse_num_input_args=-1,
|
|
516
|
+
transformers: List[ast.NodeTransformer] = [],
|
|
517
|
+
):
|
|
312
518
|
adj.func = func
|
|
313
519
|
|
|
314
|
-
|
|
315
|
-
adj.source = inspect.getsource(func)
|
|
520
|
+
adj.is_user_function = is_user_function
|
|
316
521
|
|
|
317
|
-
#
|
|
318
|
-
adj.
|
|
522
|
+
# whether the generation of the forward code is skipped for this function
|
|
523
|
+
adj.skip_forward_codegen = skip_forward_codegen
|
|
524
|
+
# whether the generation of the adjoint code is skipped for this function
|
|
525
|
+
adj.skip_reverse_codegen = skip_reverse_codegen
|
|
319
526
|
|
|
320
|
-
#
|
|
321
|
-
adj.
|
|
527
|
+
# extract name of source file
|
|
528
|
+
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
529
|
+
# get source file line number where function starts
|
|
530
|
+
_, adj.fun_lineno = inspect.getsourcelines(func)
|
|
322
531
|
|
|
532
|
+
# get function source code
|
|
533
|
+
adj.source = inspect.getsource(func)
|
|
323
534
|
# ensures that indented class methods can be parsed as kernels
|
|
324
535
|
adj.source = textwrap.dedent(adj.source)
|
|
325
536
|
|
|
326
|
-
|
|
327
|
-
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
537
|
+
adj.source_lines = adj.source.splitlines()
|
|
328
538
|
|
|
329
|
-
# build AST
|
|
539
|
+
# build AST and apply node transformers
|
|
330
540
|
adj.tree = ast.parse(adj.source)
|
|
541
|
+
adj.transformers = transformers
|
|
542
|
+
for transformer in transformers:
|
|
543
|
+
adj.tree = transformer.visit(adj.tree)
|
|
331
544
|
|
|
332
545
|
adj.fun_name = adj.tree.body[0].name
|
|
333
546
|
|
|
547
|
+
# for keeping track of line number in function code
|
|
548
|
+
adj.lineno = None
|
|
549
|
+
|
|
550
|
+
# whether the forward code shall be used for the reverse pass and a custom
|
|
551
|
+
# function signature is applied to the reverse version of the function
|
|
552
|
+
adj.custom_reverse_mode = custom_reverse_mode
|
|
553
|
+
# the number of function arguments that pertain to the forward function
|
|
554
|
+
# input arguments (i.e. the number of arguments that are not adjoint arguments)
|
|
555
|
+
adj.custom_reverse_num_input_args = custom_reverse_num_input_args
|
|
556
|
+
|
|
334
557
|
# parse argument types
|
|
335
558
|
argspec = inspect.getfullargspec(func)
|
|
336
559
|
|
|
@@ -338,16 +561,17 @@ class Adjoint:
|
|
|
338
561
|
if overload_annotations is None:
|
|
339
562
|
# use source-level argument annotations
|
|
340
563
|
if len(argspec.annotations) < len(argspec.args):
|
|
341
|
-
raise
|
|
564
|
+
raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
|
|
342
565
|
adj.arg_types = argspec.annotations
|
|
343
566
|
else:
|
|
344
567
|
# use overload argument annotations
|
|
345
568
|
for arg_name in argspec.args:
|
|
346
569
|
if arg_name not in overload_annotations:
|
|
347
|
-
raise
|
|
570
|
+
raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
|
|
348
571
|
adj.arg_types = overload_annotations.copy()
|
|
349
572
|
|
|
350
573
|
adj.args = []
|
|
574
|
+
adj.symbols = {}
|
|
351
575
|
|
|
352
576
|
for name, type in adj.arg_types.items():
|
|
353
577
|
# skip return hint
|
|
@@ -358,8 +582,23 @@ class Adjoint:
|
|
|
358
582
|
arg = Var(name, type, False)
|
|
359
583
|
adj.args.append(arg)
|
|
360
584
|
|
|
585
|
+
# pre-populate symbol dictionary with function argument names
|
|
586
|
+
# this is to avoid registering false references to overshadowed modules
|
|
587
|
+
adj.symbols[name] = arg
|
|
588
|
+
|
|
589
|
+
# There are cases where a same module might be rebuilt multiple times,
|
|
590
|
+
# for example when kernels are nested inside of functions, or when
|
|
591
|
+
# a kernel's launch raises an exception. Ideally we'd always want to
|
|
592
|
+
# avoid rebuilding kernels but some corner cases seem to depend on it,
|
|
593
|
+
# so we only avoid rebuilding kernels that errored out to give a chance
|
|
594
|
+
# for unit testing errors being spit out from kernels.
|
|
595
|
+
adj.skip_build = False
|
|
596
|
+
|
|
361
597
|
# generate function ssa form and adjoint
|
|
362
598
|
def build(adj, builder):
|
|
599
|
+
if adj.skip_build:
|
|
600
|
+
return
|
|
601
|
+
|
|
363
602
|
adj.builder = builder
|
|
364
603
|
|
|
365
604
|
adj.symbols = {} # map from symbols to adjoint variables
|
|
@@ -373,7 +612,7 @@ class Adjoint:
|
|
|
373
612
|
adj.loop_blocks = []
|
|
374
613
|
|
|
375
614
|
# holds current indent level
|
|
376
|
-
adj.
|
|
615
|
+
adj.indentation = ""
|
|
377
616
|
|
|
378
617
|
# used to generate new label indices
|
|
379
618
|
adj.label_count = 0
|
|
@@ -387,20 +626,25 @@ class Adjoint:
|
|
|
387
626
|
adj.eval(adj.tree.body[0])
|
|
388
627
|
except Exception as e:
|
|
389
628
|
try:
|
|
629
|
+
if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
|
|
630
|
+
msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
|
|
631
|
+
else:
|
|
632
|
+
msg = "Error"
|
|
390
633
|
lineno = adj.lineno + adj.fun_lineno
|
|
391
|
-
line = adj.
|
|
392
|
-
msg
|
|
634
|
+
line = adj.source_lines[adj.lineno]
|
|
635
|
+
msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
|
|
393
636
|
ex, data, traceback = sys.exc_info()
|
|
394
|
-
e = ex("".join([msg] +
|
|
637
|
+
e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
|
|
395
638
|
finally:
|
|
639
|
+
adj.skip_build = True
|
|
396
640
|
raise e
|
|
397
641
|
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
642
|
+
if builder is not None:
|
|
643
|
+
for a in adj.args:
|
|
644
|
+
if isinstance(a.type, Struct):
|
|
645
|
+
builder.build_struct_recursive(a.type)
|
|
646
|
+
elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
|
|
647
|
+
builder.build_struct_recursive(a.type.dtype)
|
|
404
648
|
|
|
405
649
|
# code generation methods
|
|
406
650
|
def format_template(adj, template, input_vars, output_var):
|
|
@@ -415,44 +659,56 @@ class Adjoint:
|
|
|
415
659
|
arg_strs = []
|
|
416
660
|
|
|
417
661
|
for a in args:
|
|
418
|
-
if
|
|
662
|
+
if isinstance(a, warp.context.Function):
|
|
419
663
|
# functions don't have a var_ prefix so strip it off here
|
|
420
|
-
if prefix == "
|
|
664
|
+
if prefix == "var":
|
|
421
665
|
arg_strs.append(a.key)
|
|
422
666
|
else:
|
|
423
|
-
arg_strs.append(prefix
|
|
424
|
-
|
|
667
|
+
arg_strs.append(f"{prefix}_{a.key}")
|
|
668
|
+
elif is_reference(a.type):
|
|
669
|
+
arg_strs.append(f"{prefix}_{a}")
|
|
670
|
+
elif isinstance(a, Var):
|
|
671
|
+
arg_strs.append(a.emit(prefix))
|
|
425
672
|
else:
|
|
426
|
-
|
|
673
|
+
raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")
|
|
427
674
|
|
|
428
675
|
return arg_strs
|
|
429
676
|
|
|
430
677
|
# generates argument string for a forward function call
|
|
431
678
|
def format_forward_call_args(adj, args, use_initializer_list):
|
|
432
|
-
arg_str = ", ".join(adj.format_args("
|
|
679
|
+
arg_str = ", ".join(adj.format_args("var", args))
|
|
433
680
|
if use_initializer_list:
|
|
434
|
-
return "{{{}}}"
|
|
681
|
+
return f"{{{arg_str}}}"
|
|
435
682
|
return arg_str
|
|
436
683
|
|
|
437
684
|
# generates argument string for a reverse function call
|
|
438
|
-
def format_reverse_call_args(
|
|
439
|
-
|
|
685
|
+
def format_reverse_call_args(
|
|
686
|
+
adj,
|
|
687
|
+
args_var,
|
|
688
|
+
args,
|
|
689
|
+
args_out,
|
|
690
|
+
use_initializer_list,
|
|
691
|
+
has_output_args=True,
|
|
692
|
+
require_original_output_arg=False,
|
|
693
|
+
):
|
|
694
|
+
formatted_var = adj.format_args("var", args_var)
|
|
440
695
|
formatted_out = []
|
|
441
|
-
if len(args_out) > 1:
|
|
442
|
-
formatted_out = adj.format_args("
|
|
696
|
+
if has_output_args and (require_original_output_arg or len(args_out) > 1):
|
|
697
|
+
formatted_out = adj.format_args("var", args_out)
|
|
443
698
|
formatted_var_adj = adj.format_args(
|
|
444
|
-
"&
|
|
699
|
+
"&adj" if use_initializer_list else "adj",
|
|
700
|
+
args,
|
|
445
701
|
)
|
|
446
|
-
formatted_out_adj = adj.format_args("
|
|
702
|
+
formatted_out_adj = adj.format_args("adj", args_out)
|
|
447
703
|
|
|
448
704
|
if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
|
|
449
705
|
# there are no adjoint arguments, so we don't need to call the reverse function
|
|
450
706
|
return None
|
|
451
707
|
|
|
452
708
|
if use_initializer_list:
|
|
453
|
-
var_str = "{{{
|
|
454
|
-
out_str = "{{{
|
|
455
|
-
adj_str = "{{{
|
|
709
|
+
var_str = f"{{{', '.join(formatted_var)}}}"
|
|
710
|
+
out_str = f"{{{', '.join(formatted_out)}}}"
|
|
711
|
+
adj_str = f"{{{', '.join(formatted_var_adj)}}}"
|
|
456
712
|
out_adj_str = ", ".join(formatted_out_adj)
|
|
457
713
|
if len(args_out) > 1:
|
|
458
714
|
arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
|
|
@@ -463,10 +719,10 @@ class Adjoint:
|
|
|
463
719
|
return arg_str
|
|
464
720
|
|
|
465
721
|
def indent(adj):
|
|
466
|
-
adj.
|
|
722
|
+
adj.indentation = adj.indentation + " "
|
|
467
723
|
|
|
468
724
|
def dedent(adj):
|
|
469
|
-
adj.
|
|
725
|
+
adj.indentation = adj.indentation[:-4]
|
|
470
726
|
|
|
471
727
|
def begin_block(adj):
|
|
472
728
|
b = Block()
|
|
@@ -481,10 +737,9 @@ class Adjoint:
|
|
|
481
737
|
def end_block(adj):
|
|
482
738
|
return adj.blocks.pop()
|
|
483
739
|
|
|
484
|
-
def add_var(adj, type=None, constant=None
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
name = str(index)
|
|
740
|
+
def add_var(adj, type=None, constant=None):
|
|
741
|
+
index = len(adj.variables)
|
|
742
|
+
name = str(index)
|
|
488
743
|
|
|
489
744
|
# allocate new variable
|
|
490
745
|
v = Var(name, type=type, constant=constant)
|
|
@@ -497,30 +752,54 @@ class Adjoint:
|
|
|
497
752
|
|
|
498
753
|
# append a statement to the forward pass
|
|
499
754
|
def add_forward(adj, statement, replay=None, skip_replay=False):
|
|
500
|
-
adj.blocks[-1].body_forward.append(adj.
|
|
755
|
+
adj.blocks[-1].body_forward.append(adj.indentation + statement)
|
|
501
756
|
|
|
502
757
|
if not skip_replay:
|
|
503
758
|
if replay:
|
|
504
759
|
# if custom replay specified then output it
|
|
505
|
-
adj.blocks[-1].body_replay.append(adj.
|
|
760
|
+
adj.blocks[-1].body_replay.append(adj.indentation + replay)
|
|
506
761
|
else:
|
|
507
762
|
# by default just replay the original statement
|
|
508
|
-
adj.blocks[-1].body_replay.append(adj.
|
|
763
|
+
adj.blocks[-1].body_replay.append(adj.indentation + statement)
|
|
509
764
|
|
|
510
765
|
# append a statement to the reverse pass
|
|
511
766
|
def add_reverse(adj, statement):
|
|
512
|
-
adj.blocks[-1].body_reverse.append(adj.
|
|
767
|
+
adj.blocks[-1].body_reverse.append(adj.indentation + statement)
|
|
513
768
|
|
|
514
769
|
def add_constant(adj, n):
|
|
515
770
|
output = adj.add_var(type=type(n), constant=n)
|
|
516
771
|
return output
|
|
517
772
|
|
|
773
|
+
def load(adj, var):
|
|
774
|
+
if is_reference(var.type):
|
|
775
|
+
var = adj.add_builtin_call("load", [var])
|
|
776
|
+
return var
|
|
777
|
+
|
|
518
778
|
def add_comp(adj, op_strings, left, comps):
|
|
519
|
-
output = adj.add_var(bool)
|
|
779
|
+
output = adj.add_var(builtins.bool)
|
|
780
|
+
|
|
781
|
+
left = adj.load(left)
|
|
782
|
+
s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
|
|
783
|
+
|
|
784
|
+
prev_comp = None
|
|
520
785
|
|
|
521
|
-
s = "var_" + str(output) + " = " + ("(" * len(comps)) + "var_" + str(left) + " "
|
|
522
786
|
for op, comp in zip(op_strings, comps):
|
|
523
|
-
|
|
787
|
+
comp_chainable = op_str_is_chainable(op)
|
|
788
|
+
if comp_chainable and prev_comp:
|
|
789
|
+
# We restrict chaining to operands of the same type
|
|
790
|
+
if prev_comp.type is comp.type:
|
|
791
|
+
prev_comp = adj.load(prev_comp)
|
|
792
|
+
comp = adj.load(comp)
|
|
793
|
+
s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
|
|
794
|
+
else:
|
|
795
|
+
raise WarpCodegenTypeError(
|
|
796
|
+
f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
|
|
797
|
+
)
|
|
798
|
+
else:
|
|
799
|
+
comp = adj.load(comp)
|
|
800
|
+
s += op + " " + comp.emit() + ") "
|
|
801
|
+
|
|
802
|
+
prev_comp = comp
|
|
524
803
|
|
|
525
804
|
s = s.rstrip() + ";"
|
|
526
805
|
|
|
@@ -529,109 +808,106 @@ class Adjoint:
|
|
|
529
808
|
return output
|
|
530
809
|
|
|
531
810
|
def add_bool_op(adj, op_string, exprs):
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
)
|
|
811
|
+
exprs = [adj.load(expr) for expr in exprs]
|
|
812
|
+
output = adj.add_var(builtins.bool)
|
|
813
|
+
command = output.emit() + " = " + (" " + op_string + " ").join([expr.emit() for expr in exprs]) + ";"
|
|
536
814
|
adj.add_forward(command)
|
|
537
815
|
|
|
538
816
|
return output
|
|
539
817
|
|
|
540
|
-
def
|
|
541
|
-
|
|
542
|
-
# we validate argument types before they go to generated native code
|
|
543
|
-
resolved_func = None
|
|
818
|
+
def resolve_func(adj, func, args, min_outputs, templates, kwds):
|
|
819
|
+
arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
|
|
544
820
|
|
|
545
|
-
if func.is_builtin():
|
|
821
|
+
if not func.is_builtin():
|
|
822
|
+
# user-defined function
|
|
823
|
+
overload = func.get_overload(arg_types)
|
|
824
|
+
if overload is not None:
|
|
825
|
+
return overload
|
|
826
|
+
else:
|
|
827
|
+
# if func is overloaded then perform overload resolution here
|
|
828
|
+
# we validate argument types before they go to generated native code
|
|
546
829
|
for f in func.overloads:
|
|
547
|
-
match = True
|
|
548
|
-
|
|
549
830
|
# skip type checking for variadic functions
|
|
550
831
|
if not f.variadic:
|
|
551
832
|
# check argument counts match are compatible (may be some default args)
|
|
552
833
|
if len(f.input_types) < len(args):
|
|
553
|
-
match = False
|
|
554
834
|
continue
|
|
555
835
|
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
836
|
+
def match_args(args, f):
|
|
837
|
+
# check argument types equal
|
|
838
|
+
for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
|
|
839
|
+
# if arg type registered as Any, treat as
|
|
840
|
+
# template allowing any type to match
|
|
841
|
+
if arg_type == Any:
|
|
842
|
+
continue
|
|
843
|
+
|
|
844
|
+
# handle function refs as a special case
|
|
845
|
+
if arg_type == Callable and type(args[i]) is warp.context.Function:
|
|
846
|
+
continue
|
|
847
|
+
|
|
848
|
+
if arg_type == Reference and is_reference(args[i].type):
|
|
849
|
+
continue
|
|
850
|
+
|
|
851
|
+
# look for default values for missing args
|
|
852
|
+
if i >= len(args):
|
|
853
|
+
if arg_name not in f.defaults:
|
|
854
|
+
return False
|
|
855
|
+
else:
|
|
856
|
+
# otherwise check arg type matches input variable type
|
|
857
|
+
if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
|
|
858
|
+
return False
|
|
859
|
+
|
|
860
|
+
return True
|
|
861
|
+
|
|
862
|
+
if not match_args(args, f):
|
|
863
|
+
continue
|
|
577
864
|
|
|
578
865
|
# check output dimensions match expectations
|
|
579
866
|
if min_outputs:
|
|
580
867
|
try:
|
|
581
868
|
value_type = f.value_func(args, kwds, templates)
|
|
582
|
-
if len(value_type) != min_outputs:
|
|
583
|
-
match = False
|
|
869
|
+
if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
|
|
584
870
|
continue
|
|
585
871
|
except Exception:
|
|
586
872
|
# value func may fail if the user has given
|
|
587
873
|
# incorrect args, so we need to catch this
|
|
588
|
-
match = False
|
|
589
874
|
continue
|
|
590
875
|
|
|
591
876
|
# found a match, use it
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
if isinstance(x.type, list):
|
|
607
|
-
if len(x.type) != 1:
|
|
608
|
-
raise Exception("Argument must not be the result from a multi-valued function")
|
|
609
|
-
arg_type = x.type[0]
|
|
610
|
-
else:
|
|
611
|
-
arg_type = x.type
|
|
612
|
-
if arg_type.__module__ == "warp.types":
|
|
613
|
-
arg_types.append(arg_type.__name__)
|
|
614
|
-
else:
|
|
615
|
-
arg_types.append(arg_type.__module__ + "." + arg_type.__name__)
|
|
616
|
-
|
|
617
|
-
if isinstance(x, warp.context.Function):
|
|
618
|
-
arg_types.append("function")
|
|
619
|
-
|
|
620
|
-
raise Exception(
|
|
621
|
-
f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
|
|
622
|
-
)
|
|
877
|
+
return f
|
|
878
|
+
|
|
879
|
+
# unresolved function, report error
|
|
880
|
+
arg_types = []
|
|
881
|
+
|
|
882
|
+
for x in args:
|
|
883
|
+
if isinstance(x, Var):
|
|
884
|
+
# shorten Warp primitive type names
|
|
885
|
+
if isinstance(x.type, list):
|
|
886
|
+
if len(x.type) != 1:
|
|
887
|
+
raise WarpCodegenError("Argument must not be the result from a multi-valued function")
|
|
888
|
+
arg_type = x.type[0]
|
|
889
|
+
else:
|
|
890
|
+
arg_type = x.type
|
|
623
891
|
|
|
624
|
-
|
|
625
|
-
|
|
892
|
+
arg_types.append(type_repr(arg_type))
|
|
893
|
+
|
|
894
|
+
if isinstance(x, warp.context.Function):
|
|
895
|
+
arg_types.append("function")
|
|
896
|
+
|
|
897
|
+
raise WarpCodegenError(
|
|
898
|
+
f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
|
|
899
|
+
)
|
|
900
|
+
|
|
901
|
+
def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
|
|
902
|
+
func = adj.resolve_func(func, args, min_outputs, templates, kwds)
|
|
626
903
|
|
|
627
904
|
# push any default values onto args
|
|
628
905
|
for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
|
|
629
906
|
if i >= len(args):
|
|
630
|
-
if arg_name in
|
|
907
|
+
if arg_name in func.defaults:
|
|
631
908
|
const = adj.add_constant(func.defaults[arg_name])
|
|
632
909
|
args.append(const)
|
|
633
910
|
else:
|
|
634
|
-
match = False
|
|
635
911
|
break
|
|
636
912
|
|
|
637
913
|
# if it is a user-function then build it recursively
|
|
@@ -639,93 +915,105 @@ class Adjoint:
|
|
|
639
915
|
adj.builder.build_function(func)
|
|
640
916
|
|
|
641
917
|
# evaluate the function type based on inputs
|
|
642
|
-
|
|
918
|
+
arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
|
|
919
|
+
return_type = func.value_func(arg_types, kwds, templates)
|
|
643
920
|
|
|
644
921
|
func_name = compute_type_str(func.native_func, templates)
|
|
922
|
+
param_types = list(func.input_types.values())
|
|
645
923
|
|
|
646
924
|
use_initializer_list = func.initializer_list_func(args, templates)
|
|
647
925
|
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
if func.skip_replay:
|
|
655
|
-
adj.add_forward(forward_call, replay="//" + forward_call)
|
|
656
|
-
else:
|
|
657
|
-
adj.add_forward(forward_call)
|
|
658
|
-
|
|
659
|
-
if not func.missing_grad and len(args):
|
|
660
|
-
arg_str = adj.format_reverse_call_args(args, [], {}, {}, use_initializer_list)
|
|
661
|
-
-            if arg_str is not None:
-                reverse_call = "{}adj_{}({});".format(func.namespace, func.native_func, arg_str)
-                adj.add_reverse(reverse_call)
+        args_var = [
+            adj.load(a)
+            if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
+            else a
+            for i, a in enumerate(args)
+        ]

-
+        if return_type is None:
+            # handles expression (zero output) functions, e.g.: void do_something();

-
-
+            output = None
+            output_list = []

-
-
-            output = adj.add_var(value_type)
-            forward_call = "var_{} = {}{}({});".format(
-                output, func.namespace, func_name, adj.format_forward_call_args(args, use_initializer_list)
+            forward_call = (
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
             )
+            replay_call = forward_call
+            if func.custom_replay_func is not None:
+                replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"

-
-
-        else:
-            adj.add_forward(forward_call)
+        elif not isinstance(return_type, list) or len(return_type) == 1:
+            # handle simple function (one output)

-            if
-
-
-
-            adj.add_reverse(reverse_call)
+            if isinstance(return_type, list):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]

-
+            forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
+            replay_call = forward_call
+            if func.custom_replay_func is not None:
+                replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"

         else:
             # handle multiple value functions

-            output = [adj.add_var(v) for v in
-
-
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
+
+            forward_call = (
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
             )
-
+            replay_call = forward_call

-
-
-
-
-
+        if func.skip_replay:
+            adj.add_forward(forward_call, replay="// " + replay_call)
+        else:
+            adj.add_forward(forward_call, replay=replay_call)
+
+        if not func.missing_grad and len(args):
+            reverse_has_output_args = (
+                func.require_original_output_arg or len(output_list) > 1
+            ) and func.custom_grad_func is None
+            arg_str = adj.format_reverse_call_args(
+                args_var,
+                args,
+                output_list,
+                use_initializer_list,
+                has_output_args=reverse_has_output_args,
+                require_original_output_arg=func.require_original_output_arg,
+            )
+            if arg_str is not None:
+                reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
+                adj.add_reverse(reverse_call)

-
-        return output[0]
+        return output

-
+    def add_builtin_call(adj, func_name, args, min_outputs=None, templates=[], kwds=None):
+        func = warp.context.builtin_functions[func_name]
+        return adj.add_call(func, args, min_outputs, templates, kwds)

     def add_return(adj, var):
         if var is None or len(var) == 0:
-            adj.add_forward("return;", "goto label{};"
+            adj.add_forward("return;", f"goto label{adj.label_count};")
         elif len(var) == 1:
-            adj.add_forward("return
+            adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
             adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
         else:
             for i, v in enumerate(var):
-                adj.add_forward("ret_{} =
-                adj.add_reverse("adj_{} += adj_ret_{};"
-            adj.add_forward("return;", "goto label{};"
+                adj.add_forward(f"ret_{i} = {v.emit()};")
+                adj.add_reverse(f"adj_{v} += adj_ret_{i};")
+            adj.add_forward("return;", f"goto label{adj.label_count};")

-        adj.add_reverse("label{}:;"
+        adj.add_reverse(f"label{adj.label_count}:;")

         adj.label_count += 1

     # define an if statement
     def begin_if(adj, cond):
-
+        cond = adj.load(cond)
+        adj.add_forward(f"if ({cond.emit()}) {{")
         adj.add_reverse("}")

         adj.indent()
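Note: the new `add_builtin_call` helper above replaces the repeated `warp.context.builtin_functions[...]` lookups that appear throughout the remaining hunks. A minimal runnable sketch of the pattern (the `BUILTINS` registry and the `add_call` body are stand-ins for illustration, not the real Warp internals):

    # Stand-in for warp.context.builtin_functions: maps builtin names such as
    # "select", "extract", or "assign" to callable Function objects.
    BUILTINS = {"add": lambda a, b: a + b}

    def add_call(func, args):
        # the real method emits forward/replay/reverse code; here we just call
        return func(*args)

    def add_builtin_call(func_name, args):
        # resolve a builtin by name in one place so call sites stay short
        return add_call(BUILTINS[func_name], args)

    assert add_builtin_call("add", [1, 2]) == 3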
@@ -734,10 +1022,12 @@ class Adjoint:
         adj.dedent()

         adj.add_forward("}")
-        adj.
+        cond = adj.load(cond)
+        adj.add_reverse(f"if ({cond.emit()}) {{")

     def begin_else(adj, cond):
-        adj.
+        cond = adj.load(cond)
+        adj.add_forward(f"if (!{cond.emit()}) {{")
         adj.add_reverse("}")

         adj.indent()
@@ -746,7 +1036,8 @@ class Adjoint:
         adj.dedent()

         adj.add_forward("}")
-        adj.
+        cond = adj.load(cond)
+        adj.add_reverse(f"if (!{cond.emit()}) {{")

     # define a for-loop
     def begin_for(adj, iter):
@@ -756,10 +1047,10 @@ class Adjoint:
         adj.indent()

         # evaluate cond
-        adj.add_forward(f"if (iter_cmp(
+        adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{cond_block.label};")

         # evaluate iter
-        val = adj.
+        val = adj.add_builtin_call("iter_next", [iter])

         adj.begin_block()

@@ -790,17 +1081,14 @@ class Adjoint:
         reverse = []

         # reverse iterator
-        reverse.append(adj.
+        reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")

         for i in cond_block.body_forward:
             reverse.append(i)

         # zero adjoints
         for i in body_block.vars:
-
-            reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}{{}};")
-            else:
-                reverse.append(adj.prefix + f"\tadj_{i} = {i.ctype()}(0);")
+            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

         # replay
         for i in body_block.body_replay:
@@ -810,14 +1098,14 @@ class Adjoint:
         for i in reversed(body_block.body_reverse):
             reverse.append(i)

-        reverse.append(adj.
-        reverse.append(adj.
+        reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
+        reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")

         adj.blocks[-1].body_reverse.extend(reversed(reverse))

     # define a while loop
     def begin_while(adj, cond):
-        #
+        # evaluate condition in its own block
         # so we can control replay
         cond_block = adj.begin_block()
         adj.loop_blocks.append(cond_block)
@@ -825,7 +1113,7 @@ class Adjoint:

         c = adj.eval(cond)

-        cond_block.body_forward.append(f"if ((
+        cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{cond_block.label};")

         # being block around loop
         adj.begin_block()
@@ -859,10 +1147,7 @@ class Adjoint:

         # zero adjoints of local vars
         for i in body_block.vars:
-
-            reverse.append(f"adj_{i} = {i.ctype()}{{}};")
-            else:
-                reverse.append(f"adj_{i} = {i.ctype()}(0);")
+            reverse.append(f"{i.emit_adj()} = {{}};")

         # replay
         for i in body_block.body_replay:
@@ -882,6 +1167,10 @@ class Adjoint:
         for f in node.body:
             adj.eval(f)

+        if adj.return_var is not None and len(adj.return_var) == 1:
+            if not isinstance(node.body[-1], ast.Return):
+                adj.add_forward("return {};", skip_replay=True)
+
     def emit_If(adj, node):
         if len(node.body) == 0:
             return None
@@ -909,7 +1198,7 @@ class Adjoint:

         if var1 != var2:
             # insert a phi function that selects var1, var2 based on cond
-            out = adj.
+            out = adj.add_builtin_call("select", [cond, var1, var2])
             adj.symbols[sym] = out

         symbols_prev = adj.symbols.copy()
@@ -933,7 +1222,7 @@ class Adjoint:
         if var1 != var2:
             # insert a phi function that selects var1, var2 based on cond
             # note the reversed order of vars since we want to use !cond as our select
-            out = adj.
+            out = adj.add_builtin_call("select", [cond, var2, var1])
             adj.symbols[sym] = out

     def emit_Compare(adj, node):
@@ -955,7 +1244,7 @@ class Adjoint:
         elif isinstance(op, ast.Or):
             func = "||"
         else:
-            raise
+            raise WarpCodegenKeyError(f"Op {op} is not supported")

         return adj.add_bool_op(func, [adj.eval(expr) for expr in node.values])

@@ -975,7 +1264,7 @@ class Adjoint:
         obj = capturedvars.get(str(node.id), None)

         if obj is None:
-            raise
+            raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))

         if warp.types.is_value(obj):
             # evaluate constant
@@ -987,26 +1276,96 @@ class Adjoint:
         # pass it back to the caller for processing
         return obj

+    @staticmethod
+    def resolve_type_attribute(var_type: type, attr: str):
+        if isinstance(var_type, type) and type_is_value(var_type):
+            if attr == "dtype":
+                return type_scalar_type(var_type)
+            elif attr == "length":
+                return type_length(var_type)
+
+        return getattr(var_type, attr, None)
+
+    def vector_component_index(adj, component, vector_type):
+        if len(component) != 1:
+            raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")
+
+        dim = vector_type._shape_[0]
+        swizzles = "xyzw"[0:dim]
+        if component not in swizzles:
+            raise WarpCodegenAttributeError(
+                f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
+            )
+
+        index = swizzles.index(component)
+        index = adj.add_constant(index)
+        return index
+
+    @staticmethod
+    def is_differentiable_value_type(var_type):
+        # checks that the argument type is a value type (i.e, not an array)
+        # possibly holding differentiable values (for which gradients must be accumulated)
+        return type_scalar_type(var_type) in float_types or isinstance(var_type, Struct)
+
     def emit_Attribute(adj, node):
-
-
+        if hasattr(node, "is_adjoint"):
+            node.value.is_adjoint = True
+
+        aggregate = adj.eval(node.value)

-
-
+        try:
+            if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
+                out = getattr(aggregate, node.attr)

                 if warp.types.is_value(out):
                     return adj.add_constant(out)

                 return out

-
-
-
+            if hasattr(node, "is_adjoint"):
+                # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
+                attr_name = aggregate.label + "." + node.attr
+                attr_type = aggregate.type.vars[node.attr].type
+
+                return Var(attr_name, attr_type)
+
+            aggregate_type = strip_reference(aggregate.type)
+
+            # reading a vector component
+            if type_is_vector(aggregate_type):
+                index = adj.vector_component_index(node.attr, aggregate_type)
+
+                return adj.add_builtin_call("extract", [aggregate, index])
+
+            else:
+                attr_type = Reference(aggregate_type.vars[node.attr].type)
+                attr = adj.add_var(attr_type)
+
+                if is_reference(aggregate.type):
+                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+                else:
+                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+
+                if adj.is_differentiable_value_type(strip_reference(attr_type)):
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+                else:
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+
+                return attr
+
+        except (KeyError, AttributeError):
+            # Try resolving as type attribute
+            aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate

-
+            type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
+            if type_attribute is not None:
+                return type_attribute

-
-
+        if isinstance(aggregate, Var):
+            raise WarpCodegenAttributeError(
+                f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
+            )
+        raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")

     def emit_String(adj, node):
         # string constant
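Note: with the `vector_component_index` helper above, reading a component such as `v.y` now lowers to an `extract` builtin call with a constant index. A small user-level sketch against the public Warp API (the array and kernel names are illustrative):

    import warp as wp

    @wp.kernel
    def read_component(vecs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        i = wp.tid()
        out[i] = vecs[i].y  # swizzle resolved to "extract" with constant index 1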
@@ -1023,19 +1382,25 @@ class Adjoint:
         adj.symbols[key] = out
         return out

+    def emit_Ellipsis(adj, node):
+        # stubbed @wp.native_func
+        return
+
     def emit_NameConstant(adj, node):
-        if node.value
+        if node.value:
             return adj.add_constant(True)
-        elif node.value == False:
-            return adj.add_constant(False)
         elif node.value is None:
-            raise
+            raise WarpCodegenTypeError("None type unsupported")
+        else:
+            return adj.add_constant(False)

     def emit_Constant(adj, node):
         if isinstance(node, ast.Str):
             return adj.emit_String(node)
         elif isinstance(node, ast.Num):
             return adj.emit_Num(node)
+        elif isinstance(node, ast.Ellipsis):
+            return adj.emit_Ellipsis(node)
         else:
             assert isinstance(node, ast.NameConstant)
             return adj.emit_NameConstant(node)
@@ -1046,18 +1411,16 @@ class Adjoint:
         right = adj.eval(node.right)

         name = builtin_operators[type(node.op)]
-        func = warp.context.builtin_functions[name]

-        return adj.
+        return adj.add_builtin_call(name, [left, right])

     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)

         name = builtin_operators[type(node.op)]
-        func = warp.context.builtin_functions[name]

-        return adj.
+        return adj.add_builtin_call(name, [arg])

     def materialize_redefinitions(adj, symbols):
         # detect symbols with conflicting definitions (assigned inside the for loop)
@@ -1067,21 +1430,19 @@ class Adjoint:
             var2 = adj.symbols[sym]

             if var1 != var2:
-                if warp.config.verbose:
+                if warp.config.verbose and not adj.custom_reverse_mode:
                     lineno = adj.lineno + adj.fun_lineno
-                    line = adj.
-                    msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this
+                    line = adj.source_lines[adj.lineno]
+                    msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                     print(msg)

                 if var1.constant is not None:
-                    raise
-                        "Error mutating a constant {} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
-                        sym
-                    )
+                    raise WarpCodegenError(
+                        f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
                     )

                 # overwrite the old variable value (violates SSA)
-                adj.
+                adj.add_builtin_call("assign", [var1, var2])

                 # reset the symbol to point to the original variable
                 adj.symbols[sym] = var1
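Note: the clearer `WarpCodegenError` above fires when a Python-level constant is mutated inside a dynamic loop; the suggested fix is unchanged. Illustrative kernel-body fragment (not taken from the diff):

    pi = 3.141          # plain constant: reassigning it in a dynamic loop raises
    pi = float(3.141)   # dynamic variable: may be reassigned inside the loop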
@@ -1100,95 +1461,132 @@ class Adjoint:

         adj.end_while()

-    def
-        # simple constant
+    def eval_num(adj, a):
         if isinstance(a, ast.Num):
-            return True
-
-
-
-
-
-
-
+            return True, a.n
+        if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
+            return True, -a.operand.n
+
+        # try and resolve the expression to an object
+        # e.g.: wp.constant in the globals scope
+        obj, _ = adj.resolve_static_expression(a)
+
+        if isinstance(obj, Var) and obj.constant is not None:
+            obj = obj.constant
+
+        return warp.types.is_int(obj), obj
+
+    # detects whether a loop contains a break (or continue) statement
+    def contains_break(adj, body):
+        for s in body:
+            if isinstance(s, ast.Break):
                 return True
+            elif isinstance(s, ast.Continue):
+                return True
+            elif isinstance(s, ast.If):
+                if adj.contains_break(s.body):
+                    return True
+                if adj.contains_break(s.orelse):
+                    return True
             else:
-
+                # note that nested for or while loops containing a break statement
+                # do not affect the current loop
+                pass
+
+        return False
+
+    # returns a constant range() if unrollable, otherwise None
+    def get_unroll_range(adj, loop):
+        if (
+            not isinstance(loop.iter, ast.Call)
+            or not isinstance(loop.iter.func, ast.Name)
+            or loop.iter.func.id != "range"
+            or len(loop.iter.args) == 0
+            or len(loop.iter.args) > 3
+        ):
+            return None

-
-
-
-
-
-
-
-
-
-
-
-
-
+        # if all range() arguments are numeric constants we will unroll
+        # note that this only handles trivial constants, it will not unroll
+        # constant compile-time expressions e.g.: range(0, 3*2)
+
+        # Evaluate the arguments and check that they are numeric constants
+        # It is important to do that in one pass, so that if evaluating these arguments have side effects
+        # the code does not get generated more than once
+        range_args = [adj.eval_num(arg) for arg in loop.iter.args]
+        arg_is_numeric, arg_values = zip(*range_args)
+
+        if all(arg_is_numeric):
+            # All argument are numeric constants
+
+            # range(end)
+            if len(loop.iter.args) == 1:
+                start = 0
+                end = arg_values[0]
+                step = 1
+
+            # range(start, end)
+            elif len(loop.iter.args) == 2:
+                start = arg_values[0]
+                end = arg_values[1]
+                step = 1
+
+            # range(start, end, step)
+            elif len(loop.iter.args) == 3:
+                start = arg_values[0]
+                end = arg_values[1]
+                step = arg_values[2]
+
+            # test if we're above max unroll count
+            max_iters = abs(end - start) // abs(step)
+            max_unroll = adj.builder.options["max_unroll"]
+
+            ok_to_unroll = True
+
+            if max_iters > max_unroll:
+                if warp.config.verbose:
+                    print(
+                        f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
+                    )
+                ok_to_unroll = False
+
+            elif adj.contains_break(loop.body):
+                if warp.config.verbose:
+                    print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
+                ok_to_unroll = False
+
+            if ok_to_unroll:
+                return range(start, end, step)
+
+        # Unroll is not possible, range needs to be valuated dynamically
+        range_call = adj.add_builtin_call(
+            "range",
+            [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
+        )
+        return range_call

     def emit_For(adj, node):
         # try and unroll simple range() statements that use constant args
-
-
-        if isinstance(node.iter, ast.Call) and node.iter.func.id == "range":
-            is_constant = True
-            for a in node.iter.args:
-                # if all range() arguments are numeric constants we will unroll
-                # note that this only handles trivial constants, it will not unroll
-                # constant compile-time expressions e.g.: range(0, 3*2)
-                if not adj.is_num(a):
-                    is_constant = False
-                    break
-
-            if is_constant:
-                # range(end)
-                if len(node.iter.args) == 1:
-                    start = 0
-                    end = adj.eval_num(node.iter.args[0])
-                    step = 1
-
-                # range(start, end)
-                elif len(node.iter.args) == 2:
-                    start = adj.eval_num(node.iter.args[0])
-                    end = adj.eval_num(node.iter.args[1])
-                    step = 1
-
-                # range(start, end, step)
-                elif len(node.iter.args) == 3:
-                    start = adj.eval_num(node.iter.args[0])
-                    end = adj.eval_num(node.iter.args[1])
-                    step = adj.eval_num(node.iter.args[2])
-
-                # test if we're above max unroll count
-                max_iters = abs(end - start) // abs(step)
-                max_unroll = adj.builder.options["max_unroll"]
-
-                if max_iters > max_unroll:
-                    if warp.config.verbose:
-                        print(
-                            f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
-                        )
-                else:
-                    # unroll
-                    for i in range(start, end, step):
-                        const_iter = adj.add_constant(i)
-                        var_iter = adj.add_call(warp.context.builtin_functions["int"], [const_iter])
-                        adj.symbols[node.target.id] = var_iter
+        unroll_range = adj.get_unroll_range(node)

-
-
-
+        if isinstance(unroll_range, range):
+            for i in unroll_range:
+                const_iter = adj.add_constant(i)
+                var_iter = adj.add_builtin_call("int", [const_iter])
+                adj.symbols[node.target.id] = var_iter

-
+                # eval body
+                for s in node.body:
+                    adj.eval(s)

-        #
-
-        # evaluate the Iterable
-
+        # otherwise generate a dynamic loop
+        else:
+            # evaluate the Iterable -- only if not previously evaluated when trying to unroll
+            if unroll_range is not None:
+                # Range has already been evaluated when trying to unroll, do not re-evaluate
+                iter = unroll_range
+            else:
+                iter = adj.eval(node.iter)

             adj.symbols[node.target.id] = adj.begin_for(iter)

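Note: `get_unroll_range` above unrolls a `range()` loop only when every argument is a trivial numeric constant and the trip count stays within the module's `max_unroll` option. A standalone sketch of the constant-arguments part of that decision, using only the standard `ast` module (`max_unroll=16` is an assumed default here, not Warp's):

    import ast

    def can_unroll(src, max_unroll=16):
        loop = ast.parse(src).body[0]
        it = loop.iter
        if not (isinstance(it, ast.Call) and isinstance(it.func, ast.Name) and it.func.id == "range"):
            return False
        # only trivial integer constants qualify, mirroring the check above
        if not all(isinstance(a, ast.Constant) and isinstance(a.value, int) for a in it.args):
            return False
        return len(range(*[a.value for a in it.args])) <= max_unroll

    assert can_unroll("for i in range(4):\n    pass")
    assert not can_unroll("for i in range(n):\n    pass")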
@@ -1217,15 +1615,28 @@ class Adjoint:
     def emit_Expr(adj, node):
         return adj.eval(node.value)

+    def check_tid_in_func_error(adj, node):
+        if adj.is_user_function:
+            if hasattr(node.func, "attr") and node.func.attr == "tid":
+                lineno = adj.lineno + adj.fun_lineno
+                line = adj.source_lines[adj.lineno]
+                raise WarpCodegenError(
+                    "tid() may only be called from a Warp kernel, not a Warp function. "
+                    "Instead, obtain the indices from a @wp.kernel and pass them as "
+                    f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
+                )
+
     def emit_Call(adj, node):
+        adj.check_tid_in_func_error(node)
+
         # try and lookup function in globals by
         # resolving path (e.g.: module.submodule.attr)
-        func, path = adj.
+        func, path = adj.resolve_static_expression(node.func)
         templates = []

-        if isinstance(func, warp.context.Function)
+        if not isinstance(func, warp.context.Function):
             if len(path) == 0:
-                raise
+                raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")

             attr = path[-1]
             caller = func
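Note: `check_tid_in_func_error` above rejects `wp.tid()` inside a `@wp.func`. The supported pattern, sketched here with the public Warp API, is to read the index in the kernel and pass it down as an argument:

    import warp as wp

    @wp.func
    def scale(i: int, data: wp.array(dtype=float)) -> float:
        return data[i] * 2.0

    @wp.kernel
    def run(data: wp.array(dtype=float), out: wp.array(dtype=float)):
        i = wp.tid()             # indices come from the kernel...
        out[i] = scale(i, data)  # ...and are passed to functions explicitly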
@@ -1250,7 +1661,7 @@ class Adjoint:
             func = caller.initializer()

             if func is None:
-                raise
+                raise WarpCodegenError(
                     f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
                 )
@@ -1259,16 +1670,25 @@ class Adjoint:
         # eval all arguments
         for arg in node.args:
             var = adj.eval(arg)
+            if not is_local_value(var):
+                raise RuntimeError(
+                    "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+                )
             args.append(var)

-        # eval all keyword
+        # eval all keyword args
         def kwval(kw):
             if isinstance(kw.value, ast.Num):
                 return kw.value.n
             elif isinstance(kw.value, ast.Tuple):
-
+                arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
+                if not all(arg_is_numeric):
+                    raise WarpCodegenError(
+                        f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
+                    )
+                return arg_values
             else:
-                return adj.
+                return adj.resolve_static_expression(kw.value)[0]

         kwds = {kw.arg: kwval(kw) for kw in node.keywords}

@@ -1285,10 +1705,26 @@ class Adjoint:
         # the ast.Index node appears in 3.7 versions
         # when performing array slices, e.g.: x = arr[i]
         # but in version 3.8 and higher it does not appear
+
+        if hasattr(node, "is_adjoint"):
+            node.value.is_adjoint = True
+
         return adj.eval(node.value)

     def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
         target = adj.eval(node.value)
+        if not is_local_value(target):
+            raise RuntimeError(
+                "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+            )

         indices = []

@@ -1308,28 +1744,34 @@ class Adjoint:
             var = adj.eval(node.slice)
             indices.append(var)

-
-
+        target_type = strip_reference(target.type)
+        if is_array(target_type):
+            if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
-                out = adj.
+                out = adj.add_builtin_call("address", [target, *indices])
             else:
                 # handles array views (fewer indices than dimensions)
-                out = adj.
+                out = adj.add_builtin_call("view", [target, *indices])

         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
-            out = adj.
+            out = adj.add_builtin_call("extract", [target, *indices])

         return out

     def emit_Assign(adj, node):
+        if len(node.targets) != 1:
+            raise WarpCodegenError("Assigning the same value to multiple variables is not supported")
+
+        lhs = node.targets[0]
+
         # handle the case where we are assigning multiple output variables
-        if isinstance(
+        if isinstance(lhs, ast.Tuple):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
             if isinstance(node.value, ast.Call):
-                node.value.expects = len(
+                node.value.expects = len(lhs.elts)

             # evaluate values
             if isinstance(node.value, ast.Tuple):
@@ -1338,40 +1780,47 @@ class Adjoint:
                 out = adj.eval(node.value)

             names = []
-            for v in
+            for v in lhs.elts:
                 if isinstance(v, ast.Name):
                     names.append(v.id)
                 else:
-                    raise
+                    raise WarpCodegenError(
                         "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                     )

             if len(names) != len(out):
-                raise
-                    "Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {}, got {})"
-                    len(out), len(names)
-                )
+                raise WarpCodegenError(
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
                 )

             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
-                        raise
-                            "Error, assigning to existing symbol {} ({}) with different type ({})"
-                            name, adj.symbols[name].type, rhs.type
-                        )
+                        raise WarpCodegenTypeError(
+                            f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                         )

                 adj.symbols[name] = rhs

-            return out
-
         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
-        elif isinstance(
-
+        elif isinstance(lhs, ast.Subscript):
+            if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
+                # handle adjoint of a variable, i.e. wp.adjoint[var]
+                lhs.slice.is_adjoint = True
+                src_var = adj.eval(lhs.slice)
+                var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
+                value = adj.eval(node.value)
+                adj.add_forward(f"{var.emit()} = {value.emit()};")
+                return
+
+            target = adj.eval(lhs.value)
             value = adj.eval(node.value)
+            if not is_local_value(value):
+                raise RuntimeError(
+                    "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
+                )

-            slice =
+            slice = lhs.slice
             indices = []

             if isinstance(slice, ast.Tuple):
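Note: the unpack check above requires a multi-output function's results to be fully received by simple names. Illustrative fragment (`sincos` stands for a hypothetical two-output function):

    s, c = sincos(x)   # ok: two outputs, two simple variable names
    s = sincos(x)      # raises: incorrect number of values to unpack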
@@ -1379,7 +1828,6 @@ class Adjoint:
                 for arg in slice.elts:
                     var = adj.eval(arg)
                     indices.append(var)
-
             elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
                 # handles the x[i, j] case (Python 3.7.x)
                 for arg in slice.value.elts:
@@ -1390,64 +1838,84 @@ class Adjoint:
                 var = adj.eval(slice)
                 indices.append(var)

-
-            adj.add_call(warp.context.builtin_functions["store"], [target, *indices, value])
+            target_type = strip_reference(target.type)

-
-            adj.
+            if is_array(target_type):
+                adj.add_builtin_call("array_store", [target, *indices, value])

-
+            elif type_is_vector(target_type) or type_is_matrix(target_type):
+                if is_reference(target.type):
+                    attr = adj.add_builtin_call("indexref", [target, *indices])
+                else:
+                    attr = adj.add_builtin_call("index", [target, *indices])
+
+                adj.add_builtin_call("store", [attr, value])
+
+                if warp.config.verbose and not adj.custom_reverse_mode:
                     lineno = adj.lineno + adj.fun_lineno
-                    line = adj.
+                    line = adj.source_lines[adj.lineno]
+                    node_source = adj.get_node_source(lhs.value)
                     print(
-                        f"Warning: mutating {
+                        f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                     )

             else:
-                raise
+                raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")

-
-
-        elif isinstance(node.targets[0], ast.Name):
+        elif isinstance(lhs, ast.Name):
             # symbol name
-            name =
+            name = lhs.id

             # evaluate rhs
             rhs = adj.eval(node.value)

             # check type matches if symbol already defined
             if name in adj.symbols:
-                if not types_equal(rhs.type, adj.symbols[name].type):
-                    raise
-                        "Error, assigning to existing symbol {} ({}) with different type ({})"
-                        name, adj.symbols[name].type, rhs.type
-                    )
+                if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
+                    raise WarpCodegenTypeError(
+                        f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                     )

             # handle simple assignment case (a = b), where we generate a value copy rather than reference
-            if isinstance(node.value, ast.Name):
-                out = adj.
-                adj.add_call(warp.context.builtin_functions["copy"], [out, rhs])
+            if isinstance(node.value, ast.Name) or is_reference(rhs.type):
+                out = adj.add_builtin_call("copy", [rhs])
             else:
                 out = rhs

             # update symbol map (assumes lhs is a Name node)
             adj.symbols[name] = out
-            return out

-        elif isinstance(
+        elif isinstance(lhs, ast.Attribute):
             rhs = adj.eval(node.value)
-
-
+            aggregate = adj.eval(lhs.value)
+            aggregate_type = strip_reference(aggregate.type)

-
-
-
-
-
+            # assigning to a vector component
+            if type_is_vector(aggregate_type):
+                index = adj.vector_component_index(lhs.attr, aggregate_type)
+
+                if is_reference(aggregate.type):
+                    attr = adj.add_builtin_call("indexref", [aggregate, index])
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+
+                adj.add_builtin_call("store", [attr, rhs])
+
+            else:
+                attr = adj.emit_Attribute(lhs)
+                if is_reference(attr.type):
+                    adj.add_builtin_call("store", [attr, rhs])
+                else:
+                    adj.add_builtin_call("assign", [attr, rhs])
+
+                if warp.config.verbose and not adj.custom_reverse_mode:
+                    lineno = adj.lineno + adj.fun_lineno
+                    line = adj.source_lines[adj.lineno]
+                    msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
+                    print(msg)

         else:
-            raise
+            raise WarpCodegenError("Error, unsupported assignment statement.")

     def emit_Return(adj, node):
         if node.value is None:
@@ -1458,30 +1926,26 @@ class Adjoint:
             var = (adj.eval(node.value),)

         if adj.return_var is not None:
-            old_ctypes = tuple(v.ctype() for v in adj.return_var)
-            new_ctypes = tuple(v.ctype() for v in var)
+            old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
+            new_ctypes = tuple(v.ctype(value_type=True) for v in var)
             if old_ctypes != new_ctypes:
-                raise
+                raise WarpCodegenTypeError(
                     f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
                 )
-        else:
-            adj.return_var = var
-
-        adj.add_return(var)
-
-    def emit_AugAssign(adj, node):
-        # convert inplace operations (+=, -=, etc) to ssa form, e.g.: c = a + b
-        left = adj.eval(node.target)
-        right = adj.eval(node.value)

-
-
-
+        if var is not None:
+            adj.return_var = tuple()
+            for ret in var:
+                if is_reference(ret.type):
+                    ret = adj.add_builtin_call("copy", [ret])
+                adj.return_var += (ret,)

-
+        adj.add_return(adj.return_var)

-
-
+    def emit_AugAssign(adj, node):
+        # replace augmented assignment with assignment statement + binary op
+        new_node = ast.Assign(targets=[node.target], value=ast.BinOp(node.target, node.op, node.value))
+        adj.eval(new_node)

     def emit_Tuple(adj, node):
         # LHS for expressions, such as i, j, k = 1, 2, 3
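Note: the new `emit_AugAssign` above desugars `x += y` into a plain assignment wrapping a binary op and re-evaluates it, instead of handling in-place operators as a special case. The same rewrite, standalone:

    import ast

    def desugar_aug_assign(node: ast.AugAssign) -> ast.Assign:
        # x += y  ->  x = x + y, reusing the target node on both sides
        return ast.Assign(
            targets=[node.target],
            value=ast.BinOp(left=node.target, op=node.op, right=node.value),
        )

    aug = ast.parse("x += y").body[0]
    print(ast.dump(desugar_aug_assign(aug)))  # Assign(..., value=BinOp(...))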
@@ -1491,122 +1955,167 @@ class Adjoint:
     def emit_Pass(adj, node):
         pass

+    node_visitors = {
+        ast.FunctionDef: emit_FunctionDef,
+        ast.If: emit_If,
+        ast.Compare: emit_Compare,
+        ast.BoolOp: emit_BoolOp,
+        ast.Name: emit_Name,
+        ast.Attribute: emit_Attribute,
+        ast.Str: emit_String,  # Deprecated in 3.8; use Constant
+        ast.Num: emit_Num,  # Deprecated in 3.8; use Constant
+        ast.NameConstant: emit_NameConstant,  # Deprecated in 3.8; use Constant
+        ast.Constant: emit_Constant,
+        ast.BinOp: emit_BinOp,
+        ast.UnaryOp: emit_UnaryOp,
+        ast.While: emit_While,
+        ast.For: emit_For,
+        ast.Break: emit_Break,
+        ast.Continue: emit_Continue,
+        ast.Expr: emit_Expr,
+        ast.Call: emit_Call,
+        ast.Index: emit_Index,  # Deprecated in 3.8; Use the index value directly instead.
+        ast.Subscript: emit_Subscript,
+        ast.Assign: emit_Assign,
+        ast.Return: emit_Return,
+        ast.AugAssign: emit_AugAssign,
+        ast.Tuple: emit_Tuple,
+        ast.Pass: emit_Pass,
+        ast.Ellipsis: emit_Ellipsis,
+    }
+
     def eval(adj, node):
         if hasattr(node, "lineno"):
             adj.set_lineno(node.lineno - 1)

-
-
-
-            ast.Compare: Adjoint.emit_Compare,
-            ast.BoolOp: Adjoint.emit_BoolOp,
-            ast.Name: Adjoint.emit_Name,
-            ast.Attribute: Adjoint.emit_Attribute,
-            ast.Str: Adjoint.emit_String,  # Deprecated in 3.8; use Constant
-            ast.Num: Adjoint.emit_Num,  # Deprecated in 3.8; use Constant
-            ast.NameConstant: Adjoint.emit_NameConstant,  # Deprecated in 3.8; use Constant
-            ast.Constant: Adjoint.emit_Constant,
-            ast.BinOp: Adjoint.emit_BinOp,
-            ast.UnaryOp: Adjoint.emit_UnaryOp,
-            ast.While: Adjoint.emit_While,
-            ast.For: Adjoint.emit_For,
-            ast.Break: Adjoint.emit_Break,
-            ast.Continue: Adjoint.emit_Continue,
-            ast.Expr: Adjoint.emit_Expr,
-            ast.Call: Adjoint.emit_Call,
-            ast.Index: Adjoint.emit_Index,  # Deprecated in 3.8; Use the index value directly instead.
-            ast.Subscript: Adjoint.emit_Subscript,
-            ast.Assign: Adjoint.emit_Assign,
-            ast.Return: Adjoint.emit_Return,
-            ast.AugAssign: Adjoint.emit_AugAssign,
-            ast.Tuple: Adjoint.emit_Tuple,
-            ast.Pass: Adjoint.emit_Pass,
-        }
-
-        emit_node = node_visitors.get(type(node))
-
-        if emit_node is not None:
-            return emit_node(adj, node)
-        else:
-            raise Exception("Error, ast node of type {} not supported".format(type(node)))
+        emit_node = adj.node_visitors[type(node)]
+
+        return emit_node(adj, node)

     # helper to evaluate expressions of the form
     # obj1.obj2.obj3.attr in the function's global scope
-    def resolve_path(adj,
-
+    def resolve_path(adj, path):
+        if len(path) == 0:
+            return None

-
-
-
+        # if root is overshadowed by local symbols, bail out
+        if path[0] in adj.symbols:
+            return None

-        if
-
+        if path[0] in __builtins__:
+            return __builtins__[path[0]]

-        #
-
+        # Look up the closure info and append it to adj.func.__globals__
+        # in case you want to define a kernel inside a function and refer
+        # to variables you've declared inside that function:
+        extract_contents = (
+            lambda contents: contents
+            if isinstance(contents, warp.context.Function) or not callable(contents)
+            else contents
+        )
+        capturedvars = dict(
+            zip(
+                adj.func.__code__.co_freevars,
+                [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
+            )
+        )
+        vars_dict = {**adj.func.__globals__, **capturedvars}

-        if
-
+        if path[0] in vars_dict:
+            func = vars_dict[path[0]]

-        #
-
-
-        # in case you want to define a kernel inside a function and refer
-        # to variables you've declared inside that function:
-        extract_contents = (
-            lambda contents: contents
-            if isinstance(contents, warp.context.Function) or not callable(contents)
-            else contents
-        )
-        capturedvars = dict(
-            zip(
-                adj.func.__code__.co_freevars,
-                [extract_contents(c.cell_contents) for c in (adj.func.__closure__ or [])],
-            )
-        )
+        # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
+        else:
+            func = getattr(warp, path[0], None)

-
-
-
-
-            pass
+        if func:
+            for i in range(1, len(path)):
+                if hasattr(func, path[i]):
+                    func = getattr(func, path[i])

-
-        # in a kernel:
+        return func

-
+    # Evaluates a static expression that does not depend on runtime values
+    # if eval_types is True, try resolving the path using evaluated type information as well
+    def resolve_static_expression(adj, root_node, eval_types=True):
+        attributes = []

-
-
-
-
+        node = root_node
+        while isinstance(node, ast.Attribute):
+            attributes.append(node.attr)
+            node = node.value

-
-
-
-
-
-
-
-
+        if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+            # support for operators returning modules
+            # i.e. operator_name(*operator_args).x.y.z
+            operator_args = node.args
+            operator_name = node.func.id
+
+            if operator_name == "type":
+                if len(operator_args) != 1:
+                    raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")
+
+                # type() operator
+                var = adj.eval(operator_args[0])
+
+                if isinstance(var, Var):
+                    var_type = strip_reference(var.type)
+                    # Allow accessing type attributes, for instance array.dtype
+                    while attributes:
+                        attr_name = attributes.pop()
+                        var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type
+
+                        if var_type is None:
+                            raise WarpCodegenAttributeError(
+                                f"{attr_name} is not an attribute of {type_repr(prev_type)}"
+                            )
+
+                    return var_type, [type_repr(var_type)]
+                else:
+                    raise WarpCodegenError(f"Cannot deduce the type of {var}")
+
+        # reverse list since ast presents it backward order
+        path = [*reversed(attributes)]
+        if isinstance(node, ast.Name):
+            path.insert(0, node.id)
+
+        # Try resolving path from captured context
+        captured_obj = adj.resolve_path(path)
+        if captured_obj is not None:
+            return captured_obj, path
+
+        # Still nothing found, maybe this is a predefined type attribute like `dtype`
+        if eval_types:
+            try:
+                val = adj.eval(root_node)
+                if val:
+                    return [val, type_repr(val)]
+
+            except Exception:
+                pass
+
+        return None, path

     # annotate generated code with the original source code line
     def set_lineno(adj, lineno):
         if adj.lineno is None or adj.lineno != lineno:
             line = lineno + adj.fun_lineno
-            source = adj.
+            source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
             adj.add_forward(f"// {source} <L {line}>")
             adj.add_reverse(f"// adj: {source} <L {line}>")
             adj.lineno = lineno

+    def get_node_source(adj, node):
+        # return the Python code corresponding to the given AST node
+        return ast.get_source_segment(adj.source, node)
+

 # ----------------
 # code generation

 cpu_module_header = """
 #define WP_NO_CRT
-#include "
+#include "builtin.h"

 // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
 #define float(x) cast_float(x)
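Note: moving `node_visitors` from a dict rebuilt inside every `eval()` call to a class attribute means the dispatch table is constructed once. A compact sketch of the pattern (the mini visitor below is illustrative, not Warp code):

    import ast

    class MiniVisitor:
        def emit_constant(self, node):
            return node.value

        def emit_binop(self, node):  # assumes ast.Add for brevity
            return self.eval(node.left) + self.eval(node.right)

        # built once at class-definition time, shared by all instances
        node_visitors = {ast.Constant: emit_constant, ast.BinOp: emit_binop}

        def eval(self, node):
            return self.node_visitors[type(node)](self, node)

    assert MiniVisitor().eval(ast.parse("1 + 2", mode="eval").body) == 3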
@@ -1615,13 +2124,16 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

-
+#define builtin_tid1d() wp::tid(wp::s_threadIdx)
+#define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)

 """

 cuda_module_header = """
 #define WP_NO_CRT
-#include "
+#include "builtin.h"

 // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
 #define float(x) cast_float(x)
@@ -1630,8 +2142,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

-
-
+#define builtin_tid1d() wp::tid(_idx)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

 """

@@ -1645,54 +2159,56 @@ struct {name}
 {{
 }}

-    {name}& operator += (const {name}&)
+    CUDA_CALLABLE {name}& operator += (const {name}& rhs)
+    {{{prefix_add_body}
+        return *this;}}

 }};

 static CUDA_CALLABLE void adj_{name}({reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}

-CUDA_CALLABLE void
+CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
 {{
-    {atomic_add_body}
-}}
+{atomic_add_body}}}


 """

-
+cpu_forward_function_template = """
 // {filename}:{lineno}
 static {return_type} {name}(
     {forward_args})
 {{
-    {forward_body}
-}}
+{forward_body}}}

+"""
+
+cpu_reverse_function_template = """
 // {filename}:{lineno}
 static void adj_{name}(
     {reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}

 """

-
+cuda_forward_function_template = """
 // {filename}:{lineno}
 static CUDA_CALLABLE {return_type} {name}(
     {forward_args})
 {{
-    {forward_body}
-}}
+{forward_body}}}

+"""
+
+cuda_reverse_function_template = """
 // {filename}:{lineno}
 static CUDA_CALLABLE void adj_{name}(
     {reverse_args})
 {{
-    {reverse_body}
-}}
+{reverse_body}}}

 """

@@ -1701,25 +2217,21 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    size_t _idx =
-
-
-
-
-
-    {forward_body}
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    {{
+{forward_body} }}
 }}

 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    size_t _idx =
-
-
-
-
-
-    {reverse_body}
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    {{
+{reverse_body} }}
 }}

 """
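Note: the CUDA kernel templates above switch from a single guarded thread index to a grid-stride loop, so a launch with fewer threads than `dim.size` still covers every element. A Python analogue of the index pattern the template emits:

    def grid_stride_indices(thread_id, num_threads, dim_size):
        # each thread starts at its global index and strides by the total
        # thread count until the whole domain is covered
        return list(range(thread_id, dim_size, num_threads))

    # 4 threads covering 10 elements: each index is visited exactly once
    covered = sorted(i for t in range(4) for i in grid_stride_indices(t, 4, 10))
    assert covered == list(range(10))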
@@ -1729,39 +2241,12 @@ cpu_kernel_template = """
 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
-    {forward_body}
-}}
+{forward_body}}}

 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
-    {reverse_body}
-}}
-
-"""
-
-cuda_module_template = """
-
-extern "C" {{
-
-// Python entry points
-WP_API void {name}_cuda_forward(
-    void* stream,
-    {forward_args})
-{{
-    {name}_cuda_kernel_forward<<<(dim.size + 256 - 1) / 256, 256, 0, (cudaStream_t)stream>>>(
-        {forward_params});
-}}
-
-WP_API void {name}_cuda_backward(
-    void* stream,
-    {reverse_args})
-{{
-    {name}_cuda_kernel_backward<<<(dim.size + 256 - 1) / 256, 256, 0, (cudaStream_t)stream>>>(
-        {reverse_params});
-}}
-
-}} // extern C
+{reverse_body}}}

 """

@@ -1773,11 +2258,9 @@ extern "C" {{
 WP_API void {name}_cpu_forward(
     {forward_args})
 {{
-    set_launch_bounds(dim);
-
     for (size_t i=0; i < dim.size; ++i)
     {{
-        s_threadIdx = i;
+        wp::s_threadIdx = i;

         {name}_cpu_kernel_forward(
             {forward_params});
@@ -1787,11 +2270,9 @@ WP_API void {name}_cpu_forward(
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
-    set_launch_bounds(dim);
-
     for (size_t i=0; i < dim.size; ++i)
     {{
-        s_threadIdx = i;
+        wp::s_threadIdx = i;

         {name}_cpu_kernel_backward(
             {reverse_params});
@@ -1837,7 +2318,7 @@ WP_API void {name}_cpu_backward(
 def constant_str(value):
     value_type = type(value)

-    if value_type == bool:
+    if value_type == bool or value_type == builtins.bool:
         if value:
             return "true"
         else:
@@ -1854,7 +2335,9 @@ def constant_str(value):

         scalar_value = runtime.core.half_bits_to_float
     else:
-
+
+        def scalar_value(x):
+            return x

     # list of scalar initializer values
     initlist = []
@@ -1871,6 +2354,9 @@ def constant_str(value):
         # make sure we emit the value of objects, e.g. uint32
         return str(value.value)

+    elif value == math.inf:
+        return "INFINITY"
+
     else:
         # otherwise just convert constant to string
         return str(value)
@@ -1879,7 +2365,7 @@ def constant_str(value):
 def indent(args, stops=1):
     sep = ",\n"
     for i in range(stops):
-        sep += "
+        sep += "    "

     # return sep + args.replace(", ", "," + sep)
     return sep.join(args)
@@ -1887,7 +2373,9 @@ def indent(args, stops=1):

 # generates a C function name based on the python function name
 def make_full_qualified_name(func):
-
+    if not isinstance(func, str):
+        func = func.__qualname__
+    return re.sub("[^0-9a-zA-Z_]+", "", func.replace(".", "__"))


 def codegen_struct(struct, device="cpu", indent_size=4):
@@ -1895,8 +2383,13 @@ def codegen_struct(struct, device="cpu", indent_size=4):
 
     body = []
     indent_block = " " * indent_size
-    for label, var in struct.vars.items():
-        body.append(var.ctype() + " " + label + ";\n")
+
+    if len(struct.vars) > 0:
+        for label, var in struct.vars.items():
+            body.append(var.ctype() + " " + label + ";\n")
+    else:
+        # for empty structs, emit the dummy attribute to avoid any compiler-specific alignment issues
+        body.append("char _dummy_;\n")
 
     forward_args = []
     reverse_args = []
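The new branch gives empty `@wp.struct` definitions a placeholder member, since a memberless struct's size and alignment are compiler-specific (C++ gives it size 1, while C forbids it outright). A reduced sketch of the body generation (a plain label-to-ctype dict stands in for `struct.vars`):

    indent_block = " " * 4

    def struct_body(members):
        # mirrors the branch above: real members if any, otherwise a dummy byte
        body = []
        if len(members) > 0:
            for label, ctype in members.items():
                body.append(ctype + " " + label + ";\n")
        else:
            body.append("char _dummy_;\n")
        return "".join(indent_block + line for line in body)

    print(struct_body({"pos": "wp::vec3", "mass": "wp::float32"}), end="")
    #     wp::vec3 pos;
    #     wp::float32 mass;
    print(struct_body({}), end="")
    #     char _dummy_;
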
@@ -1904,21 +2397,32 @@ def codegen_struct(struct, device="cpu", indent_size=4):
     forward_initializers = []
     reverse_body = []
     atomic_add_body = []
+    prefix_add_body = []
 
     # forward args
     for label, var in struct.vars.items():
-        forward_args.append(var.ctype() + " const& " + label)
-        reverse_args.append(var.ctype() + " const&")
+        var_ctype = var.ctype()
+        forward_args.append(f"{var_ctype} const& {label} = {{}}")
+        reverse_args.append(f"{var_ctype} const&")
 
-        atomic_add_body.append(f"{indent_block}adj_atomic_add(&p->{label}, t.{label});\n")
+        namespace = "wp::" if var_ctype.startswith("wp::") or var_ctype == "bool" else ""
+        atomic_add_body.append(f"{indent_block}{namespace}adj_atomic_add(&p->{label}, t.{label});\n")
 
         prefix = f"{indent_block}," if forward_initializers else ":"
         forward_initializers.append(f"{indent_block}{prefix} {label}{{{label}}}\n")
 
+    # prefix-add operator
+    for label, var in struct.vars.items():
+        if not is_array(var.type):
+            prefix_add_body.append(f"{indent_block}{label} += rhs.{label};\n")
+
     # reverse args
     for label, var in struct.vars.items():
-        reverse_args.append(var.ctype() + " & adj_" + label)
-        reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
+        reverse_args.append(var.ctype() + " & adj_" + label)
+        if is_array(var.type):
+            reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
+        else:
+            reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")
 
     reverse_args.append(name + " & adj_ret")
 
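Two behavioral changes sit in this hunk: struct members now also generate a `+=` operator body (arrays excluded, since accumulating array handles is meaningless), and reverse-mode adjoints now distinguish arrays from values; array adjoints alias the incoming gradient by assignment, while value adjoints accumulate, the standard reverse-mode rule that gradients of values sum over uses. A reduced sketch of the emitted reverse body (the labels and the string-based `is_array` are illustrative stand-ins):

    indent_block = " " * 4

    def is_array(ctype):
        # illustrative stand-in for warp's is_array() check on the member type
        return ctype.startswith("wp::array_t")

    members = {"pos": "wp::vec3", "samples": "wp::array_t<float>"}

    reverse_body = []
    for label, ctype in members.items():
        if is_array(ctype):
            # array gradients are views: assign, don't accumulate
            reverse_body.append(f"{indent_block}adj_{label} = adj_ret.{label};\n")
        else:
            # value gradients sum over uses
            reverse_body.append(f"{indent_block}adj_{label} += adj_ret.{label};\n")

    print("".join(reverse_body), end="")
    #     adj_pos += adj_ret.pos;
    #     adj_samples = adj_ret.samples;
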
@@ -1929,109 +2433,101 @@ def codegen_struct(struct, device="cpu", indent_size=4):
         forward_initializers="".join(forward_initializers),
         reverse_args=indent(reverse_args),
         reverse_body="".join(reverse_body),
+        prefix_add_body="".join(prefix_add_body),
         atomic_add_body="".join(atomic_add_body),
     )
 
 
-def codegen_func_forward_body(adj, device="cpu", indent=4):
-    body = []
-    indent_block = " " * indent
-
-    for f in adj.blocks[0].body_forward:
-        body += [f + "\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
 def codegen_func_forward(adj, func_type="kernel", device="cpu"):
-    s = ""
+    if device == "cpu":
+        indent = 4
+    elif device == "cuda":
+        if func_type == "kernel":
+            indent = 8
+        else:
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
+
+    indent_block = " " * indent
 
     # primal vars
-    s += "    //---------\n"
-    s += "    // primal vars\n"
+    lines = []
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-            s += f"    {var.ctype()} {var.emit()};\n"
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-            s += f"    const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # forward pass
-    s += "    //---------\n"
-    s += "    // forward\n"
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
 
-    if device == "cpu":
-        s += codegen_func_forward_body(adj, device=device, indent=4)
+    for f in adj.blocks[0].body_forward:
+        lines += [f + "\n"]
+
+    return "".join([indent_block + l for l in lines])
 
+
+def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
+    if device == "cpu":
+        indent = 4
     elif device == "cuda":
         if func_type == "kernel":
-            s += codegen_func_forward_body(adj, device=device, indent=8)
+            indent = 8
         else:
-            s += codegen_func_forward_body(adj, device=device, indent=4)
-
-    return s
-
+            indent = 4
+    else:
+        raise ValueError(f"Device {device} not supported for codegen")
 
-def codegen_func_reverse_body(adj, device="cpu", indent=4):
-    body = []
     indent_block = " " * indent
 
-    # forward pass
-    body += ["//---------\n"]
-    body += ["// forward\n"]
-
-    for f in adj.blocks[0].body_replay:
-        body += [f + "\n"]
-
-    # reverse pass
-    body += ["//---------\n"]
-    body += ["// reverse\n"]
-
-    for l in reversed(adj.blocks[0].body_reverse):
-        body += [l + "\n"]
-
-    body += ["return;\n"]
-
-    return "".join([indent_block + l for l in body])
-
-
-def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
-    s = ""
+    lines = []
 
     # primal vars
-    s += "    //---------\n"
-    s += "    // primal vars\n"
+    lines += ["//---------\n"]
+    lines += ["// primal vars\n"]
 
     for var in adj.variables:
         if var.constant is None:
-            s += f"    {var.ctype()} {var.emit()};\n"
+            lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
-            s += f"    const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"
+            lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
 
     # dual vars
-    s += "    //---------\n"
-    s += "    // dual vars\n"
+    lines += ["//---------\n"]
+    lines += ["// dual vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
-            s += "    " + var.ctype() + " adj_" + str(var.label) + ";\n"
-        else:
-            s += "    " + var.ctype() + " adj_" + str(var.label) + "(0);\n"
+        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
 
-    if device == "cpu":
-        s += codegen_func_reverse_body(adj, device=device, indent=4)
-    elif device == "cuda":
-        if func_type == "kernel":
-            s += codegen_func_reverse_body(adj, device=device, indent=8)
-        else:
-            s += codegen_func_reverse_body(adj, device=device, indent=4)
+    # forward pass
+    lines += ["//---------\n"]
+    lines += ["// forward\n"]
+
+    for f in adj.blocks[0].body_replay:
+        lines += [f + "\n"]
+
+    # reverse pass
+    lines += ["//---------\n"]
+    lines += ["// reverse\n"]
+
+    for l in reversed(adj.blocks[0].body_reverse):
+        lines += [l + "\n"]
+
+    # In grid-stride kernels the reverse body is in a for loop
+    if device == "cuda" and func_type == "kernel":
+        lines += ["continue;\n"]
     else:
-        raise ValueError("Device {} is not supported".format(device))
+        lines += ["return;\n"]
 
-    return s
+    return "".join([indent_block + l for l in lines])
 
 
-def codegen_func(adj, device="cpu"):
+def codegen_func(adj, c_func_name: str, device="cpu", options={}):
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
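The most consequential change in this hunk is the kernel tail: the old helpers always emitted `return;`, but as the new comment states, 0.11's CUDA kernels wrap the body in a grid-stride for loop where one thread processes several elements, so finishing one element's reverse pass must advance the loop rather than exit the kernel. A reduced sketch of the new dispatch:

    def reverse_tail(device, func_type):
        # mirrors the branch above: in grid-stride CUDA kernels the reverse body
        # sits inside a for loop, so per-element exit is a continue
        if device == "cuda" and func_type == "kernel":
            return "continue;\n"
        return "return;\n"

    print(reverse_tail("cuda", "kernel"), end="")    # continue;
    print(reverse_tail("cuda", "function"), end="")  # return;
    print(reverse_tail("cpu", "kernel"), end="")     # return;
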
@@ -2044,16 +2540,20 @@ def codegen_func(adj, device="cpu"):
     reverse_args = []
 
     # forward args
-    for arg in adj.args:
-        forward_args.append(arg.ctype() + " " + arg.emit())
-        reverse_args.append(arg.ctype() + " " + arg.emit())
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit()}"
+        forward_args.append(s)
+        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
+            reverse_args.append(s)
     if has_multiple_outputs:
         for i, arg in enumerate(adj.return_var):
             forward_args.append(arg.ctype() + " & ret_" + str(i))
             reverse_args.append(arg.ctype() + " & ret_" + str(i))
 
     # reverse args
-    for arg in adj.args:
+    for i, arg in enumerate(adj.args):
+        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
+            break
         # indexed array gradients are regular arrays
         if isinstance(arg.type, indexedarray):
             _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
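The `custom_reverse_mode` flags appear to support user-supplied adjoint functions: only the first `custom_reverse_num_input_args` arguments are treated as inputs to the generated reverse signature, and the remaining user-declared adjoint outputs are appended by reference afterwards (see the next hunk). A reduced sketch of the split (the flag names come from the code above; the argument strings are illustrative):

    args = ["wp::float32 var_a", "wp::vec3 var_b", "wp::vec3 var_adj_out"]
    custom_reverse_mode = True
    custom_reverse_num_input_args = 2

    reverse_args = []
    for i, s in enumerate(args):
        if custom_reverse_mode and i >= custom_reverse_num_input_args:
            break  # trailing args are user-declared adjoint outputs, added later
        reverse_args.append(s)

    print(reverse_args)  # ['wp::float32 var_a', 'wp::vec3 var_b']
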
@@ -2065,24 +2565,96 @@ def codegen_func(adj, device="cpu"):
             reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
     elif return_type != "void":
         reverse_args.append(return_type + " & adj_ret")
-
-    # codegen body
-    forward_body = codegen_func_forward(adj, func_type="function", device=device)
-    reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+    # custom output reverse args (user-declared)
+    if adj.custom_reverse_mode:
+        for arg in adj.args[adj.custom_reverse_num_input_args :]:
+            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")
 
     if device == "cpu":
-        template = cpu_function_template
+        forward_template = cpu_forward_function_template
+        reverse_template = cpu_reverse_function_template
     elif device == "cuda":
-        template = cuda_function_template
+        forward_template = cuda_forward_function_template
+        reverse_template = cuda_reverse_function_template
     else:
-        raise ValueError("Device {} is not supported".format(device))
+        raise ValueError(f"Device {device} is not supported")
 
-    s = template.format(
-        name=make_full_qualified_name(adj.func),
-        return_type=return_type,
+    # codegen body
+    forward_body = codegen_func_forward(adj, func_type="function", device=device)
+
+    s = ""
+    if not adj.skip_forward_codegen:
+        s += forward_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            forward_args=indent(forward_args),
+            forward_body=forward_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
+
+    if not adj.skip_reverse_codegen:
+        if adj.custom_reverse_mode:
+            reverse_body = "\t// user-defined adjoint code\n" + forward_body
+        else:
+            if options.get("enable_backward", True):
+                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
+            else:
+                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+        s += reverse_template.format(
+            name=c_func_name,
+            return_type=return_type,
+            reverse_args=indent(reverse_args),
+            forward_body=forward_body,
+            reverse_body=reverse_body,
+            filename=adj.filename,
+            lineno=adj.fun_lineno,
+        )
+
+    return s
+
+
+def codegen_snippet(adj, name, snippet, adj_snippet):
+    forward_args = []
+    reverse_args = []
+
+    # forward args
+    for i, arg in enumerate(adj.args):
+        s = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
+        forward_args.append(s)
+        reverse_args.append(s)
+
+    # reverse args
+    for i, arg in enumerate(adj.args):
+        if isinstance(arg.type, indexedarray):
+            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
+        else:
+            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
+
+    forward_template = cuda_forward_function_template
+    reverse_template = cuda_reverse_function_template
+
+    s = ""
+    s += forward_template.format(
+        name=name,
+        return_type="void",
         forward_args=indent(forward_args),
+        forward_body=snippet,
+        filename=adj.filename,
+        lineno=adj.fun_lineno,
+    )
+
+    if adj_snippet:
+        reverse_body = adj_snippet
+    else:
+        reverse_body = ""
+
+    s += reverse_template.format(
+        name=name,
+        return_type="void",
         reverse_args=indent(reverse_args),
-        forward_body=forward_body,
+        forward_body=snippet,
         reverse_body=reverse_body,
         filename=adj.filename,
         lineno=adj.fun_lineno,
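`codegen_snippet` is new: it wraps a user-provided CUDA source string as the forward body (and an optional adjoint snippet as the reverse body) instead of generating the body from a traced Python function. Note that argument names have their `var_` mangling prefix stripped, so a snippet can refer to arguments by their Python-level names. A reduced sketch of that renaming (`arg.emit()` is assumed to return the mangled `var_`-prefixed name, as the code above implies):

    # (ctype, mangled name) pairs standing in for adj.args entries
    args = [("wp::array_t<float>", "var_x"), ("wp::float32", "var_scale")]

    forward_args = []
    for ctype, mangled in args:
        # snippets see the unmangled Python-level name
        forward_args.append(f"{ctype} {mangled.replace('var_', '')}")

    print(forward_args)  # ['wp::array_t<float> x', 'wp::float32 scale']
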
@@ -2098,8 +2670,8 @@ def codegen_kernel(kernel, device, options):
 
     adj = kernel.adj
 
-    forward_args = ["launch_bounds_t dim"]
-    reverse_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
+    reverse_args = ["wp::launch_bounds_t dim"]
 
     # forward args
     for arg in adj.args:
@@ -2128,7 +2700,7 @@ def codegen_kernel(kernel, device, options):
     elif device == "cuda":
         template = cuda_kernel_template
     else:
-        raise ValueError("Device {} is not supported".format(device))
+        raise ValueError(f"Device {device} is not supported")
 
     s = template.format(
         name=kernel.get_mangled_name(),
@@ -2142,10 +2714,13 @@ def codegen_kernel(kernel, device, options):
 
 
 def codegen_module(kernel, device="cpu"):
+    if device != "cpu":
+        return ""
+
     adj = kernel.adj
 
     # build forward signature
-    forward_args = ["launch_bounds_t dim"]
+    forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim"]
 
     for arg in adj.args:
@@ -2175,14 +2750,7 @@ def codegen_module(kernel, device="cpu"):
         reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
         reverse_params.append(f"adj_{arg.label}")
 
-    if device == "cpu":
-        template = cpu_module_template
-    elif device == "cuda":
-        template = cuda_module_template
-    else:
-        raise ValueError("Device {} is not supported".format(device))
-
-    s = template.format(
+    s = cpu_module_template.format(
         name=kernel.get_mangled_name(),
         forward_args=indent(forward_args),
         reverse_args=indent(reverse_args),
|