warp-lang 0.9.0 (py3-none-win_amd64.whl) → 0.11.0 (py3-none-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/types.py
CHANGED
@@ -5,19 +5,17 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
+import builtins
 import ctypes
 import hashlib
+import inspect
 import struct
 import zlib
-import numpy as np
+from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
 
-from typing import Any
-from typing import Tuple
-from typing import TypeVar
-from typing import Generic
-from typing import List
-from typing import Callable
-from typing import Union
+import numpy as np
 
 import warp
 
@@ -54,12 +52,14 @@ def constant(x):
     global _constant_hash
 
     # hash the constant value
-    if isinstance(x, int):
+    if isinstance(x, builtins.bool):
+        # This needs to come before the check for `int` since all boolean
+        # values are also instances of `int`.
+        _constant_hash.update(struct.pack("?", x))
+    elif isinstance(x, int):
         _constant_hash.update(struct.pack("<q", x))
     elif isinstance(x, float):
         _constant_hash.update(struct.pack("<d", x))
-    elif isinstance(x, bool):
-        _constant_hash.update(struct.pack("?", x))
     elif isinstance(x, float16):
         # float16 is a special case
         p = ctypes.pointer(ctypes.c_float(x.value))
@@ -75,6 +75,14 @@ def constant(x):
     return x
 
 
+def float_to_half_bits(value):
+    return warp.context.runtime.core.float_to_half_bits(value)
+
+
+def half_bits_to_float(value):
+    return warp.context.runtime.core.half_bits_to_float(value)
+
+
 # ----------------------
 # built-in types
 
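For reference (not part of the diff): a minimal sketch of how the new module-level half-precision helpers above could be used. It assumes the Warp runtime has been initialized, since both functions forward to warp.context.runtime.core, and that the native calls return a uint16 bit pattern and a Python float respectively.

    import warp as wp
    import warp.types as wpt

    wp.init()  # load the native runtime that backs the bit-conversion helpers

    bits = wpt.float_to_half_bits(1.5)    # uint16 bit pattern of 1.5 encoded as float16
    value = wpt.half_bits_to_float(bits)  # decode back to a Python float (approximately 1.5)
    print(bits, value)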
@@ -98,19 +106,15 @@ def vector(length, dtype):
         _wp_generic_type_str_ = "vec_t"
         _wp_constructor_ = "vector"
 
-        def __init__(self, *args):
-            # special handling for float16 type: in this case, data is stored
-            # as uint16 but it's actually half precision floating point
-            # data. This means we need to convert each of the arguments
-            # to uint16s containing half float bits before storing them in
-            # the array:
-            if dtype == float16:
-                from warp.context import runtime
-
-                scalar_value = runtime.core.float_to_half_bits
-            else:
-                scalar_value = lambda x: x
+        # special handling for float16 type: in this case, data is stored
+        # as uint16 but it's actually half precision floating point
+        # data. This means we need to convert each of the arguments
+        # to uint16s containing half float bits before storing them in
+        # the array:
+        scalar_import = float_to_half_bits if _wp_scalar_type_ == float16 else lambda x: x
+        scalar_export = half_bits_to_float if _wp_scalar_type_ == float16 else lambda x: x
 
+        def __init__(self, *args):
             num_args = len(args)
             if num_args == 0:
                 super().__init__()
@@ -120,29 +124,99 @@ def vector(length, dtype):
                     self.__init__(*args[0])
                 else:
                     # set all elements to the same value
-                    value = scalar_value(args[0])
+                    value = vec_t.scalar_import(args[0])
                     for i in range(self._length_):
                         super().__setitem__(i, value)
             elif num_args == self._length_:
                 # set all scalar elements
                 for i in range(self._length_):
-                    super().__setitem__(i, scalar_value(args[i]))
+                    super().__setitem__(i, vec_t.scalar_import(args[i]))
             else:
                 raise ValueError(
                     f"Invalid number of arguments in vector constructor, expected {self._length_} elements, got {num_args}"
                 )
 
+        def __getitem__(self, key):
+            if isinstance(key, int):
+                return vec_t.scalar_export(super().__getitem__(key))
+            elif isinstance(key, slice):
+                if self._wp_scalar_type_ == float16:
+                    return [vec_t.scalar_export(x) for x in super().__getitem__(key)]
+                else:
+                    return super().__getitem__(key)
+            else:
+                raise KeyError(f"Invalid key {key}, expected int or slice")
+
+        def __setitem__(self, key, value):
+            if isinstance(key, int):
+                try:
+                    return super().__setitem__(key, vec_t.scalar_import(value))
+                except (TypeError, ctypes.ArgumentError):
+                    raise TypeError(
+                        f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
+            elif isinstance(key, slice):
+                try:
+                    iter(value)
+                except TypeError:
+                    raise TypeError(
+                        f"Expected to assign a slice from a sequence of values "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
+
+                if self._wp_scalar_type_ == float16:
+                    converted = []
+                    try:
+                        for x in value:
+                            converted.append(vec_t.scalar_import(x))
+                    except ctypes.ArgumentError:
+                        raise TypeError(
+                            f"Expected to assign a slice from a sequence of `float16` values "
+                            f"but got `{type(x).__name__}` instead"
+                        ) from None
+
+                    value = converted
+
+                try:
+                    return super().__setitem__(key, value)
+                except TypeError:
+                    for x in value:
+                        try:
+                            self._type_(x)
+                        except TypeError:
+                            raise TypeError(
+                                f"Expected to assign a slice from a sequence of `{self._wp_scalar_type_.__name__}` values "
+                                f"but got `{type(x).__name__}` instead"
+                            ) from None
+            else:
+                raise KeyError(f"Invalid key {key}, expected int or slice")
+
+        def __getattr__(self, name):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__getitem__(idx)
+
+            return self.__getattribute__(name)
+
+        def __setattr__(self, name, value):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__setitem__(idx, value)
+
+            return super().__setattr__(name, value)
+
         def __add__(self, y):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -150,17 +224,17 @@ def vector(length, dtype):
         def __rmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"
@@ -171,6 +245,17 @@ def vector(length, dtype):
                     return False
             return True
 
+        @classmethod
+        def from_ptr(cls, ptr):
+            if ptr:
+                # create a new vector instance and initialize the contents from the binary data
+                # this skips float16 conversions, assuming that float16 data is already encoded as uint16
+                value = cls()
+                ctypes.memmove(ctypes.byref(value), ptr, ctypes.sizeof(cls._type_) * cls._length_)
+                return value
+            else:
+                raise RuntimeError("NULL pointer exception")
+
     return vec_t
 
 
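As a usage sketch (not from the package): the new __getattr__/__setattr__/__getitem__/__setitem__ accessors on generated vector types allow component and slice access from Python, shown here with the stock wp.vec3 type.

    import warp as wp

    v = wp.vec3(1.0, 2.0, 3.0)
    v.y = 5.0            # component write routed through the new __setattr__
    print(v.x, v[1:])    # attribute read plus a slice read via the new __getitem__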
@@ -197,19 +282,15 @@ def matrix(shape, dtype):
 
         _wp_row_type_ = vector(0 if shape[1] == Any else shape[1], dtype)
 
-        def __init__(self, *args):
-            # special handling for float16 type: in this case, data is stored
-            # as uint16 but it's actually half precision floating point
-            # data. This means we need to convert each of the arguments
-            # to uint16s containing half float bits before storing them in
-            # the array:
-            if dtype == float16:
-                from warp.context import runtime
-
-                scalar_value = runtime.core.float_to_half_bits
-            else:
-                scalar_value = lambda x: x
+        # special handling for float16 type: in this case, data is stored
+        # as uint16 but it's actually half precision floating point
+        # data. This means we need to convert each of the arguments
+        # to uint16s containing half float bits before storing them in
+        # the array:
+        scalar_import = float_to_half_bits if _wp_scalar_type_ == float16 else lambda x: x
+        scalar_export = half_bits_to_float if _wp_scalar_type_ == float16 else lambda x: x
 
+        def __init__(self, *args):
             num_args = len(args)
             if num_args == 0:
                 super().__init__()
@@ -219,13 +300,13 @@ def matrix(shape, dtype):
                     self.__init__(*args[0])
                 else:
                     # set all elements to the same value
-                    value = scalar_value(args[0])
+                    value = mat_t.scalar_import(args[0])
                     for i in range(self._length_):
                         super().__setitem__(i, value)
             elif num_args == self._length_:
                 # set all scalar elements
                 for i in range(self._length_):
-                    super().__setitem__(i, scalar_value(args[i]))
+                    super().__setitem__(i, mat_t.scalar_import(args[i]))
             elif num_args == self._shape_[0]:
                 # row vectors
                 for i, row in enumerate(args):
@@ -235,7 +316,7 @@ def matrix(shape, dtype):
                         )
                     offset = i * self._shape_[1]
                     for i in range(self._shape_[1]):
-                        super().__setitem__(offset + i, scalar_value(row[i]))
+                        super().__setitem__(offset + i, mat_t.scalar_import(row[i]))
             else:
                 raise ValueError(
                     f"Invalid number of arguments in matrix constructor, expected {self._length_} elements, got {num_args}"
@@ -245,13 +326,13 @@ def matrix(shape, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -265,17 +346,17 @@ def matrix(shape, dtype):
         def __rmatmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             row_str = []
@@ -286,48 +367,96 @@ def matrix(shape, dtype):
             return "[" + ",\n ".join(row_str) + "]"
 
         def __eq__(self, other):
-            for i in range(self._length_):
-                if self[i] != other[i]:
-                    return False
+            for i in range(self._shape_[0]):
+                for j in range(self._shape_[1]):
+                    if self[i][j] != other[i][j]:
+                        return False
             return True
 
-
         def get_row(self, r):
             if r < 0 or r >= self._shape_[0]:
                 raise IndexError("Invalid row index")
             row_start = r * self._shape_[1]
             row_end = row_start + self._shape_[1]
-            return self._wp_row_type_(super().__getitem__(slice(row_start, row_end)))
+            row_data = super().__getitem__(slice(row_start, row_end))
+            if self._wp_scalar_type_ == float16:
+                return self._wp_row_type_(*[mat_t.scalar_export(x) for x in row_data])
+            else:
+                return self._wp_row_type_(row_data)
 
         def set_row(self, r, v):
             if r < 0 or r >= self._shape_[0]:
                 raise IndexError("Invalid row index")
+            try:
+                iter(v)
+            except TypeError:
+                raise TypeError(
+                    f"Expected to assign a slice from a sequence of values "
+                    f"but got `{type(v).__name__}` instead"
+                ) from None
+
             row_start = r * self._shape_[1]
             row_end = row_start + self._shape_[1]
+            if self._wp_scalar_type_ == float16:
+                converted = []
+                try:
+                    for x in v:
+                        converted.append(mat_t.scalar_import(x))
+                except ctypes.ArgumentError:
+                    raise TypeError(
+                        f"Expected to assign a slice from a sequence of `float16` values "
+                        f"but got `{type(x).__name__}` instead"
+                    ) from None
+
+                v = converted
             super().__setitem__(slice(row_start, row_end), v)
 
         def __getitem__(self, key):
             if isinstance(key, Tuple):
                 # element indexing m[i,j]
-                return super().__getitem__(key[0] * self._shape_[1] + key[1])
+                if len(key) != 2:
+                    raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
+                if any(isinstance(x, slice) for x in key):
+                    raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
+                return mat_t.scalar_export(super().__getitem__(key[0] * self._shape_[1] + key[1]))
             elif isinstance(key, int):
                 # row vector indexing m[r]
                 return self.get_row(key)
             else:
-                # slice etc.
-                return super().__getitem__(key)
+                raise KeyError(f"Invalid key {key}, expected int or pair of ints")
 
         def __setitem__(self, key, value):
             if isinstance(key, Tuple):
                 # element indexing m[i,j] = x
-                return super().__setitem__(key[0] * self._shape_[1] + key[1], value)
+                if len(key) != 2:
+                    raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
+                if any(isinstance(x, slice) for x in key):
+                    raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
+                try:
+                    return super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
+                except (TypeError, ctypes.ArgumentError):
+                    raise TypeError(
+                        f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
            elif isinstance(key, int):
                 # row vector indexing m[r] = v
-                self.set_row(key, value)
+                return self.set_row(key, value)
+            elif isinstance(key, slice):
+                raise KeyError(f"Slices are not supported when indexing matrices using the `m[start:end]` notation")
+            else:
+                raise KeyError(f"Invalid key {key}, expected int or pair of ints")
+
+        @classmethod
+        def from_ptr(cls, ptr):
+            if ptr:
+                # create a new matrix instance and initialize the contents from the binary data
+                # this skips float16 conversions, assuming that float16 data is already encoded as uint16
+                value = cls()
+                ctypes.memmove(ctypes.byref(value), ptr, ctypes.sizeof(cls._type_) * cls._length_)
                 return value
             else:
-                # slice etc.
-                return super().__setitem__(key, value)
+                raise RuntimeError("NULL pointer exception")
 
     return mat_t
 
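As a usage sketch (not from the package): the tuple-key indexing and row accessors added above can be exercised with the stock wp.mat22 type.

    import warp as wp

    m = wp.mat22(1.0, 2.0,
                 3.0, 4.0)
    m[0, 1] = 5.0     # element write through the `m[i, j]` path
    row = m[1]        # row read returns a wp.vec2 via get_row()
    print(m[0, 1], row)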
@@ -337,6 +466,23 @@ class void:
     pass
 
 
+class bool:
+    _length_ = 1
+    _type_ = ctypes.c_bool
+
+    def __init__(self, x=False):
+        self.value = x
+
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value != 0)
+
+    def __int__(self) -> int:
+        return int(self.value != 0)
+
+
 class float16:
     _length_ = 1
     _type_ = ctypes.c_uint16
@@ -344,6 +490,15 @@ class float16:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class float32:
     _length_ = 1
@@ -352,6 +507,15 @@ class float32:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class float64:
     _length_ = 1
@@ -360,6 +524,15 @@ class float64:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class int8:
     _length_ = 1
@@ -368,6 +541,18 @@ class int8:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint8:
     _length_ = 1
@@ -376,6 +561,18 @@ class uint8:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int16:
     _length_ = 1
@@ -384,6 +581,18 @@ class int16:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint16:
     _length_ = 1
@@ -392,6 +601,18 @@ class uint16:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int32:
     _length_ = 1
@@ -400,6 +621,18 @@ class int32:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint32:
     _length_ = 1
@@ -408,6 +641,18 @@ class uint32:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int64:
     _length_ = 1
@@ -416,6 +661,18 @@ class int64:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint64:
     _length_ = 1
@@ -424,6 +681,18 @@ class uint64:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 def quaternion(dtype=Any):
     class quat_t(vector(length=4, dtype=dtype)):
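As a usage sketch (not from the package): the scalar wrapper classes gain __bool__/__float__/__int__ (and __index__ for the integer types), so wrapped values participate in ordinary Python conversions and indexing.

    import numpy as np
    import warp as wp

    i = wp.int32(3)
    data = np.arange(10)
    print(data[i])                                     # __index__ lets the wrapper act as an index
    print(float(wp.float16(0.25)), bool(wp.uint8(0)))  # conversions via __float__ / __bool__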
@@ -453,23 +722,63 @@ class quatd(quaternion(dtype=float64)):
 
 def transformation(dtype=Any):
     class transform_t(vector(length=7, dtype=dtype)):
+        _wp_init_from_components_sig_ = inspect.Signature(
+            (
+                inspect.Parameter(
+                    "p",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0),
+                ),
+                inspect.Parameter(
+                    "q",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0, 1.0),
+                ),
+            ),
+        )
         _wp_type_params_ = [dtype]
         _wp_generic_type_str_ = "transform_t"
         _wp_constructor_ = "transformation"
 
-        def __init__(self, p=(0.0, 0.0, 0.0), q=(0.0, 0.0, 0.0, 1.0)):
-            super().__init__()
+        def __init__(self, *args, **kwargs):
+            if len(args) == 1 and len(kwargs) == 0:
+                if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
+                    # Copy constructor.
+                    super().__init__(*args[0])
+                    return
+
+            try:
+                # For backward compatibility, try to check if the arguments
+                # match the original signature that'd allow initializing
+                # the `p` and `q` components separately.
+                bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                p, q = bound_args.args
+            except (TypeError, ValueError):
+                # Fallback to the vector's constructor.
+                super().__init__(*args)
+                return
+
+            # Even if the arguments match the original “from components”
+            # signature, we still need to make sure that they represent
+            # sequences that can be unpacked.
+            if hasattr(p, "__len__") and hasattr(q, "__len__"):
+                # Initialize from the `p` and `q` components.
+                super().__init__()
+                self[0:3] = vector(length=3, dtype=dtype)(*p)
+                self[3:7] = quaternion(dtype=dtype)(*q)
+                return
 
-            self[0:3] = vector(length=3, dtype=dtype)(*p)
-            self[3:7] = quaternion(dtype=dtype)(*q)
+            # Fallback to the vector's constructor.
+            super().__init__(*args)
 
         @property
         def p(self):
-            return self[0:3]
+            return vec3(self[0:3])
 
         @property
         def q(self):
-            return self[3:7]
+            return quat(self[3:7])
 
     return transform_t
 
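As a usage sketch (not from the package): with the signature-binding constructor above, a transform can still be built from separate p and q components or entirely from defaults; shown with the stock wp.transform type.

    import warp as wp

    t = wp.transform(p=(1.0, 2.0, 3.0), q=(0.0, 0.0, 0.0, 1.0))  # "from components" path
    identity = wp.transform()                                    # defaults: p=(0,0,0), q=(0,0,0,1)
    print(t.p, t.q, identity)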
@@ -753,6 +1062,7 @@ vector_types = [
 ]
 
 np_dtype_to_warp_type = {
+    np.dtype(np.bool_): bool,
     np.dtype(np.int8): int8,
     np.dtype(np.uint8): uint8,
     np.dtype(np.int16): int16,
@@ -768,6 +1078,21 @@ np_dtype_to_warp_type = {
     np.dtype(np.float64): float64,
 }
 
+warp_type_to_np_dtype = {
+    bool: np.bool_,
+    int8: np.int8,
+    int16: np.int16,
+    int32: np.int32,
+    int64: np.int64,
+    uint8: np.uint8,
+    uint16: np.uint16,
+    uint32: np.uint32,
+    uint64: np.uint64,
+    float16: np.float16,
+    float32: np.float32,
+    float64: np.float64,
+}
+
 
 # represent a Python range iterator
 class range_t:
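For reference (not part of the diff): the two lookup tables above map between NumPy dtypes and Warp scalar types in both directions; the sketch below only uses names introduced in this file.

    import numpy as np
    import warp.types as wpt

    assert wpt.np_dtype_to_warp_type[np.dtype(np.bool_)] is wpt.bool
    assert wpt.warp_type_to_np_dtype[wpt.float16] is np.float16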
@@ -777,18 +1102,21 @@ class range_t:
 
 # definition just for kernel type (cannot be a parameter), see bvh.h
 class bvh_query_t:
+    """Object used to track state during BVH traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see mesh.h
 class mesh_query_aabb_t:
+    """Object used to track state during mesh traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see hash_grid.h
 class hash_grid_query_t:
+    """Object used to track state during neighbor traversal."""
     def __init__(self):
         pass
 
@@ -800,6 +1128,8 @@ LAUNCH_MAX_DIMS = 4
 # must match array.h
 ARRAY_TYPE_REGULAR = 0
 ARRAY_TYPE_INDEXED = 1
+ARRAY_TYPE_FABRIC = 2
+ARRAY_TYPE_FABRIC_INDEXED = 3
 
 
 # represents bounds for kernel launch (number of threads across multiple dimensions)
@@ -851,6 +1181,30 @@ class array_t(ctypes.Structure):
             self.shape[i] = shape[i]
             self.strides[i] = strides[i]
 
+    # structured type description used when array_t is packed in a struct and shared via numpy structured array.
+    @classmethod
+    def numpy_dtype(cls):
+        return cls._numpy_dtype_
+
+    # structured value used when array_t is packed in a struct and shared via a numpy structured array
+    def numpy_value(self):
+        return (self.data, self.grad, list(self.shape), list(self.strides), self.ndim)
+
+
+# NOTE: must match array_t._fields_
+array_t._numpy_dtype_ = {
+    "names": ["data", "grad", "shape", "strides", "ndim"],
+    "formats": ["u8", "u8", f"{ARRAY_MAX_DIMS}i4", f"{ARRAY_MAX_DIMS}i4", "i4"],
+    "offsets": [
+        array_t.data.offset,
+        array_t.grad.offset,
+        array_t.shape.offset,
+        array_t.strides.offset,
+        array_t.ndim.offset,
+    ],
+    "itemsize": ctypes.sizeof(array_t),
+}
+
 
 class indexedarray_t(ctypes.Structure):
     _fields_ = [
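For reference (not part of the diff): array_t.numpy_dtype() returns a dict that NumPy accepts as a structured dtype description, so array_t records embedded in structs can be viewed from NumPy; a minimal sketch.

    import numpy as np
    from warp.types import array_t

    # build the structured dtype that mirrors array_t._fields_
    descriptor = np.dtype(array_t.numpy_dtype())
    print(descriptor.names, descriptor.itemsize)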
@@ -892,16 +1246,20 @@ def type_length(dtype):
         return dtype._length_
 
 
+def type_scalar_type(dtype):
+    return getattr(dtype, "_wp_scalar_type_", dtype)
+
+
 def type_size_in_bytes(dtype):
     if dtype.__module__ == "ctypes":
         return ctypes.sizeof(dtype)
-    elif type_is_struct(dtype):
+    elif isinstance(dtype, warp.codegen.Struct):
         return ctypes.sizeof(dtype.ctype)
     elif dtype == float or dtype == int:
         return 4
     elif hasattr(dtype, "_type_"):
         return getattr(dtype, "_length_", 1) * ctypes.sizeof(dtype._type_)
 
     else:
         return 0
 
@@ -916,9 +1274,9 @@ def type_to_warp(dtype):
 
 
 def type_typestr(dtype):
-    from warp.codegen import Struct
-
-    if dtype == float16:
+    if dtype == bool:
+        return "?"
+    elif dtype == float16:
         return "<f2"
     elif dtype == float32:
         return "<f4"
@@ -940,8 +1298,8 @@ def type_typestr(dtype):
         return "<i8"
     elif dtype == uint64:
         return "<u8"
-    elif isinstance(dtype, Struct):
-        return f"|V{ctypes.sizeof(dtype.ctype)}"
+    elif isinstance(dtype, warp.codegen.Struct):
+        return f"|V{ctypes.sizeof(dtype.ctype)}"
     elif issubclass(dtype, ctypes.Array):
         return type_typestr(dtype._wp_scalar_type_)
     else:
@@ -954,9 +1312,16 @@ def type_repr(t):
         return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
     if type_is_vector(t):
         return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
-    elif type_is_matrix(t):
+    if type_is_matrix(t):
         return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
-    else:
+    if isinstance(t, warp.codegen.Struct):
+        return type_repr(t.cls)
+    if t in scalar_types:
+        return t.__name__
+
+    try:
+        return t.__module__ + "." + t.__qualname__
+    except AttributeError:
         return str(t)
 
 
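For reference (not part of the diff): a short sketch of the type-introspection helpers touched above; type_scalar_type() is new in this release, and type_repr() now also handles structs and plain scalar types.

    import warp as wp
    from warp.types import type_repr, type_scalar_type

    print(type_repr(wp.vec3))           # human-readable description of a vector type
    print(type_scalar_type(wp.mat33))   # underlying scalar type (float32); scalars pass through unchanged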
@@ -974,14 +1339,6 @@ def type_is_float(t):
     return t in float_types
 
 
-def type_is_struct(dtype):
-    from warp.codegen import Struct
-
-    if isinstance(dtype, Struct):
-        return True
-    else:
-        return False
-
 # returns True if the passed *type* is a vector
 def type_is_vector(t):
     if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
@@ -1000,7 +1357,7 @@ def type_is_matrix(t):
 
 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
 def type_is_value(x):
-    if (x == int) or (x == float) or (x == bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
+    if (x == int) or (x == float) or (x == builtins.bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
         return True
     else:
         return False
@@ -1028,14 +1385,16 @@ def types_equal(a, b, match_generic=False):
     # convert to canonical types
     if a == float:
         a = float32
-    if a == int:
+    elif a == int:
         a = int32
 
     if b == float:
         b = float32
-    if b == int:
+    elif b == int:
         b = int32
 
+    compatible_bool_types = [builtins.bool, bool]
+
     def are_equal(p1, p2):
         if match_generic:
             if p1 == Any or p2 == Any:
@@ -1052,7 +1411,22 @@ def types_equal(a, b, match_generic=False):
                 return True
             if p1 == Float and p2 == Float:
                 return True
-        return p1 == p2
+
+        # convert to canonical types
+        if p1 == float:
+            p1 = float32
+        elif p1 == int:
+            p1 = int32
+
+        if p2 == float:
+            p2 = float32
+        elif b == int:
+            p2 = int32
+
+        if p1 in compatible_bool_types and p2 in compatible_bool_types:
+            return True
+        else:
+            return p1 == p2
 
     if (
         hasattr(a, "_wp_generic_type_str_")
@@ -1060,9 +1434,7 @@ def types_equal(a, b, match_generic=False):
         and a._wp_generic_type_str_ == b._wp_generic_type_str_
     ):
         return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
-    if isinstance(a, array) and isinstance(b, array):
-        return True
-    if isinstance(a, indexedarray) and isinstance(b, indexedarray):
+    if is_array(a) and type(a) is type(b):
         return True
     else:
         return are_equal(a, b)
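For reference (not part of the diff): types_equal() canonicalizes Python int/float and, with the change above, treats Python bool and Warp's bool wrapper as compatible; a small sketch.

    import warp as wp
    from warp.types import types_equal

    print(types_equal(wp.float32, float))   # True, Python float canonicalizes to float32
    print(types_equal(wp.vec3, wp.vec3d))   # False, same generic type but different scalar parameters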
@@ -1093,18 +1465,18 @@ class array(Array):
         dtype: DType = Any,
         shape=None,
         strides=None,
-        length=0,
+        length=None,
         ptr=None,
-        grad_ptr=None,
-        capacity=0,
+        capacity=None,
         device=None,
+        pinned=False,
         copy=True,
-        owner=True,
+        owner=True,  # TODO: replace with deleter=None
         ndim=None,
+        grad=None,
         requires_grad=False,
-        pinned=False,
     ):
-        """Constructs a new Warp array object
+        """Constructs a new Warp array object
 
         When the ``data`` argument is a valid list, tuple, or ndarray the array will be constructed from this object's data.
         For objects that are not stored sequentially in memory (e.g.: a list), then the data will first
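As a usage sketch (not from the package): the construction paths described by the updated docstring, i.e. building an array from data, allocating new storage, and the annotation-only form used for kernel arguments. Standard Warp calls outside this file (wp.init, wp.zeros, wp.launch) are assumed unchanged.

    import numpy as np
    import warp as wp

    wp.init()

    a = wp.array(np.arange(8, dtype=np.float32), device="cpu")        # construct from existing data
    b = wp.zeros(shape=(4, 4), dtype=wp.float32, requires_grad=True)  # preferred way to get new storage

    @wp.kernel
    def scale(x: wp.array(dtype=float)):  # annotation-only construction, no memory allocated
        i = wp.tid()
        x[i] = x[i] * 2.0

    wp.launch(scale, dim=a.shape[0], inputs=[a], device="cpu")
    print(a.numpy(), b.shape)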
@@ -1115,39 +1487,38 @@ class array(Array):
|
|
|
1115
1487
|
allocation should reside on the same device given by the device argument, and the user should set the length
|
|
1116
1488
|
and dtype parameter appropriately.
|
|
1117
1489
|
|
|
1490
|
+
If neither ``data`` nor ``ptr`` are specified, the ``shape`` or ``length`` arguments are checked next.
|
|
1491
|
+
This construction path can be used to create new uninitialized arrays, but users are encouraged to call
|
|
1492
|
+
``wp.empty()``, ``wp.zeros()``, or ``wp.full()`` instead to create new arrays.
|
|
1493
|
+
|
|
1494
|
+
If none of the above arguments are specified, a simple type annotation is constructed. This is used when annotating
|
|
1495
|
+
kernel arguments or struct members (e.g.,``arr: wp.array(dtype=float)``). In this case, only ``dtype`` and ``ndim``
|
|
1496
|
+
are taken into account and no memory is allocated for the array.
|
|
1497
|
+
|
|
1118
1498
|
Args:
|
|
1119
1499
|
data (Union[list, tuple, ndarray]) An object to construct the array from, can be a Tuple, List, or generally any type convertible to an np.array
|
|
1120
1500
|
dtype (Union): One of the built-in types, e.g.: :class:`warp.mat33`, if dtype is Any and data an ndarray then it will be inferred from the array data type
|
|
1121
1501
|
shape (tuple): Dimensions of the array
|
|
1122
1502
|
strides (tuple): Number of bytes in each dimension between successive elements of the array
|
|
1123
|
-
length (int): Number of elements
|
|
1503
|
+
length (int): Number of elements of the data type (deprecated, users should use `shape` argument)
|
|
1124
1504
|
ptr (uint64): Address of an external memory address to alias (data should be None)
|
|
1125
|
-
grad_ptr (uint64): Address of an external memory address to alias for the gradient array
|
|
1126
1505
|
capacity (int): Maximum size in bytes of the ptr allocation (data should be None)
|
|
1127
1506
|
device (Devicelike): Device the array lives on
|
|
1128
1507
|
copy (bool): Whether the incoming data will be copied or aliased, this is only possible when the incoming `data` already lives on the device specified and types match
|
|
1129
1508
|
owner (bool): Should the array object try to deallocate memory when it is deleted
|
|
1130
1509
|
requires_grad (bool): Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
|
|
1510
|
+
grad (array): The gradient array to use
|
|
1131
1511
|
pinned (bool): Whether to allocate pinned host memory, which allows asynchronous host-device transfers (only applicable with device="cpu")
|
|
1132
1512
|
|
|
1133
1513
|
"""
|
|
1134
1514
|
|
|
1135
1515
|
self.owner = False
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
|
|
1141
|
-
|
|
1142
|
-
elif isinstance(shape, List):
|
|
1143
|
-
shape = tuple(shape)
|
|
1144
|
-
|
|
1145
|
-
self.shape = shape
|
|
1146
|
-
|
|
1147
|
-
if len(shape) > ARRAY_MAX_DIMS:
|
|
1148
|
-
raise RuntimeError(
|
|
1149
|
-
f"Arrays may only have {ARRAY_MAX_DIMS} dimensions maximum, trying to create array with {len(shape)} dims."
|
|
1150
|
-
)
|
|
1516
|
+
self.ctype = None
|
|
1517
|
+
self._requires_grad = False
|
|
1518
|
+
self._grad = None
|
|
1519
|
+
# __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
|
|
1520
|
+
self._array_interface = None
|
|
1521
|
+
self.is_transposed = False
|
|
1151
1522
|
|
|
1152
1523
|
# canonicalize dtype
|
|
1153
1524
|
if dtype == int:
|
|
@@ -1155,20 +1526,78 @@ class array(Array):
|
|
|
1155
1526
|
elif dtype == float:
|
|
1156
1527
|
dtype = float32
|
|
1157
1528
|
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1529
|
+
# convert shape to tuple (or leave shape=None if neither shape nor length were specified)
|
|
1530
|
+
if shape is not None:
|
|
1531
|
+
if isinstance(shape, int):
|
|
1532
|
+
shape = (shape,)
|
|
1533
|
+
else:
|
|
1534
|
+
shape = tuple(shape)
|
|
1535
|
+
if len(shape) > ARRAY_MAX_DIMS:
|
|
1536
|
+
raise RuntimeError(
|
|
1537
|
+
f"Failed to create array with shape {shape}, the maximum number of dimensions is {ARRAY_MAX_DIMS}"
|
|
1538
|
+
)
|
|
1539
|
+
elif length is not None:
|
|
1540
|
+
# backward compatibility
|
|
1541
|
+
shape = (length,)
|
|
1162
1542
|
|
|
1543
|
+
# determine the construction path from the given arguments
|
|
1163
1544
|
if data is not None:
|
|
1164
|
-
|
|
1165
|
-
raise RuntimeError(f"Cannot allocate memory on device {device} while graph capture is active")
|
|
1166
|
-
|
|
1545
|
+
# data or ptr, not both
|
|
1167
1546
|
if ptr is not None:
|
|
1168
|
-
|
|
1169
|
-
|
|
1547
|
+
raise RuntimeError("Can only construct arrays with either `data` or `ptr` arguments, not both")
|
|
1548
|
+
self._init_from_data(data, dtype, shape, device, copy, pinned)
|
|
1549
|
+
elif ptr is not None:
|
|
1550
|
+
self._init_from_ptr(ptr, dtype, shape, strides, capacity, device, owner, pinned)
|
|
1551
|
+
elif shape is not None:
|
|
1552
|
+
self._init_new(dtype, shape, strides, device, pinned)
|
|
1553
|
+
else:
|
|
1554
|
+
self._init_annotation(dtype, ndim or 1)
|
|
1170
1555
|
|
|
1171
|
-
|
|
1556
|
+
# initialize gradient, if needed
|
|
1557
|
+
if self.device is not None:
|
|
1558
|
+
if grad is not None:
|
|
1559
|
+
# this will also check whether the gradient array is compatible
|
|
1560
|
+
self.grad = grad
|
|
1561
|
+
else:
|
|
1562
|
+
# allocate gradient if needed
|
|
1563
|
+
self._requires_grad = requires_grad
|
|
1564
|
+
if requires_grad:
|
|
1565
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
1566
|
+
self._alloc_grad()
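The constructor above dispatches to one of four initialization paths depending on which arguments are supplied. A minimal sketch of how a caller might exercise each path (variable names and values are illustrative, not part of the source):

import numpy as np
import warp as wp

wp.init()

# from existing data (a list, tuple, or NumPy array)
a = wp.array(data=np.arange(8, dtype=np.float32), dtype=wp.float32, device="cpu")

# a new allocation described only by shape and dtype (contents uninitialized; use wp.zeros() for zero-filled)
b = wp.array(shape=(4, 4), dtype=wp.float32, device="cpu")

# wrap an external pointer without copying or taking ownership (the buffer must outlive the array)
buf = np.ones(8, dtype=np.float32)
c = wp.array(ptr=buf.ctypes.data, dtype=wp.float32, shape=(8,), device="cpu", owner=False)

# no data, ptr, or shape: a type annotation, e.g. for kernel signatures
d = wp.array(dtype=wp.float32, ndim=2)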
|
|
1567
|
+
|
|
1568
|
+
def _init_from_data(self, data, dtype, shape, device, copy, pinned):
|
|
1569
|
+
if not hasattr(data, "__len__"):
|
|
1570
|
+
raise RuntimeError(f"Data must be a sequence or array, got scalar {data}")
|
|
1571
|
+
|
|
1572
|
+
if hasattr(dtype, "_wp_scalar_type_"):
|
|
1573
|
+
dtype_shape = dtype._shape_
|
|
1574
|
+
dtype_ndim = len(dtype_shape)
|
|
1575
|
+
scalar_dtype = dtype._wp_scalar_type_
|
|
1576
|
+
else:
|
|
1577
|
+
dtype_shape = ()
|
|
1578
|
+
dtype_ndim = 0
|
|
1579
|
+
scalar_dtype = dtype
|
|
1580
|
+
|
|
1581
|
+
# convert input data to ndarray (handles lists, tuples, etc.) and determine dtype
|
|
1582
|
+
if dtype == Any:
|
|
1583
|
+
# infer dtype from data
|
|
1584
|
+
try:
|
|
1585
|
+
arr = np.array(data, copy=False, ndmin=1)
|
|
1586
|
+
except Exception as e:
|
|
1587
|
+
raise RuntimeError(f"Failed to convert input data to an array: {e}")
|
|
1588
|
+
dtype = np_dtype_to_warp_type.get(arr.dtype)
|
|
1589
|
+
if dtype is None:
|
|
1590
|
+
raise RuntimeError(f"Unsupported input data dtype: {arr.dtype}")
|
|
1591
|
+
elif isinstance(dtype, warp.codegen.Struct):
|
|
1592
|
+
if isinstance(data, np.ndarray):
|
|
1593
|
+
# construct from numpy structured array
|
|
1594
|
+
if data.dtype != dtype.numpy_dtype():
|
|
1595
|
+
raise RuntimeError(
|
|
1596
|
+
f"Invalid source data type for array of structs, expected {dtype.numpy_dtype()}, got {data.dtype}"
|
|
1597
|
+
)
|
|
1598
|
+
arr = data
|
|
1599
|
+
elif isinstance(data, (list, tuple)):
|
|
1600
|
+
# construct from a sequence of structs
|
|
1172
1601
|
try:
|
|
1173
1602
|
# convert each struct instance to its corresponding ctype
|
|
1174
1603
|
ctype_list = [v.__ctype__() for v in data]
|
|
@@ -1176,156 +1605,227 @@ class array(Array):
|
|
|
1176
1605
|
ctype_arr = (dtype.ctype * len(ctype_list))(*ctype_list)
|
|
1177
1606
|
# convert to numpy
|
|
1178
1607
|
arr = np.frombuffer(ctype_arr, dtype=dtype.ctype)
|
|
1179
|
-
#arr = np.array(ctype_arr, copy=False)
|
|
1180
|
-
|
|
1181
|
-
except Exception as e:
|
|
1182
|
-
raise RuntimeError(
|
|
1183
|
-
"Error while trying to construct Warp array from a Python list of Warp structs." + str(e))
|
|
1184
|
-
|
|
1185
|
-
else:
|
|
1186
|
-
try:
|
|
1187
|
-
# convert tuples and lists of numeric types to ndarray
|
|
1188
|
-
arr = np.array(data, copy=False)
|
|
1189
1608
|
except Exception as e:
|
|
1190
1609
|
raise RuntimeError(
|
|
1191
|
-
"
|
|
1192
|
-
+ str(e)
|
|
1193
|
-
)
|
|
1194
|
-
|
|
1195
|
-
if dtype == Any:
|
|
1196
|
-
# infer dtype from the source data array
|
|
1197
|
-
dtype = np_dtype_to_warp_type[arr.dtype]
|
|
1198
|
-
|
|
1199
|
-
# try to convert numeric src array to destination type
|
|
1200
|
-
if not isinstance(dtype, warp.codegen.Struct):
|
|
1201
|
-
try:
|
|
1202
|
-
arr = arr.astype(dtype=type_typestr(dtype), copy=False)
|
|
1203
|
-
except:
|
|
1204
|
-
raise RuntimeError(
|
|
1205
|
-
f"Could not convert input data with type {arr.dtype} to array with type {dtype._type_}"
|
|
1610
|
+
f"Error while trying to construct Warp array from a sequence of Warp structs: {e}"
|
|
1206
1611
|
)
|
|
1612
|
+
else:
|
|
1613
|
+
raise RuntimeError(
|
|
1614
|
+
"Invalid data argument for array of structs, expected a sequence of structs or a NumPy structured array"
|
|
1615
|
+
)
|
|
1616
|
+
else:
|
|
1617
|
+
# convert input data to the given dtype
|
|
1618
|
+
npdtype = warp_type_to_np_dtype.get(scalar_dtype)
|
|
1619
|
+
if npdtype is None:
|
|
1620
|
+
raise RuntimeError(
|
|
1621
|
+
f"Failed to convert input data to an array with Warp type {warp.context.type_str(dtype)}"
|
|
1622
|
+
)
|
|
1623
|
+
try:
|
|
1624
|
+
arr = np.array(data, dtype=npdtype, copy=False, ndmin=1)
|
|
1625
|
+
except Exception as e:
|
|
1626
|
+
raise RuntimeError(f"Failed to convert input data to an array with type {npdtype}: {e}")
|
|
1627
|
+
|
|
1628
|
+
# determine whether the input needs reshaping
|
|
1629
|
+
target_npshape = None
|
|
1630
|
+
if shape is not None:
|
|
1631
|
+
target_npshape = (*shape, *dtype_shape)
|
|
1632
|
+
elif dtype_ndim > 0:
|
|
1633
|
+
# prune inner dimensions of length 1
|
|
1634
|
+
while arr.ndim > 1 and arr.shape[-1] == 1:
|
|
1635
|
+
arr = np.squeeze(arr, axis=-1)
|
|
1636
|
+
# if the inner dims don't match exactly, check if the innermost dim is a multiple of type length
|
|
1637
|
+
if arr.ndim < dtype_ndim or arr.shape[-dtype_ndim:] != dtype_shape:
|
|
1638
|
+
if arr.shape[-1] == dtype._length_:
|
|
1639
|
+
target_npshape = (*arr.shape[:-1], *dtype_shape)
|
|
1640
|
+
elif arr.shape[-1] % dtype._length_ == 0:
|
|
1641
|
+
target_npshape = (*arr.shape[:-1], arr.shape[-1] // dtype._length_, *dtype_shape)
|
|
1642
|
+
else:
|
|
1643
|
+
if dtype_ndim == 1:
|
|
1644
|
+
raise RuntimeError(
|
|
1645
|
+
f"The inner dimensions of the input data are not compatible with the requested vector type {warp.context.type_str(dtype)}: expected an inner dimension that is a multiple of {dtype._length_}"
|
|
1646
|
+
)
|
|
1647
|
+
else:
|
|
1648
|
+
raise RuntimeError(
|
|
1649
|
+
f"The inner dimensions of the input data are not compatible with the requested matrix type {warp.context.type_str(dtype)}: expected inner dimensions {dtype._shape_} or a multiple of {dtype._length_}"
|
|
1650
|
+
)
|
|
1207
1651
|
|
|
1208
|
-
|
|
1209
|
-
|
|
1652
|
+
if target_npshape is not None:
|
|
1653
|
+
try:
|
|
1654
|
+
arr = arr.reshape(target_npshape)
|
|
1655
|
+
except Exception as e:
|
|
1656
|
+
raise RuntimeError(
|
|
1657
|
+
f"Failed to reshape the input data to the given shape {shape} and type {warp.context.type_str(dtype)}: {e}"
|
|
1658
|
+
)
|
|
1210
1659
|
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1660
|
+
# determine final shape and strides
|
|
1661
|
+
if dtype_ndim > 0:
|
|
1662
|
+
# make sure the inner dims are contiguous for vector/matrix types
|
|
1663
|
+
scalar_size = type_size_in_bytes(dtype._wp_scalar_type_)
|
|
1664
|
+
inner_contiguous = arr.strides[-1] == scalar_size
|
|
1665
|
+
if inner_contiguous and dtype_ndim > 1:
|
|
1666
|
+
inner_contiguous = arr.strides[-2] == scalar_size * dtype_shape[-1]
|
|
1214
1667
|
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
strides = arr.__array_interface__.get("strides", None)
|
|
1668
|
+
if not inner_contiguous:
|
|
1669
|
+
arr = np.ascontiguousarray(arr)
|
|
1218
1670
|
|
|
1219
|
-
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
if arr.ndim == 1:
|
|
1225
|
-
arr = arr.reshape((-1, *dtype._shape_))
|
|
1671
|
+
shape = arr.shape[:-dtype_ndim] or (1,)
|
|
1672
|
+
strides = arr.strides[:-dtype_ndim] or (type_size_in_bytes(dtype),)
|
|
1673
|
+
else:
|
|
1674
|
+
shape = arr.shape or (1,)
|
|
1675
|
+
strides = arr.strides or (type_size_in_bytes(dtype),)
|
|
1226
1676
|
|
|
1227
|
-
|
|
1228
|
-
# e.g.: array of mat22 objects should have shape (n, 2, 2)
|
|
1229
|
-
dtype_ndim = len(dtype._shape_)
|
|
1677
|
+
device = warp.get_device(device)
|
|
1230
1678
|
|
|
1231
|
-
|
|
1232
|
-
|
|
1679
|
+
if device.is_cpu and not copy and not pinned:
|
|
1680
|
+
# reference numpy memory directly
|
|
1681
|
+
self._init_from_ptr(arr.ctypes.data, dtype, shape, strides, None, device, False, False)
|
|
1682
|
+
# keep a ref to the source array to keep allocation alive
|
|
1683
|
+
self._ref = arr
|
|
1684
|
+
else:
|
|
1685
|
+
# copy data into a new array
|
|
1686
|
+
self._init_new(dtype, shape, None, device, pinned)
|
|
1687
|
+
src = array(
|
|
1688
|
+
ptr=arr.ctypes.data,
|
|
1689
|
+
dtype=dtype,
|
|
1690
|
+
shape=shape,
|
|
1691
|
+
strides=strides,
|
|
1692
|
+
device="cpu",
|
|
1693
|
+
copy=False,
|
|
1694
|
+
owner=False,
|
|
1695
|
+
)
|
|
1696
|
+
warp.copy(self, src)
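A short example of the inner-dimension matching performed above when constructing arrays of vector types from NumPy data (the values are chosen only for illustration):

import numpy as np
import warp as wp

wp.init()

# both a flat (n*3,) buffer and an explicit (n, 3) buffer can back an array of wp.vec3,
# because the innermost dimension is matched against the dtype length
flat = np.arange(12, dtype=np.float32)      # 12 scalars -> 4 vec3 elements
nested = flat.reshape(4, 3)

v0 = wp.array(data=flat, dtype=wp.vec3, device="cpu")
v1 = wp.array(data=nested, dtype=wp.vec3, device="cpu")
assert v0.shape == (4,) and v1.shape == (4,)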
|
|
1233
1697
|
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
)
|
|
1698
|
+
def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, owner, pinned):
|
|
1699
|
+
if dtype == Any:
|
|
1700
|
+
raise RuntimeError("A concrete data type is required to create the array")
|
|
1238
1701
|
|
|
1239
|
-
|
|
1702
|
+
device = warp.get_device(device)
|
|
1240
1703
|
|
|
1241
|
-
|
|
1242
|
-
|
|
1704
|
+
size = 1
|
|
1705
|
+
for d in shape:
|
|
1706
|
+
size *= d
|
|
1243
1707
|
|
|
1244
|
-
|
|
1245
|
-
# ref numpy memory directly
|
|
1246
|
-
self.shape = shape
|
|
1247
|
-
self.ptr = ptr
|
|
1248
|
-
self.grad_ptr = grad_ptr
|
|
1249
|
-
self.dtype = dtype
|
|
1250
|
-
self.strides = strides
|
|
1251
|
-
self.capacity = arr.size * type_size_in_bytes(dtype)
|
|
1252
|
-
self.device = device
|
|
1253
|
-
self.owner = False
|
|
1254
|
-
self.pinned = False
|
|
1708
|
+
contiguous_strides = strides_from_shape(shape, dtype)
|
|
1255
1709
|
|
|
1256
|
-
|
|
1257
|
-
|
|
1710
|
+
if strides is None:
|
|
1711
|
+
strides = contiguous_strides
|
|
1712
|
+
is_contiguous = True
|
|
1713
|
+
if capacity is None:
|
|
1714
|
+
capacity = size * type_size_in_bytes(dtype)
|
|
1715
|
+
else:
|
|
1716
|
+
is_contiguous = strides == contiguous_strides
|
|
1717
|
+
if capacity is None:
|
|
1718
|
+
capacity = shape[0] * strides[0]
|
|
1719
|
+
|
|
1720
|
+
self.dtype = dtype
|
|
1721
|
+
self.ndim = len(shape)
|
|
1722
|
+
self.size = size
|
|
1723
|
+
self.capacity = capacity
|
|
1724
|
+
self.shape = shape
|
|
1725
|
+
self.strides = strides
|
|
1726
|
+
self.ptr = ptr
|
|
1727
|
+
self.device = device
|
|
1728
|
+
self.owner = owner
|
|
1729
|
+
self.pinned = pinned if device.is_cpu else False
|
|
1730
|
+
self.is_contiguous = is_contiguous
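For reference, a hedged sketch of wrapping externally owned memory through this path from user code; the NumPy buffer is illustrative and must be kept alive by the caller since `owner=False`:

import numpy as np
import warp as wp

wp.init()

buf = np.zeros((4, 4), dtype=np.float32)

wrapped = wp.array(
    ptr=buf.ctypes.data,
    dtype=wp.float32,
    shape=buf.shape,
    strides=buf.strides,   # optional; if omitted, contiguous strides are assumed
    device="cpu",
    owner=False,           # Warp will not free this allocation
)
wrapped.fill_(3.0)
assert buf[0, 0] == 3.0    # both objects reference the same memory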
|
|
1258
1731
|
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
# and a new destination array to copy it to
|
|
1263
|
-
src = array(
|
|
1264
|
-
dtype=dtype,
|
|
1265
|
-
shape=shape,
|
|
1266
|
-
strides=strides,
|
|
1267
|
-
capacity=arr.size * type_size_in_bytes(dtype),
|
|
1268
|
-
ptr=ptr,
|
|
1269
|
-
device="cpu",
|
|
1270
|
-
copy=False,
|
|
1271
|
-
owner=False,
|
|
1272
|
-
)
|
|
1273
|
-
dest = warp.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned)
|
|
1274
|
-
dest.owner = False
|
|
1732
|
+
def _init_new(self, dtype, shape, strides, device, pinned):
|
|
1733
|
+
if dtype == Any:
|
|
1734
|
+
raise RuntimeError("A concrete data type is required to create the array")
|
|
1275
1735
|
|
|
1276
|
-
|
|
1277
|
-
warp.copy(dest, src, stream=device.null_stream)
|
|
1736
|
+
device = warp.get_device(device)
|
|
1278
1737
|
|
|
1279
|
-
|
|
1280
|
-
|
|
1738
|
+
size = 1
|
|
1739
|
+
for d in shape:
|
|
1740
|
+
size *= d
|
|
1281
1741
|
|
|
1282
|
-
|
|
1283
|
-
self.owner = True
|
|
1742
|
+
contiguous_strides = strides_from_shape(shape, dtype)
|
|
1284
1743
|
|
|
1744
|
+
if strides is None:
|
|
1745
|
+
strides = contiguous_strides
|
|
1746
|
+
is_contiguous = True
|
|
1747
|
+
capacity = size * type_size_in_bytes(dtype)
|
|
1285
1748
|
else:
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
self.strides = strides
|
|
1289
|
-
self.capacity = capacity
|
|
1290
|
-
self.dtype = dtype
|
|
1291
|
-
self.ptr = ptr
|
|
1292
|
-
self.grad_ptr = grad_ptr
|
|
1293
|
-
self.device = device
|
|
1294
|
-
self.owner = owner
|
|
1295
|
-
if device is not None and device.is_cpu:
|
|
1296
|
-
self.pinned = pinned
|
|
1297
|
-
else:
|
|
1298
|
-
self.pinned = False
|
|
1749
|
+
is_contiguous = strides == contiguous_strides
|
|
1750
|
+
capacity = shape[0] * strides[0]
|
|
1299
1751
|
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
self.ndim = len(self.shape)
|
|
1752
|
+
if capacity > 0:
|
|
1753
|
+
ptr = device.allocator.alloc(capacity, pinned=pinned)
|
|
1754
|
+
if ptr is None:
|
|
1755
|
+
raise RuntimeError(f"Array allocation failed on device: {device} for {capacity} bytes")
|
|
1305
1756
|
else:
|
|
1306
|
-
|
|
1757
|
+
ptr = None
|
|
1307
1758
|
|
|
1308
|
-
|
|
1309
|
-
self.
|
|
1310
|
-
|
|
1311
|
-
|
|
1759
|
+
self.dtype = dtype
|
|
1760
|
+
self.ndim = len(shape)
|
|
1761
|
+
self.size = size
|
|
1762
|
+
self.capacity = capacity
|
|
1763
|
+
self.shape = shape
|
|
1764
|
+
self.strides = strides
|
|
1765
|
+
self.ptr = ptr
|
|
1766
|
+
self.device = device
|
|
1767
|
+
self.owner = True
|
|
1768
|
+
self.pinned = pinned if device.is_cpu else False
|
|
1769
|
+
self.is_contiguous = is_contiguous
|
|
1770
|
+
|
|
1771
|
+
def _init_annotation(self, dtype, ndim):
|
|
1772
|
+
self.dtype = dtype
|
|
1773
|
+
self.ndim = ndim
|
|
1774
|
+
self.size = 0
|
|
1775
|
+
self.capacity = 0
|
|
1776
|
+
self.shape = (0,) * ndim
|
|
1777
|
+
self.strides = (0,) * ndim
|
|
1778
|
+
self.ptr = None
|
|
1779
|
+
self.device = None
|
|
1780
|
+
self.owner = False
|
|
1781
|
+
self.pinned = False
|
|
1782
|
+
self.is_contiguous = False
|
|
1312
1783
|
|
|
1313
|
-
|
|
1784
|
+
@property
|
|
1785
|
+
def __array_interface__(self):
|
|
1786
|
+
# raising an AttributeError here makes hasattr() return False
|
|
1787
|
+
if self.device is None or not self.device.is_cpu:
|
|
1788
|
+
raise AttributeError(f"__array_interface__ not supported because device is {self.device}")
|
|
1314
1789
|
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1320
|
-
self.strides
|
|
1321
|
-
|
|
1790
|
+
if self._array_interface is None:
|
|
1791
|
+
# get flat shape (including type shape)
|
|
1792
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
1793
|
+
# struct
|
|
1794
|
+
arr_shape = self.shape
|
|
1795
|
+
arr_strides = self.strides
|
|
1796
|
+
descr = self.dtype.numpy_dtype()
|
|
1797
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
1798
|
+
# vector type, flatten the dimensions into one tuple
|
|
1799
|
+
arr_shape = (*self.shape, *self.dtype._shape_)
|
|
1800
|
+
dtype_strides = strides_from_shape(self.dtype._shape_, self.dtype._type_)
|
|
1801
|
+
arr_strides = (*self.strides, *dtype_strides)
|
|
1802
|
+
descr = None
|
|
1322
1803
|
else:
|
|
1323
|
-
|
|
1324
|
-
|
|
1804
|
+
# scalar type
|
|
1805
|
+
arr_shape = self.shape
|
|
1806
|
+
arr_strides = self.strides
|
|
1807
|
+
descr = None
|
|
1808
|
+
|
|
1809
|
+
self._array_interface = {
|
|
1810
|
+
"data": (self.ptr if self.ptr is not None else 0, False),
|
|
1811
|
+
"shape": tuple(arr_shape),
|
|
1812
|
+
"strides": tuple(arr_strides),
|
|
1813
|
+
"typestr": type_typestr(self.dtype),
|
|
1814
|
+
"descr": descr, # optional description of structured array layout
|
|
1815
|
+
"version": 3,
|
|
1816
|
+
}
|
|
1325
1817
|
|
|
1326
|
-
|
|
1818
|
+
return self._array_interface
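Because CPU arrays expose `__array_interface__`, NumPy can alias the memory without a copy; a minimal sketch:

import numpy as np
import warp as wp

wp.init()

a = wp.zeros(4, dtype=wp.float32, device="cpu")

view = np.array(a, copy=False)   # consumes __array_interface__, no copy
view[:] = 7.0
print(a.numpy())                 # [7. 7. 7. 7.]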
|
|
1327
1819
|
|
|
1328
|
-
|
|
1820
|
+
@property
|
|
1821
|
+
def __cuda_array_interface__(self):
|
|
1822
|
+
# raising an AttributeError here makes hasattr() return False
|
|
1823
|
+
if self.device is None or not self.device.is_cuda:
|
|
1824
|
+
raise AttributeError(f"__cuda_array_interface__ is not supported because device is {self.device}")
|
|
1825
|
+
|
|
1826
|
+
if self._array_interface is None:
|
|
1827
|
+
# get flat shape (including type shape)
|
|
1828
|
+
if issubclass(self.dtype, ctypes.Array):
|
|
1329
1829
|
# vector type, flatten the dimensions into one tuple
|
|
1330
1830
|
arr_shape = (*self.shape, *self.dtype._shape_)
|
|
1331
1831
|
dtype_strides = strides_from_shape(self.dtype._shape_, self.dtype._type_)
|
|
@@ -1335,44 +1835,18 @@ class array(Array):
|
|
|
1335
1835
|
arr_shape = self.shape
|
|
1336
1836
|
arr_strides = self.strides
|
|
1337
1837
|
|
|
1338
|
-
|
|
1339
|
-
self.
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
}
|
|
1346
|
-
|
|
1347
|
-
# set up cuda array interface access so we can treat this object as a Torch tensor
|
|
1348
|
-
elif device.is_cuda:
|
|
1349
|
-
self.__cuda_array_interface__ = {
|
|
1350
|
-
"data": (self.ptr, False),
|
|
1351
|
-
"shape": tuple(arr_shape),
|
|
1352
|
-
"strides": tuple(arr_strides),
|
|
1353
|
-
"typestr": type_typestr(self.dtype),
|
|
1354
|
-
"version": 2,
|
|
1355
|
-
}
|
|
1356
|
-
|
|
1357
|
-
# controls if gradients will be computed by wp.Tape
|
|
1358
|
-
# this will trigger allocation of a gradient array if it doesn't exist already
|
|
1359
|
-
self.requires_grad = requires_grad
|
|
1360
|
-
|
|
1361
|
-
else:
|
|
1362
|
-
# array has no data
|
|
1363
|
-
self.strides = (0,) * self.ndim
|
|
1364
|
-
self.is_contiguous = False
|
|
1365
|
-
self.requires_grad = False
|
|
1838
|
+
self._array_interface = {
|
|
1839
|
+
"data": (self.ptr if self.ptr is not None else 0, False),
|
|
1840
|
+
"shape": tuple(arr_shape),
|
|
1841
|
+
"strides": tuple(arr_strides),
|
|
1842
|
+
"typestr": type_typestr(self.dtype),
|
|
1843
|
+
"version": 2,
|
|
1844
|
+
}
|
|
1366
1845
|
|
|
1367
|
-
self.
|
|
1846
|
+
return self._array_interface
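Any consumer of the `__cuda_array_interface__` protocol can alias a CUDA array's memory in the same way. The sketch below uses CuPy purely as an example consumer and assumes it is installed alongside a CUDA-capable device:

import warp as wp
import cupy as cp   # example consumer of __cuda_array_interface__ (assumed installed)

wp.init()

a = wp.zeros(8, dtype=wp.float32, device="cuda:0")

b = cp.asarray(a)   # aliases the same device memory, no copy
b += 1.0
print(a.numpy())    # all ones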
|
|
1368
1847
|
|
|
1369
1848
|
def __del__(self):
|
|
1370
|
-
if self.owner
|
|
1371
|
-
# TODO: ill-timed gc could trigger superfluous context switches here
|
|
1372
|
-
# Delegate to a separate thread? (e.g., device_free_async)
|
|
1373
|
-
if self.device.is_capturing:
|
|
1374
|
-
raise RuntimeError(f"Cannot free memory on device {self.device} while graph capture is active")
|
|
1375
|
-
|
|
1849
|
+
if self.owner:
|
|
1376
1850
|
# use CUDA context guard to avoid side effects during garbage collection
|
|
1377
1851
|
with self.device.context_guard:
|
|
1378
1852
|
self.device.allocator.free(self.ptr, self.capacity, self.pinned)
|
|
@@ -1385,7 +1859,7 @@ class array(Array):
|
|
|
1385
1859
|
# for 'empty' arrays we just return the type information, these are used in kernel function signatures
|
|
1386
1860
|
return f"array{self.dtype}"
|
|
1387
1861
|
else:
|
|
1388
|
-
return str(self.
|
|
1862
|
+
return str(self.numpy())
|
|
1389
1863
|
|
|
1390
1864
|
def __getitem__(self, key):
|
|
1391
1865
|
if isinstance(key, int):
|
|
@@ -1436,7 +1910,7 @@ class array(Array):
|
|
|
1436
1910
|
if stop < 0:
|
|
1437
1911
|
stop = self.shape[idx] + stop
|
|
1438
1912
|
|
|
1439
|
-
if start < 0 or start
|
|
1913
|
+
if start < 0 or start >= self.shape[idx]:
|
|
1440
1914
|
raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
|
|
1441
1915
|
if stop < 1 or stop > self.shape[idx]:
|
|
1442
1916
|
raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
|
|
@@ -1460,23 +1934,37 @@ class array(Array):
|
|
|
1460
1934
|
start = k
|
|
1461
1935
|
if start < 0:
|
|
1462
1936
|
start = self.shape[idx] + start
|
|
1463
|
-
if start < 0 or start
|
|
1937
|
+
if start < 0 or start >= self.shape[idx]:
|
|
1464
1938
|
raise RuntimeError(f"Invalid indexing in slice: {k}")
|
|
1465
1939
|
new_dim -= 1
|
|
1466
1940
|
|
|
1467
1941
|
ptr_offset += self.strides[idx] * start
|
|
1468
1942
|
|
|
1943
|
+
# handle grad
|
|
1944
|
+
if self.grad is not None:
|
|
1945
|
+
new_grad = array(
|
|
1946
|
+
ptr=self.grad.ptr + ptr_offset if self.grad.ptr is not None else None,
|
|
1947
|
+
dtype=self.grad.dtype,
|
|
1948
|
+
shape=tuple(new_shape),
|
|
1949
|
+
strides=tuple(new_strides),
|
|
1950
|
+
device=self.grad.device,
|
|
1951
|
+
pinned=self.grad.pinned,
|
|
1952
|
+
owner=False,
|
|
1953
|
+
)
|
|
1954
|
+
# store back-ref to stop data being destroyed
|
|
1955
|
+
new_grad._ref = self.grad
|
|
1956
|
+
else:
|
|
1957
|
+
new_grad = None
|
|
1958
|
+
|
|
1469
1959
|
a = array(
|
|
1960
|
+
ptr=self.ptr + ptr_offset if self.ptr is not None else None,
|
|
1470
1961
|
dtype=self.dtype,
|
|
1471
1962
|
shape=tuple(new_shape),
|
|
1472
1963
|
strides=tuple(new_strides),
|
|
1473
|
-
ptr=self.ptr + ptr_offset,
|
|
1474
|
-
grad_ptr=(self.grad_ptr + ptr_offset if self.grad_ptr is not None else None),
|
|
1475
|
-
capacity=self.capacity,
|
|
1476
1964
|
device=self.device,
|
|
1965
|
+
pinned=self.pinned,
|
|
1477
1966
|
owner=False,
|
|
1478
|
-
|
|
1479
|
-
requires_grad=self.requires_grad,
|
|
1967
|
+
grad=new_grad,
|
|
1480
1968
|
)
|
|
1481
1969
|
|
|
1482
1970
|
# store back-ref to stop data being destroyed
|
|
@@ -1494,7 +1982,7 @@ class array(Array):
|
|
|
1494
1982
|
def __ctype__(self):
|
|
1495
1983
|
if self.ctype is None:
|
|
1496
1984
|
data = 0 if self.ptr is None else ctypes.c_uint64(self.ptr)
|
|
1497
|
-
grad = 0 if self.
|
|
1985
|
+
grad = 0 if self.grad is None or self.grad.ptr is None else ctypes.c_uint64(self.grad.ptr)
|
|
1498
1986
|
self.ctype = array_t(data=data, grad=grad, ndim=self.ndim, shape=self.shape, strides=self.strides)
|
|
1499
1987
|
|
|
1500
1988
|
return self.ctype
|
|
@@ -1522,25 +2010,31 @@ class array(Array):
|
|
|
1522
2010
|
return self._grad
|
|
1523
2011
|
|
|
1524
2012
|
@grad.setter
|
|
1525
|
-
def grad(self,
|
|
1526
|
-
|
|
1527
|
-
self.ctype = None
|
|
1528
|
-
if value is None:
|
|
1529
|
-
self.grad_ptr = None
|
|
2013
|
+
def grad(self, grad):
|
|
2014
|
+
if grad is None:
|
|
1530
2015
|
self._grad = None
|
|
1531
|
-
|
|
1532
|
-
if self._grad is None:
|
|
1533
|
-
self.grad_ptr = value.ptr
|
|
1534
|
-
self._grad = value
|
|
2016
|
+
self._requires_grad = False
|
|
1535
2017
|
else:
|
|
1536
|
-
|
|
2018
|
+
# make sure the given gradient array is compatible
|
|
2019
|
+
if (
|
|
2020
|
+
grad.dtype != self.dtype
|
|
2021
|
+
or grad.shape != self.shape
|
|
2022
|
+
or grad.strides != self.strides
|
|
2023
|
+
or grad.device != self.device
|
|
2024
|
+
):
|
|
2025
|
+
raise ValueError("The given gradient array is incompatible")
|
|
2026
|
+
self._grad = grad
|
|
2027
|
+
self._requires_grad = True
|
|
2028
|
+
|
|
2029
|
+
# trigger re-creation of C-representation
|
|
2030
|
+
self.ctype = None
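A minimal end-to-end sketch of the gradient machinery above, assuming the standard wp.Tape / wp.launch API and a CPU device; the kernel and array contents are illustrative:

import warp as wp

wp.init()

@wp.kernel
def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = 2.0 * x[i]

x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True, device="cpu")
y = wp.zeros(3, dtype=float, requires_grad=True, device="cpu")

tape = wp.Tape()
with tape:
    wp.launch(scale, dim=3, inputs=[x, y], device="cpu")

y.grad.fill_(1.0)        # seed the output gradient
tape.backward()          # propagate adjoints back through the recorded launches
print(x.grad.numpy())    # [2. 2. 2.]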
|
|
1537
2031
|
|
|
1538
2032
|
@property
|
|
1539
2033
|
def requires_grad(self):
|
|
1540
2034
|
return self._requires_grad
|
|
1541
2035
|
|
|
1542
2036
|
@requires_grad.setter
|
|
1543
|
-
def requires_grad(self, value: bool):
|
|
2037
|
+
def requires_grad(self, value: builtins.bool):
|
|
1544
2038
|
if value and self._grad is None:
|
|
1545
2039
|
self._alloc_grad()
|
|
1546
2040
|
elif not value:
|
|
@@ -1548,18 +2042,15 @@ class array(Array):
|
|
|
1548
2042
|
|
|
1549
2043
|
self._requires_grad = value
|
|
1550
2044
|
|
|
1551
|
-
|
|
1552
|
-
|
|
1553
|
-
num_bytes = self.size * type_size_in_bytes(self.dtype)
|
|
1554
|
-
self.grad_ptr = self.device.allocator.alloc(num_bytes, pinned=self.pinned)
|
|
1555
|
-
if self.grad_ptr is None:
|
|
1556
|
-
raise RuntimeError("Memory allocation failed on device: {} for {} bytes".format(self.device, num_bytes))
|
|
1557
|
-
with warp.ScopedStream(self.device.null_stream):
|
|
1558
|
-
self.device.memset(self.grad_ptr, 0, num_bytes)
|
|
2045
|
+
# trigger re-creation of C-representation
|
|
2046
|
+
self.ctype = None
|
|
1559
2047
|
|
|
2048
|
+
def _alloc_grad(self):
|
|
1560
2049
|
self._grad = array(
|
|
1561
|
-
|
|
2050
|
+
dtype=self.dtype, shape=self.shape, strides=self.strides, device=self.device, pinned=self.pinned
|
|
1562
2051
|
)
|
|
2052
|
+
self._grad.zero_()
|
|
2053
|
+
|
|
1563
2054
|
# trigger re-creation of C-representation
|
|
1564
2055
|
self.ctype = None
|
|
1565
2056
|
|
|
@@ -1568,171 +2059,195 @@ class array(Array):
|
|
|
1568
2059
|
# member attributes available during code-gen (e.g.: d = array.shape[0])
|
|
1569
2060
|
# Note: we use a shared dict for all array instances
|
|
1570
2061
|
if array._vars is None:
|
|
1571
|
-
|
|
1572
|
-
|
|
1573
|
-
array._vars = {"shape": Var("shape", shape_t)}
|
|
2062
|
+
array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
|
|
1574
2063
|
return array._vars
|
|
1575
2064
|
|
|
1576
2065
|
def zero_(self):
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
)
|
|
2066
|
+
"""Zeroes-out the array entries."""
|
|
2067
|
+
if self.is_contiguous:
|
|
2068
|
+
# simple memset is usually faster than generic fill
|
|
2069
|
+
self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
|
|
2070
|
+
else:
|
|
2071
|
+
self.fill_(0)
|
|
1584
2072
|
|
|
1585
2073
|
def fill_(self, value):
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
2074
|
+
"""Set all array entries to `value`
|
|
2075
|
+
|
|
2076
|
+
Args:
|
|
2077
|
+
value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
|
|
2078
|
+
|
|
2079
|
+
Raises:
|
|
2080
|
+
ValueError: If `value` cannot be converted to the array's ``dtype``.
|
|
2081
|
+
|
|
2082
|
+
Examples:
|
|
2083
|
+
``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
|
|
2084
|
+
|
|
2085
|
+
>>> arr = wp.zeros(2, dtype=wp.mat22)
|
|
2086
|
+
>>> arr.numpy()
|
|
2087
|
+
array([[[0., 0.],
|
|
2088
|
+
[0., 0.]],
|
|
2089
|
+
<BLANKLINE>
|
|
2090
|
+
[[0., 0.],
|
|
2091
|
+
[0., 0.]]], dtype=float32)
|
|
2092
|
+
>>> arr.fill_([[1, 2], [3, 4]])
|
|
2093
|
+
>>> arr.numpy()
|
|
2094
|
+
array([[[1., 2.],
|
|
2095
|
+
[3., 4.]],
|
|
2096
|
+
<BLANKLINE>
|
|
2097
|
+
[[1., 2.],
|
|
2098
|
+
[3., 4.]]], dtype=float32)
|
|
2099
|
+
"""
|
|
2100
|
+
if self.size == 0:
|
|
2101
|
+
return
|
|
1608
2102
|
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
|
|
1612
|
-
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
elem_size = ctypes.sizeof(elem_type)
|
|
1618
|
-
|
|
1619
|
-
# convert value to array type
|
|
1620
|
-
# we need a special case for float16 because it's annoying...
|
|
1621
|
-
if types_equal(self.dtype, float16) or (
|
|
1622
|
-
hasattr(self.dtype, "_wp_scalar_type_") and types_equal(self.dtype._wp_scalar_type_, float16)
|
|
1623
|
-
):
|
|
1624
|
-
# special case for float16:
|
|
1625
|
-
# If you just do elem_type(value), it'll just convert "value"
|
|
1626
|
-
# to uint16 then interpret the bits as float16, which will
|
|
1627
|
-
# mess the data up. Instead, we use float_to_half_bits() to
|
|
1628
|
-
# convert "value" to a float16 and return its bits in a uint16:
|
|
1629
|
-
|
|
1630
|
-
from warp.context import runtime
|
|
1631
|
-
|
|
1632
|
-
src_value = elem_type(runtime.core.float_to_half_bits(ctypes.c_float(value)))
|
|
2103
|
+
# try to convert the given value to the array dtype
|
|
2104
|
+
try:
|
|
2105
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2106
|
+
if isinstance(value, self.dtype.cls):
|
|
2107
|
+
cvalue = value.__ctype__()
|
|
2108
|
+
elif value == 0:
|
|
2109
|
+
# allow zero-initializing structs using default constructor
|
|
2110
|
+
cvalue = self.dtype().__ctype__()
|
|
1633
2111
|
else:
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
# use memset for these special cases because it's quicker (probably...):
|
|
1637
|
-
total_bytes = self.size * type_size_in_bytes(self.dtype)
|
|
1638
|
-
if elem_size in [1, 2, 4] and (total_bytes % 4 == 0):
|
|
1639
|
-
# interpret as a 4 byte integer:
|
|
1640
|
-
dest_value = ctypes.cast(ctypes.pointer(src_value), ctypes.POINTER(ctypes.c_int)).contents
|
|
1641
|
-
if elem_size == 1:
|
|
1642
|
-
# need to repeat the bits, otherwise we'll get an array interleaved with zeros:
|
|
1643
|
-
dest_value.value = dest_value.value & 0x000000FF
|
|
1644
|
-
dest_value.value = (
|
|
1645
|
-
dest_value.value
|
|
1646
|
-
+ (dest_value.value << 8)
|
|
1647
|
-
+ (dest_value.value << 16)
|
|
1648
|
-
+ (dest_value.value << 24)
|
|
1649
|
-
)
|
|
1650
|
-
elif elem_size == 2:
|
|
1651
|
-
# need to repeat the bits, otherwise we'll get an array interleaved with zeros:
|
|
1652
|
-
dest_value.value = dest_value.value & 0x0000FFFF
|
|
1653
|
-
dest_value.value = dest_value.value + (dest_value.value << 16)
|
|
1654
|
-
|
|
1655
|
-
self.device.memset(
|
|
1656
|
-
ctypes.cast(self.ptr, ctypes.POINTER(ctypes.c_int)), dest_value, ctypes.c_size_t(total_bytes)
|
|
2112
|
+
raise ValueError(
|
|
2113
|
+
f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
|
|
1657
2114
|
)
|
|
2115
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
2116
|
+
# vector/matrix
|
|
2117
|
+
cvalue = self.dtype(value)
|
|
2118
|
+
else:
|
|
2119
|
+
# scalar
|
|
2120
|
+
if type(value) in warp.types.scalar_types:
|
|
2121
|
+
value = value.value
|
|
2122
|
+
if self.dtype == float16:
|
|
2123
|
+
cvalue = self.dtype._type_(float_to_half_bits(value))
|
|
1658
2124
|
else:
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
2125
|
+
cvalue = self.dtype._type_(value)
|
|
2126
|
+
except Exception as e:
|
|
2127
|
+
raise ValueError(f"Failed to convert the value to the array data type: {e}")
|
|
2128
|
+
|
|
2129
|
+
cvalue_ptr = ctypes.pointer(cvalue)
|
|
2130
|
+
cvalue_size = ctypes.sizeof(cvalue)
|
|
2131
|
+
|
|
2132
|
+
# prefer using memtile for contiguous arrays, because it should be faster than generic fill
|
|
2133
|
+
if self.is_contiguous:
|
|
2134
|
+
self.device.memtile(self.ptr, cvalue_ptr, cvalue_size, self.size)
|
|
2135
|
+
else:
|
|
2136
|
+
carr = self.__ctype__()
|
|
2137
|
+
carr_ptr = ctypes.pointer(carr)
|
|
2138
|
+
|
|
2139
|
+
if self.device.is_cuda:
|
|
2140
|
+
warp.context.runtime.core.array_fill_device(
|
|
2141
|
+
self.device.context, carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size
|
|
2142
|
+
)
|
|
2143
|
+
else:
|
|
2144
|
+
warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
|
|
1662
2145
|
|
|
1663
|
-
# equivalent to wrapping src data in an array and copying to self
|
|
1664
2146
|
def assign(self, src):
|
|
1665
|
-
|
|
2147
|
+
"""Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
|
|
2148
|
+
if is_array(src):
|
|
1666
2149
|
warp.copy(self, src)
|
|
1667
2150
|
else:
|
|
1668
|
-
warp.copy(self, array(src, dtype=self.dtype, copy=False, device="cpu"))
|
|
2151
|
+
warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
|
|
1669
2152
|
|
|
1670
|
-
# convert array to ndarray (alias memory through array interface)
|
|
1671
2153
|
def numpy(self):
|
|
1672
|
-
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
2154
|
+
"""Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
|
|
2155
|
+
If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
|
|
2156
|
+
automatically performed to ensure that any outstanding work is completed.
|
|
2157
|
+
"""
|
|
2158
|
+
if self.ptr:
|
|
2159
|
+
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2160
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
2161
|
+
a = self.to("cpu", requires_grad=False)
|
|
2162
|
+
# convert through __array_interface__
|
|
2163
|
+
# Note: this handles arrays of structs using `descr`, so the result will be a structured NumPy array
|
|
2164
|
+
return np.array(a, copy=False)
|
|
2165
|
+
else:
|
|
2166
|
+
# return an empty numpy array with the correct dtype and shape
|
|
2167
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2168
|
+
npdtype = self.dtype.numpy_dtype()
|
|
2169
|
+
npshape = self.shape
|
|
2170
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
2171
|
+
npdtype = warp_type_to_np_dtype[self.dtype._wp_scalar_type_]
|
|
2172
|
+
npshape = (*self.shape, *self.dtype._shape_)
|
|
1676
2173
|
else:
|
|
1677
|
-
|
|
2174
|
+
npdtype = warp_type_to_np_dtype[self.dtype]
|
|
2175
|
+
npshape = self.shape
|
|
2176
|
+
return np.empty(npshape, dtype=npdtype)
|
|
1678
2177
|
|
|
1679
|
-
if isinstance(self.dtype, warp.codegen.Struct):
|
|
1680
|
-
# Note: cptr holds a backref to the source array to avoid it being deallocated
|
|
1681
|
-
p = a.cptr()
|
|
1682
|
-
return np.ctypeslib.as_array(p, self.shape)
|
|
1683
|
-
else:
|
|
1684
|
-
# convert through array interface
|
|
1685
|
-
return np.array(a, copy=False)
|
|
1686
|
-
|
|
1687
|
-
# return a ctypes cast of the array address
|
|
1688
|
-
# note that accesses to this object are *not* bounds checked
|
|
1689
2178
|
def cptr(self):
|
|
1690
|
-
|
|
1691
|
-
|
|
1692
|
-
|
|
1693
|
-
|
|
2179
|
+
"""Return a ctypes cast of the array address.
|
|
2180
|
+
|
|
2181
|
+
Notes:
|
|
2182
|
+
|
|
2183
|
+
#. Only CPU arrays support this method.
|
|
2184
|
+
#. The array must be contiguous.
|
|
2185
|
+
#. Accesses to this object are **not** bounds checked.
|
|
2186
|
+
#. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
|
|
2187
|
+
"""
|
|
2188
|
+
if not self.ptr:
|
|
2189
|
+
return None
|
|
2190
|
+
|
|
2191
|
+
if self.device != "cpu" or not self.is_contiguous:
|
|
2192
|
+
raise RuntimeError(
|
|
2193
|
+
"Accessing array memory through a ctypes ptr is only supported for contiguous CPU arrays."
|
|
2194
|
+
)
|
|
2195
|
+
|
|
2196
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2197
|
+
p = ctypes.cast(self.ptr, ctypes.POINTER(self.dtype.ctype))
|
|
2198
|
+
else:
|
|
2199
|
+
p = ctypes.cast(self.ptr, ctypes.POINTER(self.dtype._type_))
|
|
1694
2200
|
|
|
1695
2201
|
# store backref to the underlying array to avoid it being deallocated
|
|
1696
2202
|
p._ref = self
|
|
1697
2203
|
|
|
1698
2204
|
return p
|
|
1699
2205
|
|
|
1700
|
-
# returns a flattened list of items in the array as a Python list
|
|
1701
2206
|
def list(self):
|
|
1702
|
-
a
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
2207
|
+
"""Returns a flattened list of items in the array as a Python list."""
|
|
2208
|
+
a = self.numpy()
|
|
2209
|
+
|
|
2210
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2211
|
+
# struct
|
|
2212
|
+
a = a.flatten()
|
|
2213
|
+
data = a.ctypes.data
|
|
2214
|
+
stride = a.strides[0]
|
|
2215
|
+
return [self.dtype.from_ptr(data + i * stride) for i in range(self.size)]
|
|
2216
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
2217
|
+
# vector/matrix - flatten, but preserve inner vector/matrix dimensions
|
|
2218
|
+
a = a.reshape((self.size, *self.dtype._shape_))
|
|
2219
|
+
data = a.ctypes.data
|
|
2220
|
+
stride = a.strides[0]
|
|
2221
|
+
return [self.dtype.from_ptr(data + i * stride) for i in range(self.size)]
|
|
2222
|
+
else:
|
|
2223
|
+
# scalar
|
|
2224
|
+
return list(a.flatten())
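A small example of the `numpy()` and `list()` conversions for an array of vectors (values are illustrative):

import warp as wp

wp.init()

a = wp.array([[1.0, 2.0], [3.0, 4.0]], dtype=wp.vec2, device="cpu")

print(a.numpy().shape)   # (2, 2): outer shape plus the vec2 component dimension
print(a.list())          # flattened Python list of vec2 values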
|
|
1706
2225
|
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
# convert data from one device to another, nop if already on device
|
|
1710
|
-
def to(self, device):
|
|
2226
|
+
def to(self, device, requires_grad=None):
|
|
2227
|
+
"""Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
|
|
1711
2228
|
device = warp.get_device(device)
|
|
1712
2229
|
if self.device == device:
|
|
1713
2230
|
return self
|
|
1714
2231
|
else:
|
|
1715
|
-
|
|
1716
|
-
# to copy between devices, array must be contiguous
|
|
1717
|
-
warp.copy(dest, self.contiguous())
|
|
1718
|
-
return dest
|
|
2232
|
+
return warp.clone(self, device=device, requires_grad=requires_grad)
|
|
1719
2233
|
|
|
1720
2234
|
def flatten(self):
|
|
2235
|
+
"""Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
|
|
2236
|
+
if self.ndim == 1:
|
|
2237
|
+
return self
|
|
2238
|
+
|
|
1721
2239
|
if not self.is_contiguous:
|
|
1722
2240
|
raise RuntimeError("Flattening non-contiguous arrays is unsupported.")
|
|
1723
2241
|
|
|
1724
2242
|
a = array(
|
|
2243
|
+
ptr=self.ptr,
|
|
1725
2244
|
dtype=self.dtype,
|
|
1726
2245
|
shape=(self.size,),
|
|
1727
|
-
strides=(type_size_in_bytes(self.dtype),),
|
|
1728
|
-
ptr=self.ptr,
|
|
1729
|
-
grad_ptr=self.grad_ptr,
|
|
1730
|
-
capacity=self.capacity,
|
|
1731
2246
|
device=self.device,
|
|
2247
|
+
pinned=self.pinned,
|
|
1732
2248
|
copy=False,
|
|
1733
2249
|
owner=False,
|
|
1734
|
-
|
|
1735
|
-
requires_grad=self.requires_grad,
|
|
2250
|
+
grad=None if self.grad is None else self.grad.flatten(),
|
|
1736
2251
|
)
|
|
1737
2252
|
|
|
1738
2253
|
# store back-ref to stop data being destroyed
|
|
@@ -1740,6 +2255,11 @@ class array(Array):
|
|
|
1740
2255
|
return a
|
|
1741
2256
|
|
|
1742
2257
|
def reshape(self, shape):
|
|
2258
|
+
"""Returns a reshaped array. Only supported for contiguous arrays.
|
|
2259
|
+
|
|
2260
|
+
Args:
|
|
2261
|
+
shape: An int or tuple of ints specifying the shape of the returned array.
|
|
2262
|
+
"""
|
|
1743
2263
|
if not self.is_contiguous:
|
|
1744
2264
|
raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
|
|
1745
2265
|
|
|
@@ -1748,7 +2268,7 @@ class array(Array):
|
|
|
1748
2268
|
raise RuntimeError("shape parameter is required.")
|
|
1749
2269
|
if isinstance(shape, int):
|
|
1750
2270
|
shape = (shape,)
|
|
1751
|
-
elif isinstance(shape,
|
|
2271
|
+
elif not isinstance(shape, tuple):
|
|
1752
2272
|
shape = tuple(shape)
|
|
1753
2273
|
|
|
1754
2274
|
if len(shape) > ARRAY_MAX_DIMS:
|
|
@@ -1756,6 +2276,23 @@ class array(Array):
|
|
|
1756
2276
|
f"Arrays may only have {ARRAY_MAX_DIMS} dimensions maximum, trying to create array with {len(shape)} dims."
|
|
1757
2277
|
)
|
|
1758
2278
|
|
|
2279
|
+
# check for -1 dimension and reformat
|
|
2280
|
+
if -1 in shape:
|
|
2281
|
+
idx = self.size
|
|
2282
|
+
denom = 1
|
|
2283
|
+
minus_one_count = 0
|
|
2284
|
+
for i, d in enumerate(shape):
|
|
2285
|
+
if d == -1:
|
|
2286
|
+
idx = i
|
|
2287
|
+
minus_one_count += 1
|
|
2288
|
+
else:
|
|
2289
|
+
denom *= d
|
|
2290
|
+
if minus_one_count > 1:
|
|
2291
|
+
raise RuntimeError("Cannot infer shape if more than one index is -1.")
|
|
2292
|
+
new_shape = list(shape)
|
|
2293
|
+
new_shape[idx] = int(self.size / denom)
|
|
2294
|
+
shape = tuple(new_shape)
|
|
2295
|
+
|
|
1759
2296
|
size = 1
|
|
1760
2297
|
for d in shape:
|
|
1761
2298
|
size *= d
|
|
@@ -1764,17 +2301,15 @@ class array(Array):
|
|
|
1764
2301
|
raise RuntimeError("Reshaped array must have the same total size as the original.")
|
|
1765
2302
|
|
|
1766
2303
|
a = array(
|
|
2304
|
+
ptr=self.ptr,
|
|
1767
2305
|
dtype=self.dtype,
|
|
1768
2306
|
shape=shape,
|
|
1769
2307
|
strides=None,
|
|
1770
|
-
ptr=self.ptr,
|
|
1771
|
-
grad_ptr=self.grad_ptr,
|
|
1772
|
-
capacity=self.capacity,
|
|
1773
2308
|
device=self.device,
|
|
2309
|
+
pinned=self.pinned,
|
|
1774
2310
|
copy=False,
|
|
1775
2311
|
owner=False,
|
|
1776
|
-
|
|
1777
|
-
requires_grad=self.requires_grad,
|
|
2312
|
+
grad=None if self.grad is None else self.grad.reshape(shape),
|
|
1778
2313
|
)
|
|
1779
2314
|
|
|
1780
2315
|
# store back-ref to stop data being destroyed
|
|
@@ -1782,49 +2317,55 @@ class array(Array):
|
|
|
1782
2317
|
return a
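A brief sketch of `flatten()` and `reshape()` on a contiguous array, including the -1 dimension inference added above:

import warp as wp

wp.init()

a = wp.zeros(12, dtype=wp.float32, device="cpu")

b = a.reshape((3, 4))    # explicit shape
c = a.reshape((2, -1))   # a single -1 is inferred from the total size -> (2, 6)
d = b.flatten()          # zero-copy 1-D view

assert c.shape == (2, 6) and d.shape == (12,)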
|
|
1783
2318
|
|
|
1784
2319
|
def view(self, dtype):
|
|
2320
|
+
"""Returns a zero-copy view of this array's memory with a different data type.
|
|
2321
|
+
``dtype`` must have the same byte size as the array's native ``dtype``.
|
|
2322
|
+
"""
|
|
1785
2323
|
if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
|
|
1786
|
-
raise RuntimeError("
|
|
1787
|
-
else:
|
|
1788
|
-
# return an alias of the array memory with different type information
|
|
1789
|
-
a = array(
|
|
1790
|
-
data=None,
|
|
1791
|
-
dtype=dtype,
|
|
1792
|
-
shape=self.shape,
|
|
1793
|
-
strides=self.strides,
|
|
1794
|
-
ptr=self.ptr,
|
|
1795
|
-
grad_ptr=self.grad_ptr,
|
|
1796
|
-
capacity=self.capacity,
|
|
1797
|
-
device=self.device,
|
|
1798
|
-
copy=False,
|
|
1799
|
-
owner=False,
|
|
1800
|
-
ndim=self.ndim,
|
|
1801
|
-
requires_grad=self.requires_grad,
|
|
1802
|
-
)
|
|
2324
|
+
raise RuntimeError("Cannot cast dtypes of unequal byte size")
|
|
1803
2325
|
|
|
1804
|
-
|
|
1805
|
-
|
|
2326
|
+
# return an alias of the array memory with different type information
|
|
2327
|
+
a = array(
|
|
2328
|
+
ptr=self.ptr,
|
|
2329
|
+
dtype=dtype,
|
|
2330
|
+
shape=self.shape,
|
|
2331
|
+
strides=self.strides,
|
|
2332
|
+
device=self.device,
|
|
2333
|
+
pinned=self.pinned,
|
|
2334
|
+
copy=False,
|
|
2335
|
+
owner=False,
|
|
2336
|
+
grad=None if self.grad is None else self.grad.view(dtype),
|
|
2337
|
+
)
|
|
2338
|
+
|
|
2339
|
+
a._ref = self
|
|
2340
|
+
return a
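A short example of `view()`, reinterpreting the same bytes under a different dtype of equal size (the values are illustrative):

import warp as wp

wp.init()

a = wp.array([1.0, -1.0], dtype=wp.float32, device="cpu")

bits = a.view(wp.uint32)   # zero-copy reinterpretation; both dtypes are 4 bytes wide
print(bits.numpy())        # IEEE-754 bit patterns of 1.0 and -1.0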
|
|
1806
2341
|
|
|
1807
2342
|
def contiguous(self):
|
|
2343
|
+
"""Returns a contiguous array with this array's data. No-op if array is already contiguous."""
|
|
1808
2344
|
if self.is_contiguous:
|
|
1809
2345
|
return self
|
|
1810
2346
|
|
|
1811
2347
|
a = warp.empty_like(self)
|
|
1812
2348
|
warp.copy(a, self)
|
|
1813
|
-
|
|
1814
2349
|
return a
|
|
1815
2350
|
|
|
1816
|
-
# note: transpose operation will return an array with a non-contiguous access pattern
|
|
1817
2351
|
def transpose(self, axes=None):
|
|
2352
|
+
"""Returns an zero-copy view of the array with axes transposed.
|
|
2353
|
+
|
|
2354
|
+
Note: The transpose operation will return an array with a non-contiguous access pattern.
|
|
2355
|
+
|
|
2356
|
+
Args:
|
|
2357
|
+
axes (optional): Specifies how the axes are permuted. If not specified, the order of the axes is reversed.
|
|
2358
|
+
"""
|
|
1818
2359
|
# noop if 1d array
|
|
1819
|
-
if
|
|
2360
|
+
if self.ndim == 1:
|
|
1820
2361
|
return self
|
|
1821
2362
|
|
|
1822
2363
|
if axes is None:
|
|
1823
2364
|
# reverse the order of the axes
|
|
1824
2365
|
axes = range(self.ndim)[::-1]
|
|
1825
|
-
|
|
1826
|
-
if len(axes) != len(self.shape):
|
|
2366
|
+
elif len(axes) != len(self.shape):
|
|
1827
2367
|
raise RuntimeError("Length of parameter axes must be equal in length to array shape")
|
|
2368
|
+
|
|
1828
2369
|
shape = []
|
|
1829
2370
|
strides = []
|
|
1830
2371
|
for a in axes:
|
|
@@ -1836,20 +2377,19 @@ class array(Array):
|
|
|
1836
2377
|
strides.append(self.strides[a])
|
|
1837
2378
|
|
|
1838
2379
|
a = array(
|
|
1839
|
-
|
|
2380
|
+
ptr=self.ptr,
|
|
1840
2381
|
dtype=self.dtype,
|
|
1841
2382
|
shape=tuple(shape),
|
|
1842
2383
|
strides=tuple(strides),
|
|
1843
|
-
ptr=self.ptr,
|
|
1844
|
-
grad_ptr=self.grad_ptr,
|
|
1845
|
-
capacity=self.capacity,
|
|
1846
2384
|
device=self.device,
|
|
2385
|
+
pinned=self.pinned,
|
|
1847
2386
|
copy=False,
|
|
1848
2387
|
owner=False,
|
|
1849
|
-
|
|
1850
|
-
requires_grad=self.requires_grad,
|
|
2388
|
+
grad=None if self.grad is None else self.grad.transpose(axes=axes),
|
|
1851
2389
|
)
|
|
1852
2390
|
|
|
2391
|
+
a.is_transposed = not self.is_transposed
|
|
2392
|
+
|
|
1853
2393
|
a._ref = self
|
|
1854
2394
|
return a
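A minimal sketch of `transpose()` followed by `contiguous()`, since the transposed view has a non-contiguous access pattern:

import warp as wp

wp.init()

a = wp.zeros((2, 3), dtype=wp.float32, device="cpu")

t = a.transpose()                  # axes reversed -> shape (3, 2), strided view
assert t.shape == (3, 2) and not t.is_contiguous

c = t.contiguous()                 # materialize a contiguous copy when needed
assert c.is_contiguous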
|
|
1855
2395
|
|
|
@@ -1878,12 +2418,13 @@ def array4d(*args, **kwargs):
|
|
|
1878
2418
|
return array(*args, **kwargs)
|
|
1879
2419
|
|
|
1880
2420
|
|
|
2421
|
+
# TODO: Rewrite so that we take only shape, not length and optional shape
|
|
1881
2422
|
def from_ptr(ptr, length, dtype=None, shape=None, device=None):
|
|
1882
2423
|
return array(
|
|
1883
2424
|
dtype=dtype,
|
|
1884
2425
|
length=length,
|
|
1885
2426
|
capacity=length * type_size_in_bytes(dtype),
|
|
1886
|
-
ptr=ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
|
|
2427
|
+
ptr=0 if ptr == 0 else ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
|
|
1887
2428
|
shape=shape,
|
|
1888
2429
|
device=device,
|
|
1889
2430
|
owner=False,
|
|
@@ -1891,12 +2432,113 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
|
|
|
1891
2432
|
)
|
|
1892
2433
|
|
|
1893
2434
|
|
|
1894
|
-
class
|
|
2435
|
+
# A base class for non-contiguous arrays, providing the implementation of common methods like
|
|
2436
|
+
# contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
|
|
2437
|
+
class noncontiguous_array_base(Generic[T]):
|
|
2438
|
+
def __init__(self, array_type_id):
|
|
2439
|
+
self.type_id = array_type_id
|
|
2440
|
+
self.is_contiguous = False
|
|
2441
|
+
|
|
2442
|
+
# return a contiguous copy
|
|
2443
|
+
def contiguous(self):
|
|
2444
|
+
a = warp.empty_like(self)
|
|
2445
|
+
warp.copy(a, self)
|
|
2446
|
+
return a
|
|
2447
|
+
|
|
2448
|
+
# copy data from one device to another, nop if already on device
|
|
2449
|
+
def to(self, device):
|
|
2450
|
+
device = warp.get_device(device)
|
|
2451
|
+
if self.device == device:
|
|
2452
|
+
return self
|
|
2453
|
+
else:
|
|
2454
|
+
return warp.clone(self, device=device)
|
|
2455
|
+
|
|
2456
|
+
# return a contiguous numpy copy
|
|
2457
|
+
def numpy(self):
|
|
2458
|
+
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2459
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
2460
|
+
return self.contiguous().numpy()
|
|
2461
|
+
|
|
2462
|
+
# returns a flattened list of items in the array as a Python list
|
|
2463
|
+
def list(self):
|
|
2464
|
+
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2465
|
+
with warp.ScopedStream(self.device.null_stream):
|
|
2466
|
+
return self.contiguous().list()
|
|
2467
|
+
|
|
2468
|
+
# equivalent to wrapping src data in an array and copying to self
|
|
2469
|
+
def assign(self, src):
|
|
2470
|
+
if is_array(src):
|
|
2471
|
+
warp.copy(self, src)
|
|
2472
|
+
else:
|
|
2473
|
+
warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
|
|
2474
|
+
|
|
2475
|
+
def zero_(self):
|
|
2476
|
+
self.fill_(0)
|
|
2477
|
+
|
|
2478
|
+
def fill_(self, value):
|
|
2479
|
+
if self.size == 0:
|
|
2480
|
+
return
|
|
2481
|
+
|
|
2482
|
+
# try to convert the given value to the array dtype
|
|
2483
|
+
try:
|
|
2484
|
+
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2485
|
+
if isinstance(value, self.dtype.cls):
|
|
2486
|
+
cvalue = value.__ctype__()
|
|
2487
|
+
elif value == 0:
|
|
2488
|
+
# allow zero-initializing structs using default constructor
|
|
2489
|
+
cvalue = self.dtype().__ctype__()
|
|
2490
|
+
else:
|
|
2491
|
+
raise ValueError(
|
|
2492
|
+
f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
|
|
2493
|
+
)
|
|
2494
|
+
elif issubclass(self.dtype, ctypes.Array):
|
|
2495
|
+
# vector/matrix
|
|
2496
|
+
cvalue = self.dtype(value)
|
|
2497
|
+
else:
|
|
2498
|
+
# scalar
|
|
2499
|
+
if type(value) in warp.types.scalar_types:
|
|
2500
|
+
value = value.value
|
|
2501
|
+
if self.dtype == float16:
|
|
2502
|
+
cvalue = self.dtype._type_(float_to_half_bits(value))
|
|
2503
|
+
else:
|
|
2504
|
+
cvalue = self.dtype._type_(value)
|
|
2505
|
+
except Exception as e:
|
|
2506
|
+
raise ValueError(f"Failed to convert the value to the array data type: {e}")
|
|
2507
|
+
|
|
2508
|
+
cvalue_ptr = ctypes.pointer(cvalue)
|
|
2509
|
+
cvalue_size = ctypes.sizeof(cvalue)
|
|
2510
|
+
|
|
2511
|
+
ctype = self.__ctype__()
|
|
2512
|
+
ctype_ptr = ctypes.pointer(ctype)
|
|
2513
|
+
|
|
2514
|
+
if self.device.is_cuda:
|
|
2515
|
+
warp.context.runtime.core.array_fill_device(
|
|
2516
|
+
self.device.context, ctype_ptr, self.type_id, cvalue_ptr, cvalue_size
|
|
2517
|
+
)
|
|
2518
|
+
else:
|
|
2519
|
+
warp.context.runtime.core.array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
|
|
2520
|
+
|
|
2521
|
+
|
|
2522
|
+
# helper to check index array properties
|
|
2523
|
+
def check_index_array(indices, expected_device):
|
|
2524
|
+
if not isinstance(indices, array):
|
|
2525
|
+
raise ValueError(f"Indices must be a Warp array, got {type(indices)}")
|
|
2526
|
+
if indices.ndim != 1:
|
|
2527
|
+
raise ValueError(f"Index array must be one-dimensional, got {indices.ndim}")
|
|
2528
|
+
if indices.dtype != int32:
|
|
2529
|
+
raise ValueError(f"Index array must use int32, got dtype {indices.dtype}")
|
|
2530
|
+
if indices.device != expected_device:
|
|
2531
|
+
raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
|
|
2532
|
+
|
|
2533
|
+
|
|
2534
|
+
class indexedarray(noncontiguous_array_base[T]):
|
|
1895
2535
|
# member attributes available during code-gen (e.g.: d = arr.shape[0])
|
|
1896
2536
|
# (initialized when needed)
|
|
1897
2537
|
_vars = None
|
|
1898
2538
|
|
|
1899
2539
|
def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
|
|
2540
|
+
super().__init__(ARRAY_TYPE_INDEXED)
|
|
2541
|
+
|
|
1900
2542
|
# canonicalize types
|
|
1901
2543
|
if dtype is not None:
|
|
1902
2544
|
if dtype == int:
|
|
@@ -1926,17 +2568,6 @@ class indexedarray(Generic[T]):
|
|
|
1926
2568
|
shape = list(data.shape)
|
|
1927
2569
|
|
|
1928
2570
|
if indices is not None:
|
|
1929
|
-
# helper to check index array properties
|
|
1930
|
-
def check_index_array(inds, data):
|
|
1931
|
-
if inds.ndim != 1:
|
|
1932
|
-
raise ValueError(f"Index array must be one-dimensional, got {inds.ndim}")
|
|
1933
|
-
if inds.dtype != int32:
|
|
1934
|
-
raise ValueError(f"Index array must use int32, got dtype {inds.dtype}")
|
|
1935
|
-
if inds.device != data.device:
|
|
1936
|
-
raise ValueError(
|
|
1937
|
-
f"Index array device ({inds.device} does not match data array device ({data.device}))"
|
|
1938
|
-
)
|
|
1939
|
-
|
|
1940
2571
|
if isinstance(indices, (list, tuple)):
|
|
1941
2572
|
if len(indices) > self.ndim:
|
|
1942
2573
|
raise ValueError(
|
|
@@ -1944,16 +2575,14 @@ class indexedarray(Generic[T]):
|
|
|
1944
2575
|
)
|
|
1945
2576
|
|
|
1946
2577
|
for i in range(len(indices)):
|
|
1947
|
-
if
|
|
1948
|
-
check_index_array(indices[i], data)
|
|
2578
|
+
if indices[i] is not None:
|
|
2579
|
+
check_index_array(indices[i], data.device)
|
|
1949
2580
|
self.indices[i] = indices[i]
|
|
1950
2581
|
shape[i] = len(indices[i])
|
|
1951
|
-
elif indices[i] is not None:
|
|
1952
|
-
raise TypeError(f"Invalid index array type: {type(indices[i])}")
|
|
1953
2582
|
|
|
1954
2583
|
elif isinstance(indices, array):
|
|
1955
2584
|
# only a single index array was provided
|
|
1956
|
-
check_index_array(indices, data)
|
|
2585
|
+
check_index_array(indices, data.device)
|
|
1957
2586
|
self.indices[0] = indices
|
|
1958
2587
|
shape[0] = len(indices)
|
|
1959
2588
|
|
|
@@ -1975,13 +2604,15 @@ class indexedarray(Generic[T]):
|
|
|
1975
2604
|
for d in self.shape:
|
|
1976
2605
|
self.size *= d
|
|
1977
2606
|
|
|
1978
|
-
self.is_contiguous = False
|
|
1979
|
-
|
|
1980
2607
|
def __len__(self):
|
|
1981
2608
|
return self.shape[0]
|
|
1982
2609
|
|
|
1983
2610
|
def __str__(self):
|
|
1984
|
-
|
|
2611
|
+
if self.device is None:
|
|
2612
|
+
# type annotation
|
|
2613
|
+
return f"indexedarray{self.dtype}"
|
|
2614
|
+
else:
|
|
2615
|
+
return str(self.numpy())
|
|
1985
2616
|
|
|
1986
2617
|
# construct a C-representation of the array for passing to kernels
|
|
1987
2618
|
def __ctype__(self):
|
|
@@ -1992,48 +2623,9 @@ class indexedarray(Generic[T]):
|
|
|
1992
2623
|
# member attributes available during code-gen (e.g.: d = arr.shape[0])
|
|
1993
2624
|
# Note: we use a shared dict for all indexedarray instances
|
|
1994
2625
|
if indexedarray._vars is None:
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
indexedarray._vars = {"shape": Var("shape", shape_t)}
|
|
2626
|
+
indexedarray._vars = {"shape": warp.codegen.Var("shape", shape_t)}
|
|
1998
2627
|
return indexedarray._vars
|
|
1999
2628
|
|
|
2000
|
-
def contiguous(self):
|
|
2001
|
-
a = warp.empty_like(self)
|
|
2002
|
-
warp.copy(a, self)
|
|
2003
|
-
|
|
2004
|
-
return a
|
|
2005
|
-
|
|
2006
|
-
# convert data from one device to another, nop if already on device
|
|
2007
|
-
def to(self, device):
|
|
2008
|
-
device = warp.get_device(device)
|
|
2009
|
-
if self.device == device:
|
|
2010
|
-
return self
|
|
2011
|
-
else:
|
|
2012
|
-
dest = warp.empty(shape=self.shape, dtype=self.dtype, device=device)
|
|
2013
|
-
# to copy between devices, array must be contiguous
|
|
2014
|
-
warp.copy(dest, self.contiguous())
|
|
2015
|
-
return dest
|
|
2016
|
-
|
|
2017
|
-
# convert array to ndarray (alias memory through array interface)
|
|
2018
|
-
def numpy(self):
|
|
2019
|
-
# use the CUDA default stream for synchronous behaviour with other streams
|
|
2020
|
-
with warp.ScopedStream(self.device.null_stream):
|
|
2021
|
-
|
|
2022
|
-
a = self.contiguous().to("cpu")
|
|
2023
|
-
|
|
2024
|
-
if isinstance(self.dtype, warp.codegen.Struct):
|
|
2025
|
-
p = ctypes.cast(a.ptr, ctypes.POINTER(a.dtype.ctype))
|
|
2026
|
-
np.ctypeslib.as_array(p, self.shape)
|
|
2027
|
-
else:
|
|
2028
|
-
# convert through array interface
|
|
2029
|
-
return np.array(a, copy=False)
|
|
2030
|
-
|
|
2031
|
-
# returns a flattened list of items in the array as a Python list
|
|
2032
|
-
def list(self):
|
|
2033
|
-
a = self.flatten()
|
|
2034
|
-
p = ctypes.cast(a.ptr, ctypes.POINTER(a.dtype.ctype))
|
|
2035
|
-
return p[:a.size]
|
|
2036
|
-
|
|
2037
2629
|
|
|
2038
2630
|
# aliases for indexedarrays with small dimensions
|
|
2039
2631
|
def indexedarray1d(*args, **kwargs):
|
|
@@ -2059,7 +2651,22 @@ def indexedarray4d(*args, **kwargs):
|
|
|
2059
2651
|
return indexedarray(*args, **kwargs)
|
|
2060
2652
|
|
|
2061
2653
|
|
|
2062
|
-
|
|
2654
|
+
from warp.fabric import fabricarray, indexedfabricarray # noqa: E402
|
|
2655
|
+
|
|
2656
|
+
array_types = (array, indexedarray, fabricarray, indexedfabricarray)
|
|
2657
|
+
|
|
2658
|
+
|
|
2659
|
+
def array_type_id(a):
|
|
2660
|
+
if isinstance(a, array):
|
|
2661
|
+
return ARRAY_TYPE_REGULAR
|
|
2662
|
+
elif isinstance(a, indexedarray):
|
|
2663
|
+
return ARRAY_TYPE_INDEXED
|
|
2664
|
+
elif isinstance(a, fabricarray):
|
|
2665
|
+
return ARRAY_TYPE_FABRIC
|
|
2666
|
+
elif isinstance(a, indexedfabricarray):
|
|
2667
|
+
return ARRAY_TYPE_FABRIC_INDEXED
|
|
2668
|
+
else:
|
|
2669
|
+
raise ValueError("Invalid array type")
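A hedged example of constructing an indexed (gather) view over a regular array; the data and index values are illustrative:

import warp as wp

wp.init()

data = wp.array([10.0, 20.0, 30.0, 40.0], dtype=wp.float32, device="cpu")
idx = wp.array([0, 2], dtype=wp.int32, device="cpu")

ia = wp.indexedarray(data, [idx])   # non-contiguous view selecting rows 0 and 2
print(ia.shape)                     # (2,)
print(ia.numpy())                   # [10. 30.]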
|
|
2063
2670
|
|
|
2064
2671
|
|
|
2065
2672
|
class Bvh:
|
|
@@ -2117,11 +2724,11 @@ class Bvh:
|
|
|
2117
2724
|
with self.device.context_guard:
|
|
2118
2725
|
runtime.core.bvh_destroy_device(self.id)
|
|
2119
2726
|
|
|
2120
|
-
except:
|
|
2727
|
+
except Exception:
|
|
2121
2728
|
pass
|
|
2122
2729
|
|
|
2123
2730
|
def refit(self):
|
|
2124
|
-
"""Refit the
|
|
2731
|
+
"""Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
|
|
2125
2732
|
|
|
2126
2733
|
from warp.context import runtime
|
|
2127
2734
|
|
|
@@ -2141,7 +2748,7 @@ class Mesh:
|
|
|
2141
2748
|
"indices": Var("indices", array(dtype=int32)),
|
|
2142
2749
|
}
|
|
2143
2750
|
|
|
2144
|
-
def __init__(self, points=None, indices=None, velocities=None):
|
|
2751
|
+
def __init__(self, points=None, indices=None, velocities=None, support_winding_number=False):
|
|
2145
2752
|
"""Class representing a triangle mesh.
|
|
2146
2753
|
|
|
2147
2754
|
Attributes:
|
|
@@ -2152,6 +2759,7 @@ class Mesh:
|
|
|
2152
2759
|
points (:class:`warp.array`): Array of vertex positions of type :class:`warp.vec3`
|
|
2153
2760
|
indices (:class:`warp.array`): Array of triangle indices of type :class:`warp.int32`, should be a 1d array with shape (num_tris, 3)
|
|
2154
2761
|
velocities (:class:`warp.array`): Array of vertex velocities of type :class:`warp.vec3` (optional)
|
|
2762
|
+
support_winding_number (bool): If true, the mesh will build additional data structures to support `wp.mesh_query_point_sign_winding_number()` queries
|
|
2155
2763
|
"""
|
|
2156
2764
|
|
|
2157
2765
|
if points.device != indices.device:
|
|
@@ -2183,6 +2791,7 @@ class Mesh:
|
|
|
2183
2791
|
indices.__ctype__(),
|
|
2184
2792
|
int(len(points)),
|
|
2185
2793
|
int(indices.size / 3),
|
|
2794
|
+
int(support_winding_number),
|
|
2186
2795
|
)
|
|
2187
2796
|
else:
|
|
2188
2797
|
self.id = runtime.core.mesh_create_device(
|
|
@@ -2192,6 +2801,7 @@ class Mesh:
|
|
|
2192
2801
|
indices.__ctype__(),
|
|
2193
2802
|
int(len(points)),
|
|
2194
2803
|
int(indices.size / 3),
|
|
2804
|
+
int(support_winding_number),
|
|
2195
2805
|
)
|
|
2196
2806
|
|
|
2197
2807
|
def __del__(self):
|
|
@@ -2204,7 +2814,7 @@ class Mesh:
|
|
|
2204
2814
|
# use CUDA context guard to avoid side effects during garbage collection
|
|
2205
2815
|
with self.device.context_guard:
|
|
2206
2816
|
runtime.core.mesh_destroy_device(self.id)
|
|
2207
|
-
except:
|
|
2817
|
+
except Exception:
|
|
2208
2818
|
pass
|
|
2209
2819
|
|
|
2210
2820
|
def refit(self):
|
|
@@ -2220,16 +2830,14 @@ class Mesh:
|
|
|
2220
2830
|
|
|
2221
2831
|
|
|
2222
2832
|
class Volume:
|
|
2833
|
+
#: Enum value to specify nearest-neighbor interpolation during sampling
|
|
2223
2834
|
CLOSEST = constant(0)
|
|
2835
|
+
#: Enum value to specify trilinear interpolation during sampling
|
|
2224
2836
|
LINEAR = constant(1)
|
|
2225
2837
|
|
|
2226
2838
|
def __init__(self, data: array):
|
|
2227
2839
|
"""Class representing a sparse grid.
|
|
2228
2840
|
|
|
2229
|
-
Attributes:
|
|
2230
|
-
CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
|
|
2231
|
-
LINEAR (int): Enum value to specify trilinear interpolation during sampling
|
|
2232
|
-
|
|
2233
2841
|
Args:
|
|
2234
2842
|
data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
|
|
2235
2843
|
"""
|
|
@@ -2271,19 +2879,20 @@ class Volume:
                 with self.device.context_guard:
                     runtime.core.volume_destroy_device(self.id)

-        except:
+        except Exception:
             pass

-    def array(self):
+    def array(self) -> array:
+        """Returns the raw memory buffer of the Volume as an array"""
         buf = ctypes.c_void_p(0)
         size = ctypes.c_uint64(0)
         if self.device.is_cpu:
             self.context.core.volume_get_buffer_info_host(self.id, ctypes.byref(buf), ctypes.byref(size))
         else:
             self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
-        return array(ptr=buf.value, dtype=uint8,
+        return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)

-    def get_tiles(self):
+    def get_tiles(self) -> array:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")

@@ -2294,11 +2903,9 @@ class Volume:
         else:
             self.context.core.volume_get_tiles_device(self.id, ctypes.byref(buf), ctypes.byref(size))
         num_tiles = size.value // (3 * 4)
-        return array(
-            ptr=buf.value, dtype=int32, shape=(num_tiles, 3), length=size.value, device=self.device, owner=True
-        )
+        return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)

-    def get_voxel_size(self):
+    def get_voxel_size(self) -> Tuple[float, float, float]:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")

@@ -2307,7 +2914,13 @@ class Volume:
         return (dx.value, dy.value, dz.value)

     @classmethod
-    def load_from_nvdb(cls, file_or_buffer, device=None):
+    def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
+        """Creates a Volume object from a NanoVDB file or in-memory buffer.
+
+        Returns:
+
+            A ``warp.Volume`` object.
+        """
         try:
             data = file_or_buffer.read()
         except AttributeError:
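Typical use of `load_from_nvdb()` as documented above ("smoke.nvdb" is a placeholder path):

    import warp as wp

    wp.init()

    with open("smoke.nvdb", "rb") as f:
        volume = wp.Volume.load_from_nvdb(f, device="cuda:0")

    print(volume.get_voxel_size())   # (dx, dy, dz)
    print(volume.array().shape)      # raw NanoVDB buffer exposed as a warp.array of uint8
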
@@ -2336,6 +2949,90 @@ class Volume:
         data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
         return cls(data_array)

+    @classmethod
+    def load_from_numpy(
+        cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
+    ) -> Volume:
+        """Creates a Volume object from a dense 3D NumPy array.
+
+        This function is only supported for CUDA devices.
+
+        Args:
+            min_world: The 3D coordinate of the lower corner of the volume.
+            voxel_size: The size of each voxel in spatial coordinates.
+            bg_value: Background value
+            device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
+
+        Returns:
+
+            A ``warp.Volume`` object.
+        """
+
+        import math
+
+        target_shape = (
+            math.ceil(ndarray.shape[0] / 8) * 8,
+            math.ceil(ndarray.shape[1] / 8) * 8,
+            math.ceil(ndarray.shape[2] / 8) * 8,
+        )
+        if hasattr(bg_value, "__len__"):
+            # vec3, assuming the numpy array is 4D
+            padded_array = np.array((target_shape[0], target_shape[1], target_shape[2], 3), dtype=np.single)
+            padded_array[:, :, :, :] = np.array(bg_value)
+            padded_array[0 : ndarray.shape[0], 0 : ndarray.shape[1], 0 : ndarray.shape[2], :] = ndarray
+        else:
+            padded_amount = (
+                math.ceil(ndarray.shape[0] / 8) * 8 - ndarray.shape[0],
+                math.ceil(ndarray.shape[1] / 8) * 8 - ndarray.shape[1],
+                math.ceil(ndarray.shape[2] / 8) * 8 - ndarray.shape[2],
+            )
+            padded_array = np.pad(
+                ndarray,
+                ((0, padded_amount[0]), (0, padded_amount[1]), (0, padded_amount[2])),
+                mode="constant",
+                constant_values=bg_value,
+            )
+
+        shape = padded_array.shape
+        volume = warp.Volume.allocate(
+            min_world,
+            [
+                min_world[0] + (shape[0] - 1) * voxel_size,
+                min_world[1] + (shape[1] - 1) * voxel_size,
+                min_world[2] + (shape[2] - 1) * voxel_size,
+            ],
+            voxel_size,
+            bg_value=bg_value,
+            points_in_world_space=True,
+            translation=min_world,
+            device=device,
+        )
+
+        # Populate volume
+        if hasattr(bg_value, "__len__"):
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_v,
+                dim=(shape[0], shape[1], shape[2]),
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
+                device=device,
+            )
+        elif isinstance(bg_value, int):
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_i,
+                dim=shape,
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.int32, device=device)],
+                device=device,
+            )
+        else:
+            warp.launch(
+                warp.utils.copy_dense_volume_to_nano_vdb_f,
+                dim=shape,
+                inputs=[volume.id, warp.array(padded_array, dtype=warp.float32, device=device)],
+                device=device,
+            )
+
+        return volume
+
     @classmethod
     def allocate(
         cls,
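A sketch of the new `load_from_numpy()` entry point, following the signature added above (CUDA only):

    import numpy as np
    import warp as wp

    wp.init()

    # dense 32^3 scalar field: signed distance to a sphere of radius 10 voxels
    x, y, z = np.mgrid[0:32, 0:32, 0:32].astype(np.float32)
    field = np.sqrt((x - 16.0) ** 2 + (y - 16.0) ** 2 + (z - 16.0) ** 2) - 10.0

    volume = wp.Volume.load_from_numpy(
        field,
        min_world=(0.0, 0.0, 0.0),   # world position of voxel (0, 0, 0)
        voxel_size=0.1,
        bg_value=1.0e6,              # value reported for unallocated voxels
        device="cuda:0",
    )
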
@@ -2346,9 +3043,11 @@ class Volume:
         translation=(0.0, 0.0, 0.0),
         points_in_world_space=False,
         device=None,
-    ):
+    ) -> Volume:
         """Allocate a new Volume based on the bounding box defined by min and max.

+        This function is only supported for CUDA devices.
+
         Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
         If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
         translation, and the volume is allocated with those.
@@ -2357,12 +3056,12 @@ class Volume:
         the resulting tiles will be available in the new volume.

         Args:
-            min (array-like): Lower 3D
-            max (array-like): Upper 3D
-            voxel_size (float): Voxel size of the new volume
+            min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
+            max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like): translation between the index and world spaces
-            device (Devicelike):
+            translation (array-like): translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".

         """
         if points_in_world_space:
@@ -2387,9 +3086,11 @@ class Volume:
     @classmethod
     def allocate_by_tiles(
         cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
-    ):
+    ) -> Volume:
         """Allocate a new Volume with active tiles for each point tile_points.

+        This function is only supported for CUDA devices.
+
         The smallest unit of allocation is a dense tile of 8x8x8 voxels.
         This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.

@@ -2399,13 +3100,13 @@ class Volume:

         Args:
             tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
-                The array can be a
+                The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
                 or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like):
-            device (Devicelike):
+            translation (array-like): Translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".

         """
         from warp.context import runtime
@@ -2442,7 +3143,7 @@ class Volume:
                 translation[2],
                 in_world_space,
             )
-        elif
+        elif isinstance(bg_value, int):
             volume.id = volume.context.core.volume_i_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
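A sketch of `allocate_by_tiles()` with index-space tile coordinates, matching the N-by-3 `int32` form documented above (CUDA only):

    import numpy as np
    import warp as wp

    wp.init()
    device = "cuda:0"

    # activate two 8x8x8 tiles by their index-space origins
    tile_points = wp.array(np.array([[0, 0, 0], [8, 0, 0]], dtype=np.int32), dtype=wp.int32, device=device)

    volume = wp.Volume.allocate_by_tiles(
        tile_points, voxel_size=0.1, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=device
    )

    print(volume.get_tiles().numpy())  # index-space origins of the active tiles
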
@@ -2473,6 +3174,67 @@ class Volume:
         return volume


+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
+class mesh_query_point_t:
+    """Output for the mesh query point functions.
+
+    Attributes:
+        result (bool): Whether a point is found within the given constraints.
+        sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
+            Note that mesh must be watertight for this to be robust
+        face (int32): Index of the closest face.
+        u (float32): Barycentric u coordinate of the closest point.
+        v (float32): Barycentric v coordinate of the closest point.
+
+    See Also:
+        :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
+        :func:`mesh_query_furthest_point_no_sign`,
+        :func:`mesh_query_point_sign_normal`,
+        and :func:`mesh_query_point_sign_winding_number`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+    }
+
+
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+class mesh_query_ray_t:
+    """Output for the mesh query ray functions.
+
+    Attributes:
+        result (bool): Whether a hit is found within the given constraints.
+        sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
+        face (int32): Index of the closest face.
+        t (float32): Distance of the closest hit along the ray.
+        u (float32): Barycentric u coordinate of the closest hit.
+        v (float32): Barycentric v coordinate of the closest hit.
+        normal (vec3f): Face normal.
+
+    See Also:
+        :func:`mesh_query_ray`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "t": Var("t", float32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+        "normal": Var("normal", vec3),
+    }
+
+
 def matmul(
     a: array2d,
     b: array2d,
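The two kernel-only types above describe the outputs of the mesh query builtins. A sketch of consuming such a result inside a kernel, assuming the struct-returning form `wp.mesh_query_ray(id, start, dir, max_t)` whose fields mirror `mesh_query_ray_t`; the exact builtin signature should be checked against the installed version:

    import warp as wp

    wp.init()

    @wp.kernel
    def raycast(mesh_id: wp.uint64,
                origins: wp.array(dtype=wp.vec3),
                dirs: wp.array(dtype=wp.vec3),
                hit_t: wp.array(dtype=float)):
        tid = wp.tid()
        # assumed struct-returning overload; fields match mesh_query_ray_t above
        q = wp.mesh_query_ray(mesh_id, origins[tid], dirs[tid], 1.0e6)
        if q.result:
            hit_t[tid] = q.t       # distance to the closest hit
        else:
            hit_t[tid] = -1.0
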
@@ -2480,7 +3242,7 @@ def matmul(
     d: array2d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2509,6 +3271,11 @@ def matmul(
             "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )

+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -2543,13 +3310,13 @@ def matmul(
         ctypes.c_void_p(d.ptr),
         alpha,
         beta,
-
-
+        not a.is_transposed,
+        not b.is_transposed,
         allow_tf32x3_arith,
         1,
     )
     if not ret:
-        raise RuntimeError("
+        raise RuntimeError("matmul failed.")


 def adj_matmul(
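Basic use of `wp.matmul()` as declared above; per the new check, all operands must be contiguous, although `a` and/or `b` may additionally be transposed views:

    import numpy as np
    import warp as wp

    wp.init()
    device = "cuda:0"
    m, k, n = 64, 32, 48

    a = wp.array(np.random.randn(m, k).astype(np.float32), dtype=wp.float32, device=device)
    b = wp.array(np.random.randn(k, n).astype(np.float32), dtype=wp.float32, device=device)
    c = wp.zeros((m, n), dtype=wp.float32, device=device)
    d = wp.zeros((m, n), dtype=wp.float32, device=device)

    # d = alpha * (a @ b) + beta * c
    wp.matmul(a, b, c, d, alpha=1.0, beta=0.0)
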
@@ -2562,7 +3329,7 @@ def adj_matmul(
     adj_d: array2d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes the adjoint of a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2613,6 +3380,19 @@ def adj_matmul(
             "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
         )

+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -2633,75 +3413,105 @@ def adj_matmul(

     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
+        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return

     cc = device.arch

     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")

     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")

     # adj_c
-
-
-
-
-
-
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        1,
+    warp.launch(
+        kernel=warp.utils.add_kernel_2d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")


 def batched_matmul(
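The updated CPU fallback above also documents the GEMM adjoint formulas, which now accumulate into the incoming gradients instead of overwriting them. A plain NumPy restatement of that step, for illustration only:

    import numpy as np

    # forward: d = alpha * (a @ b) + beta * c
    def adj_matmul_reference(a, b, adj_a, adj_b, adj_c, adj_d, alpha=1.0, beta=0.0):
        adj_a += alpha * adj_d @ b.T   # accumulated, matching adj_a.assign(... + adj_a.numpy())
        adj_b += alpha * a.T @ adj_d
        adj_c += beta * adj_d

    a = np.random.randn(4, 3)
    b = np.random.randn(3, 5)
    adj_d = np.ones((4, 5))
    adj_a, adj_b, adj_c = np.zeros_like(a), np.zeros_like(b), np.zeros((4, 5))
    adj_matmul_reference(a, b, adj_a, adj_b, adj_c, adj_d)
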
@@ -2711,7 +3521,7 @@ def batched_matmul(
     d: array3d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2740,6 +3550,11 @@ def batched_matmul(
             "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )

+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[1]
     n = b.shape[2]
     k = a.shape[2]
@@ -2751,7 +3566,7 @@ def batched_matmul(

     if runtime.tape:
         runtime.tape.record_func(
-            backward=lambda:
+            backward=lambda: adj_batched_matmul(
                 a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
             ),
             arrays=[a, b, c, d],
@@ -2762,26 +3577,55 @@ def batched_matmul(
         d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
         return

+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            n,
+            k,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
+            alpha,
+            beta,
+            not a.is_transposed,
+            not b.is_transposed,
+            allow_tf32x3_arith,
+            max_batch_count,
+        )
+        if not ret:
+            raise RuntimeError("Batched matmul failed.")
+
+    idx_start = iters * max_batch_count
     ret = runtime.core.cutlass_gemm(
         cc,
         m,
         n,
         k,
         type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
+        ctypes.c_void_p(a[idx_start:,:,:].ptr),
+        ctypes.c_void_p(b[idx_start:,:,:].ptr),
+        ctypes.c_void_p(c[idx_start:,:,:].ptr),
+        ctypes.c_void_p(d[idx_start:,:,:].ptr),
         alpha,
         beta,
-
-
+        not a.is_transposed,
+        not b.is_transposed,
         allow_tf32x3_arith,
-
+        remainder,
     )
     if not ret:
-        raise RuntimeError("Batched matmul failed.")
+        raise RuntimeError("Batched matmul failed.")


 def adj_batched_matmul(
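The batched path now splits work because `batch_count` may exceed `max_batch_count` (65535), a CUDA array size maximum. A simplified sketch of the splitting arithmetic used above:

    batch_count = 200000
    max_batch_count = 65535

    iters = batch_count // max_batch_count      # 3 full launches
    remainder = batch_count % max_batch_count   # 3395 matrices in the final launch

    chunks = [(i * max_batch_count, (i + 1) * max_batch_count) for i in range(iters)]
    chunks.append((iters * max_batch_count, batch_count))
    print(chunks)  # [(0, 65535), (65535, 131070), (131070, 196605), (196605, 200000)]
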
@@ -2794,7 +3638,7 @@ def adj_batched_matmul(
     adj_d: array3d,
     alpha: float = 1.0,
     beta: float = 0.0,
-    allow_tf32x3_arith: bool = False,
+    allow_tf32x3_arith: builtins.bool = False,
     device=None,
 ):
     """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2861,78 +3705,215 @@ def adj_batched_matmul(
         )
     )

+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return

+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch

+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+
+        # adj_a
+        if not a.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                m,
+                k,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                True,
+                b.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                m,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                not b.is_transposed,
+                False,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+        # adj_b
+        if not b.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                n,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                a.is_transposed,
+                True,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                n,
+                k,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                False,
+                not a.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+    idx_start = iters * max_batch_count
+
     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")

     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")

     # adj_c
-
-
-
-
-
-
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        batch_count,
+    warp.launch(
+        kernel=warp.utils.add_kernel_3d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
-

 class HashGrid:
     def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3001,7 +3982,7 @@ class HashGrid:
                 with self.device.context_guard:
                     runtime.core.hash_grid_destroy_device(self.id)

-        except:
+        except Exception:
             pass


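`HashGrid` construction as declared above; `build()` and the kernel-side `wp.hash_grid_query()` are the standard companion API assumed here:

    import numpy as np
    import warp as wp

    wp.init()
    device = "cuda:0"

    points = wp.array(np.random.rand(1024, 3).astype(np.float32), dtype=wp.vec3, device=device)

    grid = wp.HashGrid(128, 128, 128, device=device)
    grid.build(points=points, radius=0.1)   # cell size derived from the query radius

    # grid.id can now be passed to kernels that call wp.hash_grid_query(grid.id, pos, radius)
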
@@ -3075,7 +4056,7 @@ class MarchingCubes:

         if error:
             raise RuntimeError(
-                "
+                "Buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
             )

         # resize the geometry arrays
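The error above fires when the preallocated output buffers are too small. A hedged sketch of the surrounding workflow, assuming the `MarchingCubes(nx, ny, nz, max_verts, max_tris)` constructor and its `surface()` method:

    import warp as wp

    wp.init()
    device = "cuda:0"
    nx = ny = nz = 64

    # scalar field to be surfaced; normally filled by a kernel (e.g. an SDF)
    field = wp.zeros(shape=(nx, ny, nz), dtype=float, device=device)

    mc = wp.MarchingCubes(nx=nx, ny=ny, nz=nz, max_verts=100000, max_tris=300000, device=device)
    mc.surface(field=field, threshold=0.0)

    print(mc.verts.shape, mc.indices.shape)  # extracted triangle mesh
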
@@ -3131,7 +4112,7 @@ def type_matches_template(arg_type, template_type):
         return True
     elif is_array(template_type):
         # ensure the argument type is a non-generic array with matching dtype and dimensionality
-        if type(arg_type)
+        if type(arg_type) is not type(template_type):
             return False
         if not type_matches_template(arg_type.dtype, template_type.dtype):
             return False
@@ -3160,9 +4141,53 @@ def type_matches_template(arg_type, template_type):
         return True


+def infer_argument_types(args, template_types, arg_names=None):
+    """Resolve argument types with the given list of template types."""
+
+    if len(args) != len(template_types):
+        raise RuntimeError("Number of arguments must match number of template types.")
+
+    arg_types = []
+
+    for i in range(len(args)):
+        arg = args[i]
+        arg_type = type(arg)
+        arg_name = arg_names[i] if arg_names else str(i)
+        if arg_type in warp.types.array_types:
+            arg_types.append(arg_type(dtype=arg.dtype, ndim=arg.ndim))
+        elif arg_type in warp.types.scalar_types:
+            arg_types.append(arg_type)
+        elif arg_type in [int, float]:
+            # canonicalize type
+            arg_types.append(warp.types.type_to_warp(arg_type))
+        elif hasattr(arg_type, "_wp_scalar_type_"):
+            # vector/matrix type
+            arg_types.append(arg_type)
+        elif issubclass(arg_type, warp.codegen.StructInstance):
+            # a struct
+            arg_types.append(arg._cls)
+        # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
+        #     arg_types.append(arg_type)
+        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
+        #     arg_types.append(arg_type)
+        elif arg is None:
+            # allow passing None for arrays
+            t = template_types[i]
+            if warp.types.is_array(t):
+                arg_types.append(type(t)(dtype=t.dtype, ndim=t.ndim))
+            else:
+                raise TypeError(f"Unable to infer the type of argument '{arg_name}', got None")
+        else:
+            # TODO: attempt to figure out if it's a vector/matrix type given as a numpy array, list, etc.
+            raise TypeError(f"Unable to infer the type of argument '{arg_name}', got {arg_type}")
+
+    return arg_types
+
+
 simple_type_codes = {
     int: "i4",
     float: "f4",
+    builtins.bool: "b",
     bool: "b",
     str: "str",  # accepted by print()
     int8: "i1",
@@ -3181,6 +4206,8 @@ simple_type_codes = {
     launch_bounds_t: "lb",
     hash_grid_query_t: "hgq",
    mesh_query_aabb_t: "mqa",
+    mesh_query_point_t: "mqp",
+    mesh_query_ray_t: "mqr",
     bvh_query_t: "bvhq",
 }

@@ -3197,14 +4224,14 @@ def get_type_code(arg_type):
         # check for "special" vector/matrix subtypes
         if hasattr(arg_type, "_wp_generic_type_str_"):
             type_str = arg_type._wp_generic_type_str_
-            if type_str == "
+            if type_str == "quat_t":
                 return f"q{dtype_code}"
             elif type_str == "transform_t":
                 return f"t{dtype_code}"
-            elif type_str == "spatial_vector_t":
-
-            elif type_str == "spatial_matrix_t":
-
+            # elif type_str == "spatial_vector_t":
+            #     return f"sv{dtype_code}"
+            # elif type_str == "spatial_matrix_t":
+            #     return f"sm{dtype_code}"
         # generic vector/matrix
         ndim = len(arg_type._shape_)
         if ndim == 1:
@@ -3227,6 +4254,10 @@ def get_type_code(arg_type):
         return f"a{arg_type.ndim}{get_type_code(arg_type.dtype)}"
     elif isinstance(arg_type, indexedarray):
         return f"ia{arg_type.ndim}{get_type_code(arg_type.dtype)}"
+    elif isinstance(arg_type, fabricarray):
+        return f"fa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
+    elif isinstance(arg_type, indexedfabricarray):
+        return f"ifa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
     elif isinstance(arg_type, warp.codegen.Struct):
         return warp.codegen.make_full_qualified_name(arg_type.cls)
     elif arg_type == Scalar: