warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
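The largest Python-side change below is in warp/types.py: the generated vector types gain `.x`/`.y`/`.z`/`.w` component accessors (via new `__getattr__`/`__setattr__` hooks), and `array.fill_()` is documented to accept sequences for vector and matrix dtypes. A minimal sketch of how those additions behave, assuming warp 0.11.0's public API (`wp.init`, `wp.vec3`, `wp.zeros`, `wp.mat22`) and a successfully initialized runtime:

import warp as wp

wp.init()

# Component access on generated vector types, added via the "xyzw" __getattr__/__setattr__ hooks.
v = wp.vec3(1.0, 2.0, 3.0)
v.x = 4.0                  # equivalent to v[0] = 4.0
print(v.x, v.y, v.z)       # 4.0 2.0 3.0

# fill_() accepting a nested sequence for a matrix dtype, mirroring the doctest
# added to the fill_() docstring in the warp/types.py diff below.
arr = wp.zeros(2, dtype=wp.mat22)
arr.fill_([[1, 2], [3, 4]])
print(arr.numpy())         # two 2x2 matrices, each [[1., 2.], [3., 4.]]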
warp/types.py
CHANGED
|
@@ -5,19 +5,17 @@
|
|
|
5
5
|
# distribution of this software and related documentation without an express
|
|
6
6
|
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
7
|
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import builtins
|
|
8
11
|
import ctypes
|
|
9
12
|
import hashlib
|
|
13
|
+
import inspect
|
|
10
14
|
import struct
|
|
11
15
|
import zlib
|
|
12
|
-
import
|
|
16
|
+
from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
|
|
13
17
|
|
|
14
|
-
|
|
15
|
-
from typing import Tuple
|
|
16
|
-
from typing import TypeVar
|
|
17
|
-
from typing import Generic
|
|
18
|
-
from typing import List
|
|
19
|
-
from typing import Callable
|
|
20
|
-
from typing import Union
|
|
18
|
+
import numpy as np
|
|
21
19
|
|
|
22
20
|
import warp
|
|
23
21
|
|
|
@@ -54,12 +52,14 @@ def constant(x):
|
|
|
54
52
|
global _constant_hash
|
|
55
53
|
|
|
56
54
|
# hash the constant value
|
|
57
|
-
if isinstance(x,
|
|
55
|
+
if isinstance(x, builtins.bool):
|
|
56
|
+
# This needs to come before the check for `int` since all boolean
|
|
57
|
+
# values are also instances of `int`.
|
|
58
|
+
_constant_hash.update(struct.pack("?", x))
|
|
59
|
+
elif isinstance(x, int):
|
|
58
60
|
_constant_hash.update(struct.pack("<q", x))
|
|
59
61
|
elif isinstance(x, float):
|
|
60
62
|
_constant_hash.update(struct.pack("<d", x))
|
|
61
|
-
elif isinstance(x, bool):
|
|
62
|
-
_constant_hash.update(struct.pack("?", x))
|
|
63
63
|
elif isinstance(x, float16):
|
|
64
64
|
# float16 is a special case
|
|
65
65
|
p = ctypes.pointer(ctypes.c_float(x.value))
|
|
@@ -149,28 +149,74 @@ def vector(length, dtype):
|
|
|
149
149
|
|
|
150
150
|
def __setitem__(self, key, value):
|
|
151
151
|
if isinstance(key, int):
|
|
152
|
-
|
|
153
|
-
|
|
152
|
+
try:
|
|
153
|
+
return super().__setitem__(key, vec_t.scalar_import(value))
|
|
154
|
+
except (TypeError, ctypes.ArgumentError):
|
|
155
|
+
raise TypeError(
|
|
156
|
+
f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
|
|
157
|
+
f"but got `{type(value).__name__}` instead"
|
|
158
|
+
) from None
|
|
154
159
|
elif isinstance(key, slice):
|
|
160
|
+
try:
|
|
161
|
+
iter(value)
|
|
162
|
+
except TypeError:
|
|
163
|
+
raise TypeError(
|
|
164
|
+
f"Expected to assign a slice from a sequence of values "
|
|
165
|
+
f"but got `{type(value).__name__}` instead"
|
|
166
|
+
) from None
|
|
167
|
+
|
|
155
168
|
if self._wp_scalar_type_ == float16:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
169
|
+
converted = []
|
|
170
|
+
try:
|
|
171
|
+
for x in value:
|
|
172
|
+
converted.append(vec_t.scalar_import(x))
|
|
173
|
+
except ctypes.ArgumentError:
|
|
174
|
+
raise TypeError(
|
|
175
|
+
f"Expected to assign a slice from a sequence of `float16` values "
|
|
176
|
+
f"but got `{type(x).__name__}` instead"
|
|
177
|
+
) from None
|
|
178
|
+
|
|
179
|
+
value = converted
|
|
180
|
+
|
|
181
|
+
try:
|
|
159
182
|
return super().__setitem__(key, value)
|
|
183
|
+
except TypeError:
|
|
184
|
+
for x in value:
|
|
185
|
+
try:
|
|
186
|
+
self._type_(x)
|
|
187
|
+
except TypeError:
|
|
188
|
+
raise TypeError(
|
|
189
|
+
f"Expected to assign a slice from a sequence of `{self._wp_scalar_type_.__name__}` values "
|
|
190
|
+
f"but got `{type(x).__name__}` instead"
|
|
191
|
+
) from None
|
|
160
192
|
else:
|
|
161
193
|
raise KeyError(f"Invalid key {key}, expected int or slice")
|
|
162
194
|
|
|
195
|
+
def __getattr__(self, name):
|
|
196
|
+
idx = "xyzw".find(name)
|
|
197
|
+
if idx != -1:
|
|
198
|
+
return self.__getitem__(idx)
|
|
199
|
+
|
|
200
|
+
return self.__getattribute__(name)
|
|
201
|
+
|
|
202
|
+
def __setattr__(self, name, value):
|
|
203
|
+
idx = "xyzw".find(name)
|
|
204
|
+
if idx != -1:
|
|
205
|
+
return self.__setitem__(idx, value)
|
|
206
|
+
|
|
207
|
+
return super().__setattr__(name, value)
|
|
208
|
+
|
|
163
209
|
def __add__(self, y):
|
|
164
210
|
return warp.add(self, y)
|
|
165
211
|
|
|
166
212
|
def __radd__(self, y):
|
|
167
|
-
return warp.add(
|
|
213
|
+
return warp.add(y, self)
|
|
168
214
|
|
|
169
215
|
def __sub__(self, y):
|
|
170
216
|
return warp.sub(self, y)
|
|
171
217
|
|
|
172
|
-
def __rsub__(self,
|
|
173
|
-
return warp.sub(
|
|
218
|
+
def __rsub__(self, y):
|
|
219
|
+
return warp.sub(y, self)
|
|
174
220
|
|
|
175
221
|
def __mul__(self, y):
|
|
176
222
|
return warp.mul(self, y)
|
|
@@ -178,17 +224,17 @@ def vector(length, dtype):
|
|
|
178
224
|
def __rmul__(self, x):
|
|
179
225
|
return warp.mul(x, self)
|
|
180
226
|
|
|
181
|
-
def
|
|
227
|
+
def __truediv__(self, y):
|
|
182
228
|
return warp.div(self, y)
|
|
183
229
|
|
|
184
|
-
def
|
|
230
|
+
def __rtruediv__(self, x):
|
|
185
231
|
return warp.div(x, self)
|
|
186
232
|
|
|
187
|
-
def __pos__(self
|
|
188
|
-
return warp.pos(self
|
|
233
|
+
def __pos__(self):
|
|
234
|
+
return warp.pos(self)
|
|
189
235
|
|
|
190
|
-
def __neg__(self
|
|
191
|
-
return warp.neg(self
|
|
236
|
+
def __neg__(self):
|
|
237
|
+
return warp.neg(self)
|
|
192
238
|
|
|
193
239
|
def __str__(self):
|
|
194
240
|
return f"[{', '.join(map(str, self))}]"
|
|
@@ -280,13 +326,13 @@ def matrix(shape, dtype):
|
|
|
280
326
|
return warp.add(self, y)
|
|
281
327
|
|
|
282
328
|
def __radd__(self, y):
|
|
283
|
-
return warp.add(
|
|
329
|
+
return warp.add(y, self)
|
|
284
330
|
|
|
285
331
|
def __sub__(self, y):
|
|
286
332
|
return warp.sub(self, y)
|
|
287
333
|
|
|
288
|
-
def __rsub__(self,
|
|
289
|
-
return warp.sub(
|
|
334
|
+
def __rsub__(self, y):
|
|
335
|
+
return warp.sub(y, self)
|
|
290
336
|
|
|
291
337
|
def __mul__(self, y):
|
|
292
338
|
return warp.mul(self, y)
|
|
@@ -300,17 +346,17 @@ def matrix(shape, dtype):
|
|
|
300
346
|
def __rmatmul__(self, x):
|
|
301
347
|
return warp.mul(x, self)
|
|
302
348
|
|
|
303
|
-
def
|
|
349
|
+
def __truediv__(self, y):
|
|
304
350
|
return warp.div(self, y)
|
|
305
351
|
|
|
306
|
-
def
|
|
352
|
+
def __rtruediv__(self, x):
|
|
307
353
|
return warp.div(x, self)
|
|
308
354
|
|
|
309
|
-
def __pos__(self
|
|
310
|
-
return warp.pos(self
|
|
355
|
+
def __pos__(self):
|
|
356
|
+
return warp.pos(self)
|
|
311
357
|
|
|
312
|
-
def __neg__(self
|
|
313
|
-
return warp.neg(self
|
|
358
|
+
def __neg__(self):
|
|
359
|
+
return warp.neg(self)
|
|
314
360
|
|
|
315
361
|
def __str__(self):
|
|
316
362
|
row_str = []
|
|
@@ -341,10 +387,28 @@ def matrix(shape, dtype):
|
|
|
341
387
|
def set_row(self, r, v):
|
|
342
388
|
if r < 0 or r >= self._shape_[0]:
|
|
343
389
|
raise IndexError("Invalid row index")
|
|
390
|
+
try:
|
|
391
|
+
iter(v)
|
|
392
|
+
except TypeError:
|
|
393
|
+
raise TypeError(
|
|
394
|
+
f"Expected to assign a slice from a sequence of values "
|
|
395
|
+
f"but got `{type(v).__name__}` instead"
|
|
396
|
+
) from None
|
|
397
|
+
|
|
344
398
|
row_start = r * self._shape_[1]
|
|
345
399
|
row_end = row_start + self._shape_[1]
|
|
346
400
|
if self._wp_scalar_type_ == float16:
|
|
347
|
-
|
|
401
|
+
converted = []
|
|
402
|
+
try:
|
|
403
|
+
for x in v:
|
|
404
|
+
converted.append(mat_t.scalar_import(x))
|
|
405
|
+
except ctypes.ArgumentError:
|
|
406
|
+
raise TypeError(
|
|
407
|
+
f"Expected to assign a slice from a sequence of `float16` values "
|
|
408
|
+
f"but got `{type(x).__name__}` instead"
|
|
409
|
+
) from None
|
|
410
|
+
|
|
411
|
+
v = converted
|
|
348
412
|
super().__setitem__(slice(row_start, row_end), v)
|
|
349
413
|
|
|
350
414
|
def __getitem__(self, key):
|
|
@@ -352,6 +416,8 @@ def matrix(shape, dtype):
|
|
|
352
416
|
# element indexing m[i,j]
|
|
353
417
|
if len(key) != 2:
|
|
354
418
|
raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
|
|
419
|
+
if any(isinstance(x, slice) for x in key):
|
|
420
|
+
raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
|
|
355
421
|
return mat_t.scalar_export(super().__getitem__(key[0] * self._shape_[1] + key[1]))
|
|
356
422
|
elif isinstance(key, int):
|
|
357
423
|
# row vector indexing m[r]
|
|
@@ -364,12 +430,20 @@ def matrix(shape, dtype):
|
|
|
364
430
|
# element indexing m[i,j] = x
|
|
365
431
|
if len(key) != 2:
|
|
366
432
|
raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
|
|
367
|
-
|
|
368
|
-
|
|
433
|
+
if any(isinstance(x, slice) for x in key):
|
|
434
|
+
raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
|
|
435
|
+
try:
|
|
436
|
+
return super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
|
|
437
|
+
except (TypeError, ctypes.ArgumentError):
|
|
438
|
+
raise TypeError(
|
|
439
|
+
f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
|
|
440
|
+
f"but got `{type(value).__name__}` instead"
|
|
441
|
+
) from None
|
|
369
442
|
elif isinstance(key, int):
|
|
370
443
|
# row vector indexing m[r] = v
|
|
371
|
-
self.set_row(key, value)
|
|
372
|
-
|
|
444
|
+
return self.set_row(key, value)
|
|
445
|
+
elif isinstance(key, slice):
|
|
446
|
+
raise KeyError(f"Slices are not supported when indexing matrices using the `m[start:end]` notation")
|
|
373
447
|
else:
|
|
374
448
|
raise KeyError(f"Invalid key {key}, expected int or pair of ints")
|
|
375
449
|
|
|
@@ -392,6 +466,23 @@ class void:
|
|
|
392
466
|
pass
|
|
393
467
|
|
|
394
468
|
|
|
469
|
+
class bool:
|
|
470
|
+
_length_ = 1
|
|
471
|
+
_type_ = ctypes.c_bool
|
|
472
|
+
|
|
473
|
+
def __init__(self, x=False):
|
|
474
|
+
self.value = x
|
|
475
|
+
|
|
476
|
+
def __bool__(self) -> bool:
|
|
477
|
+
return self.value != 0
|
|
478
|
+
|
|
479
|
+
def __float__(self) -> float:
|
|
480
|
+
return float(self.value != 0)
|
|
481
|
+
|
|
482
|
+
def __int__(self) -> int:
|
|
483
|
+
return int(self.value != 0)
|
|
484
|
+
|
|
485
|
+
|
|
395
486
|
class float16:
|
|
396
487
|
_length_ = 1
|
|
397
488
|
_type_ = ctypes.c_uint16
|
|
@@ -399,6 +490,15 @@ class float16:
|
|
|
399
490
|
def __init__(self, x=0.0):
|
|
400
491
|
self.value = x
|
|
401
492
|
|
|
493
|
+
def __bool__(self) -> bool:
|
|
494
|
+
return self.value != 0.0
|
|
495
|
+
|
|
496
|
+
def __float__(self) -> float:
|
|
497
|
+
return float(self.value)
|
|
498
|
+
|
|
499
|
+
def __int__(self) -> int:
|
|
500
|
+
return int(self.value)
|
|
501
|
+
|
|
402
502
|
|
|
403
503
|
class float32:
|
|
404
504
|
_length_ = 1
|
|
@@ -407,6 +507,15 @@ class float32:
|
|
|
407
507
|
def __init__(self, x=0.0):
|
|
408
508
|
self.value = x
|
|
409
509
|
|
|
510
|
+
def __bool__(self) -> bool:
|
|
511
|
+
return self.value != 0.0
|
|
512
|
+
|
|
513
|
+
def __float__(self) -> float:
|
|
514
|
+
return float(self.value)
|
|
515
|
+
|
|
516
|
+
def __int__(self) -> int:
|
|
517
|
+
return int(self.value)
|
|
518
|
+
|
|
410
519
|
|
|
411
520
|
class float64:
|
|
412
521
|
_length_ = 1
|
|
@@ -415,6 +524,15 @@ class float64:
|
|
|
415
524
|
def __init__(self, x=0.0):
|
|
416
525
|
self.value = x
|
|
417
526
|
|
|
527
|
+
def __bool__(self) -> bool:
|
|
528
|
+
return self.value != 0.0
|
|
529
|
+
|
|
530
|
+
def __float__(self) -> float:
|
|
531
|
+
return float(self.value)
|
|
532
|
+
|
|
533
|
+
def __int__(self) -> int:
|
|
534
|
+
return int(self.value)
|
|
535
|
+
|
|
418
536
|
|
|
419
537
|
class int8:
|
|
420
538
|
_length_ = 1
|
|
@@ -423,6 +541,18 @@ class int8:
|
|
|
423
541
|
def __init__(self, x=0):
|
|
424
542
|
self.value = x
|
|
425
543
|
|
|
544
|
+
def __bool__(self) -> bool:
|
|
545
|
+
return self.value != 0
|
|
546
|
+
|
|
547
|
+
def __float__(self) -> float:
|
|
548
|
+
return float(self.value)
|
|
549
|
+
|
|
550
|
+
def __int__(self) -> int:
|
|
551
|
+
return int(self.value)
|
|
552
|
+
|
|
553
|
+
def __index__(self) -> int:
|
|
554
|
+
return int(self.value)
|
|
555
|
+
|
|
426
556
|
|
|
427
557
|
class uint8:
|
|
428
558
|
_length_ = 1
|
|
@@ -431,6 +561,18 @@ class uint8:
|
|
|
431
561
|
def __init__(self, x=0):
|
|
432
562
|
self.value = x
|
|
433
563
|
|
|
564
|
+
def __bool__(self) -> bool:
|
|
565
|
+
return self.value != 0
|
|
566
|
+
|
|
567
|
+
def __float__(self) -> float:
|
|
568
|
+
return float(self.value)
|
|
569
|
+
|
|
570
|
+
def __int__(self) -> int:
|
|
571
|
+
return int(self.value)
|
|
572
|
+
|
|
573
|
+
def __index__(self) -> int:
|
|
574
|
+
return int(self.value)
|
|
575
|
+
|
|
434
576
|
|
|
435
577
|
class int16:
|
|
436
578
|
_length_ = 1
|
|
@@ -439,6 +581,18 @@ class int16:
|
|
|
439
581
|
def __init__(self, x=0):
|
|
440
582
|
self.value = x
|
|
441
583
|
|
|
584
|
+
def __bool__(self) -> bool:
|
|
585
|
+
return self.value != 0
|
|
586
|
+
|
|
587
|
+
def __float__(self) -> float:
|
|
588
|
+
return float(self.value)
|
|
589
|
+
|
|
590
|
+
def __int__(self) -> int:
|
|
591
|
+
return int(self.value)
|
|
592
|
+
|
|
593
|
+
def __index__(self) -> int:
|
|
594
|
+
return int(self.value)
|
|
595
|
+
|
|
442
596
|
|
|
443
597
|
class uint16:
|
|
444
598
|
_length_ = 1
|
|
@@ -447,6 +601,18 @@ class uint16:
|
|
|
447
601
|
def __init__(self, x=0):
|
|
448
602
|
self.value = x
|
|
449
603
|
|
|
604
|
+
def __bool__(self) -> bool:
|
|
605
|
+
return self.value != 0
|
|
606
|
+
|
|
607
|
+
def __float__(self) -> float:
|
|
608
|
+
return float(self.value)
|
|
609
|
+
|
|
610
|
+
def __int__(self) -> int:
|
|
611
|
+
return int(self.value)
|
|
612
|
+
|
|
613
|
+
def __index__(self) -> int:
|
|
614
|
+
return int(self.value)
|
|
615
|
+
|
|
450
616
|
|
|
451
617
|
class int32:
|
|
452
618
|
_length_ = 1
|
|
@@ -455,6 +621,18 @@ class int32:
|
|
|
455
621
|
def __init__(self, x=0):
|
|
456
622
|
self.value = x
|
|
457
623
|
|
|
624
|
+
def __bool__(self) -> bool:
|
|
625
|
+
return self.value != 0
|
|
626
|
+
|
|
627
|
+
def __float__(self) -> float:
|
|
628
|
+
return float(self.value)
|
|
629
|
+
|
|
630
|
+
def __int__(self) -> int:
|
|
631
|
+
return int(self.value)
|
|
632
|
+
|
|
633
|
+
def __index__(self) -> int:
|
|
634
|
+
return int(self.value)
|
|
635
|
+
|
|
458
636
|
|
|
459
637
|
class uint32:
|
|
460
638
|
_length_ = 1
|
|
@@ -463,6 +641,18 @@ class uint32:
|
|
|
463
641
|
def __init__(self, x=0):
|
|
464
642
|
self.value = x
|
|
465
643
|
|
|
644
|
+
def __bool__(self) -> bool:
|
|
645
|
+
return self.value != 0
|
|
646
|
+
|
|
647
|
+
def __float__(self) -> float:
|
|
648
|
+
return float(self.value)
|
|
649
|
+
|
|
650
|
+
def __int__(self) -> int:
|
|
651
|
+
return int(self.value)
|
|
652
|
+
|
|
653
|
+
def __index__(self) -> int:
|
|
654
|
+
return int(self.value)
|
|
655
|
+
|
|
466
656
|
|
|
467
657
|
class int64:
|
|
468
658
|
_length_ = 1
|
|
@@ -471,6 +661,18 @@ class int64:
|
|
|
471
661
|
def __init__(self, x=0):
|
|
472
662
|
self.value = x
|
|
473
663
|
|
|
664
|
+
def __bool__(self) -> bool:
|
|
665
|
+
return self.value != 0
|
|
666
|
+
|
|
667
|
+
def __float__(self) -> float:
|
|
668
|
+
return float(self.value)
|
|
669
|
+
|
|
670
|
+
def __int__(self) -> int:
|
|
671
|
+
return int(self.value)
|
|
672
|
+
|
|
673
|
+
def __index__(self) -> int:
|
|
674
|
+
return int(self.value)
|
|
675
|
+
|
|
474
676
|
|
|
475
677
|
class uint64:
|
|
476
678
|
_length_ = 1
|
|
@@ -479,6 +681,18 @@ class uint64:
|
|
|
479
681
|
def __init__(self, x=0):
|
|
480
682
|
self.value = x
|
|
481
683
|
|
|
684
|
+
def __bool__(self) -> bool:
|
|
685
|
+
return self.value != 0
|
|
686
|
+
|
|
687
|
+
def __float__(self) -> float:
|
|
688
|
+
return float(self.value)
|
|
689
|
+
|
|
690
|
+
def __int__(self) -> int:
|
|
691
|
+
return int(self.value)
|
|
692
|
+
|
|
693
|
+
def __index__(self) -> int:
|
|
694
|
+
return int(self.value)
|
|
695
|
+
|
|
482
696
|
|
|
483
697
|
def quaternion(dtype=Any):
|
|
484
698
|
class quat_t(vector(length=4, dtype=dtype)):
|
|
@@ -508,23 +722,63 @@ class quatd(quaternion(dtype=float64)):
|
|
|
508
722
|
|
|
509
723
|
def transformation(dtype=Any):
|
|
510
724
|
class transform_t(vector(length=7, dtype=dtype)):
|
|
725
|
+
_wp_init_from_components_sig_ = inspect.Signature(
|
|
726
|
+
(
|
|
727
|
+
inspect.Parameter(
|
|
728
|
+
"p",
|
|
729
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
730
|
+
default=(0.0, 0.0, 0.0),
|
|
731
|
+
),
|
|
732
|
+
inspect.Parameter(
|
|
733
|
+
"q",
|
|
734
|
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
735
|
+
default=(0.0, 0.0, 0.0, 1.0),
|
|
736
|
+
),
|
|
737
|
+
),
|
|
738
|
+
)
|
|
511
739
|
_wp_type_params_ = [dtype]
|
|
512
740
|
_wp_generic_type_str_ = "transform_t"
|
|
513
741
|
_wp_constructor_ = "transformation"
|
|
514
742
|
|
|
515
|
-
def __init__(self,
|
|
516
|
-
|
|
743
|
+
def __init__(self, *args, **kwargs):
|
|
744
|
+
if len(args) == 1 and len(kwargs) == 0:
|
|
745
|
+
if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
|
|
746
|
+
# Copy constructor.
|
|
747
|
+
super().__init__(*args[0])
|
|
748
|
+
return
|
|
749
|
+
|
|
750
|
+
try:
|
|
751
|
+
# For backward compatibility, try to check if the arguments
|
|
752
|
+
# match the original signature that'd allow initializing
|
|
753
|
+
# the `p` and `q` components separately.
|
|
754
|
+
bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
|
|
755
|
+
bound_args.apply_defaults()
|
|
756
|
+
p, q = bound_args.args
|
|
757
|
+
except (TypeError, ValueError):
|
|
758
|
+
# Fallback to the vector's constructor.
|
|
759
|
+
super().__init__(*args)
|
|
760
|
+
return
|
|
761
|
+
|
|
762
|
+
# Even if the arguments match the original “from components”
|
|
763
|
+
# signature, we still need to make sure that they represent
|
|
764
|
+
# sequences that can be unpacked.
|
|
765
|
+
if hasattr(p, "__len__") and hasattr(q, "__len__"):
|
|
766
|
+
# Initialize from the `p` and `q` components.
|
|
767
|
+
super().__init__()
|
|
768
|
+
self[0:3] = vector(length=3, dtype=dtype)(*p)
|
|
769
|
+
self[3:7] = quaternion(dtype=dtype)(*q)
|
|
770
|
+
return
|
|
517
771
|
|
|
518
|
-
|
|
519
|
-
|
|
772
|
+
# Fallback to the vector's constructor.
|
|
773
|
+
super().__init__(*args)
|
|
520
774
|
|
|
521
775
|
@property
|
|
522
776
|
def p(self):
|
|
523
|
-
return self[0:3]
|
|
777
|
+
return vec3(self[0:3])
|
|
524
778
|
|
|
525
779
|
@property
|
|
526
780
|
def q(self):
|
|
527
|
-
return self[3:7]
|
|
781
|
+
return quat(self[3:7])
|
|
528
782
|
|
|
529
783
|
return transform_t
|
|
530
784
|
|
|
@@ -808,6 +1062,7 @@ vector_types = [
|
|
|
808
1062
|
]
|
|
809
1063
|
|
|
810
1064
|
np_dtype_to_warp_type = {
|
|
1065
|
+
np.dtype(np.bool_): bool,
|
|
811
1066
|
np.dtype(np.int8): int8,
|
|
812
1067
|
np.dtype(np.uint8): uint8,
|
|
813
1068
|
np.dtype(np.int16): int16,
|
|
@@ -824,6 +1079,7 @@ np_dtype_to_warp_type = {
|
|
|
824
1079
|
}
|
|
825
1080
|
|
|
826
1081
|
warp_type_to_np_dtype = {
|
|
1082
|
+
bool: np.bool_,
|
|
827
1083
|
int8: np.int8,
|
|
828
1084
|
int16: np.int16,
|
|
829
1085
|
int32: np.int32,
|
|
@@ -846,18 +1102,21 @@ class range_t:
|
|
|
846
1102
|
|
|
847
1103
|
# definition just for kernel type (cannot be a parameter), see bvh.h
|
|
848
1104
|
class bvh_query_t:
|
|
1105
|
+
"""Object used to track state during BVH traversal."""
|
|
849
1106
|
def __init__(self):
|
|
850
1107
|
pass
|
|
851
1108
|
|
|
852
1109
|
|
|
853
1110
|
# definition just for kernel type (cannot be a parameter), see mesh.h
|
|
854
1111
|
class mesh_query_aabb_t:
|
|
1112
|
+
"""Object used to track state during mesh traversal."""
|
|
855
1113
|
def __init__(self):
|
|
856
1114
|
pass
|
|
857
1115
|
|
|
858
1116
|
|
|
859
1117
|
# definition just for kernel type (cannot be a parameter), see hash_grid.h
|
|
860
1118
|
class hash_grid_query_t:
|
|
1119
|
+
"""Object used to track state during neighbor traversal."""
|
|
861
1120
|
def __init__(self):
|
|
862
1121
|
pass
|
|
863
1122
|
|
|
@@ -869,6 +1128,8 @@ LAUNCH_MAX_DIMS = 4
|
|
|
869
1128
|
# must match array.h
|
|
870
1129
|
ARRAY_TYPE_REGULAR = 0
|
|
871
1130
|
ARRAY_TYPE_INDEXED = 1
|
|
1131
|
+
ARRAY_TYPE_FABRIC = 2
|
|
1132
|
+
ARRAY_TYPE_FABRIC_INDEXED = 3
|
|
872
1133
|
|
|
873
1134
|
|
|
874
1135
|
# represents bounds for kernel launch (number of threads across multiple dimensions)
|
|
@@ -992,7 +1253,7 @@ def type_scalar_type(dtype):
|
|
|
992
1253
|
def type_size_in_bytes(dtype):
|
|
993
1254
|
if dtype.__module__ == "ctypes":
|
|
994
1255
|
return ctypes.sizeof(dtype)
|
|
995
|
-
elif
|
|
1256
|
+
elif isinstance(dtype, warp.codegen.Struct):
|
|
996
1257
|
return ctypes.sizeof(dtype.ctype)
|
|
997
1258
|
elif dtype == float or dtype == int:
|
|
998
1259
|
return 4
|
|
@@ -1013,9 +1274,9 @@ def type_to_warp(dtype):
|
|
|
1013
1274
|
|
|
1014
1275
|
|
|
1015
1276
|
def type_typestr(dtype):
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1277
|
+
if dtype == bool:
|
|
1278
|
+
return "?"
|
|
1279
|
+
elif dtype == float16:
|
|
1019
1280
|
return "<f2"
|
|
1020
1281
|
elif dtype == float32:
|
|
1021
1282
|
return "<f4"
|
|
@@ -1037,7 +1298,7 @@ def type_typestr(dtype):
|
|
|
1037
1298
|
return "<i8"
|
|
1038
1299
|
elif dtype == uint64:
|
|
1039
1300
|
return "<u8"
|
|
1040
|
-
elif isinstance(dtype, Struct):
|
|
1301
|
+
elif isinstance(dtype, warp.codegen.Struct):
|
|
1041
1302
|
return f"|V{ctypes.sizeof(dtype.ctype)}"
|
|
1042
1303
|
elif issubclass(dtype, ctypes.Array):
|
|
1043
1304
|
return type_typestr(dtype._wp_scalar_type_)
|
|
@@ -1051,9 +1312,16 @@ def type_repr(t):
|
|
|
1051
1312
|
return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
|
|
1052
1313
|
if type_is_vector(t):
|
|
1053
1314
|
return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
|
|
1054
|
-
|
|
1315
|
+
if type_is_matrix(t):
|
|
1055
1316
|
return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
|
|
1056
|
-
|
|
1317
|
+
if isinstance(t, warp.codegen.Struct):
|
|
1318
|
+
return type_repr(t.cls)
|
|
1319
|
+
if t in scalar_types:
|
|
1320
|
+
return t.__name__
|
|
1321
|
+
|
|
1322
|
+
try:
|
|
1323
|
+
return t.__module__ + "." + t.__qualname__
|
|
1324
|
+
except AttributeError:
|
|
1057
1325
|
return str(t)
|
|
1058
1326
|
|
|
1059
1327
|
|
|
@@ -1071,15 +1339,6 @@ def type_is_float(t):
|
|
|
1071
1339
|
return t in float_types
|
|
1072
1340
|
|
|
1073
1341
|
|
|
1074
|
-
def type_is_struct(dtype):
|
|
1075
|
-
from warp.codegen import Struct
|
|
1076
|
-
|
|
1077
|
-
if isinstance(dtype, Struct):
|
|
1078
|
-
return True
|
|
1079
|
-
else:
|
|
1080
|
-
return False
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
1342
|
# returns True if the passed *type* is a vector
|
|
1084
1343
|
def type_is_vector(t):
|
|
1085
1344
|
if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
|
|
@@ -1098,7 +1357,7 @@ def type_is_matrix(t):
|
|
|
1098
1357
|
|
|
1099
1358
|
# returns true for all value types (int, float, bool, scalars, vectors, matrices)
|
|
1100
1359
|
def type_is_value(x):
|
|
1101
|
-
if (x == int) or (x == float) or (x == bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
|
|
1360
|
+
if (x == int) or (x == float) or (x == builtins.bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
|
|
1102
1361
|
return True
|
|
1103
1362
|
else:
|
|
1104
1363
|
return False
|
|
@@ -1126,14 +1385,16 @@ def types_equal(a, b, match_generic=False):
|
|
|
1126
1385
|
# convert to canonical types
|
|
1127
1386
|
if a == float:
|
|
1128
1387
|
a = float32
|
|
1129
|
-
|
|
1388
|
+
elif a == int:
|
|
1130
1389
|
a = int32
|
|
1131
1390
|
|
|
1132
1391
|
if b == float:
|
|
1133
1392
|
b = float32
|
|
1134
|
-
|
|
1393
|
+
elif b == int:
|
|
1135
1394
|
b = int32
|
|
1136
1395
|
|
|
1396
|
+
compatible_bool_types = [builtins.bool, bool]
|
|
1397
|
+
|
|
1137
1398
|
def are_equal(p1, p2):
|
|
1138
1399
|
if match_generic:
|
|
1139
1400
|
if p1 == Any or p2 == Any:
|
|
@@ -1150,7 +1411,22 @@ def types_equal(a, b, match_generic=False):
|
|
|
1150
1411
|
return True
|
|
1151
1412
|
if p1 == Float and p2 == Float:
|
|
1152
1413
|
return True
|
|
1153
|
-
|
|
1414
|
+
|
|
1415
|
+
# convert to canonical types
|
|
1416
|
+
if p1 == float:
|
|
1417
|
+
p1 = float32
|
|
1418
|
+
elif p1 == int:
|
|
1419
|
+
p1 = int32
|
|
1420
|
+
|
|
1421
|
+
if p2 == float:
|
|
1422
|
+
p2 = float32
|
|
1423
|
+
elif b == int:
|
|
1424
|
+
p2 = int32
|
|
1425
|
+
|
|
1426
|
+
if p1 in compatible_bool_types and p2 in compatible_bool_types:
|
|
1427
|
+
return True
|
|
1428
|
+
else:
|
|
1429
|
+
return p1 == p2
|
|
1154
1430
|
|
|
1155
1431
|
if (
|
|
1156
1432
|
hasattr(a, "_wp_generic_type_str_")
|
|
@@ -1158,9 +1434,7 @@ def types_equal(a, b, match_generic=False):
|
|
|
1158
1434
|
and a._wp_generic_type_str_ == b._wp_generic_type_str_
|
|
1159
1435
|
):
|
|
1160
1436
|
return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
|
|
1161
|
-
if
|
|
1162
|
-
return True
|
|
1163
|
-
if isinstance(a, indexedarray) and isinstance(b, indexedarray):
|
|
1437
|
+
if is_array(a) and type(a) is type(b):
|
|
1164
1438
|
return True
|
|
1165
1439
|
else:
|
|
1166
1440
|
return are_equal(a, b)
|
|
@@ -1244,6 +1518,7 @@ class array(Array):
|
|
|
1244
1518
|
self._grad = None
|
|
1245
1519
|
# __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
|
|
1246
1520
|
self._array_interface = None
|
|
1521
|
+
self.is_transposed = False
|
|
1247
1522
|
|
|
1248
1523
|
# canonicalize dtype
|
|
1249
1524
|
if dtype == int:
|
|
@@ -1317,7 +1592,9 @@ class array(Array):
|
|
|
1317
1592
|
if isinstance(data, np.ndarray):
|
|
1318
1593
|
# construct from numpy structured array
|
|
1319
1594
|
if data.dtype != dtype.numpy_dtype():
|
|
1320
|
-
raise RuntimeError(
|
|
1595
|
+
raise RuntimeError(
|
|
1596
|
+
f"Invalid source data type for array of structs, expected {dtype.numpy_dtype()}, got {data.dtype}"
|
|
1597
|
+
)
|
|
1321
1598
|
arr = data
|
|
1322
1599
|
elif isinstance(data, (list, tuple)):
|
|
1323
1600
|
# construct from a sequence of structs
|
|
@@ -1329,9 +1606,13 @@ class array(Array):
|
|
|
1329
1606
|
# convert to numpy
|
|
1330
1607
|
arr = np.frombuffer(ctype_arr, dtype=dtype.ctype)
|
|
1331
1608
|
except Exception as e:
|
|
1332
|
-
raise RuntimeError(
|
|
1609
|
+
raise RuntimeError(
|
|
1610
|
+
f"Error while trying to construct Warp array from a sequence of Warp structs: {e}"
|
|
1611
|
+
)
|
|
1333
1612
|
else:
|
|
1334
|
-
raise RuntimeError(
|
|
1613
|
+
raise RuntimeError(
|
|
1614
|
+
"Invalid data argument for array of structs, expected a sequence of structs or a NumPy structured array"
|
|
1615
|
+
)
|
|
1335
1616
|
else:
|
|
1336
1617
|
# convert input data to the given dtype
|
|
1337
1618
|
npdtype = warp_type_to_np_dtype.get(scalar_dtype)
|
|
@@ -1416,7 +1697,7 @@ class array(Array):
|
|
|
1416
1697
|
|
|
1417
1698
|
def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, owner, pinned):
|
|
1418
1699
|
if dtype == Any:
|
|
1419
|
-
raise RuntimeError(
|
|
1700
|
+
raise RuntimeError("A concrete data type is required to create the array")
|
|
1420
1701
|
|
|
1421
1702
|
device = warp.get_device(device)
|
|
1422
1703
|
|
|
@@ -1450,7 +1731,7 @@ class array(Array):
|
|
|
1450
1731
|
|
|
1451
1732
|
def _init_new(self, dtype, shape, strides, device, pinned):
|
|
1452
1733
|
if dtype == Any:
|
|
1453
|
-
raise RuntimeError(
|
|
1734
|
+
raise RuntimeError("A concrete data type is required to create the array")
|
|
1454
1735
|
|
|
1455
1736
|
device = warp.get_device(device)
|
|
1456
1737
|
|
|
@@ -1753,7 +2034,7 @@ class array(Array):
|
|
|
1753
2034
|
return self._requires_grad
|
|
1754
2035
|
|
|
1755
2036
|
@requires_grad.setter
|
|
1756
|
-
def requires_grad(self, value: bool):
|
|
2037
|
+
def requires_grad(self, value: builtins.bool):
|
|
1757
2038
|
if value and self._grad is None:
|
|
1758
2039
|
self._alloc_grad()
|
|
1759
2040
|
elif not value:
|
|
@@ -1778,12 +2059,11 @@ class array(Array):
|
|
|
1778
2059
|
# member attributes available during code-gen (e.g.: d = array.shape[0])
|
|
1779
2060
|
# Note: we use a shared dict for all array instances
|
|
1780
2061
|
if array._vars is None:
|
|
1781
|
-
|
|
1782
|
-
|
|
1783
|
-
array._vars = {"shape": Var("shape", shape_t)}
|
|
2062
|
+
array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
|
|
1784
2063
|
return array._vars
|
|
1785
2064
|
|
|
1786
2065
|
def zero_(self):
|
|
2066
|
+
"""Zeroes-out the array entries."""
|
|
1787
2067
|
if self.is_contiguous:
|
|
1788
2068
|
# simple memset is usually faster than generic fill
|
|
1789
2069
|
self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
|
|
@@ -1791,6 +2071,32 @@ class array(Array):
|
|
|
1791
2071
|
self.fill_(0)
|
|
1792
2072
|
|
|
1793
2073
|
def fill_(self, value):
|
|
2074
|
+
"""Set all array entries to `value`
|
|
2075
|
+
|
|
2076
|
+
args:
|
|
2077
|
+
value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
|
|
2078
|
+
|
|
2079
|
+
Raises:
|
|
2080
|
+
ValueError: If `value` cannot be converted to the array's ``dtype``.
|
|
2081
|
+
|
|
2082
|
+
Examples:
|
|
2083
|
+
``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
|
|
2084
|
+
|
|
2085
|
+
>>> arr = wp.zeros(2, dtype=wp.mat22)
|
|
2086
|
+
>>> arr.numpy()
|
|
2087
|
+
array([[[0., 0.],
|
|
2088
|
+
[0., 0.]],
|
|
2089
|
+
<BLANKLINE>
|
|
2090
|
+
[[0., 0.],
|
|
2091
|
+
[0., 0.]]], dtype=float32)
|
|
2092
|
+
>>> arr.fill_([[1, 2], [3, 4]])
|
|
2093
|
+
>>> arr.numpy()
|
|
2094
|
+
array([[[1., 2.],
|
|
2095
|
+
[3., 4.]],
|
|
2096
|
+
<BLANKLINE>
|
|
2097
|
+
[[1., 2.],
|
|
2098
|
+
[3., 4.]]], dtype=float32)
|
|
2099
|
+
"""
|
|
1794
2100
|
if self.size == 0:
|
|
1795
2101
|
return
|
|
1796
2102
|
|
|
@@ -1837,19 +2143,22 @@ class array(Array):
|
|
|
1837
2143
|
else:
|
|
1838
2144
|
warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
|
|
1839
2145
|
|
|
1840
|
-
# equivalent to wrapping src data in an array and copying to self
|
|
1841
2146
|
def assign(self, src):
|
|
2147
|
+
"""Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
|
|
1842
2148
|
if is_array(src):
|
|
1843
2149
|
warp.copy(self, src)
|
|
1844
2150
|
else:
|
|
1845
2151
|
warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
|
|
1846
2152
|
|
|
1847
|
-
# convert array to ndarray (alias memory through array interface)
|
|
1848
2153
|
def numpy(self):
|
|
2154
|
+
"""Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
|
|
2155
|
+
If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
|
|
2156
|
+
automatically performed to ensure that any outstanding work is completed.
|
|
2157
|
+
"""
|
|
1849
2158
|
if self.ptr:
|
|
1850
2159
|
# use the CUDA default stream for synchronous behaviour with other streams
|
|
1851
2160
|
with warp.ScopedStream(self.device.null_stream):
|
|
1852
|
-
a = self.to("cpu")
|
|
2161
|
+
a = self.to("cpu", requires_grad=False)
|
|
1853
2162
|
# convert through __array_interface__
|
|
1854
2163
|
# Note: this handles arrays of structs using `descr`, so the result will be a structured NumPy array
|
|
1855
2164
|
return np.array(a, copy=False)
|
|
@@ -1866,12 +2175,16 @@ class array(Array):
|
|
|
1866
2175
|
npshape = self.shape
|
|
1867
2176
|
return np.empty(npshape, dtype=npdtype)
|
|
1868
2177
|
|
|
1869
|
-
# return a ctypes cast of the array address
|
|
1870
|
-
# note #1: only CPU arrays support this method
|
|
1871
|
-
# note #2: the array must be contiguous
|
|
1872
|
-
# note #3: accesses to this object are *not* bounds checked
|
|
1873
|
-
# note #4: for float16 types, a pointer to the internal uint16 representation is returned
|
|
1874
2178
|
def cptr(self):
|
|
2179
|
+
"""Return a ctypes cast of the array address.
|
|
2180
|
+
|
|
2181
|
+
Notes:
|
|
2182
|
+
|
|
2183
|
+
#. Only CPU arrays support this method.
|
|
2184
|
+
#. The array must be contiguous.
|
|
2185
|
+
#. Accesses to this object are **not** bounds checked.
|
|
2186
|
+
#. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
|
|
2187
|
+
"""
|
|
1875
2188
|
if not self.ptr:
|
|
1876
2189
|
return None
|
|
1877
2190
|
|
|
@@ -1890,8 +2203,8 @@ class array(Array):
|
|
|
1890
2203
|
|
|
1891
2204
|
return p
|
|
1892
2205
|
|
|
1893
|
-
# returns a flattened list of items in the array as a Python list
|
|
1894
2206
|
def list(self):
|
|
2207
|
+
"""Returns a flattened list of items in the array as a Python list."""
|
|
1895
2208
|
a = self.numpy()
|
|
1896
2209
|
|
|
1897
2210
|
if isinstance(self.dtype, warp.codegen.Struct):
|
|
@@ -1910,15 +2223,16 @@ class array(Array):
|
|
|
1910
2223
|
# scalar
|
|
1911
2224
|
return list(a.flatten())
|
|
1912
2225
|
|
|
1913
|
-
|
|
1914
|
-
|
|
2226
|
+
def to(self, device, requires_grad=None):
|
|
2227
|
+
"""Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
|
|
1915
2228
|
device = warp.get_device(device)
|
|
1916
2229
|
if self.device == device:
|
|
1917
2230
|
return self
|
|
1918
2231
|
else:
|
|
1919
|
-
return warp.clone(self, device=device)
|
|
2232
|
+
return warp.clone(self, device=device, requires_grad=requires_grad)
|
|
1920
2233
|
|
|
1921
2234
|
def flatten(self):
|
|
2235
|
+
"""Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
|
|
1922
2236
|
if self.ndim == 1:
|
|
1923
2237
|
return self
|
|
1924
2238
|
|
|
@@ -1941,6 +2255,11 @@ class array(Array):
|
|
|
1941
2255
|
return a
|
|
1942
2256
|
|
|
1943
2257
|
def reshape(self, shape):
|
|
2258
|
+
"""Returns a reshaped array. Only supported for contiguous arrays.
|
|
2259
|
+
|
|
2260
|
+
Args:
|
|
2261
|
+
shape : An int or tuple of ints specifying the shape of the returned array.
|
|
2262
|
+
"""
|
|
1944
2263
|
if not self.is_contiguous:
|
|
1945
2264
|
raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
|
|
1946
2265
|
|
|
@@ -1998,6 +2317,9 @@ class array(Array):
|
|
|
1998
2317
|
return a
|
|
1999
2318
|
|
|
2000
2319
|
def view(self, dtype):
|
|
2320
|
+
"""Returns a zero-copy view of this array's memory with a different data type.
|
|
2321
|
+
``dtype`` must have the same byte size of the array's native ``dtype``.
|
|
2322
|
+
"""
|
|
2001
2323
|
if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
|
|
2002
2324
|
raise RuntimeError("Cannot cast dtypes of unequal byte size")
|
|
2003
2325
|
|
|
@@ -2018,6 +2340,7 @@ class array(Array):
|
|
|
2018
2340
|
return a
|
|
2019
2341
|
|
|
2020
2342
|
def contiguous(self):
|
|
2343
|
+
"""Returns a contiguous array with this array's data. No-op if array is already contiguous."""
|
|
2021
2344
|
if self.is_contiguous:
|
|
2022
2345
|
return self
|
|
2023
2346
|
|
|
@@ -2025,8 +2348,14 @@ class array(Array):
|
|
|
2025
2348
|
warp.copy(a, self)
|
|
2026
2349
|
return a
|
|
2027
2350
|
|
|
2028
|
-
# note: transpose operation will return an array with a non-contiguous access pattern
|
|
2029
2351
|
def transpose(self, axes=None):
|
|
2352
|
+
"""Returns an zero-copy view of the array with axes transposed.
|
|
2353
|
+
|
|
2354
|
+
Note: The transpose operation will return an array with a non-contiguous access pattern.
|
|
2355
|
+
|
|
2356
|
+
Args:
|
|
2357
|
+
axes (optional): Specifies the how the axes are permuted. If not specified, the axes order will be reversed.
|
|
2358
|
+
"""
|
|
2030
2359
|
# noop if 1d array
|
|
2031
2360
|
if self.ndim == 1:
|
|
2032
2361
|
return self
|
|
@@ -2059,6 +2388,8 @@ class array(Array):
|
|
|
2059
2388
|
grad=None if self.grad is None else self.grad.transpose(axes=axes),
|
|
2060
2389
|
)
|
|
2061
2390
|
|
|
2391
|
+
a.is_transposed = not self.is_transposed
|
|
2392
|
+
|
|
2062
2393
|
a._ref = self
|
|
2063
2394
|
return a
|
|
2064
2395
|
|
|
@@ -2093,7 +2424,7 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
|
|
|
2093
2424
|
dtype=dtype,
|
|
2094
2425
|
length=length,
|
|
2095
2426
|
capacity=length * type_size_in_bytes(dtype),
|
|
2096
|
-
ptr=ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
|
|
2427
|
+
ptr=0 if ptr == 0 else ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
|
|
2097
2428
|
shape=shape,
|
|
2098
2429
|
device=device,
|
|
2099
2430
|
owner=False,
|
|
@@ -2101,12 +2432,113 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
|
|
|
2101
2432
|
)
|
|
2102
2433
|
|
|
2103
2434
|
|
|
2104
|
-
class
|
|
2435
|
+
# A base class for non-contiguous arrays, providing the implementation of common methods like
|
|
2436
|
+
# contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
|
|
2437
|
+
class noncontiguous_array_base(Generic[T]):
|
|
2438
|
+
def __init__(self, array_type_id):
|
|
2439
|
+
self.type_id = array_type_id
|
|
2440
|
+
self.is_contiguous = False
|
|
2441
|
+
|
|
2442
|
+
# return a contiguous copy
|
|
2443
|
+
def contiguous(self):
|
|
2444
|
+
a = warp.empty_like(self)
|
|
2445
|
+
warp.copy(a, self)
|
|
2446
|
+
return a
|
|
2447
|
+
|
|
2448
|
+
# copy data from one device to another, nop if already on device
|
|
2449
|
+
def to(self, device):
|
|
2450
|
+
device = warp.get_device(device)
|
|
2451
|
+
if self.device == device:
|
|
2452
|
+
return self
|
|
2453
|
+
else:
|
|
2454
|
+
return warp.clone(self, device=device)
|
|
2455
|
+
|
|
2456
|
+
# return a contiguous numpy copy
|
|
2457
|
+
def numpy(self):
|
|
2458
|
+
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2459
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
2460
|
+
return self.contiguous().numpy()
|
|
2461
|
+
|
|
2462
|
+
# returns a flattened list of items in the array as a Python list
|
|
2463
|
+
def list(self):
|
|
2464
|
+
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2465
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
2466
|
+
return self.contiguous().list()
|
|
2467
|
+
|
|
2468
|
+
# equivalent to wrapping src data in an array and copying to self
|
|
2469
|
+
def assign(self, src):
|
|
2470
|
+
if is_array(src):
|
|
2471
|
+
warp.copy(self, src)
|
|
2472
|
+
else:
|
|
2473
|
+
warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
|
|
2474
|
+
|
|
2475
|
+
def zero_(self):
|
|
2476
|
+
self.fill_(0)
|
|
2477
|
+
|
|
2478
|
+
def fill_(self, value):
|
|
2479
|
+
if self.size == 0:
|
|
2480
|
+
return
|
|
2481
|
+
|
|
2482
|
+
# try to convert the given value to the array dtype
|
|
2483
|
+
try:
|
|
2484
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2485
|
+
if isinstance(value, self.dtype.cls):
|
|
2486
|
+
cvalue = value.__ctype__()
|
|
2487
|
+
elif value == 0:
|
|
2488
|
+
# allow zero-initializing structs using default constructor
|
|
2489
|
+
cvalue = self.dtype().__ctype__()
|
|
2490
|
+
else:
|
|
2491
|
+
raise ValueError(
|
|
2492
|
+
f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
|
|
2493
|
+
)
|
|
2494
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
2495
|
+
# vector/matrix
|
|
2496
|
+
cvalue = self.dtype(value)
|
|
2497
|
+
else:
|
|
2498
|
+
# scalar
|
|
2499
|
+
if type(value) in warp.types.scalar_types:
|
|
2500
|
+
value = value.value
|
|
2501
|
+
if self.dtype == float16:
|
|
2502
|
+
cvalue = self.dtype._type_(float_to_half_bits(value))
|
|
2503
|
+
else:
|
|
2504
|
+
cvalue = self.dtype._type_(value)
|
|
2505
|
+
except Exception as e:
|
|
2506
|
+
raise ValueError(f"Failed to convert the value to the array data type: {e}")
|
|
2507
|
+
|
|
2508
|
+
cvalue_ptr = ctypes.pointer(cvalue)
|
|
2509
|
+
cvalue_size = ctypes.sizeof(cvalue)
|
|
2510
|
+
|
|
2511
|
+
ctype = self.__ctype__()
|
|
2512
|
+
ctype_ptr = ctypes.pointer(ctype)
|
|
2513
|
+
|
|
2514
|
+
if self.device.is_cuda:
|
|
2515
|
+
warp.context.runtime.core.array_fill_device(
|
|
2516
|
+
self.device.context, ctype_ptr, self.type_id, cvalue_ptr, cvalue_size
|
|
2517
|
+
)
|
|
2518
|
+
else:
|
|
2519
|
+
warp.context.runtime.core.array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
|
|
2520
|
+
|
|
2521
|
+
|
|
2522
|
+
# helper to check index array properties
|
|
2523
|
+
def check_index_array(indices, expected_device):
|
|
2524
|
+
if not isinstance(indices, array):
|
|
2525
|
+
raise ValueError(f"Indices must be a Warp array, got {type(indices)}")
|
|
2526
|
+
if indices.ndim != 1:
|
|
2527
|
+
raise ValueError(f"Index array must be one-dimensional, got {indices.ndim}")
|
|
2528
|
+
if indices.dtype != int32:
|
|
2529
|
+
raise ValueError(f"Index array must use int32, got dtype {indices.dtype}")
|
|
2530
|
+
if indices.device != expected_device:
|
|
2531
|
+
raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
|
|
2532
|
+
|
|
2533
|
+
+class indexedarray(noncontiguous_array_base[T]):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None
 
     def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+        super().__init__(ARRAY_TYPE_INDEXED)
+
         # canonicalize types
         if dtype is not None:
             if dtype == int:
@@ -2136,17 +2568,6 @@ class indexedarray(Generic[T]):
             shape = list(data.shape)
 
         if indices is not None:
-            # helper to check index array properties
-            def check_index_array(inds, data):
-                if inds.ndim != 1:
-                    raise ValueError(f"Index array must be one-dimensional, got {inds.ndim}")
-                if inds.dtype != int32:
-                    raise ValueError(f"Index array must use int32, got dtype {inds.dtype}")
-                if inds.device != data.device:
-                    raise ValueError(
-                        f"Index array device ({inds.device} does not match data array device ({data.device}))"
-                    )
-
             if isinstance(indices, (list, tuple)):
                 if len(indices) > self.ndim:
                     raise ValueError(
@@ -2154,16 +2575,14 @@ class indexedarray(Generic[T]):
                     )
 
                 for i in range(len(indices)):
-                    if
-                        check_index_array(indices[i], data)
+                    if indices[i] is not None:
+                        check_index_array(indices[i], data.device)
                         self.indices[i] = indices[i]
                         shape[i] = len(indices[i])
-                    elif indices[i] is not None:
-                        raise TypeError(f"Invalid index array type: {type(indices[i])}")
 
             elif isinstance(indices, array):
                 # only a single index array was provided
-                check_index_array(indices, data)
+                check_index_array(indices, data.device)
                 self.indices[0] = indices
                 shape[0] = len(indices)
 
@@ -2185,8 +2604,6 @@ class indexedarray(Generic[T]):
         for d in self.shape:
             self.size *= d
 
-        self.is_contiguous = False
-
     def __len__(self):
         return self.shape[0]
 
@@ -2206,89 +2623,9 @@ class indexedarray(Generic[T]):
         # member attributes available during code-gen (e.g.: d = arr.shape[0])
        # Note: we use a shared dict for all indexedarray instances
        if indexedarray._vars is None:
-
-
-            indexedarray._vars = {"shape": Var("shape", shape_t)}
+            indexedarray._vars = {"shape": warp.codegen.Var("shape", shape_t)}
         return indexedarray._vars
 
-    def contiguous(self):
-        a = warp.empty_like(self)
-        warp.copy(a, self)
-        return a
-
-    # convert data from one device to another, nop if already on device
-    def to(self, device):
-        device = warp.get_device(device)
-        if self.device == device:
-            return self
-        else:
-            return warp.clone(self, device=device)
-
-    # return a contiguous numpy copy
-    def numpy(self):
-        # use the CUDA default stream for synchronous behaviour with other streams
-        with warp.ScopedStream(self.device.null_stream):
-            return self.contiguous().numpy()
-
-    # returns a flattened list of items in the array as a Python list
-    def list(self):
-        # use the CUDA default stream for synchronous behaviour with other streams
-        with warp.ScopedStream(self.device.null_stream):
-            return self.contiguous().list()
-
-    def zero_(self):
-        self.fill_(0)
-
-    def fill_(self, value):
-        if self.size == 0:
-            return
-
-        # try to convert the given value to the array dtype
-        try:
-            if isinstance(self.dtype, warp.codegen.Struct):
-                if isinstance(value, self.dtype.cls):
-                    cvalue = value.__ctype__()
-                elif value == 0:
-                    # allow zero-initializing structs using default constructor
-                    cvalue = self.dtype().__ctype__()
-                else:
-                    raise ValueError(
-                        f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
-                    )
-            elif issubclass(self.dtype, ctypes.Array):
-                # vector/matrix
-                cvalue = self.dtype(value)
-            else:
-                # scalar
-                if type(value) in warp.types.scalar_types:
-                    value = value.value
-                if self.dtype == float16:
-                    cvalue = self.dtype._type_(float_to_half_bits(value))
-                else:
-                    cvalue = self.dtype._type_(value)
-        except Exception as e:
-            raise ValueError(f"Failed to convert the value to the array data type: {e}")
-
-        cvalue_ptr = ctypes.pointer(cvalue)
-        cvalue_size = ctypes.sizeof(cvalue)
-
-        ctype = self.__ctype__()
-        ctype_ptr = ctypes.pointer(ctype)
-
-        if self.device.is_cuda:
-            warp.context.runtime.core.array_fill_device(
-                self.device.context, ctype_ptr, ARRAY_TYPE_INDEXED, cvalue_ptr, cvalue_size
-            )
-        else:
-            warp.context.runtime.core.array_fill_host(ctype_ptr, ARRAY_TYPE_INDEXED, cvalue_ptr, cvalue_size)
-
-    # equivalent to wrapping src data in an array and copying to self
-    def assign(self, src):
-        if is_array(src):
-            warp.copy(self, src)
-        else:
-            warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
-
 
 # aliases for indexedarrays with small dimensions
 def indexedarray1d(*args, **kwargs):
@@ -2314,16 +2651,22 @@ def indexedarray4d(*args, **kwargs):
     return indexedarray(*args, **kwargs)
 
 
-
+from warp.fabric import fabricarray, indexedfabricarray  # noqa: E402
+
+array_types = (array, indexedarray, fabricarray, indexedfabricarray)
 
 
 def array_type_id(a):
-    if isinstance(a,
-        return
-    elif isinstance(a,
-        return
+    if isinstance(a, array):
+        return ARRAY_TYPE_REGULAR
+    elif isinstance(a, indexedarray):
+        return ARRAY_TYPE_INDEXED
+    elif isinstance(a, fabricarray):
+        return ARRAY_TYPE_FABRIC
+    elif isinstance(a, indexedfabricarray):
+        return ARRAY_TYPE_FABRIC_INDEXED
     else:
-        raise ValueError(
+        raise ValueError("Invalid array type")
 
 
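Illustrative sketch of querying the array type identifiers defined in this diff (the data and indices are assumptions; the constants come from warp.types):

import numpy as np
import warp as wp
import warp.types as wpt

wp.init()

data = wp.array(np.arange(16, dtype=np.float32), dtype=wp.float32, device="cpu")
idx = wp.array(np.array([0, 2, 4], dtype=np.int32), dtype=wp.int32, device="cpu")
iview = wp.indexedarray(data, [idx])

print(wpt.array_type_id(data))   # ARRAY_TYPE_REGULAR
print(wpt.array_type_id(iview))  # ARRAY_TYPE_INDEXED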
 class Bvh:
@@ -2381,11 +2724,11 @@ class Bvh:
             with self.device.context_guard:
                 runtime.core.bvh_destroy_device(self.id)
 
-        except:
+        except Exception:
             pass
 
     def refit(self):
-        """Refit the
+        """Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
 
         from warp.context import runtime
 
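A hedged sketch of the refit() workflow this docstring describes (array sizes and contents are illustrative):

import numpy as np
import warp as wp

wp.init()

lowers_np = np.random.rand(64, 3).astype(np.float32)
lowers = wp.array(lowers_np, dtype=wp.vec3)
uppers = wp.array(lowers_np + 0.1, dtype=wp.vec3)

bvh = wp.Bvh(lowers, uppers)
# ... update the contents of lowers/uppers in place ...
bvh.refit()  # rebuild the hierarchy bounds after modifying the arrays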
@@ -2471,7 +2814,7 @@ class Mesh:
             # use CUDA context guard to avoid side effects during garbage collection
             with self.device.context_guard:
                 runtime.core.mesh_destroy_device(self.id)
-        except:
+        except Exception:
             pass
 
     def refit(self):
@@ -2487,16 +2830,14 @@ class Mesh:
 
 
 class Volume:
+    #: Enum value to specify nearest-neighbor interpolation during sampling
     CLOSEST = constant(0)
+    #: Enum value to specify trilinear interpolation during sampling
     LINEAR = constant(1)
 
     def __init__(self, data: array):
         """Class representing a sparse grid.
 
-        Attributes:
-            CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
-            LINEAR (int): Enum value to specify trilinear interpolation during sampling
-
         Args:
             data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
         """
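The interpolation constants are now documented directly on the class. A hedged kernel sketch of how they are typically passed to the volume sampling builtins (the kernel name and arguments are assumptions):

import warp as wp

@wp.kernel
def sample_volume(
    volume: wp.uint64,
    points: wp.array(dtype=wp.vec3),
    out: wp.array(dtype=float),
):
    tid = wp.tid()
    # transform the world-space point into the volume's index space, then sample with trilinear filtering
    uvw = wp.volume_world_to_index(volume, points[tid])
    out[tid] = wp.volume_sample_f(volume, uvw, wp.Volume.LINEAR)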
@@ -2538,10 +2879,11 @@ class Volume:
                 with self.device.context_guard:
                     runtime.core.volume_destroy_device(self.id)
 
-        except:
+        except Exception:
             pass
 
-    def array(self):
+    def array(self) -> array:
+        """Returns the raw memory buffer of the Volume as an array"""
         buf = ctypes.c_void_p(0)
         size = ctypes.c_uint64(0)
         if self.device.is_cpu:
@@ -2550,7 +2892,7 @@ class Volume:
             self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
         return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
 
-    def get_tiles(self):
+    def get_tiles(self) -> array:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2563,7 +2905,7 @@ class Volume:
         num_tiles = size.value // (3 * 4)
         return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
 
-    def get_voxel_size(self):
+    def get_voxel_size(self) -> Tuple[float, float, float]:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2572,7 +2914,13 @@ class Volume:
         return (dx.value, dy.value, dz.value)
 
     @classmethod
-    def load_from_nvdb(cls, file_or_buffer, device=None):
+    def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
+        """Creates a Volume object from a NanoVDB file or in-memory buffer.
+
+        Returns:
+
+            A ``warp.Volume`` object.
+        """
         try:
             data = file_or_buffer.read()
         except AttributeError:
@@ -2601,6 +2949,90 @@ class Volume:
         data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
         return cls(data_array)
 
+    @classmethod
+    def load_from_numpy(
+        cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
+    ) -> Volume:
+        """Creates a Volume object from a dense 3D NumPy array.
+
+        This function is only supported for CUDA devices.
+
+        Args:
+            min_world: The 3D coordinate of the lower corner of the volume.
+            voxel_size: The size of each voxel in spatial coordinates.
+            bg_value: Background value
+            device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
+
+        Returns:
+
+            A ``warp.Volume`` object.
+        """
+
+        import math
+
+        target_shape = (
+            math.ceil(ndarray.shape[0] / 8) * 8,
+            math.ceil(ndarray.shape[1] / 8) * 8,
+            math.ceil(ndarray.shape[2] / 8) * 8,
+        )
+        if hasattr(bg_value, "__len__"):
+            # vec3, assuming the numpy array is 4D
+            padded_array = np.array((target_shape[0], target_shape[1], target_shape[2], 3), dtype=np.single)
+            padded_array[:, :, :, :] = np.array(bg_value)
+            padded_array[0 : ndarray.shape[0], 0 : ndarray.shape[1], 0 : ndarray.shape[2], :] = ndarray
+        else:
+            padded_amount = (
+                math.ceil(ndarray.shape[0] / 8) * 8 - ndarray.shape[0],
+                math.ceil(ndarray.shape[1] / 8) * 8 - ndarray.shape[1],
+                math.ceil(ndarray.shape[2] / 8) * 8 - ndarray.shape[2],
+            )
+            padded_array = np.pad(
+                ndarray,
+                ((0, padded_amount[0]), (0, padded_amount[1]), (0, padded_amount[2])),
+                mode="constant",
+                constant_values=bg_value,
+            )
+
+        shape = padded_array.shape
+        volume = warp.Volume.allocate(
+            min_world,
+            [
+                min_world[0] + (shape[0] - 1) * voxel_size,
+                min_world[1] + (shape[1] - 1) * voxel_size,
+                min_world[2] + (shape[2] - 1) * voxel_size,
+            ],
+            voxel_size,
+            bg_value=bg_value,
+            points_in_world_space=True,
+            translation=min_world,
+            device=device,
+        )
+
+        # Populate volume
+        if hasattr(bg_value, "__len__"):
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_v,
+                dim=(shape[0], shape[1], shape[2]),
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
+                device=device,
+            )
+        elif isinstance(bg_value, int):
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_i,
+                dim=shape,
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.int32, device=device)],
+                device=device,
+            )
+        else:
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_f,
+                dim=shape,
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.float32, device=device)],
+                device=device,
+            )
+
+        return volume
+
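A short usage sketch for the new load_from_numpy classmethod (the dense grid contents are illustrative, and a CUDA device is required per the docstring):

import numpy as np
import warp as wp

wp.init()

dense = np.zeros((32, 32, 32), dtype=np.float32)
dense[8:24, 8:24, 8:24] = 1.0

volume = wp.Volume.load_from_numpy(
    dense, min_world=(0.0, 0.0, 0.0), voxel_size=0.1, bg_value=0.0, device="cuda:0"
)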
     @classmethod
     def allocate(
         cls,
@@ -2611,9 +3043,11 @@ class Volume:
         translation=(0.0, 0.0, 0.0),
         points_in_world_space=False,
         device=None,
-    ):
+    ) -> Volume:
         """Allocate a new Volume based on the bounding box defined by min and max.
 
+        This function is only supported for CUDA devices.
+
         Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
         If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
         translation, and the volume is allocated with those.
@@ -2622,12 +3056,12 @@ class Volume:
         the resulting tiles will be available in the new volume.
 
         Args:
-            min (array-like): Lower 3D
-            max (array-like): Upper 3D
-            voxel_size (float): Voxel size of the new volume
+            min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
+            max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like): translation between the index and world spaces
-            device (Devicelike):
+            translation (array-like): translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         if points_in_world_space:
@@ -2652,9 +3086,11 @@ class Volume:
     @classmethod
     def allocate_by_tiles(
         cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
-    ):
+    ) -> Volume:
         """Allocate a new Volume with active tiles for each point tile_points.
 
+        This function is only supported for CUDA devices.
+
         The smallest unit of allocation is a dense tile of 8x8x8 voxels.
         This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
 
@@ -2664,13 +3100,13 @@ class Volume:
 
         Args:
             tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
-                The array can be a
+                The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
                 or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like):
-            device (Devicelike):
+            translation (array-like): Translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         from warp.context import runtime
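Illustrative sketch of allocate_by_tiles with a 1D array of world-space vec3 tile points, per the docstring above (point count and extents are assumptions):

import numpy as np
import warp as wp

wp.init()

# one world-space point per 8x8x8 tile that should be allocated
tile_points = wp.array(np.random.rand(100, 3).astype(np.float32) * 5.0, dtype=wp.vec3, device="cuda:0")
volume = wp.Volume.allocate_by_tiles(tile_points, voxel_size=0.1, bg_value=0.0, device="cuda:0")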
@@ -2707,7 +3143,7 @@ class Volume:
                 translation[2],
                 in_world_space,
             )
-        elif
+        elif isinstance(bg_value, int):
             volume.id = volume.context.core.volume_i_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
@@ -2738,6 +3174,67 @@ class Volume:
         return volume
 
 
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
+class mesh_query_point_t:
+    """Output for the mesh query point functions.
+
+    Attributes:
+        result (bool): Whether a point is found within the given constraints.
+        sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
+            Note that mesh must be watertight for this to be robust
+        face (int32): Index of the closest face.
+        u (float32): Barycentric u coordinate of the closest point.
+        v (float32): Barycentric v coordinate of the closest point.
+
+    See Also:
+        :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
+        :func:`mesh_query_furthest_point_no_sign`,
+        :func:`mesh_query_point_sign_normal`,
+        and :func:`mesh_query_point_sign_winding_number`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+    }
+
+
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+class mesh_query_ray_t:
+    """Output for the mesh query ray functions.
+
+    Attributes:
+        result (bool): Whether a hit is found within the given constraints.
+        sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
+        face (int32): Index of the closest face.
+        t (float32): Distance of the closest hit along the ray.
+        u (float32): Barycentric u coordinate of the closest hit.
+        v (float32): Barycentric v coordinate of the closest hit.
+        normal (vec3f): Face normal.
+
+    See Also:
+        :func:`mesh_query_ray`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "t": Var("t", float32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+        "normal": Var("normal", vec3),
+    }
+
+
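These structs are the return types used by the mesh query builtins inside kernels. A hedged kernel sketch, assuming the struct-returning overload of wp.mesh_query_point that accompanies these types (mesh id, max distance, and output layout are illustrative):

import warp as wp

@wp.kernel
def closest_point_on_mesh(
    mesh: wp.uint64,
    points: wp.array(dtype=wp.vec3),
    out: wp.array(dtype=wp.vec3),
):
    tid = wp.tid()
    # the query result fields follow the Attributes documented above
    q = wp.mesh_query_point(mesh, points[tid], 1.0e6)
    if q.result:
        out[tid] = wp.mesh_eval_position(mesh, q.face, q.u, q.v)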
 def matmul(
     a: array2d,
     b: array2d,
@@ -2745,7 +3242,7 @@ def matmul(
     d: array2d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2774,6 +3271,11 @@ def matmul(
             "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )
 
+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
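The new check rejects non-contiguous inputs unless A and/or B are transposed views. A minimal sketch of a valid call with contiguous operands (shapes and device are illustrative):

import numpy as np
import warp as wp

wp.init()

dev = "cuda:0"
a = wp.array(np.random.rand(8, 4).astype(np.float32), dtype=wp.float32, device=dev)
b = wp.array(np.random.rand(4, 16).astype(np.float32), dtype=wp.float32, device=dev)
c = wp.zeros((8, 16), dtype=wp.float32, device=dev)
d = wp.zeros((8, 16), dtype=wp.float32, device=dev)

wp.matmul(a, b, c, d, alpha=1.0, beta=0.0)  # d = a @ b; all inputs contiguous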
@@ -2808,13 +3310,13 @@ def matmul(
         ctypes.c_void_p(d.ptr),
         alpha,
         beta,
-
-
+        not a.is_transposed,
+        not b.is_transposed,
         allow_tf32x3_arith,
         1,
     )
     if not ret:
-        raise RuntimeError("
+        raise RuntimeError("matmul failed.")
 
 
 def adj_matmul(
@@ -2827,7 +3329,7 @@ def adj_matmul(
     adj_d: array2d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes the adjoint of a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2878,6 +3380,19 @@ def adj_matmul(
             "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
         )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -2898,75 +3413,105 @@ def adj_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
+        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
     cc = device.arch
 
     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-
-
-
-
-
-
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        1,
+    warp.launch(
+        kernel=warp.utils.add_kernel_2d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
 
 
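The adjoint path now accumulates into existing gradients rather than overwriting them, and adj_c is updated with an element-wise add kernel instead of a GEMM call. A hedged sketch of how this surfaces when differentiating through wp.matmul with a tape (shapes, seeds, and device are assumptions):

import numpy as np
import warp as wp

wp.init()

dev = "cuda:0"
a = wp.array(np.random.rand(8, 4).astype(np.float32), dtype=wp.float32, device=dev, requires_grad=True)
b = wp.array(np.random.rand(4, 16).astype(np.float32), dtype=wp.float32, device=dev, requires_grad=True)
c = wp.zeros((8, 16), dtype=wp.float32, device=dev, requires_grad=True)
d = wp.zeros((8, 16), dtype=wp.float32, device=dev, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.matmul(a, b, c, d)

# seed the output adjoint and propagate; gradients accumulate into a.grad, b.grad, c.grad
tape.backward(grads={d: wp.array(np.ones((8, 16), dtype=np.float32), dtype=wp.float32, device=dev)})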
def batched_matmul(
|
|
@@ -2976,7 +3521,7 @@ def batched_matmul(
|
|
|
2976
3521
|
d: array3d,
|
|
2977
3522
|
alpha: float = 1.0,
|
|
2978
3523
|
beta: float = 0.0,
|
|
2979
|
-
allow_tf32x3_arith: bool = False,
|
|
3524
|
+
allow_tf32x3_arith: builtins.bool = False,
|
|
2980
3525
|
device=None,
|
|
2981
3526
|
):
|
|
2982
3527
|
"""Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
|
|
@@ -3005,6 +3550,11 @@ def batched_matmul(
|
|
|
3005
3550
|
"wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
|
|
3006
3551
|
)
|
|
3007
3552
|
|
|
3553
|
+
if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
|
|
3554
|
+
raise RuntimeError(
|
|
3555
|
+
"wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
|
|
3556
|
+
)
|
|
3557
|
+
|
|
3008
3558
|
m = a.shape[1]
|
|
3009
3559
|
n = b.shape[2]
|
|
3010
3560
|
k = a.shape[2]
|
|
@@ -3016,7 +3566,7 @@ def batched_matmul(
|
|
|
3016
3566
|
|
|
3017
3567
|
if runtime.tape:
|
|
3018
3568
|
runtime.tape.record_func(
|
|
3019
|
-
backward=lambda:
|
|
3569
|
+
backward=lambda: adj_batched_matmul(
|
|
3020
3570
|
a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
|
|
3021
3571
|
),
|
|
3022
3572
|
arrays=[a, b, c, d],
|
|
@@ -3027,26 +3577,55 @@ def batched_matmul(
|
|
|
3027
3577
|
d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
|
|
3028
3578
|
return
|
|
3029
3579
|
|
|
3580
|
+
# handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
|
|
3581
|
+
max_batch_count = 65535
|
|
3582
|
+
iters = int(batch_count / max_batch_count)
|
|
3583
|
+
remainder = batch_count % max_batch_count
|
|
3584
|
+
|
|
3030
3585
|
cc = device.arch
|
|
3586
|
+
for i in range(iters):
|
|
3587
|
+
idx_start = i * max_batch_count
|
|
3588
|
+
idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
|
|
3589
|
+
ret = runtime.core.cutlass_gemm(
|
|
3590
|
+
cc,
|
|
3591
|
+
m,
|
|
3592
|
+
n,
|
|
3593
|
+
k,
|
|
3594
|
+
type_typestr(a.dtype).encode(),
|
|
3595
|
+
ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
|
|
3596
|
+
ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
|
|
3597
|
+
ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
|
|
3598
|
+
ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
|
|
3599
|
+
alpha,
|
|
3600
|
+
beta,
|
|
3601
|
+
not a.is_transposed,
|
|
3602
|
+
not b.is_transposed,
|
|
3603
|
+
allow_tf32x3_arith,
|
|
3604
|
+
max_batch_count,
|
|
3605
|
+
)
|
|
3606
|
+
if not ret:
|
|
3607
|
+
raise RuntimeError("Batched matmul failed.")
|
|
3608
|
+
|
|
3609
|
+
idx_start = iters * max_batch_count
|
|
3031
3610
|
ret = runtime.core.cutlass_gemm(
|
|
3032
3611
|
cc,
|
|
3033
3612
|
m,
|
|
3034
3613
|
n,
|
|
3035
3614
|
k,
|
|
3036
3615
|
type_typestr(a.dtype).encode(),
|
|
3037
|
-
ctypes.c_void_p(a.ptr),
|
|
3038
|
-
ctypes.c_void_p(b.ptr),
|
|
3039
|
-
ctypes.c_void_p(c.ptr),
|
|
3040
|
-
ctypes.c_void_p(d.ptr),
|
|
3616
|
+
ctypes.c_void_p(a[idx_start:,:,:].ptr),
|
|
3617
|
+
ctypes.c_void_p(b[idx_start:,:,:].ptr),
|
|
3618
|
+
ctypes.c_void_p(c[idx_start:,:,:].ptr),
|
|
3619
|
+
ctypes.c_void_p(d[idx_start:,:,:].ptr),
|
|
3041
3620
|
alpha,
|
|
3042
3621
|
beta,
|
|
3043
|
-
|
|
3044
|
-
|
|
3622
|
+
not a.is_transposed,
|
|
3623
|
+
not b.is_transposed,
|
|
3045
3624
|
allow_tf32x3_arith,
|
|
3046
|
-
|
|
3625
|
+
remainder,
|
|
3047
3626
|
)
|
|
3048
3627
|
if not ret:
|
|
3049
|
-
raise RuntimeError("Batched matmul failed.")
|
|
3628
|
+
raise RuntimeError("Batched matmul failed.")
|
|
3050
3629
|
|
|
3051
3630
|
|
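Illustrative arithmetic for the chunking introduced above, which splits a large batch into launches of at most 65535 matrices plus a remainder launch (the batch count is an example value):

batch_count = 150000
max_batch_count = 65535
iters = batch_count // max_batch_count      # 2 full launches of 65535 matrices each
remainder = batch_count % max_batch_count   # final launch covering the remaining 18930 matrices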
 def adj_batched_matmul(
@@ -3059,7 +3638,7 @@ def adj_batched_matmul(
     adj_d: array3d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -3126,78 +3705,215 @@ def adj_batched_matmul(
         )
     )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
 
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+
+        # adj_a
+        if not a.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                m,
+                k,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                True,
+                b.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                m,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                not b.is_transposed,
+                False,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+        # adj_b
+        if not b.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                n,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                a.is_transposed,
+                True,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                n,
+                k,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                False,
+                not a.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+    idx_start = iters * max_batch_count
+
     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-
-
-
-
-
-
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        batch_count,
+    warp.launch(
+        kernel=warp.utils.add_kernel_3d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
-
 
 class HashGrid:
     def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3266,7 +3982,7 @@ class HashGrid:
             with self.device.context_guard:
                 runtime.core.hash_grid_destroy_device(self.id)
 
-        except:
+        except Exception:
             pass
 
 
@@ -3340,7 +4056,7 @@ class MarchingCubes:
 
         if error:
             raise RuntimeError(
-                "
+                "Buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
             )
 
         # resize the geometry arrays
@@ -3396,7 +4112,7 @@ def type_matches_template(arg_type, template_type):
         return True
     elif is_array(template_type):
         # ensure the argument type is a non-generic array with matching dtype and dimensionality
-        if type(arg_type)
+        if type(arg_type) is not type(template_type):
            return False
        if not type_matches_template(arg_type.dtype, template_type.dtype):
            return False
@@ -3429,7 +4145,7 @@ def infer_argument_types(args, template_types, arg_names=None):
     """Resolve argument types with the given list of template types."""
 
     if len(args) != len(template_types):
-        raise RuntimeError(
+        raise RuntimeError("Number of arguments must match number of template types.")
 
     arg_types = []
 
@@ -3452,7 +4168,7 @@ def infer_argument_types(args, template_types, arg_names=None):
             arg_types.append(arg._cls)
         # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
         #     arg_types.append(arg_type)
-        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
+        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
         #     arg_types.append(arg_type)
         elif arg is None:
             # allow passing None for arrays
@@ -3471,6 +4187,7 @@ def infer_argument_types(args, template_types, arg_names=None):
 simple_type_codes = {
     int: "i4",
     float: "f4",
+    builtins.bool: "b",
     bool: "b",
     str: "str",  # accepted by print()
     int8: "i1",
@@ -3489,6 +4206,8 @@ simple_type_codes = {
     launch_bounds_t: "lb",
     hash_grid_query_t: "hgq",
     mesh_query_aabb_t: "mqa",
+    mesh_query_point_t: "mqp",
+    mesh_query_ray_t: "mqr",
     bvh_query_t: "bvhq",
 }
 
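Illustrative only: with these entries, get_type_code() presumably resolves the new query structs through the same short-code table used for other builtin kernel types, along the lines of:

# get_type_code(warp.types.mesh_query_point_t)  -> "mqp"
# get_type_code(warp.types.mesh_query_ray_t)    -> "mqr"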
@@ -3505,14 +4224,14 @@ def get_type_code(arg_type):
|
|
|
3505
4224
|
# check for "special" vector/matrix subtypes
|
|
3506
4225
|
if hasattr(arg_type, "_wp_generic_type_str_"):
|
|
3507
4226
|
type_str = arg_type._wp_generic_type_str_
|
|
3508
|
-
if type_str == "
|
|
4227
|
+
if type_str == "quat_t":
|
|
3509
4228
|
return f"q{dtype_code}"
|
|
3510
4229
|
elif type_str == "transform_t":
|
|
3511
4230
|
return f"t{dtype_code}"
|
|
3512
|
-
elif type_str == "spatial_vector_t":
|
|
3513
|
-
|
|
3514
|
-
elif type_str == "spatial_matrix_t":
|
|
3515
|
-
|
|
4231
|
+
# elif type_str == "spatial_vector_t":
|
|
4232
|
+
# return f"sv{dtype_code}"
|
|
4233
|
+
# elif type_str == "spatial_matrix_t":
|
|
4234
|
+
# return f"sm{dtype_code}"
|
|
3516
4235
|
# generic vector/matrix
|
|
3517
4236
|
ndim = len(arg_type._shape_)
|
|
3518
4237
|
if ndim == 1:
|
|
@@ -3535,6 +4254,10 @@ def get_type_code(arg_type):
|
|
|
3535
4254
|
return f"a{arg_type.ndim}{get_type_code(arg_type.dtype)}"
|
|
3536
4255
|
elif isinstance(arg_type, indexedarray):
|
|
3537
4256
|
return f"ia{arg_type.ndim}{get_type_code(arg_type.dtype)}"
|
|
4257
|
+
elif isinstance(arg_type, fabricarray):
|
|
4258
|
+
return f"fa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
|
|
4259
|
+
elif isinstance(arg_type, indexedfabricarray):
|
|
4260
|
+
return f"ifa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
|
|
3538
4261
|
elif isinstance(arg_type, warp.codegen.Struct):
|
|
3539
4262
|
return warp.codegen.make_full_qualified_name(arg_type.cls)
|
|
3540
4263
|
elif arg_type == Scalar:
|