warp-lang 0.9.0-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/context.py
CHANGED
|
@@ -5,37 +5,27 @@
|
|
|
5
5
|
# distribution of this software and related documentation without an express
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
|
|
8
|
-
import
|
|
9
|
-
import os
|
|
10
|
-
import sys
|
|
11
|
-
import hashlib
|
|
8
|
+
import ast
|
|
12
9
|
import ctypes
|
|
10
|
+
import gc
|
|
11
|
+
import hashlib
|
|
12
|
+
import inspect
|
|
13
|
+
import io
|
|
14
|
+
import os
|
|
13
15
|
import platform
|
|
14
|
-
import
|
|
16
|
+
import sys
|
|
15
17
|
import types
|
|
16
|
-
import
|
|
17
|
-
|
|
18
|
-
from typing import Tuple
|
|
19
|
-
from typing import List
|
|
20
|
-
from typing import Dict
|
|
21
|
-
from typing import Any
|
|
22
|
-
from typing import Callable
|
|
23
|
-
from typing import Union
|
|
24
|
-
from typing import Mapping
|
|
25
|
-
from typing import Optional
|
|
26
|
-
|
|
18
|
+
from copy import copy as shallowcopy
|
|
27
19
|
from types import ModuleType
|
|
20
|
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
|
|
28
21
|
|
|
29
|
-
|
|
22
|
+
import numpy as np
|
|
30
23
|
|
|
31
24
|
import warp
|
|
32
|
-
import warp.utils
|
|
33
|
-
import warp.codegen
|
|
34
25
|
import warp.build
|
|
26
|
+
import warp.codegen
|
|
35
27
|
import warp.config
|
|
36
28
|
|
|
37
|
-
import numpy as np
|
|
38
|
-
|
|
39
29
|
# represents either a built-in or user-defined function
|
|
40
30
|
|
|
41
31
|
|
|
@@ -46,6 +36,18 @@ def create_value_func(type):
|
|
|
46
36
|
return value_func
|
|
47
37
|
|
|
48
38
|
|
|
39
|
+
def get_function_args(func):
|
|
40
|
+
"""Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type."""
|
|
41
|
+
import inspect
|
|
42
|
+
|
|
43
|
+
argspec = inspect.getfullargspec(func)
|
|
44
|
+
|
|
45
|
+
# use source-level argument annotations
|
|
46
|
+
if len(argspec.annotations) < len(argspec.args):
|
|
47
|
+
raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")
|
|
48
|
+
return argspec.annotations
|
|
49
|
+
|
|
50
|
+
|
|
49
51
|
class Function:
|
|
50
52
|
def __init__(
|
|
51
53
|
self,
|
|
@@ -67,6 +69,17 @@ class Function:
|
|
|
67
69
|
generic=False,
|
|
68
70
|
native_func=None,
|
|
69
71
|
defaults=None,
|
|
72
|
+
custom_replay_func=None,
|
|
73
|
+
native_snippet=None,
|
|
74
|
+
adj_native_snippet=None,
|
|
75
|
+
skip_forward_codegen=False,
|
|
76
|
+
skip_reverse_codegen=False,
|
|
77
|
+
custom_reverse_num_input_args=-1,
|
|
78
|
+
custom_reverse_mode=False,
|
|
79
|
+
overloaded_annotations=None,
|
|
80
|
+
code_transformers=[],
|
|
81
|
+
skip_adding_overload=False,
|
|
82
|
+
require_original_output_arg=False,
|
|
70
83
|
):
|
|
71
84
|
self.func = func # points to Python function decorated with @wp.func, may be None for builtins
|
|
72
85
|
self.key = key
|
|
@@ -80,6 +93,12 @@ class Function:
|
|
|
80
93
|
self.module = module
|
|
81
94
|
self.variadic = variadic # function can take arbitrary number of inputs, e.g.: printf()
|
|
82
95
|
self.defaults = defaults
|
|
96
|
+
# Function instance for a custom implementation of the replay pass
|
|
97
|
+
self.custom_replay_func = custom_replay_func
|
|
98
|
+
self.native_snippet = native_snippet
|
|
99
|
+
self.adj_native_snippet = adj_native_snippet
|
|
100
|
+
self.custom_grad_func = None
|
|
101
|
+
self.require_original_output_arg = require_original_output_arg
|
|
83
102
|
|
|
84
103
|
if initializer_list_func is None:
|
|
85
104
|
self.initializer_list_func = lambda x, y: False
|
|
@@ -108,7 +127,16 @@ class Function:
|
|
|
108
127
|
self.user_overloads = {}
|
|
109
128
|
|
|
110
129
|
# user defined (Python) function
|
|
111
|
-
self.adj = warp.codegen.Adjoint(
|
|
130
|
+
self.adj = warp.codegen.Adjoint(
|
|
131
|
+
func,
|
|
132
|
+
is_user_function=True,
|
|
133
|
+
skip_forward_codegen=skip_forward_codegen,
|
|
134
|
+
skip_reverse_codegen=skip_reverse_codegen,
|
|
135
|
+
custom_reverse_num_input_args=custom_reverse_num_input_args,
|
|
136
|
+
custom_reverse_mode=custom_reverse_mode,
|
|
137
|
+
overload_annotations=overloaded_annotations,
|
|
138
|
+
transformers=code_transformers,
|
|
139
|
+
)
|
|
112
140
|
|
|
113
141
|
# record input types
|
|
114
142
|
for name, type in self.adj.arg_types.items():
|
|
@@ -136,11 +164,12 @@ class Function:
|
|
|
136
164
|
else:
|
|
137
165
|
self.mangled_name = None
|
|
138
166
|
|
|
139
|
-
|
|
167
|
+
if not skip_adding_overload:
|
|
168
|
+
self.add_overload(self)
|
|
140
169
|
|
|
141
170
|
# add to current module
|
|
142
171
|
if module:
|
|
143
|
-
module.register_function(self)
|
|
172
|
+
module.register_function(self, skip_adding_overload)
|
|
144
173
|
|
|
145
174
|
def __call__(self, *args, **kwargs):
|
|
146
175
|
# handles calling a builtin (native) function
|
|
@@ -149,124 +178,52 @@ class Function:
|
|
|
149
178
|
# from within a kernel (experimental).
|
|
150
179
|
|
|
151
180
|
if self.is_builtin() and self.mangled_name:
|
|
152
|
-
#
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
for
|
|
156
|
-
if
|
|
181
|
+
# For each of this function's existing overloads, we attempt to pack
|
|
182
|
+
# the given arguments into the C types expected by the corresponding
|
|
183
|
+
# parameters, and we rinse and repeat until we get a match.
|
|
184
|
+
for overload in self.overloads:
|
|
185
|
+
if overload.generic:
|
|
157
186
|
continue
|
|
158
187
|
|
|
159
|
-
|
|
160
|
-
if
|
|
161
|
-
|
|
162
|
-
f"Couldn't find function {self.key} with mangled name {f.mangled_name} in the Warp native library"
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
try:
|
|
166
|
-
# try and pack args into what the function expects
|
|
167
|
-
params = []
|
|
168
|
-
for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
|
|
169
|
-
a = args[i]
|
|
170
|
-
|
|
171
|
-
# try to convert to a value type (vec3, mat33, etc)
|
|
172
|
-
if issubclass(arg_type, ctypes.Array):
|
|
173
|
-
# wrap the arg_type (which is an ctypes.Array) in a structure
|
|
174
|
-
# to ensure parameter is passed to the .dll by value rather than reference
|
|
175
|
-
class ValueArg(ctypes.Structure):
|
|
176
|
-
_fields_ = [("value", arg_type)]
|
|
177
|
-
|
|
178
|
-
x = ValueArg()
|
|
179
|
-
|
|
180
|
-
# force conversion to ndarray first (handles tuple / list, Gf.Vec3 case)
|
|
181
|
-
if isinstance(a, ctypes.Array) == False:
|
|
182
|
-
# assume you want the float32 version of the function so it doesn't just
|
|
183
|
-
# grab an override for a random data type:
|
|
184
|
-
if arg_type._type_ != ctypes.c_float:
|
|
185
|
-
raise RuntimeError(
|
|
186
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' does not have c_float type."
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
a = np.array(a)
|
|
190
|
-
|
|
191
|
-
# flatten to 1D array
|
|
192
|
-
v = a.flatten()
|
|
193
|
-
if len(v) != arg_type._length_:
|
|
194
|
-
raise RuntimeError(
|
|
195
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' has length {len(v)}, but expected {arg_type._length_}. Could not convert parameter to {arg_type}."
|
|
196
|
-
)
|
|
197
|
-
|
|
198
|
-
for i in range(arg_type._length_):
|
|
199
|
-
x.value[i] = v[i]
|
|
200
|
-
|
|
201
|
-
else:
|
|
202
|
-
# already a built-in type, check it matches
|
|
203
|
-
if not warp.types.types_equal(type(a), arg_type):
|
|
204
|
-
raise RuntimeError(
|
|
205
|
-
f"Error calling function '{f.key}', parameter for argument '{arg_name}' has type '{type(a)}' but expected '{arg_type}'"
|
|
206
|
-
)
|
|
207
|
-
|
|
208
|
-
x.value = a
|
|
209
|
-
|
|
210
|
-
params.append(x)
|
|
211
|
-
|
|
212
|
-
else:
|
|
213
|
-
try:
|
|
214
|
-
# try to pack as a scalar type
|
|
215
|
-
params.append(arg_type._type_(a))
|
|
216
|
-
except:
|
|
217
|
-
raise RuntimeError(
|
|
218
|
-
f"Error calling function {f.key}, unable to pack function parameter type {type(a)} for param {arg_name}, expected {arg_type}"
|
|
219
|
-
)
|
|
220
|
-
|
|
221
|
-
# returns the corresponding ctype for a scalar or vector warp type
|
|
222
|
-
def type_ctype(dtype):
|
|
223
|
-
if dtype == float:
|
|
224
|
-
return ctypes.c_float
|
|
225
|
-
elif dtype == int:
|
|
226
|
-
return ctypes.c_int32
|
|
227
|
-
elif issubclass(dtype, ctypes.Array):
|
|
228
|
-
return dtype
|
|
229
|
-
elif issubclass(dtype, ctypes.Structure):
|
|
230
|
-
return dtype
|
|
231
|
-
else:
|
|
232
|
-
# scalar type
|
|
233
|
-
return dtype._type_
|
|
234
|
-
|
|
235
|
-
value_type = type_ctype(f.value_func(None, None, None))
|
|
236
|
-
|
|
237
|
-
# construct return value (passed by address)
|
|
238
|
-
ret = value_type()
|
|
239
|
-
ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
|
|
188
|
+
success, return_value = call_builtin(overload, *args)
|
|
189
|
+
if success:
|
|
190
|
+
return return_value
|
|
240
191
|
|
|
241
|
-
|
|
192
|
+
# overload resolution or call failed
|
|
193
|
+
raise RuntimeError(
|
|
194
|
+
f"Couldn't find a function '{self.key}' compatible with "
|
|
195
|
+
f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
|
|
196
|
+
)
|
|
242
197
|
|
|
243
|
-
|
|
244
|
-
|
|
198
|
+
if hasattr(self, "user_overloads") and len(self.user_overloads):
|
|
199
|
+
# user-defined function with overloads
|
|
245
200
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
# return scalar types as int/float
|
|
251
|
-
return ret.value
|
|
201
|
+
if len(kwargs):
|
|
202
|
+
raise RuntimeError(
|
|
203
|
+
f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
|
|
204
|
+
)
|
|
252
205
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
206
|
+
# try and find a matching overload
|
|
207
|
+
for overload in self.user_overloads.values():
|
|
208
|
+
if len(overload.input_types) != len(args):
|
|
209
|
+
continue
|
|
210
|
+
template_types = list(overload.input_types.values())
|
|
211
|
+
arg_names = list(overload.input_types.keys())
|
|
212
|
+
try:
|
|
213
|
+
# attempt to unify argument types with function template types
|
|
214
|
+
warp.types.infer_argument_types(args, template_types, arg_names)
|
|
215
|
+
return overload.func(*args)
|
|
216
|
+
except Exception:
|
|
257
217
|
continue
|
|
258
218
|
|
|
259
|
-
|
|
260
|
-
# raise the last exception encountered
|
|
261
|
-
if error:
|
|
262
|
-
raise error
|
|
263
|
-
else:
|
|
264
|
-
raise RuntimeError(f"Error calling function '{f.key}'.")
|
|
219
|
+
raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")
|
|
265
220
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
221
|
+
# user-defined function with no overloads
|
|
222
|
+
if self.func is None:
|
|
223
|
+
raise RuntimeError(f"Error calling function '{self.key}', function is undefined")
|
|
224
|
+
|
|
225
|
+
# this function has no overloads, call it like a plain Python function
|
|
226
|
+
return self.func(*args, **kwargs)
|
|
270
227
|
|
|
271
228
|
def is_builtin(self):
|
|
272
229
|
return self.func is None
|
|
@@ -286,7 +243,7 @@ class Function:
|
|
|
286
243
|
# todo: construct a default value for each of the functions args
|
|
287
244
|
# so we can generate the return type for overloaded functions
|
|
288
245
|
return_type = type_str(self.value_func(None, None, None))
|
|
289
|
-
except:
|
|
246
|
+
except Exception:
|
|
290
247
|
return False
|
|
291
248
|
|
|
292
249
|
if return_type.startswith("Tuple"):
|
|
@@ -379,10 +336,187 @@ class Function:
|
|
|
379
336
|
return None
|
|
380
337
|
|
|
381
338
|
def __repr__(self):
|
|
382
|
-
inputs_str = ", ".join([f"{k}: {v
|
|
339
|
+
inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
|
|
383
340
|
return f"<Function {self.key}({inputs_str})>"
|
|
384
341
|
|
|
385
342
|
|
|
343
|
+
def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
|
|
344
|
+
uses_non_warp_array_type = False
|
|
345
|
+
|
|
346
|
+
# Retrieve the built-in function from Warp's dll.
|
|
347
|
+
c_func = getattr(warp.context.runtime.core, func.mangled_name)
|
|
348
|
+
|
|
349
|
+
# Try gathering the parameters that the function expects and pack them
|
|
350
|
+
# into their corresponding C types.
|
|
351
|
+
c_params = []
|
|
352
|
+
for i, (_, arg_type) in enumerate(func.input_types.items()):
|
|
353
|
+
param = params[i]
|
|
354
|
+
|
|
355
|
+
try:
|
|
356
|
+
iter(param)
|
|
357
|
+
except TypeError:
|
|
358
|
+
is_array = False
|
|
359
|
+
else:
|
|
360
|
+
is_array = True
|
|
361
|
+
|
|
362
|
+
if is_array:
|
|
363
|
+
if not issubclass(arg_type, ctypes.Array):
|
|
364
|
+
return (False, None)
|
|
365
|
+
|
|
366
|
+
# The argument expects a built-in Warp type like a vector or a matrix.
|
|
367
|
+
|
|
368
|
+
c_param = None
|
|
369
|
+
|
|
370
|
+
if isinstance(param, ctypes.Array):
|
|
371
|
+
# The given parameter is also a built-in Warp type, so we only need
|
|
372
|
+
# to make sure that it matches with the argument.
|
|
373
|
+
if not warp.types.types_equal(type(param), arg_type):
|
|
374
|
+
return (False, None)
|
|
375
|
+
|
|
376
|
+
if isinstance(param, arg_type):
|
|
377
|
+
c_param = param
|
|
378
|
+
else:
|
|
379
|
+
# Cast the value to its argument type to make sure that it
|
|
380
|
+
# can be assigned to the field of the `Param` struct.
|
|
381
|
+
# This could error otherwise when, for example, the field type
|
|
382
|
+
# is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
|
|
383
|
+
# even though both types are semantically identical.
|
|
384
|
+
c_param = arg_type(param)
|
|
385
|
+
else:
|
|
386
|
+
# Flatten the parameter values into a flat 1-D array.
|
|
387
|
+
arr = []
|
|
388
|
+
ndim = 1
|
|
389
|
+
stack = [(0, param)]
|
|
390
|
+
while stack:
|
|
391
|
+
depth, elem = stack.pop(0)
|
|
392
|
+
try:
|
|
393
|
+
# If `elem` is a sequence, then it should be possible
|
|
394
|
+
# to add its elements to the stack for later processing.
|
|
395
|
+
stack.extend((depth + 1, x) for x in elem)
|
|
396
|
+
except TypeError:
|
|
397
|
+
# Since `elem` doesn't seem to be a sequence,
|
|
398
|
+
# we must have a leaf value that we need to add to our
|
|
399
|
+
# resulting array.
|
|
400
|
+
arr.append(elem)
|
|
401
|
+
ndim = max(depth, ndim)
|
|
402
|
+
|
|
403
|
+
assert ndim > 0
|
|
404
|
+
|
|
405
|
+
# Ensure that if the given parameter value is, say, a 2-D array,
|
|
406
|
+
# then we try to resolve it against a matrix argument rather than
|
|
407
|
+
# a vector.
|
|
408
|
+
if ndim > len(arg_type._shape_):
|
|
409
|
+
return (False, None)
|
|
410
|
+
|
|
411
|
+
elem_count = len(arr)
|
|
412
|
+
if elem_count != arg_type._length_:
|
|
413
|
+
return (False, None)
|
|
414
|
+
|
|
415
|
+
# Retrieve the element type of the sequence while ensuring
|
|
416
|
+
# that it's homogeneous.
|
|
417
|
+
elem_type = type(arr[0])
|
|
418
|
+
for i in range(1, elem_count):
|
|
419
|
+
if type(arr[i]) is not elem_type:
|
|
420
|
+
raise ValueError("All array elements must share the same type.")
|
|
421
|
+
|
|
422
|
+
expected_elem_type = arg_type._wp_scalar_type_
|
|
423
|
+
if not (
|
|
424
|
+
elem_type is expected_elem_type
|
|
425
|
+
or (elem_type is float and expected_elem_type is warp.types.float32)
|
|
426
|
+
or (elem_type is int and expected_elem_type is warp.types.int32)
|
|
427
|
+
or (
|
|
428
|
+
issubclass(elem_type, np.number)
|
|
429
|
+
and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
|
|
430
|
+
)
|
|
431
|
+
):
|
|
432
|
+
# The parameter value has a type not matching the type defined
|
|
433
|
+
# for the corresponding argument.
|
|
434
|
+
return (False, None)
|
|
435
|
+
|
|
436
|
+
if elem_type in warp.types.int_types:
|
|
437
|
+
# Pass the value through the expected integer type
|
|
438
|
+
# in order to evaluate any integer wrapping.
|
|
439
|
+
# For example `uint8(-1)` should result in the value `-255`.
|
|
440
|
+
arr = tuple(elem_type._type_(x.value).value for x in arr)
|
|
441
|
+
elif elem_type in warp.types.float_types:
|
|
442
|
+
# Extract the floating-point values.
|
|
443
|
+
arr = tuple(x.value for x in arr)
|
|
444
|
+
|
|
445
|
+
c_param = arg_type()
|
|
446
|
+
if warp.types.type_is_matrix(arg_type):
|
|
447
|
+
rows, cols = arg_type._shape_
|
|
448
|
+
for i in range(rows):
|
|
449
|
+
idx_start = i * cols
|
|
450
|
+
idx_end = idx_start + cols
|
|
451
|
+
c_param[i] = arr[idx_start:idx_end]
|
|
452
|
+
else:
|
|
453
|
+
c_param[:] = arr
|
|
454
|
+
|
|
455
|
+
uses_non_warp_array_type = True
|
|
456
|
+
|
|
457
|
+
c_params.append(ctypes.byref(c_param))
|
|
458
|
+
else:
|
|
459
|
+
if issubclass(arg_type, ctypes.Array):
|
|
460
|
+
return (False, None)
|
|
461
|
+
|
|
462
|
+
if not (
|
|
463
|
+
isinstance(param, arg_type)
|
|
464
|
+
or (type(param) is float and arg_type is warp.types.float32)
|
|
465
|
+
or (type(param) is int and arg_type is warp.types.int32)
|
|
466
|
+
or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
|
|
467
|
+
):
|
|
468
|
+
return (False, None)
|
|
469
|
+
|
|
470
|
+
if type(param) in warp.types.scalar_types:
|
|
471
|
+
param = param.value
|
|
472
|
+
|
|
473
|
+
# try to pack as a scalar type
|
|
474
|
+
if arg_type == warp.types.float16:
|
|
475
|
+
c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
|
|
476
|
+
else:
|
|
477
|
+
c_params.append(arg_type._type_(param))
|
|
478
|
+
|
|
479
|
+
# returns the corresponding ctype for a scalar or vector warp type
|
|
480
|
+
value_type = func.value_func(None, None, None)
|
|
481
|
+
if value_type == float:
|
|
482
|
+
value_ctype = ctypes.c_float
|
|
483
|
+
elif value_type == int:
|
|
484
|
+
value_ctype = ctypes.c_int32
|
|
485
|
+
elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
|
|
486
|
+
value_ctype = value_type
|
|
487
|
+
else:
|
|
488
|
+
# scalar type
|
|
489
|
+
value_ctype = value_type._type_
|
|
490
|
+
|
|
491
|
+
# construct return value (passed by address)
|
|
492
|
+
ret = value_ctype()
|
|
493
|
+
ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
|
|
494
|
+
c_params.append(ret_addr)
|
|
495
|
+
|
|
496
|
+
# Call the built-in function from Warp's dll.
|
|
497
|
+
c_func(*c_params)
|
|
498
|
+
|
|
499
|
+
if uses_non_warp_array_type:
|
|
500
|
+
warp.utils.warn(
|
|
501
|
+
"Support for built-in functions called with non-Warp array types, "
|
|
502
|
+
"such as lists, tuples, NumPy arrays, and others, will be dropped "
|
|
503
|
+
"in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
|
|
504
|
+
"`wp.quat`, or `wp.transform`.",
|
|
505
|
+
DeprecationWarning,
|
|
506
|
+
stacklevel=3,
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
|
|
510
|
+
# return vector types as ctypes
|
|
511
|
+
return (True, ret)
|
|
512
|
+
|
|
513
|
+
if value_type == warp.types.float16:
|
|
514
|
+
return (True, warp.types.half_bits_to_float(ret.value))
|
|
515
|
+
|
|
516
|
+
# return scalar types as int/float
|
|
517
|
+
return (True, ret.value)
|
|
518
|
+
|
|
519
|
+
|
|
386
520
|
class KernelHooks:
|
|
387
521
|
def __init__(self, forward, backward):
|
|
388
522
|
self.forward = forward
|
|
@@ -391,13 +525,23 @@ class KernelHooks:
|
|
|
391
525
|
|
|
392
526
|
# caches source and compiled entry points for a kernel (will be populated after module loads)
|
|
393
527
|
class Kernel:
|
|
394
|
-
def __init__(self, func, key, module, options=None):
|
|
528
|
+
def __init__(self, func, key=None, module=None, options=None, code_transformers=[]):
|
|
395
529
|
self.func = func
|
|
396
|
-
|
|
397
|
-
|
|
530
|
+
|
|
531
|
+
if module is None:
|
|
532
|
+
self.module = get_module(func.__module__)
|
|
533
|
+
else:
|
|
534
|
+
self.module = module
|
|
535
|
+
|
|
536
|
+
if key is None:
|
|
537
|
+
unique_key = self.module.generate_unique_kernel_key(func.__name__)
|
|
538
|
+
self.key = unique_key
|
|
539
|
+
else:
|
|
540
|
+
self.key = key
|
|
541
|
+
|
|
398
542
|
self.options = {} if options is None else options
|
|
399
543
|
|
|
400
|
-
self.adj = warp.codegen.Adjoint(func)
|
|
544
|
+
self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)
|
|
401
545
|
|
|
402
546
|
# check if generic
|
|
403
547
|
self.is_generic = False
|
|
@@ -415,8 +559,8 @@ class Kernel:
|
|
|
415
559
|
# argument indices by name
|
|
416
560
|
self.arg_indices = dict((a.label, i) for i, a in enumerate(self.adj.args))
|
|
417
561
|
|
|
418
|
-
if module:
|
|
419
|
-
module.register_kernel(self)
|
|
562
|
+
if self.module:
|
|
563
|
+
self.module.register_kernel(self)
|
|
420
564
|
|
|
421
565
|
def infer_argument_types(self, args):
|
|
422
566
|
template_types = list(self.adj.arg_types.values())
|
|
@@ -425,44 +569,8 @@ class Kernel:
|
|
|
425
569
|
raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")
|
|
426
570
|
|
|
427
571
|
arg_names = list(self.adj.arg_types.keys())
|
|
428
|
-
arg_types = []
|
|
429
|
-
|
|
430
|
-
for i in range(len(args)):
|
|
431
|
-
arg = args[i]
|
|
432
|
-
arg_type = type(arg)
|
|
433
|
-
if arg_type in warp.types.array_types:
|
|
434
|
-
arg_types.append(arg_type(dtype=arg.dtype, ndim=arg.ndim))
|
|
435
|
-
elif arg_type in warp.types.scalar_types:
|
|
436
|
-
arg_types.append(arg_type)
|
|
437
|
-
elif arg_type in [int, float]:
|
|
438
|
-
# canonicalize type
|
|
439
|
-
arg_types.append(warp.types.type_to_warp(arg_type))
|
|
440
|
-
elif hasattr(arg_type, "_wp_scalar_type_"):
|
|
441
|
-
# vector/matrix type
|
|
442
|
-
arg_types.append(arg_type)
|
|
443
|
-
elif issubclass(arg_type, warp.codegen.StructInstance):
|
|
444
|
-
# a struct
|
|
445
|
-
arg_types.append(arg._struct_)
|
|
446
|
-
# elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
|
|
447
|
-
# arg_types.append(arg_type)
|
|
448
|
-
# elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
|
|
449
|
-
# arg_types.append(arg_type)
|
|
450
|
-
elif arg is None:
|
|
451
|
-
# allow passing None for arrays
|
|
452
|
-
t = template_types[i]
|
|
453
|
-
if warp.types.is_array(t):
|
|
454
|
-
arg_types.append(type(t)(dtype=t.dtype, ndim=t.ndim))
|
|
455
|
-
else:
|
|
456
|
-
raise TypeError(
|
|
457
|
-
f"Unable to infer the type of argument '{arg_names[i]}' for kernel {self.key}, got None"
|
|
458
|
-
)
|
|
459
|
-
else:
|
|
460
|
-
# TODO: attempt to figure out if it's a vector/matrix type given as a numpy array, list, etc.
|
|
461
|
-
raise TypeError(
|
|
462
|
-
f"Unable to infer the type of argument '{arg_names[i]}' for kernel {self.key}, got {arg_type}"
|
|
463
|
-
)
|
|
464
572
|
|
|
465
|
-
return
|
|
573
|
+
return warp.types.infer_argument_types(args, template_types, arg_names)
|
|
466
574
|
|
|
467
575
|
def add_overload(self, arg_types):
|
|
468
576
|
if len(arg_types) != len(self.adj.arg_types):
|
|
@@ -529,7 +637,7 @@ def func(f):
|
|
|
529
637
|
name = warp.codegen.make_full_qualified_name(f)
|
|
530
638
|
|
|
531
639
|
m = get_module(f.__module__)
|
|
532
|
-
|
|
640
|
+
Function(
|
|
533
641
|
func=f, key=name, namespace="", module=m, value_func=None
|
|
534
642
|
) # value_type not known yet, will be inferred during Adjoint.build()
|
|
535
643
|
|
|
@@ -537,6 +645,167 @@ def func(f):
|
|
|
537
645
|
return m.functions[name]
|
|
538
646
|
|
|
539
647
|
|
|
648
|
+
def func_native(snippet, adj_snippet=None):
|
|
649
|
+
"""
|
|
650
|
+
Decorator to register native code snippet, @func_native
|
|
651
|
+
"""
|
|
652
|
+
|
|
653
|
+
def snippet_func(f):
|
|
654
|
+
name = warp.codegen.make_full_qualified_name(f)
|
|
655
|
+
|
|
656
|
+
m = get_module(f.__module__)
|
|
657
|
+
func = Function(
|
|
658
|
+
func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet
|
|
659
|
+
) # cuda snippets do not have a return value_type
|
|
660
|
+
|
|
661
|
+
return m.functions[name]
|
|
662
|
+
|
|
663
|
+
return snippet_func
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def func_grad(forward_fn):
|
|
667
|
+
"""
|
|
668
|
+
Decorator to register a custom gradient function for a given forward function.
|
|
669
|
+
The function signature must correspond to one of the function overloads in the following way:
|
|
670
|
+
the first part of the input arguments are the original input variables with the same types as their
|
|
671
|
+
corresponding arguments in the original function, and the second part of the input arguments are the
|
|
672
|
+
adjoint variables of the output variables (if available) of the original function with the same types as the
|
|
673
|
+
output variables. The function must not return anything.
|
|
674
|
+
"""
|
|
675
|
+
|
|
676
|
+
def wrapper(grad_fn):
|
|
677
|
+
generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
|
|
678
|
+
if generic:
|
|
679
|
+
raise RuntimeError(
|
|
680
|
+
f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
|
|
681
|
+
)
|
|
682
|
+
|
|
683
|
+
reverse_args = {}
|
|
684
|
+
reverse_args.update(forward_fn.input_types)
|
|
685
|
+
|
|
686
|
+
# create temporary Adjoint instance to analyze the function signature
|
|
687
|
+
adj = warp.codegen.Adjoint(
|
|
688
|
+
grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
|
|
689
|
+
)
|
|
690
|
+
|
|
691
|
+
from warp.types import types_equal
|
|
692
|
+
|
|
693
|
+
grad_args = adj.args
|
|
694
|
+
grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)
|
|
695
|
+
|
|
696
|
+
generic = any(warp.types.type_is_generic(x.type) for x in grad_args)
|
|
697
|
+
if generic:
|
|
698
|
+
raise RuntimeError(
|
|
699
|
+
f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
def match_function(f):
|
|
703
|
+
# check whether the function overload f matches the signature of the provided gradient function
|
|
704
|
+
if not hasattr(f.adj, "return_var"):
|
|
705
|
+
f.adj.build(None)
|
|
706
|
+
expected_args = list(f.input_types.items())
|
|
707
|
+
if f.adj.return_var is not None:
|
|
708
|
+
expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
|
|
709
|
+
if len(grad_args) != len(expected_args):
|
|
710
|
+
return False
|
|
711
|
+
if any(not types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args)):
|
|
712
|
+
return False
|
|
713
|
+
return True
|
|
714
|
+
|
|
715
|
+
def add_custom_grad(f: Function):
|
|
716
|
+
# register custom gradient function
|
|
717
|
+
f.custom_grad_func = Function(
|
|
718
|
+
grad_fn,
|
|
719
|
+
key=f.key,
|
|
720
|
+
namespace=f.namespace,
|
|
721
|
+
input_types=reverse_args,
|
|
722
|
+
value_func=None,
|
|
723
|
+
module=f.module,
|
|
724
|
+
template_func=f.template_func,
|
|
725
|
+
skip_forward_codegen=True,
|
|
726
|
+
custom_reverse_mode=True,
|
|
727
|
+
custom_reverse_num_input_args=len(f.input_types),
|
|
728
|
+
skip_adding_overload=False,
|
|
729
|
+
code_transformers=f.adj.transformers,
|
|
730
|
+
)
|
|
731
|
+
f.adj.skip_reverse_codegen = True
|
|
732
|
+
|
|
733
|
+
if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
|
|
734
|
+
# find matching overload for which this grad function is defined
|
|
735
|
+
for sig, f in forward_fn.user_overloads.items():
|
|
736
|
+
if not grad_sig.startswith(sig):
|
|
737
|
+
continue
|
|
738
|
+
if match_function(f):
|
|
739
|
+
add_custom_grad(f)
|
|
740
|
+
return
|
|
741
|
+
raise RuntimeError(
|
|
742
|
+
f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
|
|
743
|
+
)
|
|
744
|
+
else:
|
|
745
|
+
# resolve return variables
|
|
746
|
+
forward_fn.adj.build(None)
|
|
747
|
+
|
|
748
|
+
expected_args = list(forward_fn.input_types.items())
|
|
749
|
+
if forward_fn.adj.return_var is not None:
|
|
750
|
+
expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]
|
|
751
|
+
|
|
752
|
+
# check if the signature matches this function
|
|
753
|
+
if match_function(forward_fn):
|
|
754
|
+
add_custom_grad(forward_fn)
|
|
755
|
+
else:
|
|
756
|
+
raise RuntimeError(
|
|
757
|
+
f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
|
|
758
|
+
"forward function arguments plus the adjoint variables corresponding to the return variables:"
|
|
759
|
+
f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
return wrapper
|
|
763
|
+
|
|
764
|
+
|
|
765
|
+
def func_replay(forward_fn):
|
|
766
|
+
"""
|
|
767
|
+
Decorator to register a custom replay function for a given forward function.
|
|
768
|
+
The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
|
|
769
|
+
The provided function has to match the signature of one of the original forward function overloads.
|
|
770
|
+
"""
|
|
771
|
+
|
|
772
|
+
def wrapper(replay_fn):
|
|
773
|
+
generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
|
|
774
|
+
if generic:
|
|
775
|
+
raise RuntimeError(
|
|
776
|
+
f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
args = get_function_args(replay_fn)
|
|
780
|
+
arg_types = list(args.values())
|
|
781
|
+
generic = any(warp.types.type_is_generic(x) for x in arg_types)
|
|
782
|
+
if generic:
|
|
783
|
+
raise RuntimeError(
|
|
784
|
+
f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
f = forward_fn.get_overload(arg_types)
|
|
788
|
+
if f is None:
|
|
789
|
+
inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
|
|
790
|
+
raise RuntimeError(
|
|
791
|
+
f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
|
|
792
|
+
)
|
|
793
|
+
f.custom_replay_func = Function(
|
|
794
|
+
replay_fn,
|
|
795
|
+
key=f"replay_{f.key}",
|
|
796
|
+
namespace=f.namespace,
|
|
797
|
+
input_types=f.input_types,
|
|
798
|
+
value_func=f.value_func,
|
|
799
|
+
module=f.module,
|
|
800
|
+
template_func=f.template_func,
|
|
801
|
+
skip_reverse_codegen=True,
|
|
802
|
+
skip_adding_overload=True,
|
|
803
|
+
code_transformers=f.adj.transformers,
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
return wrapper
|
|
807
|
+
|
|
808
|
+
|
|
540
809
|
# decorator to register kernel, @kernel, custom_name may be a string
|
|
541
810
|
# that creates a kernel with a different name from the actual function
|
|
542
811
|
def kernel(f=None, *, enable_backward=None):
|
|
@@ -664,6 +933,7 @@ def add_builtin(
|
|
|
664
933
|
missing_grad=False,
|
|
665
934
|
native_func=None,
|
|
666
935
|
defaults=None,
|
|
936
|
+
require_original_output_arg=False,
|
|
667
937
|
):
|
|
668
938
|
# wrap simple single-type functions with a value_func()
|
|
669
939
|
if value_func is None:
|
|
@@ -676,7 +946,7 @@ def add_builtin(
|
|
|
676
946
|
def initializer_list_func(args, templates):
|
|
677
947
|
return False
|
|
678
948
|
|
|
679
|
-
if defaults
|
|
949
|
+
if defaults is None:
|
|
680
950
|
defaults = {}
|
|
681
951
|
|
|
682
952
|
# Add specialized versions of this builtin if it's generic by matching arguments against
|
|
@@ -757,8 +1027,8 @@ def add_builtin(
|
|
|
757
1027
|
# on the generated argument list and skip generation if it fails.
|
|
758
1028
|
# This also gives us the return type, which we keep for later:
|
|
759
1029
|
try:
|
|
760
|
-
return_type = value_func(
|
|
761
|
-
except Exception
|
|
1030
|
+
return_type = value_func(argtypes, {}, [])
|
|
1031
|
+
except Exception:
|
|
762
1032
|
continue
|
|
763
1033
|
|
|
764
1034
|
# The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
|
|
@@ -788,6 +1058,7 @@ def add_builtin(
|
|
|
788
1058
|
hidden=True,
|
|
789
1059
|
skip_replay=skip_replay,
|
|
790
1060
|
missing_grad=missing_grad,
|
|
1061
|
+
require_original_output_arg=require_original_output_arg,
|
|
791
1062
|
)
|
|
792
1063
|
|
|
793
1064
|
func = Function(
|
|
@@ -808,6 +1079,7 @@ def add_builtin(
|
|
|
808
1079
|
generic=generic,
|
|
809
1080
|
native_func=native_func,
|
|
810
1081
|
defaults=defaults,
|
|
1082
|
+
require_original_output_arg=require_original_output_arg,
|
|
811
1083
|
)
|
|
812
1084
|
|
|
813
1085
|
if key in builtin_functions:
|
|
@@ -817,7 +1089,7 @@ def add_builtin(
|
|
|
817
1089
|
|
|
818
1090
|
# export means the function will be added to the `warp` module namespace
|
|
819
1091
|
# so that users can call it directly from the Python interpreter
|
|
820
|
-
if export
|
|
1092
|
+
if export:
|
|
821
1093
|
if hasattr(warp, key):
|
|
822
1094
|
# check that we haven't already created something at this location
|
|
823
1095
|
# if it's just an overload stub for auto-complete then overwrite it
|
|
@@ -884,6 +1156,8 @@ class ModuleBuilder:
|
|
|
884
1156
|
for func in module.functions.values():
|
|
885
1157
|
for f in func.user_overloads.values():
|
|
886
1158
|
self.build_function(f)
|
|
1159
|
+
if f.custom_replay_func is not None:
|
|
1160
|
+
self.build_function(f.custom_replay_func)
|
|
887
1161
|
|
|
888
1162
|
# build all kernel entry points
|
|
889
1163
|
for kernel in module.kernels.values():
|
|
@@ -900,12 +1174,13 @@ class ModuleBuilder:
|
|
|
900
1174
|
while stack:
|
|
901
1175
|
s = stack.pop()
|
|
902
1176
|
|
|
903
|
-
|
|
904
|
-
structs.append(s)
|
|
1177
|
+
structs.append(s)
|
|
905
1178
|
|
|
906
1179
|
for var in s.vars.values():
|
|
907
1180
|
if isinstance(var.type, warp.codegen.Struct):
|
|
908
1181
|
stack.append(var.type)
|
|
1182
|
+
elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
|
|
1183
|
+
stack.append(var.type.dtype)
|
|
909
1184
|
|
|
910
1185
|
# Build them in reverse to generate a correct dependency order.
|
|
911
1186
|
for s in reversed(structs):
|
|
@@ -931,7 +1206,7 @@ class ModuleBuilder:
|
|
|
931
1206
|
if not func.value_func:
|
|
932
1207
|
|
|
933
1208
|
def wrap(adj):
|
|
934
|
-
def value_type(
|
|
1209
|
+
def value_type(arg_types, kwds, templates):
|
|
935
1210
|
if adj.return_var is None or len(adj.return_var) == 0:
|
|
936
1211
|
return None
|
|
937
1212
|
if len(adj.return_var) == 1:
|
|
@@ -946,56 +1221,41 @@ class ModuleBuilder:
|
|
|
946
1221
|
# use dict to preserve import order
|
|
947
1222
|
self.functions[func] = None
|
|
948
1223
|
|
|
949
|
-
def
|
|
950
|
-
|
|
1224
|
+
def codegen(self, device):
|
|
1225
|
+
source = ""
|
|
951
1226
|
|
|
952
1227
|
# code-gen structs
|
|
953
1228
|
for struct in self.structs.keys():
|
|
954
|
-
|
|
1229
|
+
source += warp.codegen.codegen_struct(struct)
|
|
955
1230
|
|
|
956
1231
|
# code-gen all imported functions
|
|
957
1232
|
for func in self.functions.keys():
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
if not kernel.is_generic:
|
|
963
|
-
cpp_source += warp.codegen.codegen_kernel(kernel, device="cpu", options=self.options)
|
|
964
|
-
cpp_source += warp.codegen.codegen_module(kernel, device="cpu")
|
|
1233
|
+
if func.native_snippet is None:
|
|
1234
|
+
source += warp.codegen.codegen_func(
|
|
1235
|
+
func.adj, c_func_name=func.native_func, device=device, options=self.options
|
|
1236
|
+
)
|
|
965
1237
|
else:
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
# add headers
|
|
971
|
-
cpp_source = warp.codegen.cpu_module_header + cpp_source
|
|
972
|
-
|
|
973
|
-
return cpp_source
|
|
974
|
-
|
|
975
|
-
def codegen_cuda(self):
|
|
976
|
-
cu_source = ""
|
|
977
|
-
|
|
978
|
-
# code-gen structs
|
|
979
|
-
for struct in self.structs.keys():
|
|
980
|
-
cu_source += warp.codegen.codegen_struct(struct)
|
|
981
|
-
|
|
982
|
-
# code-gen all imported functions
|
|
983
|
-
for func in self.functions.keys():
|
|
984
|
-
cu_source += warp.codegen.codegen_func(func.adj, device="cuda")
|
|
1238
|
+
source += warp.codegen.codegen_snippet(
|
|
1239
|
+
func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
|
|
1240
|
+
)
|
|
985
1241
|
|
|
986
1242
|
for kernel in self.module.kernels.values():
|
|
1243
|
+
# each kernel gets an entry point in the module
|
|
987
1244
|
if not kernel.is_generic:
|
|
988
|
-
|
|
989
|
-
|
|
1245
|
+
source += warp.codegen.codegen_kernel(kernel, device=device, options=self.options)
|
|
1246
|
+
source += warp.codegen.codegen_module(kernel, device=device)
|
|
990
1247
|
else:
|
|
991
1248
|
for k in kernel.overloads.values():
|
|
992
|
-
|
|
993
|
-
|
|
1249
|
+
source += warp.codegen.codegen_kernel(k, device=device, options=self.options)
|
|
1250
|
+
source += warp.codegen.codegen_module(k, device=device)
|
|
994
1251
|
|
|
995
1252
|
# add headers
|
|
996
|
-
|
|
1253
|
+
if device == "cpu":
|
|
1254
|
+
source = warp.codegen.cpu_module_header + source
|
|
1255
|
+
else:
|
|
1256
|
+
source = warp.codegen.cuda_module_header + source
|
|
997
1257
|
|
|
998
|
-
return
|
|
1258
|
+
return source
|
|
999
1259
|
|
|
1000
1260
|
|
|
1001
1261
|
# -----------------------------------------------------
|
|
@@ -1014,7 +1274,6 @@ class Module:
|
|
|
1014
1274
|
self.constants = []
|
|
1015
1275
|
self.structs = {}
|
|
1016
1276
|
|
|
1017
|
-
self.dll = None
|
|
1018
1277
|
self.cpu_module = None
|
|
1019
1278
|
self.cuda_modules = {} # module lookup by CUDA context
|
|
1020
1279
|
|
|
@@ -1058,6 +1317,10 @@ class Module:
|
|
|
1058
1317
|
|
|
1059
1318
|
self.content_hash = None
|
|
1060
1319
|
|
|
1320
|
+
# number of times module auto-generates kernel key for user
|
|
1321
|
+
# used to ensure unique kernel keys
|
|
1322
|
+
self.count = 0
|
|
1323
|
+
|
|
1061
1324
|
def register_struct(self, struct):
|
|
1062
1325
|
self.structs[struct.key] = struct
|
|
1063
1326
|
|
|
@@ -1072,7 +1335,7 @@ class Module:
|
|
|
1072
1335
|
# for a reload of module on next launch
|
|
1073
1336
|
self.unload()
|
|
1074
1337
|
|
|
1075
|
-
def register_function(self, func):
|
|
1338
|
+
def register_function(self, func, skip_adding_overload=False):
|
|
1076
1339
|
if func.key not in self.functions:
|
|
1077
1340
|
self.functions[func.key] = func
|
|
1078
1341
|
else:
|
|
@@ -1092,7 +1355,7 @@ class Module:
|
|
|
1092
1355
|
)
|
|
1093
1356
|
if sig == sig_existing:
|
|
1094
1357
|
self.functions[func.key] = func
|
|
1095
|
-
|
|
1358
|
+
elif not skip_adding_overload:
|
|
1096
1359
|
func_existing.add_overload(func)
|
|
1097
1360
|
|
|
1098
1361
|
self.find_references(func.adj)
|
|
@@ -1100,6 +1363,11 @@ class Module:
|
|
|
1100
1363
|
# for a reload of module on next launch
|
|
1101
1364
|
self.unload()
|
|
1102
1365
|
|
|
1366
|
+
def generate_unique_kernel_key(self, key):
|
|
1367
|
+
unique_key = f"{key}_{self.count}"
|
|
1368
|
+
self.count += 1
|
|
1369
|
+
return unique_key
|
|
1370
|
+
|
|
1103
1371
|
# collect all referenced functions / structs
|
|
1104
1372
|
# given the AST of a function or kernel
|
|
1105
1373
|
def find_references(self, adj):
|
|
@@ -1113,13 +1381,13 @@ class Module:
        if isinstance(node, ast.Call):
            try:
                # try to resolve the function
-                func, _ = adj.
+                func, _ = adj.resolve_static_expression(node.func, eval_types=False)

                # if this is a user-defined function, add a module reference
                if isinstance(func, warp.context.Function) and func.module is not None:
                    add_ref(func.module)

-            except:
+            except Exception:
                # Lookups may fail for builtins, but that's ok.
                # Lookups may also fail for functions in this module that haven't been imported yet,
                # and that's ok too (not an external reference).

@@ -1139,6 +1407,11 @@ class Module:

            return getattr(obj, "__annotations__", {})

+        def get_type_name(type_hint):
+            if isinstance(type_hint, warp.codegen.Struct):
+                return get_type_name(type_hint.cls)
+            return type_hint
+
        def hash_recursive(module, visited):
            # Hash this module, including all referenced modules recursively.
            # The visited set tracks modules already visited to avoid circular references.

@@ -1151,7 +1424,8 @@ class Module:
            # struct source
            for struct in module.structs.values():
                s = ",".join(
-                    "{}: {}".format(name, type_hint)
+                    "{}: {}".format(name, get_type_name(type_hint))
+                    for name, type_hint in get_annotations(struct.cls).items()
                )
                ch.update(bytes(s, "utf-8"))

@@ -1160,13 +1434,29 @@ class Module:
                s = func.adj.source
                ch.update(bytes(s, "utf-8"))

+                if func.custom_grad_func:
+                    s = func.custom_grad_func.adj.source
+                    ch.update(bytes(s, "utf-8"))
+                if func.custom_replay_func:
+                    s = func.custom_replay_func.adj.source
+
+                # cache func arg types
+                for arg, arg_type in func.adj.arg_types.items():
+                    s = f"{arg}: {get_type_name(arg_type)}"
+                    ch.update(bytes(s, "utf-8"))
+
            # kernel source
            for kernel in module.kernels.values():
-
-
-
-
-
+                ch.update(bytes(kernel.adj.source, "utf-8"))
+                # cache kernel arg types
+                for arg, arg_type in kernel.adj.arg_types.items():
+                    s = f"{arg}: {get_type_name(arg_type)}"
+                    ch.update(bytes(s, "utf-8"))
+                # for generic kernels the Python source is always the same,
+                # but we hash the type signatures of all the overloads
+                if kernel.is_generic:
+                    for sig in sorted(kernel.overloads.keys()):
+                        ch.update(bytes(sig, "utf-8"))

            module.content_hash = ch.digest()

@@ -1204,12 +1494,12 @@ class Module:
        return hash_recursive(self, visited=set())

    def load(self, device):
+        from warp.utils import ScopedTimer
+
        device = get_device(device)

        if device.is_cpu:
            # check if already loaded
-            if self.dll:
-                return True
            if self.cpu_module:
                return True
            # avoid repeated build attempts

@@ -1227,7 +1517,7 @@ class Module:
            if not warp.is_cuda_available():
                raise RuntimeError("Failed to build CUDA module because CUDA is not available")

-        with
+        with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
            build_path = warp.build.kernel_bin_dir
            gen_path = warp.build.kernel_gen_dir

@@ -1238,89 +1528,54 @@ class Module:

            module_name = "wp_" + self.name
            module_path = os.path.join(build_path, module_name)
-            obj_path = os.path.join(gen_path, module_name)
            module_hash = self.hash_module()

            builder = ModuleBuilder(self, self.options)

            if device.is_cpu:
-
-
-                    dll_path = obj_path + ".cpp.obj"
-                else:
-                    dll_path = obj_path + ".cpp.o"
-                else:
-                    if os.name == "nt":
-                        dll_path = module_path + ".dll"
-                    else:
-                        dll_path = module_path + ".so"
-
+                obj_path = os.path.join(build_path, module_name)
+                obj_path = obj_path + ".o"
                cpu_hash_path = module_path + ".cpu.hash"

                # check cache
-                if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(
+                if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
                    with open(cpu_hash_path, "rb") as f:
                        cache_hash = f.read()

                    if cache_hash == module_hash:
-
-
-
-                            return True
-                        else:
-                            self.dll = warp.build.load_dll(dll_path)
-                            if self.dll is not None:
-                                return True
+                        runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
+                        self.cpu_module = module_name
+                        return True

                # build
                try:
                    cpp_path = os.path.join(gen_path, module_name + ".cpp")

                    # write cpp sources
-                    cpp_source = builder.
+                    cpp_source = builder.codegen("cpu")

                    cpp_file = open(cpp_path, "w")
                    cpp_file.write(cpp_source)
                    cpp_file.close()

-
-                    elif sys.platform == "darwin":
-                        libs = [f"-lwarp", f"-L{bin_path}", f"-Wl,-rpath,'{bin_path}'"]
-                    else:
-                        libs = ["-l:warp.so", f"-L{bin_path}", f"-Wl,-rpath,'{bin_path}'"]
-
-                    # build DLL or object code
-                    with warp.utils.ScopedTimer("Compile x86", active=warp.config.verbose):
-                        warp.build.build_dll(
-                            dll_path,
-                            [cpp_path],
-                            None,
-                            libs,
+                    # build object code
+                    with ScopedTimer("Compile x86", active=warp.config.verbose):
+                        warp.build.build_cpu(
+                            obj_path,
+                            cpp_path,
                            mode=self.options["mode"],
                            fast_math=self.options["fast_math"],
                            verify_fp=warp.config.verify_fp,
                        )

-                    if runtime.llvm:
-                        # load the object code
-                        obj_ext = ".obj" if os.name == "nt" else ".o"
-                        obj_path = cpp_path + obj_ext
-                        runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
-                        self.cpu_module = module_name
-                    else:
-                        # load the DLL
-                        self.dll = warp.build.load_dll(dll_path)
-                        if self.dll is None:
-                            raise Exception("Failed to load CPU module")
-
                    # update cpu hash
                    with open(cpu_hash_path, "wb") as f:
                        f.write(module_hash)

+                    # load the object code
+                    runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
+                    self.cpu_module = module_name
+
                except Exception as e:
                    self.cpu_build_failed = True
                    raise (e)
@@ -1365,14 +1620,14 @@ class Module:
                    cu_path = os.path.join(gen_path, module_name + ".cu")

                    # write cuda sources
-                    cu_source = builder.
+                    cu_source = builder.codegen("cuda")

                    cu_file = open(cu_path, "w")
                    cu_file.write(cu_source)
                    cu_file.close()

                    # generate PTX or CUBIN
-                    with
+                    with ScopedTimer("Compile CUDA", active=warp.config.verbose):
                        warp.build.build_cuda(
                            cu_path,
                            output_arch,

@@ -1382,6 +1637,10 @@ class Module:
                            verify_fp=warp.config.verify_fp,
                        )

+                    # update cuda hash
+                    with open(cuda_hash_path, "wb") as f:
+                        f.write(module_hash)
+
                    # load the module
                    cuda_module = warp.build.load_cuda(output_path, device)
                    if cuda_module is not None:

@@ -1389,10 +1648,6 @@ class Module:
                    else:
                        raise Exception("Failed to load CUDA module")

-                    # update cuda hash
-                    with open(cuda_hash_path, "wb") as f:
-                        f.write(module_hash)
-
                except Exception as e:
                    self.cuda_build_failed = True
                    raise (e)

@@ -1400,10 +1655,6 @@ class Module:
        return True

    def unload(self):
-        if self.dll:
-            warp.build.unload_dll(self.dll)
-            self.dll = None
-
        if self.cpu_module:
            runtime.llvm.unload_obj(self.cpu_module.encode("utf-8"))
            self.cpu_module = None
@@ -1438,17 +1689,13 @@ class Module:
        name = kernel.get_mangled_name()

        if device.is_cpu:
-
-
-
-                )
-            else:
-                forward = eval("self.dll." + name + "_cpu_forward")
-                backward = eval("self.dll." + name + "_cpu_backward")
+            func = ctypes.CFUNCTYPE(None)
+            forward = func(
+                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
+            )
+            backward = func(
+                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
+            )
        else:
            cu_module = self.cuda_modules[device.context]
            forward = runtime.core.cuda_get_kernel(
@@ -1475,6 +1722,8 @@ class Allocator:

    def alloc(self, size_in_bytes, pinned=False):
        if self.device.is_cuda:
+            if self.device.is_capturing:
+                raise RuntimeError(f"Cannot allocate memory on device {self} while graph capture is active")
            return runtime.core.alloc_device(self.device.context, size_in_bytes)
        elif self.device.is_cpu:
            if pinned:

@@ -1484,6 +1733,8 @@ class Allocator:

    def free(self, ptr, size_in_bytes, pinned=False):
        if self.device.is_cuda:
+            if self.device.is_capturing:
+                raise RuntimeError(f"Cannot free memory on device {self} while graph capture is active")
            return runtime.core.free_device(self.device.context, ptr)
        elif self.device.is_cpu:
            if pinned:
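A minimal sketch of the new behaviour, using only names that appear in this diff (`Device.is_capturing`, `Allocator.alloc`); the exact call sites inside Warp may differ:

```python
def checked_alloc(device, size_in_bytes):
    # Refuse allocations while a CUDA graph capture is active on the device,
    # mirroring the checks added to Allocator.alloc() / Allocator.free() above.
    if device.is_cuda and device.is_capturing:
        raise RuntimeError(f"Cannot allocate memory on device {device} while graph capture is active")
    return device.allocator.alloc(size_in_bytes)
```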
@@ -1499,13 +1750,13 @@ class ContextGuard:
    def __enter__(self):
        if self.device.is_cuda:
            runtime.core.cuda_context_push_current(self.device.context)
-        elif
+        elif is_cuda_driver_initialized():
            self.saved_context = runtime.core.cuda_context_get_current()

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device.is_cuda:
            runtime.core.cuda_context_pop_current()
-        elif
+        elif is_cuda_driver_initialized():
            runtime.core.cuda_context_set_current(self.saved_context)


@@ -1596,6 +1847,29 @@ class Event:


 class Device:
+    """A device to allocate Warp arrays and to launch kernels on.
+
+    Attributes:
+        ordinal: A Warp-specific integer label for the device. ``-1`` for CPU devices.
+        name: A string label for the device. By default, CPU devices will be named according to the processor name,
+            or ``"CPU"`` if the processor name cannot be determined.
+        arch: An integer representing the compute capability version number calculated as
+            ``10 * major + minor``. ``0`` for CPU devices.
+        is_uva: A boolean indicating whether or not the device supports unified addressing.
+            ``False`` for CPU devices.
+        is_cubin_supported: A boolean indicating whether or not Warp's version of NVRTC can directly
+            generate CUDA binary files (cubin) for this device's architecture. ``False`` for CPU devices.
+        is_mempool_supported: A boolean indicating whether or not the device supports using the
+            ``cuMemAllocAsync`` and ``cuMemPool`` family of APIs for stream-ordered memory allocations. ``False`` for
+            CPU devices.
+        is_primary: A boolean indicating whether or not this device's CUDA context is also the
+            device's primary context.
+        uuid: A string representing the UUID of the CUDA device. The UUID is in the same format used by
+            ``nvidia-smi -L``. ``None`` for CPU devices.
+        pci_bus_id: A string identifier for the CUDA device in the format ``[domain]:[bus]:[device]``, in which
+            ``domain``, ``bus``, and ``device`` are all hexadecimal values. ``None`` for CPU devices.
+    """
+
     def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
         self.runtime = runtime
         self.alias = alias
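The attributes documented in the new docstring can be read directly from a device handle. A short usage sketch, assuming Warp 0.11.0 and a CUDA-capable machine:

```python
import warp as wp

wp.init()

d = wp.get_device("cuda:0")
print(d.name, d.arch)          # device name and 10 * major + minor compute capability
print(d.uuid)                  # "GPU-..." in the same format as `nvidia-smi -L`
print(d.pci_bus_id)            # "[domain]:[bus]:[device]" with hexadecimal fields
print(d.is_mempool_supported)  # False if stream-ordered allocators are unavailable
```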
@@ -1625,6 +1899,9 @@ class Device:
            self.arch = 0
            self.is_uva = False
            self.is_cubin_supported = False
+            self.is_mempool_supported = False
+            self.uuid = None
+            self.pci_bus_id = None

            # TODO: add more device-specific dispatch functions
            self.memset = runtime.core.memset_host

@@ -1637,6 +1914,26 @@ class Device:
            self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
            # check whether our NVRTC can generate CUBINs for this architecture
            self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
+            self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)
+
+            uuid_buffer = (ctypes.c_char * 16)()
+            runtime.core.cuda_device_get_uuid(ordinal, uuid_buffer)
+            uuid_byte_str = bytes(uuid_buffer).hex()
+            self.uuid = f"GPU-{uuid_byte_str[0:8]}-{uuid_byte_str[8:12]}-{uuid_byte_str[12:16]}-{uuid_byte_str[16:20]}-{uuid_byte_str[20:]}"
+
+            pci_domain_id = runtime.core.cuda_device_get_pci_domain_id(ordinal)
+            pci_bus_id = runtime.core.cuda_device_get_pci_bus_id(ordinal)
+            pci_device_id = runtime.core.cuda_device_get_pci_device_id(ordinal)
+            # This is (mis)named to correspond to the naming of cudaDeviceGetPCIBusId
+            self.pci_bus_id = f"{pci_domain_id:08X}:{pci_bus_id:02X}:{pci_device_id:02X}"
+
+            # Warn the user of a possible misconfiguration of their system
+            if not self.is_mempool_supported:
+                warp.utils.warn(
+                    f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
+                    "This can prevent the use of graphs and/or result in poor performance. "
+                    "Is the UVM driver enabled?"
+                )

            # initialize streams unless context acquisition is postponed
            if self._context is not None:
@@ -1660,14 +1957,17 @@ class Device:

    @property
    def is_cpu(self):
+        """A boolean indicating whether or not the device is a CPU device."""
        return self.ordinal < 0

    @property
    def is_cuda(self):
+        """A boolean indicating whether or not the device is a CUDA device."""
        return self.ordinal >= 0

    @property
    def context(self):
+        """The context associated with the device."""
        if self._context is not None:
            return self._context
        elif self.is_primary:

@@ -1682,10 +1982,16 @@ class Device:

    @property
    def has_context(self):
+        """A boolean indicating whether or not the device has a CUDA context associated with it."""
        return self._context is not None

    @property
    def stream(self):
+        """The stream associated with a CUDA device.
+
+        Raises:
+            RuntimeError: The device is not a CUDA device.
+        """
        if self.context:
            return self._stream
        else:

@@ -1703,6 +2009,7 @@ class Device:

    @property
    def has_stream(self):
+        """A boolean indicating whether or not the device has a stream associated with it."""
        return self._stream is not None

    def __str__(self):
@@ -1778,10 +2085,10 @@ class Runtime:
            warp_lib = os.path.join(bin_path, "warp.so")
            llvm_lib = os.path.join(bin_path, "warp-clang.so")

-        self.core =
+        self.core = self.load_dll(warp_lib)

-        if
-            self.llvm =
+        if os.path.exists(llvm_lib):
+            self.llvm = self.load_dll(llvm_lib)
            # setup c-types for warp-clang.dll
            self.llvm.lookup.restype = ctypes.c_uint64
        else:
@@ -1852,11 +2159,106 @@ class Runtime:
        ]
        self.core.array_copy_device.restype = ctypes.c_size_t

+        self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
+        self.core.array_fill_host.restype = None
+        self.core.array_fill_device.argtypes = [
+            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int,
+        ]
+        self.core.array_fill_device.restype = None
+
+        self.core.array_sum_double_host.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_sum_float_host.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_sum_double_device.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_sum_float_device.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+
+        self.core.array_inner_double_host.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_inner_float_host.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_inner_double_device.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+        self.core.array_inner_float_device.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+        ]
+
        self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
        self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
        self.core.array_scan_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
        self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]

+        self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+        self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
+
+        self.core.runlength_encode_int_host.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int,
+        ]
+        self.core.runlength_encode_int_device.argtypes = [
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int,
+        ]
+
        self.core.bvh_create_host.restype = ctypes.c_uint64
        self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]

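These new bindings back the array utilities exposed in `warp.utils`; a short usage sketch, assuming the `array_sum`/`array_inner` helpers and `array.fill_()` are available in this release:

```python
import warp as wp
import warp.utils

wp.init()

a = wp.full(8, 2.0, dtype=wp.float32, device="cpu")
b = wp.full(8, 3.0, dtype=wp.float32, device="cpu")

print(warp.utils.array_sum(a))       # 16.0, backed by array_sum_float_host above
print(warp.utils.array_inner(a, b))  # 48.0, backed by array_inner_float_host above
a.fill_(7.0)                         # backed by array_fill_host above
```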
@@ -1876,6 +2278,7 @@ class Runtime:
            warp.types.array_t,
            ctypes.c_int,
            ctypes.c_int,
+            ctypes.c_int,
        ]

        self.core.mesh_create_device.restype = ctypes.c_uint64

@@ -1886,6 +2289,7 @@ class Runtime:
            warp.types.array_t,
            ctypes.c_int,
            ctypes.c_int,
+            ctypes.c_int,
        ]

        self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]

@@ -1998,6 +2402,46 @@ class Runtime:
            ctypes.POINTER(ctypes.c_float),
        ]

+        bsr_matrix_from_triplets_argtypes = [
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64,
+        ]
+        self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
+        self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
+
+        self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int
+        self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int
+        self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int
+        self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int
+
+        bsr_transpose_argtypes = [
+            ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int,
+            ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64,
+        ]
+        self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
+        self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
+
        self.core.is_cuda_enabled.argtypes = None
        self.core.is_cuda_enabled.restype = ctypes.c_int
        self.core.is_cuda_compatibility_enabled.argtypes = None
@@ -2009,6 +2453,8 @@ class Runtime:
        self.core.cuda_driver_version.restype = ctypes.c_int
        self.core.cuda_toolkit_version.argtypes = None
        self.core.cuda_toolkit_version.restype = ctypes.c_int
+        self.core.cuda_driver_is_initialized.argtypes = None
+        self.core.cuda_driver_is_initialized.restype = ctypes.c_bool

        self.core.nvrtc_supported_arch_count.argtypes = None
        self.core.nvrtc_supported_arch_count.restype = ctypes.c_int

@@ -2025,6 +2471,14 @@ class Runtime:
        self.core.cuda_device_get_arch.restype = ctypes.c_int
        self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
        self.core.cuda_device_is_uva.restype = ctypes.c_int
+        self.core.cuda_device_get_uuid.argtypes = [ctypes.c_int, ctypes.c_char * 16]
+        self.core.cuda_device_get_uuid.restype = None
+        self.core.cuda_device_get_pci_domain_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_domain_id.restype = ctypes.c_int
+        self.core.cuda_device_get_pci_bus_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_bus_id.restype = ctypes.c_int
+        self.core.cuda_device_get_pci_device_id.argtypes = [ctypes.c_int]
+        self.core.cuda_device_get_pci_device_id.restype = ctypes.c_int

        self.core.cuda_context_get_current.argtypes = None
        self.core.cuda_context_get_current.restype = ctypes.c_void_p

@@ -2111,6 +2565,7 @@ class Runtime:
            ctypes.c_void_p,
            ctypes.c_void_p,
            ctypes.c_size_t,
+            ctypes.c_int,
            ctypes.POINTER(ctypes.c_void_p),
        ]
        self.core.cuda_launch_kernel.restype = ctypes.c_size_t
@@ -2140,7 +2595,6 @@ class Runtime:

        self.device_map = {}  # device lookup by alias
        self.context_map = {}  # device lookup by context
-        self.graph_capture_map = {}  # indicates whether graph capture is active for a given device

        # register CPU device
        cpu_name = platform.processor()

@@ -2149,7 +2603,6 @@ class Runtime:
        self.cpu_device = Device(self, "cpu")
        self.device_map["cpu"] = self.cpu_device
        self.context_map[None] = self.cpu_device
-        self.graph_capture_map[None] = False

        cuda_device_count = self.core.cuda_device_get_count()

@@ -2183,12 +2636,9 @@ class Runtime:
                self.set_default_device("cuda")
            else:
                self.set_default_device("cuda:0")
-            # save the initial CUDA device for backward compatibility with ScopedCudaGuard
-            self.initial_cuda_device = self.default_device
        else:
            # CUDA not available
            self.set_default_device("cpu")
-            self.initial_cuda_device = None

        # initialize kernel cache
        warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
@@ -2230,6 +2680,23 @@ class Runtime:
        # global tape
        self.tape = None

+    def load_dll(self, dll_path):
+        try:
+            if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
+                dll = ctypes.CDLL(dll_path, winmode=0)
+            else:
+                dll = ctypes.CDLL(dll_path)
+        except OSError as e:
+            if "GLIBCXX" in str(e):
+                raise RuntimeError(
+                    f"Failed to load the shared library '{dll_path}'.\n"
+                    "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
+                    "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
+                ) from e
+            else:
+                raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
+        return dll
+
    def get_device(self, ident: Devicelike = None) -> Device:
        if isinstance(ident, Device):
            return ident
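The `winmode=0` argument matters on Python 3.8+, where Windows no longer searches `PATH` for dependent DLLs by default. A reduced sketch of the same pattern (hypothetical helper name, not part of Warp's API):

```python
import ctypes
import sys


def load_native_library(path):
    # On Python 3.8+ pass winmode=0 so dependent DLLs are resolved the legacy way;
    # the parameter is accepted but only has an effect on Windows.
    if sys.version_info >= (3, 8):
        return ctypes.CDLL(path, winmode=0)
    return ctypes.CDLL(path)
```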
@@ -2345,15 +2812,7 @@ def assert_initialized():

 # global entry points
 def is_cpu_available():
-
-    return True
-
-    # initialize host build env (do this lazily) since
-    # it takes 5secs to run all the batch files to locate MSVC
-    if warp.config.host_compiler is None:
-        warp.config.host_compiler = warp.build.find_host_compiler()
-
-    return warp.config.host_compiler != ""
+    return runtime.llvm


 def is_cuda_available():
@@ -2364,6 +2823,21 @@ def is_device_available(device):
     return device in get_devices()


+def is_cuda_driver_initialized() -> bool:
+    """Returns ``True`` if the CUDA driver is initialized.
+
+    This is a stricter test than ``is_cuda_available()`` since a CUDA driver
+    call to ``cuCtxGetCurrent`` is made, and the result is compared to
+    `CUDA_SUCCESS`. Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent``
+    even if there is no context bound to the calling CPU thread.
+
+    This can be helpful in cases in which ``cuInit()`` was called before a fork.
+    """
+    assert_initialized()
+
+    return runtime.core.cuda_driver_is_initialized()
+
+
 def get_devices() -> List[Device]:
     """Returns a list of devices supported in this environment."""

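A usage sketch; whether the helper is re-exported at the top-level `wp.` namespace may vary, so it is referenced through `warp.context` here:

```python
import warp as wp
import warp.context

wp.init()

# Stricter than wp.is_cuda_available(): this makes an actual driver call
# (cuCtxGetCurrent) rather than only checking that CUDA support was built in.
if warp.context.is_cuda_driver_initialized():
    wp.synchronize()
```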
@@ -2590,63 +3064,53 @@ def zeros(
        A warp.array object representing the allocation
    """

-    if isinstance(shape, int):
-        shape = (shape,)
-    elif "n" in kwargs:
-        shape = (kwargs["n"],)
-
-        num_elements *= d
-
-    device = get_device(device)
-
-    with warp.ScopedStream(device.null_stream):
-        device.memset(ptr, 0, num_bytes)
-
-    if requires_grad:
-        # allocate gradient array
-        grad_ptr = device.allocator.alloc(num_bytes, pinned=pinned)
-        if grad_ptr is None:
-            raise RuntimeError("Memory allocation failed on device: {} for {} bytes".format(device, num_bytes))
-        with warp.ScopedStream(device.null_stream):
-            device.memset(grad_ptr, 0, num_bytes)
-
-    # construct array
-    return warp.types.array(
-        dtype=dtype,
-        shape=shape,
-        capacity=num_bytes,
-        ptr=ptr,
-        grad_ptr=grad_ptr,
-        device=device,
-        owner=True,
-        requires_grad=requires_grad,
-        pinned=pinned,
-    )
+    arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
+
+    # use the CUDA default stream for synchronous behaviour with other streams
+    with warp.ScopedStream(arr.device.null_stream):
+        arr.zero_()
+
+    return arr
+
+
+def zeros_like(
+    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+) -> warp.array:
+    """Return a zero-initialized array with the same type and dimension of another array
+
+    Args:
+        src: The template array to use for shape, data type, and device
+        device: The device where the new array will be created (defaults to src.device)
+        requires_grad: Whether the array will be tracked for back propagation
+        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
+
+    Returns:
+        A warp.array object representing the allocation
+    """
+
+    arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
+
+    arr.zero_()
+
+    return arr
+
+
+def full(
+    shape: Tuple = None,
+    value=0,
+    dtype=Any,
+    device: Devicelike = None,
+    requires_grad: bool = False,
+    pinned: bool = False,
+    **kwargs,
+) -> warp.array:
+    """Return an array with all elements initialized to the given value

    Args:
-
+        shape: Array dimensions
+        value: Element value
+        dtype: Type of each element, e.g.: float, warp.vec3, warp.mat33, etc
+        device: Device that array will live on
        requires_grad: Whether the array will be tracked for back propagation
        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)

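With `zeros()` now delegating to `empty()` plus `array.zero_()`, the allocation helpers share a single code path. A short usage sketch (the `cuda:0` line assumes a GPU is present):

```python
import warp as wp

wp.init()

a = wp.zeros(shape=(16, 3), dtype=wp.float32, device="cpu")
b = wp.zeros_like(a)                   # same shape, dtype and device as `a`
c = wp.zeros_like(a, device="cuda:0")  # new: create the zeroed array on another device
```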
@@ -2654,24 +3118,78 @@ def zeros_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
        A warp.array object representing the allocation
    """

-    if
+    if dtype == Any:
+        # determine dtype from value
+        value_type = type(value)
+        if value_type == int:
+            dtype = warp.int32
+        elif value_type == float:
+            dtype = warp.float32
+        elif value_type in warp.types.scalar_types or hasattr(value_type, "_wp_scalar_type_"):
+            dtype = value_type
+        elif isinstance(value, warp.codegen.StructInstance):
+            dtype = value._cls
+        elif hasattr(value, "__len__"):
+            # a sequence, assume it's a vector or matrix value
+            try:
+                # try to convert to a numpy array first
+                na = np.array(value, copy=False)
+            except Exception as e:
+                raise ValueError(f"Failed to interpret the value as a vector or matrix: {e}")
+
+            # determine the scalar type
+            scalar_type = warp.types.np_dtype_to_warp_type.get(na.dtype)
+            if scalar_type is None:
+                raise ValueError(f"Failed to convert {na.dtype} to a Warp data type")
+
+            # determine if vector or matrix
+            if na.ndim == 1:
+                dtype = warp.types.vector(na.size, scalar_type)
+            elif na.ndim == 2:
+                dtype = warp.types.matrix(na.shape, scalar_type)
+            else:
+                raise ValueError("Values with more than two dimensions are not supported")
        else:
-
+            raise ValueError(f"Invalid value type for Warp array: {value_type}")

-    arr = zeros(shape=src.shape, dtype=src.dtype, device=src.device, requires_grad=requires_grad, pinned=pinned)
+    arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
+
+    # use the CUDA default stream for synchronous behaviour with other streams
+    with warp.ScopedStream(arr.device.null_stream):
+        arr.fill_(value)
+
+    return arr
+
+
+def full_like(
+    src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+) -> warp.array:
+    """Return an array with all elements initialized to the given value with the same type and dimension of another array
+
+    Args:
+        src: The template array to use for shape, data type, and device
+        value: Element value
+        device: The device where the new array will be created (defaults to src.device)
+        requires_grad: Whether the array will be tracked for back propagation
+        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
+
+    Returns:
+        A warp.array object representing the allocation
+    """
+
+    arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
+
+    arr.fill_(value)

    return arr


-def clone(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> warp.array:
+def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
    """Clone an existing array, allocates a copy of the src memory

    Args:
        src: The source array to copy
+        device: The device where the new array will be created (defaults to src.device)
        requires_grad: Whether the array will be tracked for back propagation
        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)

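`full()` infers the dtype from the fill value when none is given explicitly. A usage sketch of the new helpers:

```python
import warp as wp

wp.init()

i = wp.full(10, 7, device="cpu")                # int value -> wp.int32 array
f = wp.full(10, 0.5, device="cpu")              # float value -> wp.float32 array
v = wp.full(10, (1.0, 2.0, 3.0), device="cpu")  # sequence -> vector dtype via NumPy
g = wp.full_like(f, 2.0)                        # same shape/dtype/device as `f`
```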
@@ -2679,19 +3197,11 @@ def clone(src: warp.array, requires_grad: bool = None, pinned: bool = None) -> w
        A warp.array object representing the allocation
    """

-    if hasattr(src, "requires_grad"):
-        requires_grad = src.requires_grad
-    else:
-        requires_grad = False
-
-    if pinned is None:
-        pinned = src.pinned
+    arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)

-    copy(dest, src)
+    warp.copy(arr, src)

-    return
+    return arr


 def empty(
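`clone()` now accepts an optional target device. A usage sketch (the GPU line assumes a CUDA device is present):

```python
import warp as wp

wp.init()

src = wp.zeros(1024, dtype=wp.float32, device="cpu")

same_device_copy = wp.clone(src)           # defaults to src.device
gpu_copy = wp.clone(src, device="cuda:0")  # allocate and copy to another device
```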
@@ -2705,7 +3215,7 @@ def empty(
    """Returns an uninitialized array

    Args:
-
+        shape: Array dimensions
        dtype: Type of each element, e.g.: `warp.vec3`, `warp.mat33`, etc
        device: Device that array will live on
        requires_grad: Whether the array will be tracked for back propagation

@@ -2715,15 +3225,26 @@ def empty(
        A warp.array object representing the allocation
    """

-    #
-
+    # backwards compatibility for case where users called wp.empty(n=length, ...)
+    if "n" in kwargs:
+        shape = (kwargs["n"],)
+        del kwargs["n"]
+
+    # ensure shape is specified, even if creating a zero-sized array
+    if shape is None:
+        shape = 0
+
+    return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)


-def empty_like(
+def empty_like(
+    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
+) -> warp.array:
    """Return an uninitialized array with the same type and dimension of another array

    Args:
-        src: The template array to use for
+        src: The template array to use for shape, data type, and device
+        device: The device where the new array will be created (defaults to src.device)
        requires_grad: Whether the array will be tracked for back propagation
        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)

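A usage sketch of the `empty()` / `empty_like()` signatures above, including the legacy `n=` keyword that is kept for backwards compatibility:

```python
import warp as wp

wp.init()

a = wp.empty(shape=128, dtype=wp.float32, device="cpu")  # preferred spelling
b = wp.empty(n=128, dtype=wp.float32, device="cpu")      # legacy form, still accepted
c = wp.empty_like(a, device="cpu")                       # device override is new
```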
@@ -2731,6 +3252,9 @@ def empty_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
        A warp.array object representing the allocation
    """

+    if device is None:
+        device = src.device
+
    if requires_grad is None:
        if hasattr(src, "requires_grad"):
            requires_grad = src.requires_grad
@@ -2738,14 +3262,246 @@ def empty_like(src: warp.array, requires_grad: bool = None, pinned: bool = None)
            requires_grad = False

    if pinned is None:
-
+        if hasattr(src, "pinned"):
+            pinned = src.pinned
+        else:
+            pinned = False

-    arr = empty(shape=src.shape, dtype=src.dtype, device=
+    arr = empty(shape=src.shape, dtype=src.dtype, device=device, requires_grad=requires_grad, pinned=pinned)
    return arr


-def from_numpy(
-
+def from_numpy(
+    arr: np.ndarray,
+    dtype: Optional[type] = None,
+    shape: Optional[Sequence[int]] = None,
+    device: Optional[Devicelike] = None,
+    requires_grad: bool = False,
+) -> warp.array:
+    if dtype is None:
+        base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
+        if base_type is None:
+            raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype))
+
+        dim_count = len(arr.shape)
+        if dim_count == 2:
+            dtype = warp.types.vector(length=arr.shape[1], dtype=base_type)
+        elif dim_count == 3:
+            dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=base_type)
+        else:
+            dtype = base_type
+
+    return warp.array(
+        data=arr,
+        dtype=dtype,
+        shape=shape,
+        owner=False,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
+# given a kernel destination argument type and a value convert
+# to a c-type that can be passed to a kernel
+def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
+    if warp.types.is_array(arg_type):
+        if value is None:
+            # allow for NULL arrays
+            return arg_type.__ctype__()
+
+        else:
+            # check for array type
+            # - in forward passes, array types have to match
+            # - in backward passes, indexed array gradients are regular arrays
+            if adjoint:
+                array_matches = isinstance(value, warp.array)
+            else:
+                array_matches = type(value) is type(arg_type)
+
+            if not array_matches:
+                adj = "adjoint " if adjoint else ""
+                raise RuntimeError(
+                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
+                )
+
+            # check subtype
+            if not warp.types.types_equal(value.dtype, arg_type.dtype):
+                adj = "adjoint " if adjoint else ""
+                raise RuntimeError(
+                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with dtype={arg_type.dtype} but passed array has dtype={value.dtype}."
+                )
+
+            # check dimensions
+            if value.ndim != arg_type.ndim:
+                adj = "adjoint " if adjoint else ""
+                raise RuntimeError(
+                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with {arg_type.ndim} dimension(s) but the passed array has {value.ndim} dimension(s)."
+                )
+
+            # check device
+            # if a.device != device and not device.can_access(a.device):
+            if value.device != device:
+                raise RuntimeError(
+                    f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={value.device}."
+                )
+
+            return value.__ctype__()
+
+    elif isinstance(arg_type, warp.codegen.Struct):
+        assert value is not None
+        return value.__ctype__()
+
+    # try to convert to a value type (vec3, mat33, etc)
+    elif issubclass(arg_type, ctypes.Array):
+        if warp.types.types_equal(type(value), arg_type):
+            return value
+        else:
+            # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
+            try:
+                return arg_type(value)
+            except Exception:
+                raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}")
+
+    elif isinstance(value, bool):
+        return ctypes.c_bool(value)
+
+    elif isinstance(value, arg_type):
+        try:
+            # try to pack as a scalar type
+            if arg_type is warp.types.float16:
+                return arg_type._type_(warp.types.float_to_half_bits(value.value))
+            else:
+                return arg_type._type_(value.value)
+        except Exception:
+            raise RuntimeError(
+                "Error launching kernel, unable to pack kernel parameter type "
+                f"{type(value)} for param {arg_name}, expected {arg_type}"
+            )
+
+    else:
+        try:
+            # try to pack as a scalar type
+            if arg_type is warp.types.float16:
+                return arg_type._type_(warp.types.float_to_half_bits(value))
+            else:
+                return arg_type._type_(value)
+        except Exception as e:
+            print(e)
+            raise RuntimeError(
+                "Error launching kernel, unable to pack kernel parameter type "
+                f"{type(value)} for param {arg_name}, expected {arg_type}"
+            )
+
+
+# represents all data required for a kernel launch
+# so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
+class Launch:
+    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
+        # if not specified look up hooks
+        if not hooks:
+            module = kernel.module
+            if not module.load(device):
+                return
+
+            hooks = module.get_kernel_hooks(kernel, device)
+
+        # if not specified set a zero bound
+        if not bounds:
+            bounds = warp.types.launch_bounds_t(0)
+
+        # if not specified then build a list of default value params for args
+        if not params:
+            params = []
+            params.append(bounds)
+
+            for a in kernel.adj.args:
+                if isinstance(a.type, warp.types.array):
+                    params.append(a.type.__ctype__())
+                elif isinstance(a.type, warp.codegen.Struct):
+                    params.append(a.type().__ctype__())
+                else:
+                    params.append(pack_arg(kernel, a.type, a.label, 0, device, False))
+
+            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
+            kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)
+
+            params_addr = kernel_params
+
+        self.kernel = kernel
+        self.hooks = hooks
+        self.params = params
+        self.params_addr = params_addr
+        self.device = device
+        self.bounds = bounds
+        self.max_blocks = max_blocks
+
+    def set_dim(self, dim):
+        self.bounds = warp.types.launch_bounds_t(dim)
+
+        # launch bounds always at index 0
+        self.params[0] = self.bounds
+
+        # for CUDA kernels we need to update the address to each arg
+        if self.params_addr:
+            self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))
+
+    # set kernel param at an index, will convert to ctype as necessary
+    def set_param_at_index(self, index, value):
+        arg_type = self.kernel.adj.args[index].type
+        arg_name = self.kernel.adj.args[index].label
+
+        carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)
+
+        self.params[index + 1] = carg
+
+        # for CUDA kernels we need to update the address to each arg
+        if self.params_addr:
+            self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))
+
+    # set kernel param at an index without any type conversion
+    # args must be passed as ctypes or basic int / float types
+    def set_param_at_index_from_ctype(self, index, value):
+        if isinstance(value, ctypes.Structure):
+            # not sure how to directly assign struct->struct without reallocating using ctypes
+            self.params[index + 1] = value
+
+            # for CUDA kernels we need to update the address to each arg
+            if self.params_addr:
+                self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(value))
+
+        else:
+            self.params[index + 1].__init__(value)
+
+    # set kernel param by argument name
+    def set_param_by_name(self, name, value):
+        for i, arg in enumerate(self.kernel.adj.args):
+            if arg.label == name:
+                self.set_param_at_index(i, value)
+
+    # set kernel param by argument name with no type conversions
+    def set_param_by_name_from_ctype(self, name, value):
+        # lookup argument index
+        for i, arg in enumerate(self.kernel.adj.args):
+            if arg.label == name:
+                self.set_param_at_index_from_ctype(i, value)
+
+    # set all params
+    def set_params(self, values):
+        for i, v in enumerate(values):
+            self.set_param_at_index(i, v)
+
+    # set all params without performing type-conversions
+    def set_params_from_ctypes(self, values):
+        for i, v in enumerate(values):
+            self.set_param_at_index_from_ctype(i, v)
+
+    def launch(self) -> Any:
+        if self.device.is_cpu:
+            self.hooks.forward(*self.params)
+        else:
+            runtime.core.cuda_launch_kernel(
+                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
+            )


 def launch(
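A usage sketch of the new `Launch` command object obtained via `record_cmd=True`; the kernel and device choice are illustrative (CPU is used so the example runs without a GPU):

```python
import warp as wp

wp.init()


@wp.kernel
def scale(x: wp.array(dtype=float), s: float):
    i = wp.tid()
    x[i] = x[i] * s


x = wp.full(1024, 1.0, dtype=wp.float32, device="cpu")

# build the launch once, then replay it cheaply with updated parameters
cmd = wp.launch(scale, dim=1024, inputs=[x, 2.0], device="cpu", record_cmd=True)
cmd.launch()

cmd.set_param_by_name("s", 4.0)
cmd.launch()
```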
@@ -2759,6 +3515,8 @@ def launch(
    stream: Stream = None,
    adjoint=False,
    record_tape=True,
+    record_cmd=False,
+    max_blocks=0,
 ):
    """Launch a Warp kernel on the target device


@@ -2774,6 +3532,10 @@ def launch(
        device: The device to launch on (optional)
        stream: The stream to launch on (optional)
        adjoint: Whether to run forward or backward pass (typically use False)
+        record_tape: When true the launch will be recorded the global wp.Tape() object when present
+        record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
+        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
+            If negative or zero, the maximum hardware value will be used.
    """

    assert_initialized()

@@ -2785,7 +3547,7 @@ def launch(
    device = runtime.get_device(device)

    # check function is a Kernel
-    if isinstance(kernel, Kernel)
+    if not isinstance(kernel, Kernel):
        raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")

    # debugging aid
@@ -2806,85 +3568,7 @@ def launch(
            arg_type = kernel.adj.args[i].type
            arg_name = kernel.adj.args[i].label

-            if a is None:
-                # allow for NULL arrays
-                params.append(arg_type.__ctype__())
-            else:
-                [... ~70 further removed lines: the inline array/struct/vector/scalar
-                 validation and ctypes packing previously duplicated here, now replaced
-                 by the module-level pack_arg() helper shown earlier in this diff ...]
+            params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))

        fwd_args = inputs + outputs
        adj_args = adj_inputs + adj_outputs
@@ -2926,7 +3610,13 @@ def launch(
                    f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                )

-
+            if record_cmd:
+                launch = Launch(
+                    kernel=kernel, hooks=hooks, params=params, params_addr=None, bounds=bounds, device=device
+                )
+                return launch
+            else:
+                hooks.forward(*params)

        else:
            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]

@@ -2939,7 +3629,9 @@ def launch(
                        f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

-                runtime.core.cuda_launch_kernel(
+                runtime.core.cuda_launch_kernel(
+                    device.context, hooks.backward, bounds.size, max_blocks, kernel_params
+                )

            else:
                if hooks.forward is None:

@@ -2947,7 +3639,22 @@ def launch(
                        f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

-
+                if record_cmd:
+                    launch = Launch(
+                        kernel=kernel,
+                        hooks=hooks,
+                        params=params,
+                        params_addr=kernel_params,
+                        bounds=bounds,
+                        device=device,
+                    )
+                    return launch
+
+                else:
+                    # launch
+                    runtime.core.cuda_launch_kernel(
+                        device.context, hooks.forward, bounds.size, max_blocks, kernel_params
+                    )

    try:
        runtime.verify_cuda_device(device)
@@ -2957,7 +3664,7 @@ def launch(
|
|
|
2957
3664
|
|
|
2958
3665
|
# record on tape if one is active
|
|
2959
3666
|
if runtime.tape and record_tape:
|
|
2960
|
-
runtime.tape.record_launch(kernel, dim, inputs, outputs, device)
|
|
3667
|
+
runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
|
|
2961
3668
|
|
|
2962
3669
|
|
|
2963
3670
|
def synchronize():
|
|
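The `record_cmd` path lets `launch()` hand back a `Launch` object instead of executing immediately, so a kernel invocation can be re-issued without revalidating and repacking its arguments. A minimal usage sketch, assuming the recorded object exposes a `launch()` method (kernel and array names are illustrative):

import warp as wp

wp.init()

@wp.kernel
def add_one(a: wp.array(dtype=float)):
    i = wp.tid()
    a[i] = a[i] + 1.0

a = wp.zeros(16, dtype=float, device="cpu")

# Record the launch once; arguments are checked and packed up front.
cmd = wp.launch(add_one, dim=16, inputs=[a], device="cpu", record_cmd=True)

# Replay it as many times as needed.
for _ in range(3):
    cmd.launch()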
@@ -2967,7 +3674,7 @@ def synchronize():
     or memory copies have completed.
     """

-    if
+    if is_cuda_driver_initialized():
         # save the original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()

@@ -3017,7 +3724,7 @@ def synchronize_stream(stream_or_device=None):
     runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)


-def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
+def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
     """Force user-defined kernels to be compiled and loaded

     Args:
@@ -3025,12 +3732,14 @@ def force_load(device: Union[Device, str] = None, modules: List[Module] = None):
         modules: List of modules to load. If None, load all imported modules.
     """

-    if
+    if is_cuda_driver_initialized():
         # save original context to avoid side effects
         saved_context = runtime.core.cuda_context_get_current()

     if device is None:
         devices = get_devices()
+    elif isinstance(device, list):
+        devices = [get_device(device_item) for device_item in device]
     else:
         devices = [get_device(device)]

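`force_load()` now also accepts a list of devices, so modules can be compiled and loaded for several targets in one call. A small usage sketch (device names depend on the machine):

import warp as wp

wp.init()

# Compile and load all imported modules for the CPU and the first CUDA device.
# A single device string still works as before; passing a list is the new form.
wp.force_load(device=["cpu", "cuda:0"])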
@@ -3122,7 +3831,7 @@ def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
     return get_module(m.__name__).options


-def capture_begin(device: Devicelike = None, stream=None, force_module_load=True):
+def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
     """Begin capture of a CUDA graph

     Captures all subsequent kernel launches and memory operations on CUDA devices.
@@ -3136,7 +3845,10 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True

     """

-    if
+    if force_module_load is None:
+        force_module_load = warp.config.graph_capture_module_load_default
+
+    if warp.config.verify_cuda:
         raise RuntimeError("Cannot use CUDA error verification during graph capture")

     if stream is not None:
@@ -3151,6 +3863,9 @@ def capture_begin(device: Devicelike = None, stream=None, force_module_load=True

     device.is_capturing = True

+    # disable garbage collection to avoid older allocations getting collected during graph capture
+    gc.disable()
+
     with warp.ScopedStream(stream):
         runtime.core.cuda_graph_begin_capture(device.context)

@@ -3174,6 +3889,9 @@ def capture_end(device: Devicelike = None, stream=None) -> Graph:

     device.is_capturing = False

+    # re-enable GC
+    gc.enable()
+
     if graph is None:
         raise RuntimeError(
             "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
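Garbage collection is now suspended between `capture_begin()` and `capture_end()` so the Python collector cannot free CUDA allocations mid-capture. Typical graph-capture usage is unchanged; a minimal sketch, assuming a CUDA device is available and using `wp.capture_launch()` to replay the captured graph:

import warp as wp

wp.init()

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    i = wp.tid()
    a[i] = a[i] * s

a = wp.zeros(1024, dtype=float, device="cuda:0")

# Capture a sequence of launches into a CUDA graph; GC is disabled for the
# duration of the capture and re-enabled in capture_end().
wp.capture_begin(device="cuda:0")
for _ in range(4):
    wp.launch(scale, dim=1024, inputs=[a, 2.0], device="cuda:0")
graph = wp.capture_end(device="cuda:0")

# Replay the whole captured graph with a single call.
wp.capture_launch(graph)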
@@ -3226,7 +3944,14 @@ def copy(
     if count == 0:
         return

-
+    # copying non-contiguous arrays requires that they are on the same device
+    if not (src.is_contiguous and dest.is_contiguous) and src.device != dest.device:
+        if dest.is_contiguous:
+            # make a contiguous copy of the source array
+            src = src.contiguous()
+        else:
+            # make a copy of the source array on the destination device
+            src = src.to(dest.device)

     if src.is_contiguous and dest.is_contiguous:
         bytes_to_copy = count * warp.types.type_size_in_bytes(src.dtype)
@@ -3240,10 +3965,6 @@ def copy(
         src_ptr = src.ptr + src_offset_in_bytes
         dst_ptr = dest.ptr + dst_offset_in_bytes

-        if has_grad:
-            src_grad_ptr = src.grad_ptr + src_offset_in_bytes
-            dst_grad_ptr = dest.grad_ptr + dst_offset_in_bytes
-
         if src_offset_in_bytes + bytes_to_copy > src_size_in_bytes:
             raise RuntimeError(
                 f"Trying to copy source buffer with size ({bytes_to_copy}) from offset ({src_offset_in_bytes}) is larger than source size ({src_size_in_bytes})"
@@ -3256,8 +3977,6 @@ def copy(

         if src.device.is_cpu and dest.device.is_cpu:
             runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
-            if has_grad:
-                runtime.core.memcpy_h2h(dst_grad_ptr, src_grad_ptr, bytes_to_copy)
         else:
             # figure out the CUDA context/stream for the copy
             if stream is not None:
@@ -3270,32 +3989,19 @@ def copy(
             with warp.ScopedStream(stream):
                 if src.device.is_cpu and dest.device.is_cuda:
                     runtime.core.memcpy_h2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                    if has_grad:
-                        runtime.core.memcpy_h2d(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 elif src.device.is_cuda and dest.device.is_cpu:
                     runtime.core.memcpy_d2h(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                    if has_grad:
-                        runtime.core.memcpy_d2h(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 elif src.device.is_cuda and dest.device.is_cuda:
                     if src.device == dest.device:
                         runtime.core.memcpy_d2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                        if has_grad:
-                            runtime.core.memcpy_d2d(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                     else:
                         runtime.core.memcpy_peer(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
-                        if has_grad:
-                            runtime.core.memcpy_peer(copy_device.context, dst_grad_ptr, src_grad_ptr, bytes_to_copy)
                 else:
                     raise RuntimeError("Unexpected source and destination combination")

     else:
         # handle non-contiguous and indexed arrays

-        if src.device != dest.device:
-            raise RuntimeError(
-                f"Copies between non-contiguous arrays must be on the same device, got {dest.device} and {src.device}"
-            )
-
         if src.shape != dest.shape:
             raise RuntimeError("Incompatible array shapes")

@@ -3305,18 +4011,22 @@ def copy(
         if src_elem_size != dst_elem_size:
             raise RuntimeError("Incompatible array data types")

-
-
-
-
-
+        # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
+        # TODO?
+        if (
+            isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
+            and src.ndim > 1
+            or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
+            and dest.ndim > 1
+        ):
+            raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")

         src_desc = src.__ctype__()
         dst_desc = dest.__ctype__()
         src_ptr = ctypes.pointer(src_desc)
         dst_ptr = ctypes.pointer(dst_desc)
-        src_type =
-        dst_type =
+        src_type = warp.types.array_type_id(src)
+        dst_type = warp.types.array_type_id(dest)

         if src.device.is_cuda:
             with warp.ScopedStream(stream):
@@ -3324,6 +4034,10 @@ def copy(
         else:
             runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)

+    # copy gradient, if needed
+    if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
+        copy(dest.grad, src.grad, stream=stream)
+

 def type_str(t):
     if t is None:
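With these changes `wp.copy()` stages non-contiguous cross-device transfers itself (via a contiguous copy or a copy on the destination device) and propagates gradients by recursing on `dest.grad`/`src.grad` rather than duplicating the raw memcpy calls. A small sketch of a case that previously raised (array shapes and device names are illustrative; requires a CUDA device):

import warp as wp

wp.init()

src = wp.array(data=[[1.0, 2.0], [3.0, 4.0]], dtype=float, device="cuda:0")
dst = wp.zeros((2, 2), dtype=float, device="cpu")

# src.transpose() is a non-contiguous view and lives on a different device
# than dst; copy() now stages a contiguous copy internally instead of raising.
wp.copy(dst, src.transpose())

print(dst.numpy())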
@@ -3342,6 +4056,10 @@ def type_str(t):
         return f"Array[{type_str(t.dtype)}]"
     elif isinstance(t, warp.indexedarray):
         return f"IndexedArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.fabricarray):
+        return f"FabricArray[{type_str(t.dtype)}]"
+    elif isinstance(t, warp.indexedfabricarray):
+        return f"IndexedFabricArray[{type_str(t.dtype)}]"
     elif hasattr(t, "_wp_generic_type_str_"):
         generic_type = t._wp_generic_type_str_

@@ -3368,7 +4086,7 @@ def type_str(t):
     return t.__name__


-def print_function(f, file, noentry=False):
+def print_function(f, file, noentry=False):  # pragma: no cover
     """Writes a function definition to a file for use in reST documentation

     Args:
@@ -3392,7 +4110,7 @@ def print_function(f, file, noentry=False):
         # todo: construct a default value for each of the functions args
         # so we can generate the return type for overloaded functions
         return_type = " -> " + type_str(f.value_func(None, None, None))
-    except:
+    except Exception:
         pass

     print(f".. function:: {f.key}({args}){return_type}", file=file)
@@ -3413,7 +4131,7 @@ def print_function(f, file, noentry=False):
     return True


-def
+def export_functions_rst(file):  # pragma: no cover
     header = (
         "..\n"
         "   Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
@@ -3433,6 +4151,8 @@ def print_builtins(file):

     for t in warp.types.scalar_types:
         print(f".. class:: {t.__name__}", file=file)
+    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
+    print(f".. class:: {warp.types.bool.__name__}", file=file)

     print("\n\nVector Types", file=file)
     print("------------", file=file)
@@ -3443,14 +4163,22 @@ def print_builtins(file):
     print("\nGeneric Types", file=file)
     print("-------------", file=file)

-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
-    print(
+    print(".. class:: Int", file=file)
+    print(".. class:: Float", file=file)
+    print(".. class:: Scalar", file=file)
+    print(".. class:: Vector", file=file)
+    print(".. class:: Matrix", file=file)
+    print(".. class:: Quaternion", file=file)
+    print(".. class:: Transformation", file=file)
+    print(".. class:: Array", file=file)
+
+    print("\nQuery Types", file=file)
+    print("-------------", file=file)
+    print(".. autoclass:: bvh_query_t", file=file)
+    print(".. autoclass:: hash_grid_query_t", file=file)
+    print(".. autoclass:: mesh_query_aabb_t", file=file)
+    print(".. autoclass:: mesh_query_point_t", file=file)
+    print(".. autoclass:: mesh_query_ray_t", file=file)

     # build dictionary of all functions by group
     groups = {}
@@ -3485,7 +4213,7 @@ def print_builtins(file):
     print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)


-def export_stubs(file):
+def export_stubs(file):  # pragma: no cover
     """Generates stub file for auto-complete of builtin functions"""

     import textwrap
@@ -3517,6 +4245,8 @@ def export_stubs(file):
     print("Quaternion = Generic[Float]", file=file)
     print("Transformation = Generic[Float]", file=file)
     print("Array = Generic[DType]", file=file)
+    print("FabricArray = Generic[DType]", file=file)
+    print("IndexedFabricArray = Generic[DType]", file=file)

     # prepend __init__.py
     with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
@@ -3533,7 +4263,7 @@ def export_stubs(file):

         return_str = ""

-        if f.export
+        if not f.export or f.hidden:  # or f.generic:
            continue

        try:
@@ -3543,29 +4273,42 @@ def export_stubs(file):
            if return_type:
                return_str = " -> " + type_str(return_type)

-        except:
+        except Exception:
            pass

        print("@over", file=file)
        print(f"def {f.key}({args}){return_str}:", file=file)
-        print(
+        print('    """', file=file)
        print(textwrap.indent(text=f.doc, prefix="    "), file=file)
-        print(
-        print(
+        print('    """', file=file)
+        print("    ...\n\n", file=file)


-def export_builtins(file):
-    def
+def export_builtins(file: io.TextIOBase):  # pragma: no cover
+    def ctype_arg_str(t):
        if isinstance(t, int):
            return "int"
        elif isinstance(t, float):
            return "float"
+        elif t in warp.types.vector_types:
+            return f"{t.__name__}&"
        else:
            return t.__name__

+    def ctype_ret_str(t):
+        if isinstance(t, int):
+            return "int"
+        elif isinstance(t, float):
+            return "float"
+        else:
+            return t.__name__
+
+    file.write("namespace wp {\n\n")
+    file.write('extern "C" {\n\n')
+
    for k, g in builtin_functions.items():
        for f in g.overloads:
-            if f.export
+            if not f.export or f.generic:
                continue

            simple = True
@@ -3579,7 +4322,7 @@ def export_builtins(file):
            if not simple or f.variadic:
                continue

-            args = ", ".join(f"{
+            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
            params = ", ".join(f.input_types.keys())

            return_type = ""
@@ -3587,25 +4330,25 @@ def export_builtins(file):
            try:
                # todo: construct a default value for each of the functions args
                # so we can generate the return type for overloaded functions
-                return_type =
-            except:
+                return_type = ctype_ret_str(f.value_func(None, None, None))
+            except Exception:
                continue

            if return_type.startswith("Tuple"):
                continue

            if args == "":
-
-                    f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}", file=file
-                )
+                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
            elif return_type == "None":
-
+                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
            else:
-
-                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}"
-                    file=file,
+                file.write(
+                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                )

+    file.write('\n} // extern "C"\n\n')
+    file.write("} // namespace wp\n")
+

 # initialize global runtime
 runtime = None