warp-lang 1.0.2-py3-none-manylinux2014_x86_64.whl → 1.2.0-py3-none-manylinux2014_x86_64.whl
This diff shows the changes between these two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of warp-lang has been flagged for review.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +88 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3693 -3354
- warp/codegen.py +2925 -2792
- warp/config.py +40 -36
- warp/constants.py +49 -45
- warp/context.py +5409 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +381 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
- warp/examples/benchmarks/benchmark_launches.py +293 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +232 -219
- warp/examples/core/example_fluid.py +291 -267
- warp/examples/core/example_graph_capture.py +142 -126
- warp/examples/core/example_marching_cubes.py +186 -174
- warp/examples/core/example_mesh.py +172 -155
- warp/examples/core/example_mesh_intersect.py +203 -193
- warp/examples/core/example_nvdb.py +174 -170
- warp/examples/core/example_raycast.py +103 -90
- warp/examples/core/example_raymarch.py +197 -178
- warp/examples/core/example_render_opengl.py +183 -141
- warp/examples/core/example_sph.py +403 -387
- warp/examples/core/example_torch.py +219 -181
- warp/examples/core/example_wave.py +261 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +432 -389
- warp/examples/fem/example_burgers.py +262 -0
- warp/examples/fem/example_convection_diffusion.py +180 -168
- warp/examples/fem/example_convection_diffusion_dg.py +217 -209
- warp/examples/fem/example_deformed_geometry.py +175 -159
- warp/examples/fem/example_diffusion.py +199 -173
- warp/examples/fem/example_diffusion_3d.py +178 -152
- warp/examples/fem/example_diffusion_mgpu.py +219 -214
- warp/examples/fem/example_mixed_elasticity.py +242 -222
- warp/examples/fem/example_navier_stokes.py +257 -243
- warp/examples/fem/example_stokes.py +218 -192
- warp/examples/fem/example_stokes_transfer.py +263 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +258 -246
- warp/examples/optim/example_cloth_throw.py +220 -209
- warp/examples/optim/example_diffray.py +564 -536
- warp/examples/optim/example_drone.py +862 -835
- warp/examples/optim/example_inverse_kinematics.py +174 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
- warp/examples/optim/example_spring_cage.py +237 -231
- warp/examples/optim/example_trajectory.py +221 -199
- warp/examples/optim/example_walker.py +304 -293
- warp/examples/sim/example_cartpole.py +137 -129
- warp/examples/sim/example_cloth.py +194 -186
- warp/examples/sim/example_granular.py +122 -111
- warp/examples/sim/example_granular_collision_sdf.py +195 -186
- warp/examples/sim/example_jacobian_ik.py +234 -214
- warp/examples/sim/example_particle_chain.py +116 -105
- warp/examples/sim/example_quadruped.py +191 -180
- warp/examples/sim/example_rigid_chain.py +195 -187
- warp/examples/sim/example_rigid_contact.py +187 -177
- warp/examples/sim/example_rigid_force.py +125 -125
- warp/examples/sim/example_rigid_gyroscopic.py +107 -95
- warp/examples/sim/example_rigid_soft_contact.py +132 -122
- warp/examples/sim/example_soft_body.py +188 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +61 -27
- warp/fem/cache.py +403 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +16 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +748 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +437 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/nanogrid.py +455 -0
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1684 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +179 -292
- warp/fem/space/basis_space.py +522 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +148 -267
- warp/fem/space/grid_3d_function_space.py +167 -306
- warp/fem/space/hexmesh_function_space.py +253 -352
- warp/fem/space/nanogrid_function_space.py +202 -0
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +261 -369
- warp/fem/space/restriction.py +161 -160
- warp/fem/space/shape/__init__.py +90 -15
- warp/fem/space/shape/cube_shape_function.py +728 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +224 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +153 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1081 -1025
- warp/native/builtin.h +1603 -1560
- warp/native/bvh.cpp +402 -398
- warp/native/bvh.cu +533 -525
- warp/native/bvh.h +430 -429
- warp/native/clang/clang.cpp +496 -464
- warp/native/crt.cpp +42 -32
- warp/native/crt.h +352 -335
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/exports.h +187 -0
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1545 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +292 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -4782
- warp/native/nanovdb/PNanoVDB.h +3390 -2553
- warp/native/noise.h +850 -850
- warp/native/quat.h +1112 -1085
- warp/native/rand.h +303 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1177 -1133
- warp/native/volume.cpp +529 -297
- warp/native/volume.cu +58 -32
- warp/native/volume.h +960 -538
- warp/native/volume_builder.cu +446 -425
- warp/native/volume_builder.h +34 -19
- warp/native/volume_impl.h +61 -0
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2949 -2828
- warp/native/warp.h +321 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3356 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1917 -1991
- warp/sim/integrator_xpbd.py +3288 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1289 -1227
- warp/stubs.py +2192 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +20 -22
- warp/tests/aux_test_grad_customs.py +21 -23
- warp/tests/aux_test_reference.py +9 -11
- warp/tests/aux_test_reference_reference.py +8 -10
- warp/tests/aux_test_square.py +15 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +237 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +155 -157
- warp/tests/test_arithmetic.py +1088 -1124
- warp/tests/test_array.py +2415 -2326
- warp/tests/test_array_reduce.py +148 -150
- warp/tests/test_async.py +666 -656
- warp/tests/test_atomic.py +139 -141
- warp/tests/test_bool.py +212 -149
- warp/tests/test_builtins_resolution.py +1290 -1292
- warp/tests/test_bvh.py +162 -171
- warp/tests/test_closest_point_edge_edge.py +227 -228
- warp/tests/test_codegen.py +562 -553
- warp/tests/test_compile_consts.py +217 -101
- warp/tests/test_conditional.py +244 -246
- warp/tests/test_copy.py +230 -215
- warp/tests/test_ctypes.py +630 -632
- warp/tests/test_dense.py +65 -67
- warp/tests/test_devices.py +89 -98
- warp/tests/test_dlpack.py +528 -529
- warp/tests/test_examples.py +403 -378
- warp/tests/test_fabricarray.py +952 -955
- warp/tests/test_fast_math.py +60 -54
- warp/tests/test_fem.py +1298 -1278
- warp/tests/test_fp16.py +128 -130
- warp/tests/test_func.py +336 -337
- warp/tests/test_generics.py +596 -571
- warp/tests/test_grad.py +885 -640
- warp/tests/test_grad_customs.py +331 -336
- warp/tests/test_hash_grid.py +208 -164
- warp/tests/test_import.py +37 -39
- warp/tests/test_indexedarray.py +1132 -1134
- warp/tests/test_intersect.py +65 -67
- warp/tests/test_jax.py +305 -307
- warp/tests/test_large.py +169 -164
- warp/tests/test_launch.py +352 -354
- warp/tests/test_lerp.py +217 -261
- warp/tests/test_linear_solvers.py +189 -171
- warp/tests/test_lvalue.py +419 -493
- warp/tests/test_marching_cubes.py +63 -65
- warp/tests/test_mat.py +1799 -1827
- warp/tests/test_mat_lite.py +113 -115
- warp/tests/test_mat_scalar_ops.py +2905 -2889
- warp/tests/test_math.py +124 -193
- warp/tests/test_matmul.py +498 -499
- warp/tests/test_matmul_lite.py +408 -410
- warp/tests/test_mempool.py +186 -190
- warp/tests/test_mesh.py +281 -324
- warp/tests/test_mesh_query_aabb.py +226 -241
- warp/tests/test_mesh_query_point.py +690 -702
- warp/tests/test_mesh_query_ray.py +290 -303
- warp/tests/test_mlp.py +274 -276
- warp/tests/test_model.py +108 -110
- warp/tests/test_module_hashing.py +111 -0
- warp/tests/test_modules_lite.py +36 -39
- warp/tests/test_multigpu.py +161 -163
- warp/tests/test_noise.py +244 -248
- warp/tests/test_operators.py +248 -250
- warp/tests/test_options.py +121 -125
- warp/tests/test_peer.py +131 -137
- warp/tests/test_pinned.py +76 -78
- warp/tests/test_print.py +52 -54
- warp/tests/test_quat.py +2084 -2086
- warp/tests/test_rand.py +324 -288
- warp/tests/test_reload.py +207 -217
- warp/tests/test_rounding.py +177 -179
- warp/tests/test_runlength_encode.py +188 -190
- warp/tests/test_sim_grad.py +241 -0
- warp/tests/test_sim_kinematics.py +89 -97
- warp/tests/test_smoothstep.py +166 -168
- warp/tests/test_snippet.py +303 -266
- warp/tests/test_sparse.py +466 -460
- warp/tests/test_spatial.py +2146 -2148
- warp/tests/test_special_values.py +362 -0
- warp/tests/test_streams.py +484 -473
- warp/tests/test_struct.py +708 -675
- warp/tests/test_tape.py +171 -148
- warp/tests/test_torch.py +741 -743
- warp/tests/test_transient_module.py +85 -87
- warp/tests/test_types.py +554 -659
- warp/tests/test_utils.py +488 -499
- warp/tests/test_vec.py +1262 -1268
- warp/tests/test_vec_lite.py +71 -73
- warp/tests/test_vec_scalar_ops.py +2097 -2099
- warp/tests/test_verify_fp.py +92 -94
- warp/tests/test_volume.py +961 -736
- warp/tests/test_volume_write.py +338 -265
- warp/tests/unittest_serial.py +38 -37
- warp/tests/unittest_suites.py +367 -359
- warp/tests/unittest_utils.py +434 -578
- warp/tests/unused_test_misc.py +69 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +563 -561
- warp/torch.py +321 -295
- warp/types.py +4941 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
- warp_lang-1.2.0.dist-info/RECORD +359 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
- warp/native/nanovdb/PNanoVDBWrite.h +0 -295
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/torch.py
CHANGED
@@ -1,295 +1,321 @@
[The previous 295-line version was removed in full; only fragments of it (the NVIDIA license header and a few import lines) survived the diff rendering. The new 321-line version follows.]
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import ctypes

