warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +8 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +7 -6
- warp/build_dll.py +70 -79
- warp/builtins.py +10 -6
- warp/codegen.py +51 -19
- warp/config.py +7 -8
- warp/constants.py +3 -0
- warp/context.py +948 -245
- warp/dlpack.py +198 -113
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -0
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usda +42 -0
- warp/examples/assets/nv_ant.xml +92 -0
- warp/examples/assets/nv_humanoid.xml +183 -0
- warp/examples/assets/quadruped.urdf +268 -0
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usda +56 -0
- warp/examples/assets/torus.usda +105 -0
- warp/examples/benchmarks/benchmark_api.py +383 -0
- warp/examples/benchmarks/benchmark_cloth.py +279 -0
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
- warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
- warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
- warp/examples/benchmarks/benchmark_launches.py +295 -0
- warp/examples/core/example_dem.py +221 -0
- warp/examples/core/example_fluid.py +267 -0
- warp/examples/core/example_graph_capture.py +129 -0
- warp/examples/core/example_marching_cubes.py +177 -0
- warp/examples/core/example_mesh.py +154 -0
- warp/examples/core/example_mesh_intersect.py +193 -0
- warp/examples/core/example_nvdb.py +169 -0
- warp/examples/core/example_raycast.py +89 -0
- warp/examples/core/example_raymarch.py +178 -0
- warp/examples/core/example_render_opengl.py +141 -0
- warp/examples/core/example_sph.py +389 -0
- warp/examples/core/example_torch.py +181 -0
- warp/examples/core/example_wave.py +249 -0
- warp/examples/fem/bsr_utils.py +380 -0
- warp/examples/fem/example_apic_fluid.py +391 -0
- warp/examples/fem/example_convection_diffusion.py +168 -0
- warp/examples/fem/example_convection_diffusion_dg.py +209 -0
- warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
- warp/examples/fem/example_deformed_geometry.py +159 -0
- warp/examples/fem/example_diffusion.py +173 -0
- warp/examples/fem/example_diffusion_3d.py +152 -0
- warp/examples/fem/example_diffusion_mgpu.py +214 -0
- warp/examples/fem/example_mixed_elasticity.py +222 -0
- warp/examples/fem/example_navier_stokes.py +243 -0
- warp/examples/fem/example_stokes.py +192 -0
- warp/examples/fem/example_stokes_transfer.py +249 -0
- warp/examples/fem/mesh_utils.py +109 -0
- warp/examples/fem/plot_utils.py +287 -0
- warp/examples/optim/example_bounce.py +248 -0
- warp/examples/optim/example_cloth_throw.py +210 -0
- warp/examples/optim/example_diffray.py +535 -0
- warp/examples/optim/example_drone.py +850 -0
- warp/examples/optim/example_inverse_kinematics.py +169 -0
- warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
- warp/examples/optim/example_spring_cage.py +234 -0
- warp/examples/optim/example_trajectory.py +201 -0
- warp/examples/sim/example_cartpole.py +128 -0
- warp/examples/sim/example_cloth.py +184 -0
- warp/examples/sim/example_granular.py +113 -0
- warp/examples/sim/example_granular_collision_sdf.py +185 -0
- warp/examples/sim/example_jacobian_ik.py +213 -0
- warp/examples/sim/example_particle_chain.py +106 -0
- warp/examples/sim/example_quadruped.py +179 -0
- warp/examples/sim/example_rigid_chain.py +191 -0
- warp/examples/sim/example_rigid_contact.py +176 -0
- warp/examples/sim/example_rigid_force.py +126 -0
- warp/examples/sim/example_rigid_gyroscopic.py +97 -0
- warp/examples/sim/example_rigid_soft_contact.py +124 -0
- warp/examples/sim/example_soft_body.py +178 -0
- warp/fabric.py +29 -20
- warp/fem/cache.py +0 -1
- warp/fem/dirichlet.py +0 -2
- warp/fem/integrate.py +0 -1
- warp/jax.py +45 -0
- warp/jax_experimental.py +339 -0
- warp/native/builtin.h +12 -0
- warp/native/bvh.cu +18 -18
- warp/native/clang/clang.cpp +8 -3
- warp/native/cuda_util.cpp +94 -5
- warp/native/cuda_util.h +35 -6
- warp/native/cutlass_gemm.cpp +1 -1
- warp/native/cutlass_gemm.cu +4 -1
- warp/native/error.cpp +66 -0
- warp/native/error.h +27 -0
- warp/native/mesh.cu +2 -2
- warp/native/reduce.cu +4 -4
- warp/native/runlength_encode.cu +2 -2
- warp/native/scan.cu +2 -2
- warp/native/sparse.cu +0 -1
- warp/native/temp_buffer.h +2 -2
- warp/native/warp.cpp +95 -60
- warp/native/warp.cu +1053 -218
- warp/native/warp.h +49 -32
- warp/optim/linear.py +33 -16
- warp/render/render_opengl.py +202 -101
- warp/render/render_usd.py +82 -40
- warp/sim/__init__.py +13 -4
- warp/sim/articulation.py +4 -5
- warp/sim/collide.py +320 -175
- warp/sim/import_mjcf.py +25 -30
- warp/sim/import_urdf.py +94 -63
- warp/sim/import_usd.py +51 -36
- warp/sim/inertia.py +3 -2
- warp/sim/integrator.py +233 -0
- warp/sim/integrator_euler.py +447 -469
- warp/sim/integrator_featherstone.py +1991 -0
- warp/sim/integrator_xpbd.py +1420 -640
- warp/sim/model.py +765 -487
- warp/sim/particles.py +2 -1
- warp/sim/render.py +35 -13
- warp/sim/utils.py +222 -11
- warp/stubs.py +8 -0
- warp/tape.py +16 -1
- warp/tests/aux_test_grad_customs.py +23 -0
- warp/tests/test_array.py +190 -1
- warp/tests/test_async.py +656 -0
- warp/tests/test_bool.py +50 -0
- warp/tests/test_dlpack.py +164 -11
- warp/tests/test_examples.py +166 -74
- warp/tests/test_fem.py +8 -1
- warp/tests/test_generics.py +15 -5
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +172 -12
- warp/tests/test_jax.py +254 -0
- warp/tests/test_large.py +29 -6
- warp/tests/test_launch.py +25 -0
- warp/tests/test_linear_solvers.py +20 -3
- warp/tests/test_matmul.py +61 -16
- warp/tests/test_matmul_lite.py +13 -13
- warp/tests/test_mempool.py +186 -0
- warp/tests/test_multigpu.py +3 -0
- warp/tests/test_options.py +16 -2
- warp/tests/test_peer.py +137 -0
- warp/tests/test_print.py +3 -1
- warp/tests/test_quat.py +23 -0
- warp/tests/test_sim_kinematics.py +97 -0
- warp/tests/test_snippet.py +126 -3
- warp/tests/test_streams.py +108 -79
- warp/tests/test_torch.py +16 -8
- warp/tests/test_utils.py +32 -27
- warp/tests/test_verify_fp.py +65 -0
- warp/tests/test_volume.py +1 -1
- warp/tests/unittest_serial.py +2 -0
- warp/tests/unittest_suites.py +12 -0
- warp/tests/unittest_utils.py +14 -7
- warp/thirdparty/unittest_parallel.py +15 -3
- warp/torch.py +10 -8
- warp/types.py +363 -246
- warp/utils.py +143 -19
- warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
- warp_lang-1.0.0.dist-info/METADATA +394 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
- warp/sim/optimizer.py +0 -138
- warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
- warp_lang-0.11.0.dist-info/METADATA +0 -238
- /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
- {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import unittest
|
|
9
|
+
|
|
10
|
+
import warp as wp
|
|
11
|
+
from warp.tests.unittest_utils import *
|
|
12
|
+
|
|
13
|
+
wp.init()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def setUpModule():
|
|
17
|
+
wp.config.verify_fp = True # Enable checking floating-point values to be finite
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def tearDownModule():
|
|
21
|
+
wp.config.verify_fp = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@wp.struct
|
|
25
|
+
class TestStruct:
|
|
26
|
+
field: wp.float32
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@wp.kernel
|
|
30
|
+
def finite_kernel(foos: wp.array(dtype=TestStruct)):
|
|
31
|
+
i = wp.tid()
|
|
32
|
+
foos[i].field += wp.float32(1.0)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_finite(test, device):
|
|
36
|
+
foos = wp.zeros((10,), dtype=TestStruct, device=device)
|
|
37
|
+
|
|
38
|
+
wp.launch(
|
|
39
|
+
kernel=finite_kernel,
|
|
40
|
+
dim=(10,),
|
|
41
|
+
inputs=[foos],
|
|
42
|
+
device=device,
|
|
43
|
+
)
|
|
44
|
+
wp.synchronize()
|
|
45
|
+
|
|
46
|
+
expected = TestStruct()
|
|
47
|
+
expected.field = 1.0
|
|
48
|
+
for f in foos.list():
|
|
49
|
+
if f.field != expected.field:
|
|
50
|
+
raise AssertionError(f"Unexpected result, got: {f} expected: {expected}")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
devices = get_test_devices()
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TestVerifyFP(unittest.TestCase):
|
|
57
|
+
pass
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
add_function_test(TestVerifyFP, "test_finite", test_finite, devices=devices)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
if __name__ == "__main__":
|
|
64
|
+
wp.build.clear_kernel_cache()
|
|
65
|
+
unittest.main(verbosity=2)
|
warp/tests/test_volume.py
CHANGED
|
@@ -627,7 +627,7 @@ def test_volume_from_numpy(test, device):
|
|
|
627
627
|
|
|
628
628
|
sphere_vdb_array = sphere_vdb.array()
|
|
629
629
|
test.assertEqual(sphere_vdb_array.dtype, wp.uint8)
|
|
630
|
-
test.
|
|
630
|
+
test.assertIsNone(sphere_vdb_array.deleter)
|
|
631
631
|
|
|
632
632
|
|
|
633
633
|
class TestVolume(unittest.TestCase):
|
warp/tests/unittest_serial.py
CHANGED
warp/tests/unittest_suites.py
CHANGED
|
@@ -91,6 +91,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
91
91
|
from warp.tests.test_arithmetic import TestArithmetic
|
|
92
92
|
from warp.tests.test_array import TestArray
|
|
93
93
|
from warp.tests.test_array_reduce import TestArrayReduce
|
|
94
|
+
from warp.tests.test_async import TestAsync
|
|
94
95
|
from warp.tests.test_atomic import TestAtomic
|
|
95
96
|
from warp.tests.test_bool import TestBool
|
|
96
97
|
from warp.tests.test_builtins_resolution import TestBuiltinsResolution
|
|
@@ -117,6 +118,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
117
118
|
from warp.tests.test_import import TestImport
|
|
118
119
|
from warp.tests.test_indexedarray import TestIndexedArray
|
|
119
120
|
from warp.tests.test_intersect import TestIntersect
|
|
121
|
+
from warp.tests.test_jax import TestJax
|
|
120
122
|
from warp.tests.test_large import TestLarge
|
|
121
123
|
from warp.tests.test_launch import TestLaunch
|
|
122
124
|
from warp.tests.test_lerp import TestLerp
|
|
@@ -129,6 +131,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
129
131
|
from warp.tests.test_math import TestMath
|
|
130
132
|
from warp.tests.test_matmul import TestMatmul
|
|
131
133
|
from warp.tests.test_matmul_lite import TestMatmulLite
|
|
134
|
+
from warp.tests.test_mempool import TestMempool
|
|
132
135
|
from warp.tests.test_mesh import TestMesh
|
|
133
136
|
from warp.tests.test_mesh_query_aabb import TestMeshQueryAABBMethods
|
|
134
137
|
from warp.tests.test_mesh_query_point import TestMeshQueryPoint
|
|
@@ -140,6 +143,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
140
143
|
from warp.tests.test_noise import TestNoise
|
|
141
144
|
from warp.tests.test_operators import TestOperators
|
|
142
145
|
from warp.tests.test_options import TestOptions
|
|
146
|
+
from warp.tests.test_peer import TestPeer
|
|
143
147
|
from warp.tests.test_pinned import TestPinned
|
|
144
148
|
from warp.tests.test_print import TestPrint
|
|
145
149
|
from warp.tests.test_quat import TestQuat
|
|
@@ -147,6 +151,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
147
151
|
from warp.tests.test_reload import TestReload
|
|
148
152
|
from warp.tests.test_rounding import TestRounding
|
|
149
153
|
from warp.tests.test_runlength_encode import TestRunlengthEncode
|
|
154
|
+
from warp.tests.test_sim_kinematics import TestSimKinematics
|
|
150
155
|
from warp.tests.test_smoothstep import TestSmoothstep
|
|
151
156
|
from warp.tests.test_snippet import TestSnippets
|
|
152
157
|
from warp.tests.test_sparse import TestSparse
|
|
@@ -161,6 +166,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
161
166
|
from warp.tests.test_vec import TestVec
|
|
162
167
|
from warp.tests.test_vec_lite import TestVecLite
|
|
163
168
|
from warp.tests.test_vec_scalar_ops import TestVecScalarOps
|
|
169
|
+
from warp.tests.test_verify_fp import TestVerifyFP
|
|
164
170
|
from warp.tests.test_volume import TestVolume
|
|
165
171
|
from warp.tests.test_volume_write import TestVolumeWrite
|
|
166
172
|
|
|
@@ -169,6 +175,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
169
175
|
TestArithmetic,
|
|
170
176
|
TestArray,
|
|
171
177
|
TestArrayReduce,
|
|
178
|
+
TestAsync,
|
|
172
179
|
TestAtomic,
|
|
173
180
|
TestBool,
|
|
174
181
|
TestBuiltinsResolution,
|
|
@@ -198,6 +205,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
198
205
|
TestImport,
|
|
199
206
|
TestIndexedArray,
|
|
200
207
|
TestIntersect,
|
|
208
|
+
TestJax,
|
|
201
209
|
TestLarge,
|
|
202
210
|
TestLaunch,
|
|
203
211
|
TestLerp,
|
|
@@ -210,6 +218,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
210
218
|
TestMath,
|
|
211
219
|
TestMatmul,
|
|
212
220
|
TestMatmulLite,
|
|
221
|
+
TestMempool,
|
|
213
222
|
TestMesh,
|
|
214
223
|
TestMeshQueryAABBMethods,
|
|
215
224
|
TestMeshQueryPoint,
|
|
@@ -221,6 +230,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
221
230
|
TestNoise,
|
|
222
231
|
TestOperators,
|
|
223
232
|
TestOptions,
|
|
233
|
+
TestPeer,
|
|
224
234
|
TestPinned,
|
|
225
235
|
TestPrint,
|
|
226
236
|
TestQuat,
|
|
@@ -228,6 +238,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
228
238
|
TestReload,
|
|
229
239
|
TestRounding,
|
|
230
240
|
TestRunlengthEncode,
|
|
241
|
+
TestSimKinematics,
|
|
231
242
|
TestSmoothstep,
|
|
232
243
|
TestSparse,
|
|
233
244
|
TestSnippets,
|
|
@@ -242,6 +253,7 @@ def default_suite(test_loader: unittest.TestLoader):
|
|
|
242
253
|
TestVec,
|
|
243
254
|
TestVecLite,
|
|
244
255
|
TestVecScalarOps,
|
|
256
|
+
TestVerifyFP,
|
|
245
257
|
TestVolume,
|
|
246
258
|
TestVolumeWrite,
|
|
247
259
|
]
|
warp/tests/unittest_utils.py
CHANGED
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
import ctypes
|
|
9
9
|
import ctypes.util
|
|
10
|
+
import math
|
|
10
11
|
import os
|
|
11
12
|
import sys
|
|
12
13
|
import time
|
|
@@ -101,6 +102,11 @@ def get_test_devices(mode=None):
|
|
|
101
102
|
return devices
|
|
102
103
|
|
|
103
104
|
|
|
105
|
+
def get_cuda_test_devices(mode=None):
|
|
106
|
+
devices = get_test_devices(mode=mode)
|
|
107
|
+
return [d for d in devices if d.is_cuda]
|
|
108
|
+
|
|
109
|
+
|
|
104
110
|
# redirects and captures all stdout output (including from C-libs)
|
|
105
111
|
class StdOutCapture:
|
|
106
112
|
def begin(self):
|
|
@@ -127,11 +133,12 @@ class StdOutCapture:
|
|
|
127
133
|
sys.stdout = self.tempfile
|
|
128
134
|
|
|
129
135
|
def end(self):
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
136
|
+
# The following sleep doesn't seem to fix the test_print failure on Windows
|
|
137
|
+
# if sys.platform == "win32":
|
|
138
|
+
# # Workaround for what seems to be a Windows-specific bug where
|
|
139
|
+
# # the output of CUDA's `printf` is not being immediately flushed
|
|
140
|
+
# # despite the context synchronisation.
|
|
141
|
+
# time.sleep(0.01)
|
|
135
142
|
|
|
136
143
|
os.dup2(self.target, self.saved.fileno())
|
|
137
144
|
os.close(self.target)
|
|
@@ -185,9 +192,9 @@ def assert_np_equal(result, expect, tol=0.0):
|
|
|
185
192
|
else:
|
|
186
193
|
delta = a - b
|
|
187
194
|
err = np.max(np.abs(delta))
|
|
188
|
-
if err > tol:
|
|
195
|
+
if err > tol or math.isnan(err):
|
|
189
196
|
raise AssertionError(
|
|
190
|
-
f"Maximum expected error exceeds tolerance got: {a}, expected: {b}, with err: {err} > {tol}"
|
|
197
|
+
f"Maximum expected error exceeds absolute tolerance got: {a}, expected: {b}, with err: {err} > {tol}"
|
|
191
198
|
)
|
|
192
199
|
|
|
193
200
|
|
|
@@ -53,7 +53,16 @@ def main(argv=None):
|
|
|
53
53
|
"""
|
|
54
54
|
|
|
55
55
|
# Command line arguments
|
|
56
|
-
parser = argparse.ArgumentParser(
|
|
56
|
+
parser = argparse.ArgumentParser(
|
|
57
|
+
prog="unittest-parallel",
|
|
58
|
+
# NVIDIA Modifications follow:
|
|
59
|
+
formatter_class=argparse.RawTextHelpFormatter,
|
|
60
|
+
epilog="""Example usage:
|
|
61
|
+
python -m warp.tests -s autodetect -p 'test_a*.py'
|
|
62
|
+
python -m warp.tests -s kit
|
|
63
|
+
python -m warp.tests -k 'mgpu' -k 'cuda'
|
|
64
|
+
""",
|
|
65
|
+
)
|
|
57
66
|
# parser.add_argument("-v", "--verbose", action="store_const", const=2, default=1, help="Verbose output")
|
|
58
67
|
parser.add_argument("-q", "--quiet", dest="verbose", action="store_const", const=0, default=2, help="Quiet output")
|
|
59
68
|
parser.add_argument("-f", "--failfast", action="store_true", default=False, help="Stop on first fail or error")
|
|
@@ -238,6 +247,8 @@ def main(argv=None):
|
|
|
238
247
|
|
|
239
248
|
import warp as wp
|
|
240
249
|
|
|
250
|
+
wp.init()
|
|
251
|
+
|
|
241
252
|
# force rebuild of all kernels
|
|
242
253
|
wp.build.clear_kernel_cache()
|
|
243
254
|
print("Cleared Warp kernel cache")
|
|
@@ -449,6 +460,7 @@ class ParallelTestManager:
|
|
|
449
460
|
# Clean up kernel cache (NVIDIA modification)
|
|
450
461
|
import warp as wp
|
|
451
462
|
|
|
463
|
+
wp.init()
|
|
452
464
|
wp.build.clear_kernel_cache()
|
|
453
465
|
|
|
454
466
|
# Return (test_count, errors, failures, skipped_count, expected_failure_count, unexpected_success_count)
|
|
@@ -537,11 +549,11 @@ def set_worker_cache(args, temp_dir):
|
|
|
537
549
|
cache_root_dir = os.path.join(os.getenv("WARP_CACHE_ROOT"), f"{wp.config.version}-{pid}")
|
|
538
550
|
else:
|
|
539
551
|
cache_root_dir = appdirs.user_cache_dir(
|
|
540
|
-
appname="warp", appauthor="NVIDIA
|
|
552
|
+
appname="warp", appauthor="NVIDIA", version=f"{wp.config.version}-{pid}"
|
|
541
553
|
)
|
|
542
554
|
|
|
543
555
|
wp.config.kernel_cache_dir = cache_root_dir
|
|
544
|
-
|
|
556
|
+
wp.init()
|
|
545
557
|
wp.build.clear_kernel_cache()
|
|
546
558
|
|
|
547
559
|
|
warp/torch.py
CHANGED
|
@@ -121,6 +121,7 @@ def from_torch(t, dtype=None, requires_grad=None, grad=None):
|
|
|
121
121
|
|
|
122
122
|
shape = tuple(t.shape)
|
|
123
123
|
strides = tuple(s * ctype_size for s in t.stride())
|
|
124
|
+
device = device_from_torch(t.device)
|
|
124
125
|
|
|
125
126
|
# if target is a vector or matrix type
|
|
126
127
|
# then check if trailing dimensions match
|
|
@@ -157,20 +158,21 @@ def from_torch(t, dtype=None, requires_grad=None, grad=None):
|
|
|
157
158
|
elif requires_grad:
|
|
158
159
|
# wrap the tensor gradient, allocate if necessary
|
|
159
160
|
if t.grad is None:
|
|
160
|
-
# allocate a zero-filled gradient
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
t.grad =
|
|
164
|
-
|
|
161
|
+
# allocate a zero-filled gradient if it doesn't exist
|
|
162
|
+
# Note: we use Warp to allocate the shared gradient with compatible strides
|
|
163
|
+
grad = warp.zeros(dtype=dtype, shape=shape, strides=strides, device=device)
|
|
164
|
+
t.grad = to_torch(grad, requires_grad=False)
|
|
165
|
+
else:
|
|
166
|
+
# TODO: this will fail if the strides are incompatible
|
|
167
|
+
grad = from_torch(t.grad, dtype=dtype)
|
|
165
168
|
|
|
166
|
-
a = warp.
|
|
169
|
+
a = warp.array(
|
|
167
170
|
ptr=t.data_ptr(),
|
|
168
171
|
dtype=dtype,
|
|
169
172
|
shape=shape,
|
|
170
173
|
strides=strides,
|
|
171
|
-
device=
|
|
174
|
+
device=device,
|
|
172
175
|
copy=False,
|
|
173
|
-
owner=False,
|
|
174
176
|
grad=grad,
|
|
175
177
|
requires_grad=requires_grad,
|
|
176
178
|
)
|