warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/torch.py
CHANGED
```diff
@@ -5,20 +5,21 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-import warp
-import numpy
 import ctypes
-from typing import Union
 
+import numpy
+
+import warp
 
 
 # return the warp device corresponding to a torch device
 def device_from_torch(torch_device):
+    """Return the warp device corresponding to a torch device."""
     return warp.get_device(str(torch_device))
 
 
-# return the torch device corresponding to a warp device
 def device_to_torch(wp_device):
+    """Return the torch device corresponding to a warp device."""
     device = warp.get_device(wp_device)
     if device.is_cpu or device.is_primary:
         return str(device)
```
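The import reshuffle and the new docstrings above do not change behavior. As a quick orientation, here is a hypothetical usage sketch of the two device helpers (not from the package; it assumes warp and torch are installed and that `wp.init()` is called, as this release still required explicit initialization):

```python
import warp as wp
from warp.torch import device_from_torch, device_to_torch

wp.init()

# torch device (object or string) -> warp device
d = device_from_torch("cuda:0")             # same as wp.get_device("cuda:0")

# warp device -> torch-compatible device string
s = device_to_torch(wp.get_device("cpu"))   # "cpu"
```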
```diff
@@ -29,6 +30,7 @@ def device_to_torch(wp_device):
 
 
 def dtype_from_torch(torch_dtype):
+    """Return the Warp dtype corresponding to a torch dtype."""
     # initialize lookup table on first call to defer torch import
     if dtype_from_torch.type_map is None:
         import torch
@@ -42,7 +44,7 @@ def dtype_from_torch(torch_dtype):
             torch.int16: warp.int16,
             torch.int8: warp.int8,
             torch.uint8: warp.uint8,
-            torch.bool: warp.
+            torch.bool: warp.bool,
             # currently unsupported by Warp
             # torch.bfloat16:
             # torch.complex64:
@@ -61,14 +63,14 @@ dtype_from_torch.type_map = None
 
 
 def dtype_is_compatible(torch_dtype, warp_dtype):
+    """Evaluates whether the given torch dtype is compatible with the given warp dtype."""
     # initialize lookup table on first call to defer torch import
     if dtype_is_compatible.compatible_sets is None:
         import torch
 
         dtype_is_compatible.compatible_sets = {
             torch.float64: {warp.float64},
-
-            torch.float32: {warp.float32, *warp.types.warp.types.vector_types},
+            torch.float32: {warp.float32},
             torch.float16: {warp.float16},
             # allow aliasing integer tensors as signed or unsigned integer arrays
             torch.int64: {warp.int64, warp.uint64},
@@ -76,7 +78,7 @@ def dtype_is_compatible(torch_dtype, warp_dtype):
             torch.int16: {warp.int16, warp.uint16},
             torch.int8: {warp.int8, warp.uint8},
             torch.uint8: {warp.uint8, warp.int8},
-            torch.bool: {warp.uint8, warp.int8},
+            torch.bool: {warp.bool, warp.uint8, warp.int8},
             # currently unsupported by Warp
             # torch.bfloat16:
             # torch.complex64:
@@ -86,7 +88,10 @@ def dtype_is_compatible(torch_dtype, warp_dtype):
     compatible_set = dtype_is_compatible.compatible_sets.get(torch_dtype)
 
     if compatible_set is not None:
-
+        if hasattr(warp_dtype, "_wp_scalar_type_"):
+            return warp_dtype._wp_scalar_type_ in compatible_set
+        else:
+            return warp_dtype in compatible_set
     else:
         raise TypeError(f"Invalid or unsupported data type: {torch_dtype}")
 
```
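`dtype_is_compatible` now falls back to a vector/matrix type's `_wp_scalar_type_`, and `torch.bool` gains a first-class `warp.bool` mapping. A hypothetical sketch of what the new rules allow (not from the package; assumes a CUDA device is available):

```python
import torch
import warp as wp
from warp.torch import from_torch

wp.init()

# a (N, 3) float32 tensor can be wrapped as a wp.vec3 array of shape (N,):
# dtype_is_compatible() compares wp.vec3._wp_scalar_type_ (wp.float32)
# against torch.float32, and from_torch() absorbs the trailing dimension
t = torch.zeros((1024, 3), dtype=torch.float32, device="cuda:0")
a = from_torch(t, dtype=wp.vec3)

# torch.bool now maps to wp.bool; aliasing as uint8/int8 still works
flags = torch.ones(16, dtype=torch.bool)
b = from_torch(flags)
```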
```diff
@@ -96,14 +101,21 @@ dtype_is_compatible.compatible_sets = None
 
 # wrap a torch tensor to a wp array, data is not copied
 def from_torch(t, dtype=None, requires_grad=None, grad=None):
+    """Wrap a PyTorch tensor to a Warp array without copying the data.
+
+    Args:
+        t (torch.Tensor): The torch tensor to wrap.
+        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
+        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.
+
+    Returns:
+        warp.array: The wrapped array.
+    """
     if dtype is None:
         dtype = dtype_from_torch(t.dtype)
     elif not dtype_is_compatible(t.dtype, dtype):
         raise RuntimeError(f"Incompatible data types: {t.dtype} and {dtype}")
 
-    if requires_grad is None:
-        requires_grad = t.requires_grad
-
     # get size of underlying data type to compute strides
     ctype_size = ctypes.sizeof(dtype._type_)
 
@@ -122,7 +134,7 @@ def from_torch(t, dtype=None, requires_grad=None, grad=None):
     )
 
     # ensure the inner strides are contiguous
-    stride =
+    stride = ctype_size
     for i in range(dtype_dims):
         if strides[-i - 1] != stride:
             raise RuntimeError(
@@ -130,40 +142,60 @@ def from_torch(t, dtype=None, requires_grad=None, grad=None):
             )
         stride *= dtype_shape[-i - 1]
 
-    shape = tuple(shape[:-dtype_dims])
-    strides = tuple(strides[:-dtype_dims])
+    shape = tuple(shape[:-dtype_dims]) or (1,)
+    strides = tuple(strides[:-dtype_dims]) or (ctype_size,)
 
-
+    requires_grad = t.requires_grad if requires_grad is None else requires_grad
    if grad is not None:
-
-
-
-
-
-
-
-
+        if not isinstance(grad, warp.array):
+            import torch
+
+            if isinstance(grad, torch.Tensor):
+                grad = from_torch(grad, dtype=dtype)
+            else:
+                raise ValueError(f"Invalid gradient type: {type(grad)}")
+    elif requires_grad:
+        # wrap the tensor gradient, allocate if necessary
+        if t.grad is None:
+            # allocate a zero-filled gradient tensor if it doesn't exist
+            import torch
+
+            t.grad = torch.zeros_like(t, requires_grad=False)
+        grad = from_torch(t.grad, dtype=dtype)
 
     a = warp.types.array(
         ptr=t.data_ptr(),
-        grad_ptr=grad_ptr,
         dtype=dtype,
         shape=shape,
         strides=strides,
+        device=device_from_torch(t.device),
         copy=False,
         owner=False,
+        grad=grad,
         requires_grad=requires_grad,
-        device=device_from_torch(t.device),
     )
 
     # save a reference to the source tensor, otherwise it will be deallocated
-    a.
+    a._tensor = t
     return a
 
 
-def to_torch(a):
+def to_torch(a, requires_grad=None):
+    """
+    Convert a Warp array to a PyTorch tensor without copying the data.
+
+    Args:
+        a (warp.array): The Warp array to convert.
+        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.
+
+    Returns:
+        torch.Tensor: The converted tensor.
+    """
     import torch
 
+    if requires_grad is None:
+        requires_grad = a.requires_grad
+
     # Torch does not support structured arrays
     if isinstance(a.dtype, warp.codegen.Struct):
         raise RuntimeError("Cannot convert structured Warp arrays to Torch.")
```
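The reworked `from_torch` replaces the old `grad_ptr` argument: the wrapped array now carries a full `grad` array that shares the tensor's `.grad` storage, allocating a zero-filled gradient tensor when one does not exist, and `requires_grad` defaults to the tensor's own flag. A hypothetical sketch of the resulting behavior (not from the package):

```python
import torch
import warp as wp
from warp.torch import from_torch

wp.init()

t = torch.rand(8, dtype=torch.float32, requires_grad=True)

a = from_torch(t)           # requires_grad inherited from t
assert a.requires_grad
assert t.grad is not None   # zero-filled grad tensor allocated by from_torch

# an explicit gradient (torch.Tensor or wp.array) can also be supplied
g = torch.zeros_like(t)
a2 = from_torch(t, grad=g)
```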
```diff
@@ -173,19 +205,28 @@ def to_torch(a):
         # that support the __array_interface__ protocol
         # in this case we need to workaround by going
         # to an ndarray first, see https://pearu.github.io/array_interface_pytorch.html
-
+        t = torch.as_tensor(numpy.asarray(a))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(numpy.asarray(a.grad))
+        return t
 
     elif a.device.is_cuda:
         # Torch does support the __cuda_array_interface__
         # correctly, but we must be sure to maintain a reference
         # to the owning object to prevent memory allocs going out of scope
-
+        t = torch.as_tensor(a, device=device_to_torch(a.device))
+        t.requires_grad = requires_grad
+        if requires_grad and a.requires_grad:
+            t.grad = torch.as_tensor(a.grad, device=device_to_torch(a.device))
+        return t
 
     else:
         raise RuntimeError("Unsupported device")
 
 
 def stream_from_torch(stream_or_device=None):
+    """Convert from a PyTorch CUDA stream to a Warp.Stream."""
     import torch
 
     if isinstance(stream_or_device, torch.cuda.Stream):
```
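`to_torch` gains a matching `requires_grad` override and now propagates gradients on both the CPU (`__array_interface__`) and CUDA (`__cuda_array_interface__`) paths. A hypothetical sketch of the round trip (not from the package):

```python
import warp as wp
from warp.torch import to_torch

wp.init()

a = wp.zeros(16, dtype=wp.float32, requires_grad=True)

t = to_torch(a)                                # shares a's memory; t.grad wraps a.grad
t_detached = to_torch(a, requires_grad=False)  # same data, no autograd tracking
```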
```diff
@@ -205,6 +246,7 @@ def stream_from_torch(stream_or_device=None):
 
 
 def stream_to_torch(stream_or_device=None):
+    """Convert from a Warp.Stream to a PyTorch CUDA stream."""
     import torch
 
     if isinstance(stream_or_device, warp.Stream):
```
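The stream helpers only gained docstrings in this release. For completeness, a hypothetical sketch of the two conversions (not from the package; assumes a CUDA device and that the returned torch stream type matches the current torch version's stream API):

```python
import torch
import warp as wp
from warp.torch import stream_from_torch, stream_to_torch

wp.init()

ts = torch.cuda.Stream()
ws = stream_from_torch(ts)   # wp.Stream sharing the same underlying CUDA stream
tt = stream_to_torch(ws)     # back to a torch stream on the same device
```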