import numpy

import warp


# return the warp device corresponding to a torch device
def device_from_torch(torch_device) -> warp.context.Device:
    """Return the Warp device corresponding to a Torch device."""
    return warp.get_device(str(torch_device))


def device_to_torch(warp_device: warp.context.Devicelike) -> str:
    """Return the Torch device string corresponding to a Warp device.

    Args:
        warp_device: An identifier that can be resolved to a :class:`warp.context.Device`.

    Raises:
        RuntimeError: The Warp device is not compatible with PyTorch.
    """
    device = warp.get_device(warp_device)
    if device.is_cpu or device.is_primary:
        return str(device)
    elif device.is_cuda and device.is_uva:
        # it's not a primary context, but torch can access the data ptr directly thanks to UVA
        return f"cuda:{device.ordinal}"
    raise RuntimeError(f"Warp device {device} is not compatible with torch")


def dtype_to_torch(warp_dtype):
    """Return the Torch dtype corresponding to a Warp dtype.

    Args:
        warp_dtype: A Warp data type that has a corresponding ``torch.dtype``.
            ``warp.uint16``, ``warp.uint32``, and ``warp.uint64`` are mapped
            to the signed integer ``torch.dtype`` of the same width.
    Raises:
        TypeError: Unable to find a corresponding PyTorch data type.
    """
    # initialize lookup table on first call to defer torch import
    if dtype_to_torch.type_map is None:
        import torch

        dtype_to_torch.type_map = {
            warp.float16: torch.float16,
            warp.float32: torch.float32,
            warp.float64: torch.float64,
            warp.int8: torch.int8,
            warp.int16: torch.int16,
            warp.int32: torch.int32,
            warp.int64: torch.int64,
            warp.uint8: torch.uint8,
            # torch doesn't support unsigned ints bigger than 8 bits
            warp.uint16: torch.int16,
            warp.uint32: torch.int32,
            warp.uint64: torch.int64,
            warp.bool: torch.bool,
        }

    torch_dtype = dtype_to_torch.type_map.get(warp_dtype)
    if torch_dtype is not None:
        return torch_dtype
    else:
        raise TypeError(f"Cannot convert {warp_dtype} to a Torch type")


