warp-lang 1.4.2__py3-none-macosx_10_13_universal2.whl → 1.5.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +160 -133
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/tests/unittest_suites.py
CHANGED
|
@@ -99,8 +99,11 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
99
99
|
from warp.tests.test_closest_point_edge_edge import TestClosestPointEdgeEdgeMethods
|
|
100
100
|
from warp.tests.test_codegen import TestCodeGen
|
|
101
101
|
from warp.tests.test_codegen_instancing import TestCodeGenInstancing
|
|
102
|
+
from warp.tests.test_collision import TestCollision
|
|
103
|
+
from warp.tests.test_coloring import TestColoring
|
|
102
104
|
from warp.tests.test_compile_consts import TestConstants
|
|
103
105
|
from warp.tests.test_conditional import TestConditional
|
|
106
|
+
from warp.tests.test_context import TestContext
|
|
104
107
|
from warp.tests.test_copy import TestCopy
|
|
105
108
|
from warp.tests.test_ctypes import TestCTypes
|
|
106
109
|
from warp.tests.test_dense import TestDense
|
|
@@ -115,7 +118,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
115
118
|
)
|
|
116
119
|
from warp.tests.test_fabricarray import TestFabricArray
|
|
117
120
|
from warp.tests.test_fast_math import TestFastMath
|
|
118
|
-
from warp.tests.test_fem import TestFem, TestFemShapeFunctions
|
|
121
|
+
from warp.tests.test_fem import TestFem, TestFemShapeFunctions, TestFemUtilities
|
|
119
122
|
from warp.tests.test_fp16 import TestFp16
|
|
120
123
|
from warp.tests.test_func import TestFunc
|
|
121
124
|
from warp.tests.test_future_annotations import TestFutureAnnotations
|
|
@@ -127,6 +130,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
127
130
|
from warp.tests.test_import import TestImport
|
|
128
131
|
from warp.tests.test_indexedarray import TestIndexedArray
|
|
129
132
|
from warp.tests.test_intersect import TestIntersect
|
|
133
|
+
from warp.tests.test_iter import TestIter
|
|
130
134
|
from warp.tests.test_jax import TestJax
|
|
131
135
|
from warp.tests.test_large import TestLarge
|
|
132
136
|
from warp.tests.test_launch import TestLaunch
|
|
@@ -174,6 +178,10 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
174
178
|
from warp.tests.test_streams import TestStreams
|
|
175
179
|
from warp.tests.test_struct import TestStruct
|
|
176
180
|
from warp.tests.test_tape import TestTape
|
|
181
|
+
from warp.tests.test_tile import TestTile
|
|
182
|
+
from warp.tests.test_tile_mathdx import TestTileMathDx
|
|
183
|
+
from warp.tests.test_tile_reduce import TestTileReduce
|
|
184
|
+
from warp.tests.test_tile_shared_memory import TestTileSharedMemory
|
|
177
185
|
from warp.tests.test_torch import TestTorch
|
|
178
186
|
from warp.tests.test_transient_module import TestTransientModule
|
|
179
187
|
from warp.tests.test_triangle_closest_point import TestTriangleClosestPoint
|
|
@@ -200,8 +208,11 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
200
208
|
TestClosestPointEdgeEdgeMethods,
|
|
201
209
|
TestCodeGen,
|
|
202
210
|
TestCodeGenInstancing,
|
|
203
|
-
|
|
211
|
+
TestCollision,
|
|
212
|
+
TestColoring,
|
|
204
213
|
TestConditional,
|
|
214
|
+
TestConstants,
|
|
215
|
+
TestContext,
|
|
205
216
|
TestCopy,
|
|
206
217
|
TestCTypes,
|
|
207
218
|
TestDense,
|
|
@@ -216,6 +227,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
216
227
|
TestFastMath,
|
|
217
228
|
TestFem,
|
|
218
229
|
TestFemShapeFunctions,
|
|
230
|
+
TestFemUtilities,
|
|
219
231
|
TestFp16,
|
|
220
232
|
TestFunc,
|
|
221
233
|
TestFutureAnnotations,
|
|
@@ -227,6 +239,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
227
239
|
TestImport,
|
|
228
240
|
TestIndexedArray,
|
|
229
241
|
TestIntersect,
|
|
242
|
+
TestIter,
|
|
230
243
|
TestJax,
|
|
231
244
|
TestLarge,
|
|
232
245
|
TestLaunch,
|
|
@@ -274,6 +287,10 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
274
287
|
TestStreams,
|
|
275
288
|
TestStruct,
|
|
276
289
|
TestTape,
|
|
290
|
+
TestTile,
|
|
291
|
+
TestTileMathDx,
|
|
292
|
+
TestTileReduce,
|
|
293
|
+
TestTileSharedMemory,
|
|
277
294
|
TestTorch,
|
|
278
295
|
TestTransientModule,
|
|
279
296
|
TestTriangleClosestPoint,
|
warp/tests/unittest_utils.py
CHANGED
|
@@ -58,7 +58,6 @@ def get_selected_cuda_test_devices(mode: Optional[str] = None):
|
|
|
58
58
|
"""
|
|
59
59
|
|
|
60
60
|
if mode is None:
|
|
61
|
-
global test_mode
|
|
62
61
|
mode = test_mode
|
|
63
62
|
|
|
64
63
|
if mode == "basic":
|
|
@@ -98,7 +97,6 @@ def get_test_devices(mode: Optional[str] = None):
|
|
|
98
97
|
"all": Returns all available devices.
|
|
99
98
|
"""
|
|
100
99
|
if mode is None:
|
|
101
|
-
global test_mode
|
|
102
100
|
mode = test_mode
|
|
103
101
|
|
|
104
102
|
devices = []
|
|
@@ -232,6 +230,10 @@ def create_test_func(func, device, check_output, **kwargs):
|
|
|
232
230
|
else:
|
|
233
231
|
func(self, device, **kwargs)
|
|
234
232
|
|
|
233
|
+
# Copy the __unittest_expecting_failure__ attribute from func to test_func
|
|
234
|
+
if hasattr(func, "__unittest_expecting_failure__"):
|
|
235
|
+
test_func.__unittest_expecting_failure__ = func.__unittest_expecting_failure__
|
|
236
|
+
|
|
235
237
|
return test_func
|
|
236
238
|
|
|
237
239
|
|
warp/types.py
CHANGED
|
@@ -12,9 +12,10 @@ import ctypes
|
|
|
12
12
|
import inspect
|
|
13
13
|
import struct
|
|
14
14
|
import zlib
|
|
15
|
-
from typing import Any, Callable, Generic, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
|
|
15
|
+
from typing import Any, Callable, Generic, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
|
+
import numpy.typing as npt
|
|
18
19
|
|
|
19
20
|
import warp
|
|
20
21
|
|
|
@@ -100,8 +101,10 @@ def vector(length, dtype):
|
|
|
100
101
|
|
|
101
102
|
if dtype is bool:
|
|
102
103
|
_type_ = ctypes.c_bool
|
|
103
|
-
elif dtype in
|
|
104
|
+
elif dtype in (Scalar, Float):
|
|
104
105
|
_type_ = ctypes.c_float
|
|
106
|
+
elif dtype is Int:
|
|
107
|
+
_type_ = ctypes.c_int
|
|
105
108
|
else:
|
|
106
109
|
_type_ = dtype._type_
|
|
107
110
|
|
|
@@ -289,8 +292,10 @@ def matrix(shape, dtype):
|
|
|
289
292
|
|
|
290
293
|
if dtype is bool:
|
|
291
294
|
_type_ = ctypes.c_bool
|
|
292
|
-
elif dtype in
|
|
295
|
+
elif dtype in (Scalar, Float):
|
|
293
296
|
_type_ = ctypes.c_float
|
|
297
|
+
elif dtype is Int:
|
|
298
|
+
_type_ = ctypes.c_int
|
|
294
299
|
else:
|
|
295
300
|
_type_ = dtype._type_
|
|
296
301
|
|
|
@@ -338,8 +343,8 @@ def matrix(shape, dtype):
|
|
|
338
343
|
f"Invalid argument in matrix constructor, expected row of length {self._shape_[1]}, got {row}"
|
|
339
344
|
)
|
|
340
345
|
offset = i * self._shape_[1]
|
|
341
|
-
for
|
|
342
|
-
super().__setitem__(offset +
|
|
346
|
+
for j in range(self._shape_[1]):
|
|
347
|
+
super().__setitem__(offset + j, mat_t.scalar_import(row[j]))
|
|
343
348
|
else:
|
|
344
349
|
raise ValueError(
|
|
345
350
|
f"Invalid number of arguments in matrix constructor, expected {self._length_} elements, got {num_args}"
|
|
@@ -991,43 +996,6 @@ vector_types = (
|
|
|
991
996
|
spatial_matrixd,
|
|
992
997
|
)
|
|
993
998
|
|
|
994
|
-
atomic_vector_types = (
|
|
995
|
-
vec2i,
|
|
996
|
-
vec2ui,
|
|
997
|
-
vec2l,
|
|
998
|
-
vec2ul,
|
|
999
|
-
vec2h,
|
|
1000
|
-
vec2f,
|
|
1001
|
-
vec2d,
|
|
1002
|
-
vec3i,
|
|
1003
|
-
vec3ui,
|
|
1004
|
-
vec3l,
|
|
1005
|
-
vec3ul,
|
|
1006
|
-
vec3h,
|
|
1007
|
-
vec3f,
|
|
1008
|
-
vec3d,
|
|
1009
|
-
vec4i,
|
|
1010
|
-
vec4ui,
|
|
1011
|
-
vec4l,
|
|
1012
|
-
vec4ul,
|
|
1013
|
-
vec4h,
|
|
1014
|
-
vec4f,
|
|
1015
|
-
vec4d,
|
|
1016
|
-
mat22h,
|
|
1017
|
-
mat22f,
|
|
1018
|
-
mat22d,
|
|
1019
|
-
mat33h,
|
|
1020
|
-
mat33f,
|
|
1021
|
-
mat33d,
|
|
1022
|
-
mat44h,
|
|
1023
|
-
mat44f,
|
|
1024
|
-
mat44d,
|
|
1025
|
-
quath,
|
|
1026
|
-
quatf,
|
|
1027
|
-
quatd,
|
|
1028
|
-
)
|
|
1029
|
-
atomic_types = float_types + (int32, uint32, int64, uint64) + atomic_vector_types
|
|
1030
|
-
|
|
1031
999
|
np_dtype_to_warp_type = {
|
|
1032
1000
|
# Numpy scalar types
|
|
1033
1001
|
np.bool_: bool,
|
|
@@ -1076,6 +1044,14 @@ warp_type_to_np_dtype = {
|
|
|
1076
1044
|
float64: np.float64,
|
|
1077
1045
|
}
|
|
1078
1046
|
|
|
1047
|
+
non_atomic_types = (
|
|
1048
|
+
int8,
|
|
1049
|
+
uint8,
|
|
1050
|
+
int16,
|
|
1051
|
+
uint16,
|
|
1052
|
+
int64,
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1079
1055
|
|
|
1080
1056
|
def dtype_from_numpy(numpy_dtype):
|
|
1081
1057
|
"""Return the Warp dtype corresponding to a NumPy dtype."""
|
|
@@ -1337,6 +1313,8 @@ def type_typestr(dtype):
|
|
|
1337
1313
|
def type_repr(t):
|
|
1338
1314
|
if is_array(t):
|
|
1339
1315
|
return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
|
|
1316
|
+
if is_tile(t):
|
|
1317
|
+
return str(f"tile(dtype={t.dtype}, m={t.M}, n={t.N})")
|
|
1340
1318
|
if type_is_vector(t):
|
|
1341
1319
|
return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
|
|
1342
1320
|
if type_is_matrix(t):
|
|
@@ -1448,6 +1426,8 @@ def scalars_equal(a, b, match_generic):
|
|
|
1448
1426
|
|
|
1449
1427
|
def types_equal(a, b, match_generic=False):
|
|
1450
1428
|
if match_generic:
|
|
1429
|
+
# Special cases to interpret the types listed in `int_tuple_type_hints`
|
|
1430
|
+
# as generic hints that accept any integer types.
|
|
1451
1431
|
if a in int_tuple_type_hints and isinstance(b, Sequence):
|
|
1452
1432
|
a_length = int_tuple_type_hints[a]
|
|
1453
1433
|
if (a_length == -1 or a_length == len(b)) and all(
|
|
@@ -1466,6 +1446,24 @@ def types_equal(a, b, match_generic=False):
|
|
|
1466
1446
|
if a_length is None or b_length is None or a_length == b_length:
|
|
1467
1447
|
return True
|
|
1468
1448
|
|
|
1449
|
+
a_origin = warp.codegen.get_type_origin(a)
|
|
1450
|
+
b_origin = warp.codegen.get_type_origin(b)
|
|
1451
|
+
if a_origin is tuple and b_origin is tuple:
|
|
1452
|
+
a_args = warp.codegen.get_type_args(a)
|
|
1453
|
+
b_args = warp.codegen.get_type_args(b)
|
|
1454
|
+
if len(a_args) == len(b_args) and all(
|
|
1455
|
+
scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
|
|
1456
|
+
):
|
|
1457
|
+
return True
|
|
1458
|
+
elif a_origin is tuple and isinstance(b, Sequence):
|
|
1459
|
+
a_args = warp.codegen.get_type_args(a)
|
|
1460
|
+
if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
|
|
1461
|
+
return True
|
|
1462
|
+
elif b_origin is tuple and isinstance(a, Sequence):
|
|
1463
|
+
b_args = warp.codegen.get_type_args(b)
|
|
1464
|
+
if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
|
|
1465
|
+
return True
|
|
1466
|
+
|
|
1469
1467
|
# convert to canonical types
|
|
1470
1468
|
if a == float:
|
|
1471
1469
|
a = float32
|
|
@@ -1495,6 +1493,9 @@ def types_equal(a, b, match_generic=False):
|
|
|
1495
1493
|
if getattr(a, "cls", "a") is getattr(b, "cls", "b"):
|
|
1496
1494
|
return True
|
|
1497
1495
|
|
|
1496
|
+
if is_tile(a) and is_tile(b):
|
|
1497
|
+
return True
|
|
1498
|
+
|
|
1498
1499
|
return scalars_equal(a, b, match_generic)
|
|
1499
1500
|
|
|
1500
1501
|
|
|
@@ -1581,6 +1582,23 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
|
|
|
1581
1582
|
|
|
1582
1583
|
|
|
1583
1584
|
class array(Array):
|
|
1585
|
+
"""A fixed-size multi-dimensional array containing values of the same type.
|
|
1586
|
+
|
|
1587
|
+
Attributes:
|
|
1588
|
+
dtype (DType): The data type of the array.
|
|
1589
|
+
ndim (int): The number of array dimensions.
|
|
1590
|
+
size (int): The number of items in the array.
|
|
1591
|
+
capacity (int): The amount of memory in bytes allocated for this array.
|
|
1592
|
+
shape (Tuple[int]): Dimensions of the array.
|
|
1593
|
+
strides (Tuple[int]): Number of bytes in each dimension between successive elements of the array.
|
|
1594
|
+
ptr (int): Pointer to underlying memory allocation backing the array.
|
|
1595
|
+
device (Device): The device where the array's memory allocation resides.
|
|
1596
|
+
pinned (bool): Indicates whether the array was allocated in pinned host memory.
|
|
1597
|
+
is_contiguous (bool): Indicates whether this array has a contiguous memory layout.
|
|
1598
|
+
deleter (Callable[[int, int], None]): A function to be called when the array is deleted,
|
|
1599
|
+
taking two arguments: pointer and size. If ``None``, then no function is called.
|
|
1600
|
+
"""
|
|
1601
|
+
|
|
1584
1602
|
# member attributes available during code-gen (e.g.: d = array.shape[0])
|
|
1585
1603
|
# (initialized when needed)
|
|
1586
1604
|
_vars = None
|
|
@@ -1592,21 +1610,21 @@ class array(Array):
|
|
|
1592
1610
|
|
|
1593
1611
|
def __init__(
|
|
1594
1612
|
self,
|
|
1595
|
-
data=None,
|
|
1596
|
-
dtype: DType = Any,
|
|
1597
|
-
shape=None,
|
|
1598
|
-
strides=None,
|
|
1599
|
-
length=None,
|
|
1600
|
-
ptr=None,
|
|
1601
|
-
capacity=None,
|
|
1613
|
+
data: Optional[Union[List, Tuple, npt.NDArray]] = None,
|
|
1614
|
+
dtype: Union[DType, Any] = Any,
|
|
1615
|
+
shape: Optional[Tuple[int, ...]] = None,
|
|
1616
|
+
strides: Optional[Tuple[int, ...]] = None,
|
|
1617
|
+
length: Optional[int] = None,
|
|
1618
|
+
ptr: Optional[int] = None,
|
|
1619
|
+
capacity: Optional[int] = None,
|
|
1602
1620
|
device=None,
|
|
1603
|
-
pinned=False,
|
|
1604
|
-
copy=True,
|
|
1605
|
-
owner=False, # deprecated - pass deleter instead
|
|
1606
|
-
deleter=None,
|
|
1607
|
-
ndim=None,
|
|
1608
|
-
grad=None,
|
|
1609
|
-
requires_grad=False,
|
|
1621
|
+
pinned: bool = False,
|
|
1622
|
+
copy: bool = True,
|
|
1623
|
+
owner: bool = False, # deprecated - pass deleter instead
|
|
1624
|
+
deleter: Optional[Callable[[int, int], None]] = None,
|
|
1625
|
+
ndim: Optional[int] = None,
|
|
1626
|
+
grad: Optional[array] = None,
|
|
1627
|
+
requires_grad: bool = False,
|
|
1610
1628
|
):
|
|
1611
1629
|
"""Constructs a new Warp array object
|
|
1612
1630
|
|
|
@@ -1628,20 +1646,24 @@ class array(Array):
|
|
|
1628
1646
|
are taken into account and no memory is allocated for the array.
|
|
1629
1647
|
|
|
1630
1648
|
Args:
|
|
1631
|
-
data
|
|
1632
|
-
dtype
|
|
1633
|
-
shape
|
|
1634
|
-
strides
|
|
1635
|
-
length
|
|
1636
|
-
ptr
|
|
1637
|
-
capacity
|
|
1649
|
+
data: An object to construct the array from, can be a Tuple, List, or generally any type convertible to an np.array
|
|
1650
|
+
dtype: One of the available `data types <#data-types>`_, such as :class:`warp.float32`, :class:`warp.mat33`, or a custom `struct <#structs>`_. If dtype is ``Any`` and data is an ndarray, then it will be inferred from the array data type
|
|
1651
|
+
shape: Dimensions of the array
|
|
1652
|
+
strides: Number of bytes in each dimension between successive elements of the array
|
|
1653
|
+
length: Number of elements of the data type (deprecated, users should use ``shape`` argument)
|
|
1654
|
+
ptr: Address of an external memory address to alias (``data`` should be ``None``)
|
|
1655
|
+
capacity: Maximum size in bytes of the ``ptr`` allocation (``data`` should be ``None``)
|
|
1638
1656
|
device (Devicelike): Device the array lives on
|
|
1639
|
-
copy
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1657
|
+
copy: Whether the incoming ``data`` will be copied or aliased. Aliasing requires that
|
|
1658
|
+
the incoming ``data`` already lives on the ``device`` specified and the data types match.
|
|
1659
|
+
owner: Whether the array will try to deallocate the underlying memory when it is deleted
|
|
1660
|
+
(deprecated, pass ``deleter`` if you wish to transfer ownership to Warp)
|
|
1661
|
+
deleter: Function to be called when the array is deleted, taking two arguments: pointer and size
|
|
1662
|
+
requires_grad: Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
|
|
1663
|
+
grad: The array in which to accumulate gradients in the backward pass. If ``None`` and ``requires_grad`` is ``True``,
|
|
1664
|
+
then a gradient array will be allocated automatically.
|
|
1665
|
+
pinned: Whether to allocate pinned host memory, which allows asynchronous host–device transfers
|
|
1666
|
+
(only applicable with ``device="cpu"``)
|
|
1645
1667
|
|
|
1646
1668
|
"""
|
|
1647
1669
|
|
|
@@ -2963,6 +2985,116 @@ def array_type_id(a):
|
|
|
2963
2985
|
raise ValueError("Invalid array type")
|
|
2964
2986
|
|
|
2965
2987
|
|
|
2988
|
+
# tile expression objects
|
|
2989
|
+
class Tile:
|
|
2990
|
+
alignment = 16
|
|
2991
|
+
|
|
2992
|
+
def __init__(self, dtype, M, N, op=None, storage="register", layout="rowmajor", strides=None, owner=True):
|
|
2993
|
+
self.dtype = type_to_warp(dtype)
|
|
2994
|
+
self.M = M
|
|
2995
|
+
self.N = N
|
|
2996
|
+
self.op = op
|
|
2997
|
+
self.storage = storage
|
|
2998
|
+
self.layout = layout
|
|
2999
|
+
|
|
3000
|
+
if strides is None:
|
|
3001
|
+
if layout == "rowmajor":
|
|
3002
|
+
self.strides = (N, 1)
|
|
3003
|
+
elif layout == "colmajor":
|
|
3004
|
+
self.strides = (1, M)
|
|
3005
|
+
else:
|
|
3006
|
+
self.strides = strides
|
|
3007
|
+
|
|
3008
|
+
self.owner = owner
|
|
3009
|
+
|
|
3010
|
+
# generates C-type string
|
|
3011
|
+
def ctype(self):
|
|
3012
|
+
from warp.codegen import Var
|
|
3013
|
+
|
|
3014
|
+
if self.storage == "register":
|
|
3015
|
+
return f"wp::tile_register_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N}>"
|
|
3016
|
+
elif self.storage == "shared":
|
|
3017
|
+
return f"wp::tile_shared_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{self.strides[0]}, {self.strides[1]}, {'true' if self.owner else 'false'}>"
|
|
3018
|
+
else:
|
|
3019
|
+
raise RuntimeError(f"Unrecognized tile storage type {self.storage}")
|
|
3020
|
+
|
|
3021
|
+
# generates C-initializer string
|
|
3022
|
+
def cinit(self, requires_grad=False):
|
|
3023
|
+
from warp.codegen import Var
|
|
3024
|
+
|
|
3025
|
+
if self.storage == "register":
|
|
3026
|
+
return self.ctype() + "(0.0)"
|
|
3027
|
+
elif self.storage == "shared":
|
|
3028
|
+
if self.owner:
|
|
3029
|
+
# allocate new shared memory tile
|
|
3030
|
+
return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{'true' if requires_grad else 'false'}>()"
|
|
3031
|
+
else:
|
|
3032
|
+
# tile will be initialized by another call, e.g.: tile_transpose()
|
|
3033
|
+
return "NULL"
|
|
3034
|
+
|
|
3035
|
+
# return total tile size in bytes
|
|
3036
|
+
def size_in_bytes(self):
|
|
3037
|
+
num_bytes = self.align(type_size_in_bytes(self.dtype) * self.M * self.N)
|
|
3038
|
+
return num_bytes
|
|
3039
|
+
|
|
3040
|
+
# align tile size to natural boundary, default 16-bytes
|
|
3041
|
+
def align(self, bytes):
|
|
3042
|
+
return ((bytes + self.alignment - 1) // self.alignment) * self.alignment
|
|
3043
|
+
|
|
3044
|
+
|
|
3045
|
+
class TileZeros(Tile):
|
|
3046
|
+
def __init__(self, dtype, M, N, storage="register"):
|
|
3047
|
+
Tile.__init__(self, dtype, M, N, op="zeros", storage=storage)
|
|
3048
|
+
|
|
3049
|
+
|
|
3050
|
+
class TileRange(Tile):
|
|
3051
|
+
def __init__(self, dtype, start, stop, step, storage="register"):
|
|
3052
|
+
self.start = start
|
|
3053
|
+
self.stop = stop
|
|
3054
|
+
self.step = step
|
|
3055
|
+
|
|
3056
|
+
M = 1
|
|
3057
|
+
N = int((stop - start) / step)
|
|
3058
|
+
|
|
3059
|
+
Tile.__init__(self, dtype, M, N, op="arange", storage=storage)
|
|
3060
|
+
|
|
3061
|
+
|
|
3062
|
+
class TileConstant(Tile):
|
|
3063
|
+
def __init__(self, dtype, M, N):
|
|
3064
|
+
Tile.__init__(self, dtype, M, N, op="constant", storage="register")
|
|
3065
|
+
|
|
3066
|
+
|
|
3067
|
+
class TileLoad(Tile):
|
|
3068
|
+
def __init__(self, array, M, N, storage="register"):
|
|
3069
|
+
Tile.__init__(self, array.dtype, M, N, op="load", storage=storage)
|
|
3070
|
+
|
|
3071
|
+
|
|
3072
|
+
class TileUnaryMap(Tile):
|
|
3073
|
+
def __init__(self, t, storage="register"):
|
|
3074
|
+
Tile.__init__(self, t.dtype, t.M, t.N, op="unary_map", storage=storage)
|
|
3075
|
+
|
|
3076
|
+
self.t = t
|
|
3077
|
+
|
|
3078
|
+
|
|
3079
|
+
class TileBinaryMap(Tile):
|
|
3080
|
+
def __init__(self, a, b, storage="register"):
|
|
3081
|
+
Tile.__init__(self, a.dtype, a.M, a.N, op="binary_map", storage=storage)
|
|
3082
|
+
|
|
3083
|
+
self.a = a
|
|
3084
|
+
self.b = b
|
|
3085
|
+
|
|
3086
|
+
|
|
3087
|
+
class TileShared(Tile):
|
|
3088
|
+
def __init__(self, t):
|
|
3089
|
+
Tile.__init__(self, t.dtype, t.M, t.N, "shared", storage="shared")
|
|
3090
|
+
|
|
3091
|
+
self.t = t
|
|
3092
|
+
|
|
3093
|
+
|
|
3094
|
+
def is_tile(t):
|
|
3095
|
+
return isinstance(t, Tile)
|
|
3096
|
+
|
|
3097
|
+
|
|
2966
3098
|
class Bvh:
|
|
2967
3099
|
def __new__(cls, *args, **kwargs):
|
|
2968
3100
|
instance = super(Bvh, cls).__new__(cls)
|
|
@@ -3544,9 +3676,9 @@ class Volume:
|
|
|
3544
3676
|
grid_data = bytearray()
|
|
3545
3677
|
while grid_data_offset < file_end:
|
|
3546
3678
|
chunk_size = struct.unpack("<Q", data[grid_data_offset : grid_data_offset + 8])[0]
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3679
|
+
grid_data_offset += 8
|
|
3680
|
+
grid_data += zlib.decompress(data[grid_data_offset : grid_data_offset + chunk_size])
|
|
3681
|
+
grid_data_offset += chunk_size
|
|
3550
3682
|
elif codec == 2: # blosc compression
|
|
3551
3683
|
try:
|
|
3552
3684
|
import blosc
|
|
@@ -3558,8 +3690,9 @@ class Volume:
|
|
|
3558
3690
|
grid_data = bytearray()
|
|
3559
3691
|
while grid_data_offset < file_end:
|
|
3560
3692
|
chunk_size = struct.unpack("<Q", data[grid_data_offset : grid_data_offset + 8])[0]
|
|
3561
|
-
|
|
3562
|
-
|
|
3693
|
+
grid_data_offset += 8
|
|
3694
|
+
grid_data += blosc.decompress(data[grid_data_offset : grid_data_offset + chunk_size])
|
|
3695
|
+
grid_data_offset += chunk_size
|
|
3563
3696
|
else:
|
|
3564
3697
|
raise RuntimeError(f"Unsupported codec code: {codec}")
|
|
3565
3698
|
|
|
@@ -3570,6 +3703,139 @@ class Volume:
|
|
|
3570
3703
|
data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
|
|
3571
3704
|
return cls(data_array)
|
|
3572
3705
|
|
|
3706
|
+
def save_to_nvdb(self, path, codec: Literal["none", "zip", "blosc"] = "none"):
|
|
3707
|
+
"""Serialize the Volume into a NanoVDB (.nvdb) file.
|
|
3708
|
+
|
|
3709
|
+
Args:
|
|
3710
|
+
path: File path to save.
|
|
3711
|
+
codec: Compression codec used
|
|
3712
|
+
"none" - no compression
|
|
3713
|
+
"zip" - ZIP compression
|
|
3714
|
+
"blosc" - BLOSC compression, requires the blosc module to be installed
|
|
3715
|
+
"""
|
|
3716
|
+
|
|
3717
|
+
codec_dict = {"none": 0, "zip": 1, "blosc": 2}
|
|
3718
|
+
|
|
3719
|
+
class FileHeader(ctypes.Structure):
|
|
3720
|
+
_fields_ = [
|
|
3721
|
+
("magic", ctypes.c_uint64),
|
|
3722
|
+
("version", ctypes.c_uint32),
|
|
3723
|
+
("gridCount", ctypes.c_uint16),
|
|
3724
|
+
("codec", ctypes.c_uint16),
|
|
3725
|
+
]
|
|
3726
|
+
|
|
3727
|
+
class FileMetaData(ctypes.Structure):
|
|
3728
|
+
_fields_ = [
|
|
3729
|
+
("gridSize", ctypes.c_uint64),
|
|
3730
|
+
("fileSize", ctypes.c_uint64),
|
|
3731
|
+
("nameKey", ctypes.c_uint64),
|
|
3732
|
+
("voxelCount", ctypes.c_uint64),
|
|
3733
|
+
("gridType", ctypes.c_uint32),
|
|
3734
|
+
("gridClass", ctypes.c_uint32),
|
|
3735
|
+
("worldBBox", ctypes.c_double * 6),
|
|
3736
|
+
("indexBBox", ctypes.c_uint32 * 6),
|
|
3737
|
+
("voxelSize", ctypes.c_double * 3),
|
|
3738
|
+
("nameSize", ctypes.c_uint32),
|
|
3739
|
+
("nodeCount", ctypes.c_uint32 * 4),
|
|
3740
|
+
("tileCount", ctypes.c_uint32 * 3),
|
|
3741
|
+
("codec", ctypes.c_uint16),
|
|
3742
|
+
("padding", ctypes.c_uint16),
|
|
3743
|
+
("version", ctypes.c_uint32),
|
|
3744
|
+
]
|
|
3745
|
+
|
|
3746
|
+
class GridData(ctypes.Structure):
|
|
3747
|
+
_fields_ = [
|
|
3748
|
+
("magic", ctypes.c_uint64),
|
|
3749
|
+
("checksum", ctypes.c_uint64),
|
|
3750
|
+
("version", ctypes.c_uint32),
|
|
3751
|
+
("flags", ctypes.c_uint32),
|
|
3752
|
+
("gridIndex", ctypes.c_uint32),
|
|
3753
|
+
("gridCount", ctypes.c_uint32),
|
|
3754
|
+
("gridSize", ctypes.c_uint64),
|
|
3755
|
+
("gridName", ctypes.c_char * 256),
|
|
3756
|
+
("map", ctypes.c_byte * 264),
|
|
3757
|
+
("worldBBox", ctypes.c_double * 6),
|
|
3758
|
+
("voxelSize", ctypes.c_double * 3),
|
|
3759
|
+
("gridClass", ctypes.c_uint32),
|
|
3760
|
+
("gridType", ctypes.c_uint32),
|
|
3761
|
+
("blindMetadataOffset", ctypes.c_int64),
|
|
3762
|
+
("blindMetadataCount", ctypes.c_uint32),
|
|
3763
|
+
("data0", ctypes.c_uint32),
|
|
3764
|
+
("data1", ctypes.c_uint64),
|
|
3765
|
+
("data2", ctypes.c_uint64),
|
|
3766
|
+
]
|
|
3767
|
+
|
|
3768
|
+
NVDB_MAGIC = 0x304244566F6E614E
|
|
3769
|
+
NVDB_VERSION = 32 << 21 | 3 << 10 | 3
|
|
3770
|
+
|
|
3771
|
+
try:
|
|
3772
|
+
codec_int = codec_dict[codec]
|
|
3773
|
+
except KeyError as err:
|
|
3774
|
+
raise RuntimeError(f"Unsupported codec requested: {codec}") from err
|
|
3775
|
+
|
|
3776
|
+
if codec_int == 2:
|
|
3777
|
+
try:
|
|
3778
|
+
import blosc
|
|
3779
|
+
except ImportError as err:
|
|
3780
|
+
raise RuntimeError(
|
|
3781
|
+
f"blosc compression was requested, but Python module could not be imported: {err}"
|
|
3782
|
+
) from err
|
|
3783
|
+
|
|
3784
|
+
data = self.array().numpy()
|
|
3785
|
+
grid_data = GridData.from_buffer(data)
|
|
3786
|
+
|
|
3787
|
+
if grid_data.gridIndex > 0:
|
|
3788
|
+
raise RuntimeError(
|
|
3789
|
+
"Saving of aliased Volumes is not supported. Use `save_to_nvdb` on the original volume, before any `load_next_grid` calls."
|
|
3790
|
+
)
|
|
3791
|
+
|
|
3792
|
+
file_header = FileHeader(NVDB_MAGIC, NVDB_VERSION, grid_data.gridCount, codec_int)
|
|
3793
|
+
|
|
3794
|
+
grid_data_offset = 0
|
|
3795
|
+
all_file_meta_data = []
|
|
3796
|
+
for i in range(file_header.gridCount):
|
|
3797
|
+
if i > 0:
|
|
3798
|
+
grid_data = GridData.from_buffer(data[grid_data_offset : grid_data_offset + 672])
|
|
3799
|
+
current_grid_data = data[grid_data_offset : grid_data_offset + grid_data.gridSize]
|
|
3800
|
+
if codec_int == 1: # zip compression
|
|
3801
|
+
compressed_data = zlib.compress(current_grid_data)
|
|
3802
|
+
compressed_size = len(compressed_data)
|
|
3803
|
+
elif codec_int == 2: # blosc compression
|
|
3804
|
+
compressed_data = blosc.compress(current_grid_data)
|
|
3805
|
+
compressed_size = len(compressed_data)
|
|
3806
|
+
else: # no compression
|
|
3807
|
+
compressed_data = current_grid_data
|
|
3808
|
+
compressed_size = grid_data.gridSize
|
|
3809
|
+
|
|
3810
|
+
file_meta_data = FileMetaData()
|
|
3811
|
+
file_meta_data.gridSize = grid_data.gridSize
|
|
3812
|
+
file_meta_data.fileSize = compressed_size
|
|
3813
|
+
file_meta_data.gridType = grid_data.gridType
|
|
3814
|
+
file_meta_data.gridClass = grid_data.gridClass
|
|
3815
|
+
file_meta_data.worldBBox = grid_data.worldBBox
|
|
3816
|
+
file_meta_data.voxelSize = grid_data.voxelSize
|
|
3817
|
+
file_meta_data.nameSize = len(grid_data.gridName) + 1 # including the closing 0x0
|
|
3818
|
+
file_meta_data.codec = codec_int
|
|
3819
|
+
file_meta_data.version = NVDB_VERSION
|
|
3820
|
+
|
|
3821
|
+
grid_data_offset += file_meta_data.gridSize
|
|
3822
|
+
|
|
3823
|
+
all_file_meta_data.append((file_meta_data, grid_data.gridName, compressed_data))
|
|
3824
|
+
|
|
3825
|
+
with open(path, "wb") as nvdb:
|
|
3826
|
+
nvdb.write(file_header)
|
|
3827
|
+
for file_meta_data, grid_name, _ in all_file_meta_data:
|
|
3828
|
+
nvdb.write(file_meta_data)
|
|
3829
|
+
nvdb.write(grid_name + b"\x00")
|
|
3830
|
+
|
|
3831
|
+
for file_meta_data, _, compressed_data in all_file_meta_data:
|
|
3832
|
+
if codec_int > 0:
|
|
3833
|
+
chunk_size = struct.pack("<Q", file_meta_data.fileSize)
|
|
3834
|
+
nvdb.write(chunk_size)
|
|
3835
|
+
nvdb.write(compressed_data)
|
|
3836
|
+
|
|
3837
|
+
return path
|
|
3838
|
+
|
|
3573
3839
|
@classmethod
|
|
3574
3840
|
def load_from_address(cls, grid_ptr: int, buffer_size: int = 0, device=None) -> Volume:
|
|
3575
3841
|
"""
|
warp/utils.py
CHANGED
|
@@ -18,6 +18,7 @@ import numpy as np
|
|
|
18
18
|
import warp as wp
|
|
19
19
|
import warp.context
|
|
20
20
|
import warp.types
|
|
21
|
+
from warp.context import Devicelike
|
|
21
22
|
|
|
22
23
|
warnings_seen = set()
|
|
23
24
|
|
|
@@ -38,8 +39,7 @@ def warp_showwarning(message, category, filename, lineno, file=None, line=None):
|
|
|
38
39
|
# and the import machinery don't work anymore
|
|
39
40
|
line = None
|
|
40
41
|
linecache = None
|
|
41
|
-
|
|
42
|
-
line = line
|
|
42
|
+
|
|
43
43
|
if line:
|
|
44
44
|
line = line.strip()
|
|
45
45
|
s += " %s\n" % line
|
|
@@ -554,7 +554,27 @@ def mem_report(): # pragma: no cover
|
|
|
554
554
|
|
|
555
555
|
|
|
556
556
|
class ScopedDevice:
|
|
557
|
-
|
|
557
|
+
"""A context manager to temporarily change the current default device.
|
|
558
|
+
|
|
559
|
+
For CUDA devices, this context manager makes the device's CUDA context
|
|
560
|
+
current and restores the previous CUDA context on exit. This is handy when
|
|
561
|
+
running Warp scripts as part of a bigger pipeline because it avoids any side
|
|
562
|
+
effects of changing the CUDA context in the enclosed code.
|
|
563
|
+
|
|
564
|
+
Attributes:
|
|
565
|
+
device (Device): The device that will temporarily become the default
|
|
566
|
+
device within the context.
|
|
567
|
+
saved_device (Device): The previous default device. This is restored as
|
|
568
|
+
the default device on exiting the context.
|
|
569
|
+
"""
|
|
570
|
+
|
|
571
|
+
def __init__(self, device: Devicelike):
|
|
572
|
+
"""Initializes the context manager with a device.
|
|
573
|
+
|
|
574
|
+
Args:
|
|
575
|
+
device: The device that will temporarily become the default device
|
|
576
|
+
within the context.
|
|
577
|
+
"""
|
|
558
578
|
self.device = wp.get_device(device)
|
|
559
579
|
|
|
560
580
|
def __enter__(self):
|