warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/context.py  CHANGED

@@ -5,36 +5,27 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-import
-import os
-import sys
-import hashlib
+import ast
 import ctypes
+import gc
+import hashlib
+import inspect
+import io
+import os
 import platform
-import
+import sys
 import types
-import
-
-from typing import Tuple
-from typing import List
-from typing import Dict
-from typing import Any
-from typing import Callable
-from typing import Union
-from typing import Mapping
-from typing import Optional
-
+from copy import copy as shallowcopy
 from types import ModuleType
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
-
+import numpy as np
 
 import warp
-import warp.codegen
 import warp.build
+import warp.codegen
 import warp.config
 
-import numpy as np
-
 # represents either a built-in or user-defined function
 
 
@@ -45,6 +36,18 @@ def create_value_func(type):
     return value_func
 
 
+def get_function_args(func):
+    """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
+    import inspect
+
+    argspec = inspect.getfullargspec(func)
+
+    # use source-level argument annotations
+    if len(argspec.annotations) < len(argspec.args):
+        raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")
+    return argspec.annotations
+
+
 class Function:
     def __init__(
         self,
@@ -66,8 +69,17 @@ class Function:
         generic=False,
         native_func=None,
         defaults=None,
+        custom_replay_func=None,
+        native_snippet=None,
+        adj_native_snippet=None,
+        skip_forward_codegen=False,
+        skip_reverse_codegen=False,
+        custom_reverse_num_input_args=-1,
+        custom_reverse_mode=False,
         overloaded_annotations=None,
         code_transformers=[],
+        skip_adding_overload=False,
+        require_original_output_arg=False,
     ):
         self.func = func  # points to Python function decorated with @wp.func, may be None for builtins
         self.key = key
@@ -81,6 +93,12 @@ class Function:
         self.module = module
         self.variadic = variadic  # function can take arbitrary number of inputs, e.g.: printf()
         self.defaults = defaults
+        # Function instance for a custom implementation of the replay pass
+        self.custom_replay_func = custom_replay_func
+        self.native_snippet = native_snippet
+        self.adj_native_snippet = adj_native_snippet
+        self.custom_grad_func = None
+        self.require_original_output_arg = require_original_output_arg
 
         if initializer_list_func is None:
             self.initializer_list_func = lambda x, y: False
@@ -110,7 +128,14 @@ class Function:
 
             # user defined (Python) function
             self.adj = warp.codegen.Adjoint(
-                func,
+                func,
+                is_user_function=True,
+                skip_forward_codegen=skip_forward_codegen,
+                skip_reverse_codegen=skip_reverse_codegen,
+                custom_reverse_num_input_args=custom_reverse_num_input_args,
+                custom_reverse_mode=custom_reverse_mode,
+                overload_annotations=overloaded_annotations,
+                transformers=code_transformers,
             )
 
             # record input types
@@ -139,11 +164,12 @@ class Function:
         else:
            self.mangled_name = None
 
-
+        if not skip_adding_overload:
+            self.add_overload(self)
 
         # add to current module
         if module:
-            module.register_function(self)
+            module.register_function(self, skip_adding_overload)
 
     def __call__(self, *args, **kwargs):
         # handles calling a builtin (native) function
@@ -152,121 +178,24 @@ class Function:
         # from within a kernel (experimental).
 
         if self.is_builtin() and self.mangled_name:
-            #
-
-
-            for
-                if
+            # For each of this function's existing overloads, we attempt to pack
+            # the given arguments into the C types expected by the corresponding
+            # parameters, and we rinse and repeat until we get a match.
+            for overload in self.overloads:
+                if overload.generic:
                     continue
 
-
-                if
-
-                        f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
-                    )
-
-                try:
-                    # try and pack args into what the function expects
-                    params = []
-                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                        a = args[i]
-
-                        # try to convert to a value type (vec3, mat33, etc)
-                        if issubclass(arg_type, ctypes.Array):
-                            # wrap the arg_type (which is an ctypes.Array) in a structure
-                            # to ensure parameter is passed to the .dll by value rather than reference
-                            class ValueArg(ctypes.Structure):
-                                _fields_ = [("value", arg_type)]
-
-                            x = ValueArg()
-
-                            # force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
-                            if isinstance(a, ctypes.Array) == False:
-                                # assume you want the float32 version of the function so it doesn't just
-                                # grab an override for a random data type:
-                                if arg_type._type_ != ctypes.c_float:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
-                                    )
-
-                                a = np.array(a)
-
-                                # flatten to 1D array
-                                v = a.flatten()
-                                if len(v) != arg_type._length_:
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
-                                    )
-
-                                for i in range(arg_type._length_):
-                                    x.value[i] = v[i]
-
-                            else:
-                                # already a built-in type, check it matches
-                                if not warp.types.types_equal(type(a), arg_type):
-                                    raise RuntimeError(
-                                        f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
-                                    )
-
-                                x.value = a
-
-                            params.append(x)
-
-                        else:
-                            try:
-                                # try to pack as a scalar type
-                                params.append(arg_type._type_(a))
-                            except:
-                                raise RuntimeError(
-                                    f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
-                                )
-
-                    # returns the corresponding ctype for a scalar or vector warp type
-                    def type_ctype(dtype):
-                        if dtype == float:
-                            return ctypes.c_float
-                        elif dtype == int:
-                            return ctypes.c_int32
-                        elif issubclass(dtype, ctypes.Array):
-                            return dtype
-                        elif issubclass(dtype, ctypes.Structure):
-                            return dtype
-                        else:
-                            # scalar type
-                            return dtype._type_
-
-                    value_type = type_ctype(f.value_func(None, None, None))
-
-                    # construct return value (passed by address)
-                    ret = value_type()
-                    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
-
-                    params.append(ret_addr)
-
-                    c_func = getattr(warp.context.runtime.core, f.mangled_name)
-                    c_func(*params)
-
-                    if issubclass(value_type, ctypes.Array) or issubclass(value_type, ctypes.Structure):
-                        # return vector types as ctypes
-                        return ret
-                    else:
-                        # return scalar types as int/float
-                        return ret.value
-
-                except Exception as e:
-                    # couldn't pack values to match this overload
-                    # store error and move onto the next one
-                    error = e
-                    continue
+                success, return_value = call_builtin(overload, *args)
+                if success:
+                    return return_value
 
             # overload resolution or call failed
-
-
-
-
-            raise RuntimeError(f"Error calling function '{f.key}'.")
+            raise RuntimeError(
+                f"Couldn't find a function '{self.key}' compatible with "
+                f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
+            )
 
-
+        if hasattr(self, "user_overloads") and len(self.user_overloads):
             # user-defined function with overloads
 
             if len(kwargs):
@@ -275,28 +204,26 @@ class Function:
             )
 
             # try and find a matching overload
-            for
-                if len(
+            for overload in self.user_overloads.values():
+                if len(overload.input_types) != len(args):
                     continue
-                template_types = list(
-                arg_names = list(
+                template_types = list(overload.input_types.values())
+                arg_names = list(overload.input_types.keys())
                 try:
                     # attempt to unify argument types with function template types
                     warp.types.infer_argument_types(args, template_types, arg_names)
-                    return
+                    return overload.func(*args)
                 except Exception:
                     continue
 
            raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
 
-
-
-
-        if self.func is None:
-            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
+        # user-defined function with no overloads
+        if self.func is None:
+            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
 
-
-
+        # this function has no overloads, call it like a plain Python function
+        return self.func(*args, **kwargs)
 
     def is_builtin(self):
         return self.func is None
@@ -316,7 +243,7 @@ class Function:
             # todo: construct a default value for each of the functions args
             # so we can generate the return type for overloaded functions
             return_type = type_str(self.value_func(None, None, None))
-        except:
+        except Exception:
             return False
 
         if return_type.startswith("Tuple"):
@@ -409,10 +336,187 @@ class Function:
         return None
 
     def __repr__(self):
-        inputs_str = ", ".join([f"{k}: {v
+        inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
         return f"<Function {self.key}({inputs_str})>"
 
 
+def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
+    uses_non_warp_array_type = False
+
+    # Retrieve the built-in function from Warp's dll.
+    c_func = getattr(warp.context.runtime.core, func.mangled_name)
+
+    # Try gathering the parameters that the function expects and pack them
+    # into their corresponding C types.
+    c_params = []
+    for i, (_, arg_type) in enumerate(func.input_types.items()):
+        param = params[i]
+
+        try:
+            iter(param)
+        except TypeError:
+            is_array = False
+        else:
+            is_array = True
+
+        if is_array:
+            if not issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            # The argument expects a built-in Warp type like a vector or a matrix.
+
+            c_param = None
+
+            if isinstance(param, ctypes.Array):
+                # The given parameter is also a built-in Warp type, so we only need
+                # to make sure that it matches with the argument.
+                if not warp.types.types_equal(type(param), arg_type):
+                    return (False, None)
+
+                if isinstance(param, arg_type):
+                    c_param = param
+                else:
+                    # Cast the value to its argument type to make sure that it
+                    # can be assigned to the field of the `Param` struct.
+                    # This could error otherwise when, for example, the field type
+                    # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
+                    # even though both types are semantically identical.
+                    c_param = arg_type(param)
+            else:
+                # Flatten the parameter values into a flat 1-D array.
+                arr = []
+                ndim = 1
+                stack = [(0, param)]
+                while stack:
+                    depth, elem = stack.pop(0)
+                    try:
+                        # If `elem` is a sequence, then it should be possible
+                        # to add its elements to the stack for later processing.
+                        stack.extend((depth + 1, x) for x in elem)
+                    except TypeError:
+                        # Since `elem` doesn't seem to be a sequence,
+                        # we must have a leaf value that we need to add to our
+                        # resulting array.
+                        arr.append(elem)
+                        ndim = max(depth, ndim)
+
+                assert ndim > 0
+
+                # Ensure that if the given parameter value is, say, a 2-D array,
+                # then we try to resolve it against a matrix argument rather than
+                # a vector.
+                if ndim > len(arg_type._shape_):
+                    return (False, None)
+
+                elem_count = len(arr)
+                if elem_count != arg_type._length_:
+                    return (False, None)
+
+                # Retrieve the element type of the sequence while ensuring
+                # that it's homogeneous.
+                elem_type = type(arr[0])
+                for i in range(1, elem_count):
+                    if type(arr[i]) is not elem_type:
+                        raise ValueError("All array elements must share the same type.")
+
+                expected_elem_type = arg_type._wp_scalar_type_
+                if not (
+                    elem_type is expected_elem_type
+                    or (elem_type is float and expected_elem_type is warp.types.float32)
+                    or (elem_type is int and expected_elem_type is warp.types.int32)
+                    or (
+                        issubclass(elem_type, np.number)
+                        and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
+                    )
+                ):
+                    # The parameter value has a type not matching the type defined
+                    # for the corresponding argument.
+                    return (False, None)
+
+                if elem_type in warp.types.int_types:
+                    # Pass the value through the expected integer type
+                    # in order to evaluate any integer wrapping.
+                    # For example `uint8(-1)` should result in the value `-255`.
+                    arr = tuple(elem_type._type_(x.value).value for x in arr)
+                elif elem_type in warp.types.float_types:
+                    # Extract the floating-point values.
+                    arr = tuple(x.value for x in arr)
+
+                c_param = arg_type()
+                if warp.types.type_is_matrix(arg_type):
+                    rows, cols = arg_type._shape_
+                    for i in range(rows):
+                        idx_start = i * cols
+                        idx_end = idx_start + cols
+                        c_param[i] = arr[idx_start:idx_end]
+                else:
+                    c_param[:] = arr
+
+                uses_non_warp_array_type = True
+
+            c_params.append(ctypes.byref(c_param))
+        else:
+            if issubclass(arg_type, ctypes.Array):
+                return (False, None)
+
+            if not (
+                isinstance(param, arg_type)
+                or (type(param) is float and arg_type is warp.types.float32)
+                or (type(param) is int and arg_type is warp.types.int32)
+                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
+            ):
+                return (False, None)
+
+            if type(param) in warp.types.scalar_types:
+                param = param.value
+
+            # try to pack as a scalar type
+            if arg_type == warp.types.float16:
+                c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
+            else:
+                c_params.append(arg_type._type_(param))
+
+    # returns the corresponding ctype for a scalar or vector warp type
+    value_type = func.value_func(None, None, None)
+    if value_type == float:
+        value_ctype = ctypes.c_float
+    elif value_type == int:
+        value_ctype = ctypes.c_int32
+    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
+        value_ctype = value_type
+    else:
+        # scalar type
+        value_ctype = value_type._type_
+
+    # construct return value (passed by address)
+    ret = value_ctype()
+    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
+    c_params.append(ret_addr)
+
+    # Call the built-in function from Warp's dll.
+    c_func(*c_params)
+
+    if uses_non_warp_array_type:
+        warp.utils.warn(
+            "Support for built-in functions called with non-Warp array types, "
+            "such as lists, tuples, NumPy arrays, and others, will be dropped "
+            "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
+            "`wp.quat`, or `wp.transform`.",
+            DeprecationWarning,
+            stacklevel=3,
+        )
+
+    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
+        # return vector types as ctypes
+        return (True, ret)
+
+    if value_type == warp.types.float16:
+        return (True, warp.types.half_bits_to_float(ret.value))
+
+    # return scalar types as int/float
+    return (True, ret.value)
+
+
 class KernelHooks:
     def __init__(self, forward, backward):
         self.forward = forward
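Note: the call_builtin() routine added above is what now backs calling exported built-ins directly from the Python interpreter. A minimal sketch of how this surfaces to users, assuming wp.init() has been called and using the exported wp.vec3/wp.length built-ins:

import warp as wp

wp.init()

v = wp.vec3(1.0, 2.0, 3.0)
print(wp.length(v))  # packed into C types and dispatched to the native library

# Plain sequences are still flattened and packed for now, but per the warning
# added above they emit a DeprecationWarning; prefer Warp types such as wp.vec3.
print(wp.length((1.0, 2.0, 3.0)))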
@@ -421,10 +525,20 @@ class KernelHooks:
 
 # caches source and compiled entry points for a kernel (will be populated after module loads)
 class Kernel:
-    def __init__(self, func, key, module, options=None, code_transformers=[]):
+    def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
         self.func = func
-
-
+
+        if module is None:
+            self.module = get_module(func.__module__)
+        else:
+            self.module = module
+
+        if key is None:
+            unique_key = self.module.generate_unique_kernel_key(func.__name__)
+            self.key = unique_key
+        else:
+            self.key = key
+
         self.options = {} if options is None else options
 
         self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
@@ -445,8 +559,8 @@ class Kernel:
         # argument indices by name
         self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))
 
-        if module:
-            module.register_kernel(self)
+        if self.module:
+            self.module.register_kernel(self)
 
     def infer_argument_types(self, args):
         template_types = list(self.adj.arg_types.values())
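Note: with key and module now optional, Kernel objects can be created programmatically from annotated Python functions. A hypothetical sketch only (scale is an illustrative name; the auto-generated key would look like "scale_0"):

import warp as wp
import warp.context

wp.init()

def scale(x: wp.array(dtype=float), s: float):
    i = wp.tid()
    x[i] = x[i] * s

k = warp.context.Kernel(scale)  # key and module are inferred from the function
x = wp.zeros(8, dtype=float)
wp.launch(k, dim=8, inputs=[x, 2.0])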
@@ -523,7 +637,7 @@ def func(f):
     name = warp.codegen.make_full_qualified_name(f)
 
     m = get_module(f.__module__)
-
+    Function(
         func=f, key=name, namespace="", module=m, value_func=None
     )  # value_type not known yet, will be inferred during Adjoint.build()
 
@@ -531,6 +645,167 @@ def func(f):
     return m.functions[name]
 
 
+def func_native(snippet, adj_snippet=None):
+    """
+    Decorator to register native code snippet, @func_native
+    """
+
+    def snippet_func(f):
+        name = warp.codegen.make_full_qualified_name(f)
+
+        m = get_module(f.__module__)
+        func = Function(
+            func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
+        )  # cuda snippets do not have a return value_type
+
+        return m.functions[name]
+
+    return snippet_func
+
+
+def func_grad(forward_fn):
+    """
+    Decorator to register a custom gradient function for a given forward function.
+    The function signature must correspond to one of the function overloads in the following way:
+    the first part of the input arguments are the original input variables with the same types as their
+    corresponding arguments in the original function, and the second part of the input arguments are the
+    adjoint variables of the output variables (if available) of the original function with the same types as the
+    output variables. The function must not return anything.
+    """
+
+    def wrapper(grad_fn):
+        generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+        if generic:
+            raise RuntimeError(
+                f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+            )
+
+        reverse_args = {}
+        reverse_args.update(forward_fn.input_types)
+
+        # create temporary Adjoint instance to analyze the function signature
+        adj = warp.codegen.Adjoint(
+            grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
+        )
+
+        from warp.types import types_equal
+
+        grad_args = adj.args
+        grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)
+
+        generic = any(warp.types.type_is_generic(x.type) for x in grad_args)
+        if generic:
+            raise RuntimeError(
+                f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
+            )
+
+        def match_function(f):
+            # check whether the function overload f matches the signature of the provided gradient function
+            if not hasattr(f.adj, "return_var"):
+                f.adj.build(None)
+            expected_args = list(f.input_types.items())
+            if f.adj.return_var is not None:
+                expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
+            if len(grad_args) != len(expected_args):
+                return False
+            if any(not types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args)):
+                return False
+            return True
+
+        def add_custom_grad(f: Function):
+            # register custom gradient function
+            f.custom_grad_func = Function(
+                grad_fn,
+                key=f.key,
+                namespace=f.namespace,
+                input_types=reverse_args,
+                value_func=None,
+                module=f.module,
+                template_func=f.template_func,
+                skip_forward_codegen=True,
+                custom_reverse_mode=True,
+                custom_reverse_num_input_args=len(f.input_types),
+                skip_adding_overload=False,
+                code_transformers=f.adj.transformers,
+            )
+            f.adj.skip_reverse_codegen = True
+
+        if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
+            # find matching overload for which this grad function is defined
+            for sig, f in forward_fn.user_overloads.items():
+                if not grad_sig.startswith(sig):
+                    continue
+                if match_function(f):
+                    add_custom_grad(f)
+                    return
+            raise RuntimeError(
+                f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
+            )
+        else:
+            # resolve return variables
+            forward_fn.adj.build(None)
+
+            expected_args = list(forward_fn.input_types.items())
+            if forward_fn.adj.return_var is not None:
+                expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]
+
+            # check if the signature matches this function
+            if match_function(forward_fn):
+                add_custom_grad(forward_fn)
+            else:
+                raise RuntimeError(
+                    f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
+                    "forward function arguments plus the adjoint variables corresponding to the return variables:"
+                    f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
+                )
+
+    return wrapper
+
+
+def func_replay(forward_fn):
+    """
+    Decorator to register a custom replay function for a given forward function.
+    The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
+    The provided function has to match the signature of one of the original forward function overloads.
+    """
+
+    def wrapper(replay_fn):
+        generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
+        if generic:
+            raise RuntimeError(
+                f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
+            )
+
+        args = get_function_args(replay_fn)
+        arg_types = list(args.values())
+        generic = any(warp.types.type_is_generic(x) for x in arg_types)
+        if generic:
+            raise RuntimeError(
+                f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
+            )
+
+        f = forward_fn.get_overload(arg_types)
+        if f is None:
+            inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
+            raise RuntimeError(
+                f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
+            )
+        f.custom_replay_func = Function(
+            replay_fn,
+            key=f"replay_{f.key}",
+            namespace=f.namespace,
+            input_types=f.input_types,
+            value_func=f.value_func,
+            module=f.module,
+            template_func=f.template_func,
+            skip_reverse_codegen=True,
+            skip_adding_overload=True,
+            code_transformers=f.adj.transformers,
+        )
+
+    return wrapper
+
+
 # decorator to register kernel, @kernel, custom_name may be a string
 # that creates a kernel with a different name from the actual function
 def kernel(f=None, *, enable_backward=None):
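Note: the func_grad(), func_replay(), and func_native() decorators registered above are the user-facing entry points for custom derivatives, replay functions, and native snippets. A minimal custom-gradient sketch, assuming these are exposed as wp.func_grad and that adjoints are accumulated through the wp.adjoint mapping described in Warp's autodiff documentation (safe_sqrt/adj_safe_sqrt are illustrative names):

import warp as wp

wp.init()

@wp.func
def safe_sqrt(x: float):
    return wp.sqrt(x)

# Arguments are the forward inputs followed by the adjoint of the return value;
# the gradient function must not return anything.
@wp.func_grad(safe_sqrt)
def adj_safe_sqrt(x: float, adj_ret: float):
    if x > 0.0:
        wp.adjoint[x] += adj_ret / (2.0 * wp.sqrt(x))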
@@ -658,6 +933,7 @@ def add_builtin(
     missing_grad=False,
     native_func=None,
     defaults=None,
+    require_original_output_arg=False,
 ):
     # wrap simple single-type functions with a value_func()
     if value_func is None:
@@ -670,7 +946,7 @@ def add_builtin(
         def initializer_list_func(args, templates):
             return False
 
-    if defaults
+    if defaults is None:
         defaults = {}
 
     # Add specialized versions of this builtin if it's generic by matching arguments against
@@ -751,8 +1027,8 @@ def add_builtin(
             # on the generated argument list and skip generation if it fails.
             # This also gives us the return type, which we keep for later:
             try:
-                return_type = value_func(
-            except Exception
+                return_type = value_func(argtypes, {}, [])
+            except Exception:
                 continue
 
             # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
@@ -782,6 +1058,7 @@ def add_builtin(
                 hidden=True,
                 skip_replay=skip_replay,
                 missing_grad=missing_grad,
+                require_original_output_arg=require_original_output_arg,
             )
 
     func = Function(
@@ -802,6 +1079,7 @@ def add_builtin(
         generic=generic,
         native_func=native_func,
         defaults=defaults,
+        require_original_output_arg=require_original_output_arg,
     )
 
     if key in builtin_functions:
@@ -811,7 +1089,7 @@ def add_builtin(
 
     # export means the function will be added to the `warp` module namespace
     # so that users can call it directly from the Python interpreter
-    if export
+    if export:
         if hasattr(warp, key):
             # check that we haven't already created something at this location
             # if it's just an overload stub for auto-complete then overwrite it
@@ -878,6 +1156,8 @@ class ModuleBuilder:
         for func in module.functions.values():
             for f in func.user_overloads.values():
                 self.build_function(f)
+                if f.custom_replay_func is not None:
+                    self.build_function(f.custom_replay_func)
 
         # build all kernel entry points
         for kernel in module.kernels.values():
@@ -894,8 +1174,7 @@ class ModuleBuilder:
         while stack:
             s = stack.pop()
 
-
-            structs.append(s)
+            structs.append(s)
 
             for var in s.vars.values():
                 if isinstance(var.type, warp.codegen.Struct):
@@ -927,7 +1206,7 @@ class ModuleBuilder:
         if not func.value_func:
 
             def wrap(adj):
-                def value_type(
+                def value_type(arg_types, kwds, templates):
                     if adj.return_var is None or len(adj.return_var) == 0:
                         return None
                     if len(adj.return_var) == 1:
@@ -951,7 +1230,14 @@ class ModuleBuilder:
 
         # code-gen all imported functions
         for func in self.functions.keys():
-
+            if func.native_snippet is None:
+                source += warp.codegen.codegen_func(
+                    func.adj, c_func_name=func.native_func, device=device, options=self.options
+                )
+            else:
+                source += warp.codegen.codegen_snippet(
+                    func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
+                )
 
         for kernel in self.module.kernels.values():
             # each kernel gets an entry point in the module
@@ -1031,6 +1317,10 @@ class Module:
 
         self.content_hash = None
 
+        # number of times module auto-generates kernel key for user
+        # used to ensure unique kernel keys
+        self.count = 0
+
     def register_struct(self, struct):
         self.structs[struct.key] = struct
 
@@ -1045,7 +1335,7 @@ class Module:
         # for a reload of module on next launch
         self.unload()
 
-    def register_function(self, func):
+    def register_function(self, func, skip_adding_overload=False):
         if func.key not in self.functions:
             self.functions[func.key] = func
         else:
@@ -1065,7 +1355,7 @@ class Module:
                 )
                 if sig == sig_existing:
                     self.functions[func.key] = func
-
+                elif not skip_adding_overload:
                     func_existing.add_overload(func)
 
         self.find_references(func.adj)
@@ -1073,6 +1363,11 @@ class Module:
         # for a reload of module on next launch
         self.unload()
 
+    def generate_unique_kernel_key(self, key):
+        unique_key = f"{key}_{self.count}"
+        self.count += 1
+        return unique_key
+
     # collect all referenced functions / structs
     # given the AST of a function or kernel
     def find_references(self, adj):
@@ -1086,13 +1381,13 @@ class Module:
             if isinstance(node, ast.Call):
                 try:
                     # try to resolve the function
-                    func, _ = adj.
+                    func, _ = adj.resolve_static_expression(node.func, eval_types=False)
 
                     # if this is a user-defined function, add a module reference
                     if isinstance(func, warp.context.Function) and func.module is not None:
                         add_ref(func.module)
 
-                except:
+                except Exception:
                     # Lookups may fail for builtins, but that's ok.
                     # Lookups may also fail for functions in this module that haven't been imported yet,
                     # and that's ok too (not an external reference).
@@ -1139,9 +1434,24 @@ class Module:
                 s = func.adj.source
                 ch.update(bytes(s, "utf-8"))
 
+                if func.custom_grad_func:
+                    s = func.custom_grad_func.adj.source
+                    ch.update(bytes(s, "utf-8"))
+                if func.custom_replay_func:
+                    s = func.custom_replay_func.adj.source
+
+                # cache func arg types
+                for arg, arg_type in func.adj.arg_types.items():
+                    s = f"{arg}: {get_type_name(arg_type)}"
+                    ch.update(bytes(s, "utf-8"))
+
         # kernel source
         for kernel in module.kernels.values():
             ch.update(bytes(kernel.adj.source, "utf-8"))
+            # cache kernel arg types
+            for arg, arg_type in kernel.adj.arg_types.items():
+                s = f"{arg}: {get_type_name(arg_type)}"
+                ch.update(bytes(s, "utf-8"))
             # for generic kernels the Python source is always the same,
             # but we hash the type signatures of all the overloads
             if kernel.is_generic:
@@ -1440,13 +1750,13 @@ class ContextGuard:
     def __enter__(self):
         if self.device.is_cuda:
             runtime.core.cuda_context_push_current(self.device.context)
-        elif
+        elif is_cuda_driver_initialized():
             self.saved_context = runtime.core.cuda_context_get_current()
 
     def __exit__(self, exc_type, exc_value, traceback):
         if self.device.is_cuda:
             runtime.core.cuda_context_pop_current()
-        elif
+        elif is_cuda_driver_initialized():
            runtime.core.cuda_context_set_current(self.saved_context)
 
 
@@ -1537,6 +1847,29 @@ class Event:
 
 
 class Device:
+    """A device to allocate Warp arrays and to launch kernels on.
+
+    Attributes:
+        ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
+        name: A string label for the device. By default, CPU devices will be named according to the processor name,
+            or ``"CPU"`` if the processor name cannot be determined.
+        arch: An integer representing the compute capability version number calculated as
+            ``10 * major + minor``. ``0`` for CPU devices.
+        is_uva: A boolean indicating whether or not the device supports unified addressing.
+            ``False`` for CPU devices.
+        is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
+            generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
+        is_mempool_supported: A boolean indicating whether or not the device supports using the
+            ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
+            CPU devices.
+        is_primary: A boolean indicating whether or not this device's CUDA context is also the
+            device's primary context.
+        uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
+            ``nvidia-smi -L``. ``None`` for CPU devices.
+        pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
+            ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
+    """
+
     def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
         self.runtime = runtime
         self.alias = alias
@@ -1566,6 +1899,9 @@ class Device:
             self.arch = 0
             self.is_uva = False
             self.is_cubin_supported = False
+            self.is_mempool_supported = False
+            self.uuid = None
+            self.pci_bus_id = None
 
             # TODO: add more device-specific dispatch functions
             self.memset = runtime.core.memset_host
@@ -1578,6 +1914,26 @@ class Device:
             self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
             # check whether our NVRTC can generate CUBINs for this architecture
             self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
+            self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)
+
+            uuid_buffer = (ctypes.c_char * 16)()
+            runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
+            uuid_byte_str = bytes(uuid_buffer).hex()
+            self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
+
+            pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
+            pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
+            pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
+            # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
+            self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
+
+            # Warn the user of a possible misconfiguration of their system
+            if not self.is_mempool_supported:
+                warp.utils.warn(
+                    f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
+                    "This can prevent the use of graphs and/or result in poor performance. "
+                    "Is the UVM driver enabled?"
+                )
 
         # initialize streams unless context acquisition is postponed
         if self._context is not None:
if self._context is not None:
|
|
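Taken together, these additions mean every CUDA `wp.Device` now carries a driver-reported UUID, a PCI bus identifier, and a flag for stream-ordered memory-pool support. A minimal sketch of inspecting them (not part of the diff; it assumes a CUDA-capable machine and the usual "cuda:0" device alias):

    import warp as wp

    wp.init()

    d = wp.get_device("cuda:0")
    print(d.name, d.arch)           # compute capability encoded as 10 * major + minor
    print(d.uuid)                   # "GPU-xxxxxxxx-xxxx-...", same format as nvidia-smi -L
    print(d.pci_bus_id)             # "[domain]:[bus]:[device]" in hexadecimal
    print(d.is_mempool_supported)   # stream-ordered allocator support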
@@ -1601,14 +1957,17 @@ class Device:

     @property
     def is_cpu(self):
+        """A boolean indicating whether or not the device is a CPU device."""
         return self.ordinal < 0

     @property
     def is_cuda(self):
+        """A boolean indicating whether or not the device is a CUDA device."""
         return self.ordinal >= 0

     @property
     def context(self):
+        """The context associated with the device."""
         if self._context is not None:
             return self._context
         elif self.is_primary:

@@ -1623,10 +1982,16 @@ class Device:

     @property
     def has_context(self):
+        """A boolean indicating whether or not the device has a CUDA context associated with it."""
         return self._context is not None

     @property
     def stream(self):
+        """The stream associated with a CUDA device.
+
+        Raises:
+            RuntimeError: The device is not a CUDA device.
+        """
         if self.context:
             return self._stream
         else:

@@ -1644,6 +2009,7 @@ class Device:

     @property
     def has_stream(self):
+        """A boolean indicating whether or not the device has a stream associated with it."""
         return self._stream is not None

     def __str__(self):

@@ -1721,7 +2087,7 @@ class Runtime:

         self.core = self.load_dll(warp_lib)

-        if
+        if os.path.exists(llvm_lib):
             self.llvm = self.load_dll(llvm_lib)
             # setup c-types for warp-clang.dll
             self.llvm.lookup.restype = ctypes.c_uint64

@@ -2087,6 +2453,8 @@ class Runtime:
         self.core.cuda_driver_version.restype = ctypes.c_int
         self.core.cuda_toolkit_version.argtypes = None
         self.core.cuda_toolkit_version.restype = ctypes.c_int
+        self.core.cuda_driver_is_initialized.argtypes = None
+        self.core.cuda_driver_is_initialized.restype = ctypes.c_bool

         self.core.nvrtc_supported_arch_count.argtypes = None
         self.core.nvrtc_supported_arch_count.restype = ctypes.c_int

@@ -2103,6 +2471,14 @@ class Runtime:
         self.core.cuda_device_get_arch.restype = ctypes.c_int
         self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
         self.core.cuda_device_is_uva.restype = ctypes.c_int
+        self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
+        self.core.cuda_device_get_uuid.restype = None
+        self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
+        self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
+        self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int

         self.core.cuda_context_get_current.argtypes = None
         self.core.cuda_context_get_current.restype = ctypes.c_void_p

@@ -2189,6 +2565,7 @@ class Runtime:
             ctypes.c_void_p,
             ctypes.c_void_p,
             ctypes.c_size_t,
+            ctypes.c_int,
             ctypes.POINTER(ctypes.c_void_p),
         ]
         self.core.cuda_launch_kernel.restype = ctypes.c_size_t

@@ -2309,8 +2686,15 @@ class Runtime:
                 dll = ctypes.CDLL(dll_path, winmode=0)
             else:
                 dll = ctypes.CDLL(dll_path)
-        except OSError:
-
+        except OSError as e:
+            if "GLIBCXX" in str(e):
+                raise RuntimeError(
+                    f"Failed to load the shared library '{dll_path}'.\n"
+                    "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
+                    "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
+                ) from e
+            else:
+                raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
         return dll

     def get_device(self, ident: Devicelike = None) -> Device:

@@ -2439,6 +2823,21 @@ def is_device_available(device):
     return device in get_devices()


+def is_cuda_driver_initialized() -> bool:
+    """Returns ``True`` if the CUDA driver is initialized.
+
+    This is a stricter test than ``is_cuda_available()`` since a CUDA driver
+    call to ``cuCtxGetCurrent`` is made, and the result is compared to
+    `CUDA_SUCCESS`. Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent``
+    even if there is no context bound to the calling CPU thread.
+
+    This can be helpful in cases in which ``cuInit()`` was called before a fork.
+    """
+    assert_initialized()
+
+    return runtime.core.cuda_driver_is_initialized()
+
+
 def get_devices() -> List[Device]:
     """Returns a list of devices supported in this environment."""

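The new `is_cuda_driver_initialized()` wraps the `cuda_driver_is_initialized` binding registered above. A minimal sketch of using it to decide whether GPU work is safe, for instance in a process that may have inherited CUDA state across a fork (not part of the diff; it assumes the helper is re-exported at the package level like the other context functions):

    import warp as wp

    wp.init()

    if wp.is_cuda_driver_initialized():
        a = wp.zeros(1024, dtype=float, device="cuda:0")
    else:
        # the driver is not usable from this process; fall back to the CPU
        a = wp.zeros(1024, dtype=float, device="cpu")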
@@ -2749,7 +3148,7 @@ def full(
         elif na.ndim == 2:
             dtype = warp.types.matrix(na.shape, scalar_type)
         else:
-            raise ValueError(
+            raise ValueError("Values with more than two dimensions are not supported")
     else:
         raise ValueError(f"Invalid value type for Warp array: {value_type}")

@@ -2872,8 +3271,34 @@ def empty_like(
     return arr


-def from_numpy(
-
+def from_numpy(
+    arr: np.ndarray,
+    dtype: Optional[type] = None,
+    shape: Optional[Sequence[int]] = None,
+    device: Optional[Devicelike] = None,
+    requires_grad: bool = False,
+) -> warp.array:
+    if dtype is None:
+        base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
+        if base_type is None:
+            raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
+
+        dim_count = len(arr.shape)
+        if dim_count == 2:
+            dtype = warp.types.vector(length=arr.shape[1], dtype=base_type)
+        elif dim_count == 3:
+            dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=base_type)
+        else:
+            dtype = base_type
+
+    return warp.array(
+        data=arr,
+        dtype=dtype,
+        shape=shape,
+        owner=False,
+        device=device,
+        requires_grad=requires_grad,
+    )


 # given a kernel destination argument type and a value convert
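The rewritten `from_numpy()` infers the Warp dtype from the NumPy array's dimensionality: 2D inputs become arrays of vectors, 3D inputs become arrays of matrices, and anything else stays scalar. A minimal usage sketch (not part of the diff; the shapes and device alias are illustrative):

    import numpy as np
    import warp as wp

    wp.init()

    points_np = np.random.rand(1024, 3).astype(np.float32)

    # (N, 3) float32 input -> wp.array of 3-vectors on the GPU
    points = wp.from_numpy(points_np, device="cuda:0")

    # passing an explicit dtype keeps a plain 2D scalar array instead
    flat = wp.from_numpy(points_np, dtype=wp.float32, device="cuda:0")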
@@ -2889,9 +3314,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
     # - in forward passes, array types have to match
     # - in backward passes, indexed array gradients are regular arrays
     if adjoint:
-        array_matches =
+        array_matches = isinstance(value, warp.array)
     else:
-        array_matches = type(value)
+        array_matches = type(value) is type(arg_type)

     if not array_matches:
         adj = "adjoint " if adjoint else ""

@@ -2934,7 +3359,7 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
         # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
         try:
             return arg_type(value)
-        except:
+        except Exception:
             raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}")

     elif isinstance(value, bool):

@@ -2943,27 +3368,35 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
     elif isinstance(value, arg_type):
         try:
             # try to pack as a scalar type
-
-
+            if arg_type is warp.types.float16:
+                return arg_type._type_(warp.types.float_to_half_bits(value.value))
+            else:
+                return arg_type._type_(value.value)
+        except Exception:
             raise RuntimeError(
-
+                "Error launching kernel, unable to pack kernel parameter type "
+                f"{type(value)} for param {arg_name}, expected {arg_type}"
             )

     else:
         try:
             # try to pack as a scalar type
-
+            if arg_type is warp.types.float16:
+                return arg_type._type_(warp.types.float_to_half_bits(value))
+            else:
+                return arg_type._type_(value)
         except Exception as e:
             print(e)
             raise RuntimeError(
-
+                "Error launching kernel, unable to pack kernel parameter type "
+                f"{type(value)} for param {arg_name}, expected {arg_type}"
             )


 # represents all data required for a kernel launch
 # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
 class Launch:
-    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None):
+    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
         # if not specified look up hooks
         if not hooks:
             module = kernel.module
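The scalar packing path above now special-cases `wp.float16`, converting the Python value to half bits before the launch. A sketch of a launch that would exercise it (not part of the diff; the kernel is illustrative and assumes half-precision arrays and scalars are accepted as kernel arguments, which the packing code suggests but the diff does not show directly):

    import warp as wp

    wp.init()

    @wp.kernel
    def scale(values: wp.array(dtype=wp.float16), factor: wp.float16):
        i = wp.tid()
        values[i] = values[i] * factor

    values = wp.zeros(16, dtype=wp.float16, device="cuda:0")
    wp.launch(scale, dim=16, inputs=[values, wp.float16(2.0)], device="cuda:0")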
@@ -3000,6 +3433,7 @@ class Launch:
         self.params_addr = params_addr
         self.device = device
         self.bounds = bounds
+        self.max_blocks = max_blocks

     def set_dim(self, dim):
         self.bounds = warp.types.launch_bounds_t(dim)

@@ -3065,7 +3499,9 @@ class Launch:
         if self.device.is_cpu:
             self.hooks.forward(*self.params)
         else:
-            runtime.core.cuda_launch_kernel(
+            runtime.core.cuda_launch_kernel(
+                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
+            )


 def launch(

@@ -3080,6 +3516,7 @@ def launch(
     adjoint=False,
     record_tape=True,
     record_cmd=False,
+    max_blocks=0,
 ):
     """Launch a Warp kernel on the target device

@@ -3097,6 +3534,8 @@ def launch(
         adjoint: Whether to run forward or backward pass (typically use False)
         record_tape: When true the launch will be recorded the global wp.Tape() object when present
         record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
+        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
+            If negative or zero, the maximum hardware value will be used.
     """

     assert_initialized()
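`wp.launch()` and the `Launch` command object both gain a `max_blocks` argument that caps the number of CUDA thread blocks; zero or a negative value falls back to the hardware maximum. A minimal sketch (not part of the diff; the kernel and sizes are illustrative, and the replay call assumes the recorded command is re-issued via its `launch()` method, as the `record_cmd` docstring suggests):

    import warp as wp

    wp.init()

    @wp.kernel
    def saxpy(x: wp.array(dtype=float), y: wp.array(dtype=float), a: float):
        i = wp.tid()
        y[i] = a * x[i] + y[i]

    n = 1_000_000
    x = wp.zeros(n, dtype=float, device="cuda:0")
    y = wp.zeros(n, dtype=float, device="cuda:0")

    # limit the grid to at most 256 blocks; the default of 0 lets Warp pick the maximum
    wp.launch(saxpy, dim=n, inputs=[x, y, 2.0], device="cuda:0", max_blocks=256)

    # the same limit is honoured when the launch is recorded for later replay
    cmd = wp.launch(saxpy, dim=n, inputs=[x, y, 2.0], device="cuda:0", max_blocks=256, record_cmd=True)
    cmd.launch()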
@@ -3108,7 +3547,7 @@ def launch(
     device = runtime.get_device(device)

     # check function is a Kernel
-    if isinstance(kernel, Kernel)
+    if not isinstance(kernel, Kernel):
         raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")

     # debugging aid

@@ -3190,7 +3629,9 @@ def launch(
                     f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                 )

-            runtime.core.cuda_launch_kernel(
+            runtime.core.cuda_launch_kernel(
+                device.context, hooks.backward, bounds.size, max_blocks, kernel_params
+            )

         else:
             if hooks.forward is None:

@@ -3211,7 +3652,9 @@ def launch(

             else:
                 # launch
-                runtime.core.cuda_launch_kernel(
+                runtime.core.cuda_launch_kernel(
+                    device.context, hooks.forward, bounds.size, max_blocks, kernel_params
+                )

     try:
         runtime.verify_cuda_device(device)

@@ -3221,7 +3664,7 @@ def launch(

     # record on tape if one is active
     if runtime.tape and record_tape:
-        runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
+        runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)


 def synchronize():

@@ -3231,7 +3674,7 @@ def synchronize():
     or memory copies have completed.
     """

-    if
+    if is_cuda_driver_initialized():
         # save the original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()


@@ -3281,7 +3724,7 @@ def synchronize_stream(stream_or_device=None):
     runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)


-def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
+def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
     """Force user-defined kernels to be compiled and loaded

     Args:

@@ -3289,12 +3732,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
         modules: List of modules to load. If None, load all imported modules.
     """

-    if
+    if is_cuda_driver_initialized():
         # save original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()

     if device is None:
         devices = get_devices()
+    elif isinstance(device, list):
+        devices = [get_device(device_item) for device_item in device]
     else:
         devices = [get_device(device)]

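`force_load()` now also accepts a list of devices or device aliases, so modules can be compiled and loaded for several GPUs in one call. A minimal sketch (not part of the diff; the aliases assume a machine with two GPUs):

    import warp as wp

    wp.init()

    # pre-compile all imported modules for both GPUs, e.g. before timing or graph capture
    wp.force_load(["cuda:0", "cuda:1"])

    # a single device, or None for all devices, still works as before
    wp.force_load("cpu")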
@@ -3386,7 +3831,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options


-def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
+def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
     """Begin capture of a CUDA graph

     Captures all subsequent kernel launches and memory operations on CUDA devices.

@@ -3400,7 +3845,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True

     """

-    if
+    if force_module_load is None:
+        force_module_load = warp.config.graph_capture_module_load_default
+
+    if warp.config.verify_cuda:
         raise RuntimeError("Cannot use CUDA error verification during graph capture")

     if stream is not None:

@@ -3415,6 +3863,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True

     device.is_capturing = True

+    # disable garbage collection to avoid older allocations getting collected during graph capture
+    gc.disable()
+
     with warp.ScopedStream(stream):
         runtime.core.cuda_graph_begin_capture(device.context)

@@ -3438,6 +3889,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:

     device.is_capturing = False

+    # re-enable GC
+    gc.enable()
+
     if graph is None:
         raise RuntimeError(
             "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
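`capture_begin()` now defaults `force_module_load` to `None`, deferring to `warp.config.graph_capture_module_load_default`, and Python garbage collection is suspended between `capture_begin()` and `capture_end()`. A minimal sketch of the capture/replay flow (not part of the diff; the kernel and sizes are illustrative):

    import warp as wp

    wp.init()

    @wp.kernel
    def inc(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] + 1.0

    a = wp.zeros(1024, dtype=float, device="cuda:0")

    # load modules up front in case module loading is disabled during capture
    wp.force_load("cuda:0")

    wp.capture_begin(device="cuda:0")
    wp.launch(inc, dim=1024, inputs=[a], device="cuda:0")
    graph = wp.capture_end(device="cuda:0")

    wp.capture_launch(graph)  # replay the recorded graph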
@@ -3557,6 +4011,16 @@ def copy(
     if src_elem_size != dst_elem_size:
         raise RuntimeError("Incompatible array data types")

+    # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
+    # TODO?
+    if (
+        isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
+        and src.ndim > 1
+        or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
+        and dest.ndim > 1
+    ):
+        raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")
+
     src_desc = src.__ctype__()
     dst_desc = dest.__ctype__()
     src_ptr = ctypes.pointer(src_desc)

@@ -3592,6 +4056,10 @@ def type_str(t):
         return f"Array[{type_str(t.dtype)}]"
     elif isinstance(t, warp.indexedarray):
         return f"IndexedArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.fabricarray):
+        return f"FabricArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.indexedfabricarray):
+        return f"IndexedFabricArray[{type_str(t.dtype)}]"
     elif hasattr(t, "_wp_generic_type_str_"):
         generic_type = t._wp_generic_type_str_

@@ -3618,7 +4086,7 @@ def type_str(t):
     return t.__name__


-def print_function(f, file, noentry=False):
+def print_function(f, file, noentry=False):  # pragma: no cover
     """Writes a function definition to a file for use in reST documentation

     Args:

@@ -3642,7 +4110,7 @@ def print_function(f, file, noentry=False):
         # todo: construct a default value for each of the functions args
         # so we can generate the return type for overloaded functions
         return_type = " -> " + type_str(f.value_func(None, None, None))
-    except:
+    except Exception:
         pass

     print(f".. function:: {f.key}({args}){return_type}", file=file)

@@ -3663,7 +4131,7 @@ def print_function(f, file, noentry=False):
     return True


-def
+def export_functions_rst(file):  # pragma: no cover
     header = (
         "..\n"
         "   Autogenerated File - Do not edit. Run build_docs.py to generate.\n"

@@ -3683,6 +4151,8 @@ def print_builtins(file):

     for t in warp.types.scalar_types:
         print(f".. class:: {t.__name__}", file=file)
+    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
+    print(f".. class:: {warp.types.bool.__name__}", file=file)

     print("\n\nVector Types", file=file)
     print("------------", file=file)

@@ -3693,14 +4163,22 @@ def print_builtins(file):
     print("\nGeneric Types", file=file)
     print("-------------", file=file)

-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
+    print(".. class:: Int", file=file)
+    print(".. class:: Float", file=file)
+    print(".. class:: Scalar", file=file)
+    print(".. class:: Vector", file=file)
+    print(".. class:: Matrix", file=file)
+    print(".. class:: Quaternion", file=file)
+    print(".. class:: Transformation", file=file)
+    print(".. class:: Array", file=file)
+
+    print("\nQuery Types", file=file)
+    print("-------------", file=file)
+    print(".. autoclass:: bvh_query_t", file=file)
+    print(".. autoclass:: hash_grid_query_t", file=file)
+    print(".. autoclass:: mesh_query_aabb_t", file=file)
+    print(".. autoclass:: mesh_query_point_t", file=file)
+    print(".. autoclass:: mesh_query_ray_t", file=file)

     # build dictionary of all functions by group
     groups = {}

@@ -3735,7 +4213,7 @@ def print_builtins(file):
     print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)


-def export_stubs(file):
+def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""

     import textwrap

@@ -3767,6 +4245,8 @@ def export_stubs(file):
     print("Quaternion = Generic[Float]", file=file)
     print("Transformation = Generic[Float]", file=file)
     print("Array = Generic[DType]", file=file)
+    print("FabricArray = Generic[DType]", file=file)
+    print("IndexedFabricArray = Generic[DType]", file=file)

     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:

@@ -3783,7 +4263,7 @@ def export_stubs(file):

         return_str = ""

-        if f.export
+        if not f.export or f.hidden:  # or f.generic:
            continue

        try:

@@ -3793,29 +4273,42 @@ def export_stubs(file):
             if return_type:
                 return_str = " -> " + type_str(return_type)

-        except:
+        except Exception:
             pass

         print("@over", file=file)
         print(f"def {f.key}({args}){return_str}:", file=file)
-        print(
+        print(' """', file=file)
         print(textwrap.indent(text=f.doc, prefix=" "), file=file)
-        print(
-        print(
+        print(' """', file=file)
+        print(" ...\n\n", file=file)


-def export_builtins(file):
-    def
+def export_builtins(file: io.TextIOBase):  # pragma: no cover
+    def ctype_arg_str(t):
         if isinstance(t, int):
             return "int"
         elif isinstance(t, float):
             return "float"
+        elif t in warp.types.vector_types:
+            return f"{t.__name__}&"
         else:
             return t.__name__

+    def ctype_ret_str(t):
+        if isinstance(t, int):
+            return "int"
+        elif isinstance(t, float):
+            return "float"
+        else:
+            return t.__name__
+
+    file.write("namespace wp {\n\n")
+    file.write('extern "C" {\n\n')
+
     for k, g in builtin_functions.items():
         for f in g.overloads:
-            if f.export
+            if not f.export or f.generic:
                 continue

             simple = True

@@ -3829,7 +4322,7 @@ def export_builtins(file):
             if not simple or f.variadic:
                 continue

-            args = ", ".join(f"{
+            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
             params = ", ".join(f.input_types.keys())

             return_type = ""

@@ -3837,25 +4330,25 @@ def export_builtins(file):
             try:
                 # todo: construct a default value for each of the functions args
                 # so we can generate the return type for overloaded functions
-                return_type =
-            except:
+                return_type = ctype_ret_str(f.value_func(None, None, None))
+            except Exception:
                 continue

             if return_type.startswith("Tuple"):
                 continue

             if args == "":
-
-                    f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
-                )
+                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
             elif return_type == "None":
-
+                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
             else:
-
-                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}"
-                    file=file,
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                 )

+    file.write('\n} // extern "C"\n\n')
+    file.write("} // namespace wp\n")
+

 # initialize global runtime
 runtime = None