def dtype_from_torch(torch_dtype):
    """Return the Warp dtype corresponding to a Torch dtype.

    Args:
        torch_dtype: A ``torch.dtype`` that has a corresponding Warp data type.
            Currently ``torch.bfloat16``, ``torch.complex64``, and
            ``torch.complex128`` are not supported.

    Raises:
        TypeError: Unable to find a corresponding Warp data type.
    """
    # initialize lookup table on first call to defer torch import
    if dtype_from_torch.type_map is None:
        import torch

        dtype_from_torch.type_map = {
            torch.float16: warp.float16,
            torch.float32: warp.float32,
            torch.float64: warp.float64,
            torch.int8: warp.int8,
            torch.int16: warp.int16,
            torch.int32: warp.int32,
            torch.int64: warp.int64,
            torch.uint8: warp.uint8,
            torch.bool: warp.bool,
            # currently unsupported by Warp
            # torch.bfloat16:
            # torch.complex64:
            # torch.complex128:
        }

    warp_dtype = dtype_from_torch.type_map.get(torch_dtype)

    if warp_dtype is not None:
        return warp_dtype
    else:
        raise TypeError(f"Cannot convert {torch_dtype} to a Warp type")


def dtype_is_compatible(torch_dtype, warp_dtype) -> bool:
    """Evaluates whether the given torch dtype is compatible with the given Warp dtype."""
    # initialize lookup table on first call to defer torch import
    if dtype_is_compatible.compatible_sets is None:
        import torch

        dtype_is_compatible.compatible_sets = {
            torch.float64: {warp.float64},
            torch.float32: {warp.float32},
            torch.float16: {warp.float16},
            # allow aliasing integer tensors as signed or unsigned integer arrays
            torch.int64: {warp.int64, warp.uint64},
            torch.int32: {warp.int32, warp.uint32},
            torch.int16: {warp.int16, warp.uint16},
            torch.int8: {warp.int8, warp.uint8},
            torch.uint8: {warp.uint8, warp.int8},
            torch.bool: {warp.bool, warp.uint8, warp.int8},
            # currently unsupported by Warp
            # torch.bfloat16:
            # torch.complex64:
            # torch.complex128:
        }

    compatible_set = dtype_is_compatible.compatible_sets.get(torch_dtype)

    if compatible_set is not None:
        if warp_dtype in compatible_set:
            return True
        # check if it's a vector or matrix type
        if hasattr(warp_dtype, "_wp_scalar_type_"):
            return warp_dtype._wp_scalar_type_ in compatible_set

    return False


# lookup tables initialized when needed
dtype_from_torch.type_map = None
dtype_to_torch.type_map = None
dtype_is_compatible.compatible_sets = None


# wrap a torch tensor to a wp array, data is not copied
def from_torch(t, dtype=None, requires_grad=None, grad=None):
    """Convert a Torch tensor to a Warp array without copying the data.

    Args:
        t (torch.Tensor): The torch tensor to wrap.
        dtype (warp.dtype, optional): The target data type of the resulting Warp array. Defaults to the tensor value type mapped to a Warp array value type.
        requires_grad (bool, optional): Whether the resulting array should wrap the tensor's gradient, if it exists (the grad tensor will be allocated otherwise). Defaults to the tensor's `requires_grad` value.

    Returns:
        warp.array: The wrapped array.
    """
    if dtype is None:
        dtype = dtype_from_torch(t.dtype)
    elif not dtype_is_compatible(t.dtype, dtype):
        raise RuntimeError(f"Cannot convert Torch type {t.dtype} to Warp type {dtype}")

    # get size of underlying data type to compute strides
    ctype_size = ctypes.sizeof(dtype._type_)

    shape = tuple(t.shape)
    strides = tuple(s * ctype_size for s in t.stride())
    device = device_from_torch(t.device)

    # if target is a vector or matrix type
    # then check if trailing dimensions match
    # the target type and update the shape
    if hasattr(dtype, "_shape_"):
        dtype_shape = dtype._shape_
        dtype_dims = len(dtype._shape_)
        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
            raise RuntimeError(
                f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
            )

        # ensure the inner strides are contiguous
        stride = ctype_size
        for i in range(dtype_dims):
            if strides[-i - 1] != stride:
                raise RuntimeError(
                    f"Could not convert Torch tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
                )
            stride *= dtype_shape[-i - 1]

        shape = tuple(shape[:-dtype_dims]) or (1,)
        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)

    requires_grad = t.requires_grad if requires_grad is None else requires_grad
    if grad is not None:
        if not isinstance(grad, warp.array):
            import torch

            if isinstance(grad, torch.Tensor):
                grad = from_torch(grad, dtype=dtype)
            else:
                raise ValueError(f"Invalid gradient type: {type(grad)}")
    elif requires_grad:
        # wrap the tensor gradient, allocate if necessary
        if t.grad is None:
            # allocate a zero-filled gradient if it doesn't exist
            # Note: we use Warp to allocate the shared gradient with compatible strides
            grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device)
            t.grad = to_torch(grad, requires_grad=False)
        else:
            # TODO: this will fail if the strides are incompatible
            grad = from_torch(t.grad, dtype=dtype)

    a = warp.array(
        ptr=t.data_ptr(),
        dtype=dtype,
        shape=shape,
        strides=strides,
        device=device,
        copy=False,
        grad=grad,
        requires_grad=requires_grad,
    )

    # save a reference to the source tensor, otherwise it will be deallocated
    a._tensor = t
    return a


def to_torch(a, requires_grad=None):
    """
    Convert a Warp array to a Torch tensor without copying the data.

    Args:
        a (warp.array): The Warp array to convert.
        requires_grad (bool, optional): Whether the resulting tensor should convert the array's gradient, if it exists, to a grad tensor. Defaults to the array's `requires_grad` value.

    Returns:
        torch.Tensor: The converted tensor.
    """
    import torch

    if requires_grad is None:
        requires_grad = a.requires_grad

    # Torch does not support structured arrays
    if isinstance(a.dtype, warp.codegen.Struct):
        raise RuntimeError("Cannot convert structured Warp arrays to Torch.")

    if a.device.is_cpu:
        # Torch has an issue wrapping CPU objects
        # that support the __array_interface__ protocol
        # in this case we need to workaround by going
        # to an ndarray first, see https://pearu.github.io/array_interface_pytorch.html
        t = torch.as_tensor(numpy.asarray(a))
        t.requires_grad = requires_grad
        if requires_grad and a.requires_grad:
            t.grad = torch.as_tensor(numpy.asarray(a.grad))
        return t

    elif a.device.is_cuda:
        # Torch does support the __cuda_array_interface__
        # correctly, but we must be sure to maintain a reference
        # to the owning object to prevent memory allocs going out of scope
        t = torch.as_tensor(a, device=device_to_torch(a.device))
        t.requires_grad = requires_grad
        if requires_grad and a.requires_grad:
            t.grad = torch.as_tensor(a.grad, device=device_to_torch(a.device))
        return t

    else:
        raise RuntimeError("Unsupported device")


def stream_from_torch(stream_or_device=None):
    """Convert from a Torch CUDA stream to a Warp CUDA stream."""
    import torch

    if isinstance(stream_or_device, torch.cuda.Stream):
        stream = stream_or_device
    else:
        # assume arg is a torch device
        stream = torch.cuda.current_stream(stream_or_device)

    device = device_from_torch(stream.device)

    warp_stream = warp.Stream(device, cuda_stream=stream.cuda_stream)

    # save a reference to the source stream, otherwise it may be destroyed
    warp_stream._torch_stream = stream

    return warp_stream


def stream_to_torch(stream_or_device=None):
    """Convert from a Warp CUDA stream to a Torch CUDA stream."""
    import torch

    if isinstance(stream_or_device, warp.Stream):
        stream = stream_or_device
    else:
        # assume arg is a warp device
        stream = warp.get_device(stream_or_device).stream

    device = device_to_torch(stream.device)

    torch_stream = torch.cuda.ExternalStream(stream.cuda_stream, device=device)

    # save a reference to the source stream, otherwise it may be destroyed
    torch_stream._warp_stream = stream

    return torch_stream