warp-lang 1.4.1__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1920 -111
- warp/codegen.py +186 -62
- warp/config.py +2 -2
- warp/context.py +322 -73
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/core/example_dem.py +2 -1
- warp/examples/core/example_mesh_intersect.py +3 -3
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/optim/example_walker.py +2 -2
- warp/examples/sim/example_cloth.py +2 -25
- warp/examples/sim/example_jacobian_ik.py +6 -2
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -5
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +55 -40
- warp/native/builtin.h +124 -43
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +600 -0
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1857 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +137 -65
- warp/sim/graph_coloring.py +292 -0
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +90 -17
- warp/stubs.py +651 -85
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +207 -48
- warp/tests/test_closest_point_edge_edge.py +8 -8
- warp/tests/test_codegen.py +120 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +241 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +18 -4
- warp/tests/test_fabricarray.py +33 -0
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +48 -1
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_mesh_query_point.py +5 -4
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +191 -1
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_tile.py +700 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +23 -2
- warp/tests/unittest_utils.py +4 -0
- warp/types.py +339 -73
- warp/utils.py +22 -1
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/tests/unittest_suites.py
CHANGED
|
@@ -99,8 +99,11 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
99
99
|
from warp.tests.test_closest_point_edge_edge import TestClosestPointEdgeEdgeMethods
|
|
100
100
|
from warp.tests.test_codegen import TestCodeGen
|
|
101
101
|
from warp.tests.test_codegen_instancing import TestCodeGenInstancing
|
|
102
|
+
from warp.tests.test_collision import TestCollision
|
|
103
|
+
from warp.tests.test_coloring import TestColoring
|
|
102
104
|
from warp.tests.test_compile_consts import TestConstants
|
|
103
105
|
from warp.tests.test_conditional import TestConditional
|
|
106
|
+
from warp.tests.test_context import TestContext
|
|
104
107
|
from warp.tests.test_copy import TestCopy
|
|
105
108
|
from warp.tests.test_ctypes import TestCTypes
|
|
106
109
|
from warp.tests.test_dense import TestDense
|
|
@@ -115,7 +118,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
115
118
|
)
|
|
116
119
|
from warp.tests.test_fabricarray import TestFabricArray
|
|
117
120
|
from warp.tests.test_fast_math import TestFastMath
|
|
118
|
-
from warp.tests.test_fem import TestFem, TestFemShapeFunctions
|
|
121
|
+
from warp.tests.test_fem import TestFem, TestFemShapeFunctions, TestFemUtilities
|
|
119
122
|
from warp.tests.test_fp16 import TestFp16
|
|
120
123
|
from warp.tests.test_func import TestFunc
|
|
121
124
|
from warp.tests.test_future_annotations import TestFutureAnnotations
|
|
@@ -127,6 +130,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
127
130
|
from warp.tests.test_import import TestImport
|
|
128
131
|
from warp.tests.test_indexedarray import TestIndexedArray
|
|
129
132
|
from warp.tests.test_intersect import TestIntersect
|
|
133
|
+
from warp.tests.test_iter import TestIter
|
|
130
134
|
from warp.tests.test_jax import TestJax
|
|
131
135
|
from warp.tests.test_large import TestLarge
|
|
132
136
|
from warp.tests.test_launch import TestLaunch
|
|
@@ -170,9 +174,14 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
170
174
|
from warp.tests.test_sparse import TestSparse
|
|
171
175
|
from warp.tests.test_spatial import TestSpatial
|
|
172
176
|
from warp.tests.test_special_values import TestSpecialValues
|
|
177
|
+
from warp.tests.test_static import TestStatic
|
|
173
178
|
from warp.tests.test_streams import TestStreams
|
|
174
179
|
from warp.tests.test_struct import TestStruct
|
|
175
180
|
from warp.tests.test_tape import TestTape
|
|
181
|
+
from warp.tests.test_tile import TestTile
|
|
182
|
+
from warp.tests.test_tile_mathdx import TestTileMathDx
|
|
183
|
+
from warp.tests.test_tile_reduce import TestTileReduce
|
|
184
|
+
from warp.tests.test_tile_shared_memory import TestTileSharedMemory
|
|
176
185
|
from warp.tests.test_torch import TestTorch
|
|
177
186
|
from warp.tests.test_transient_module import TestTransientModule
|
|
178
187
|
from warp.tests.test_triangle_closest_point import TestTriangleClosestPoint
|
|
@@ -199,8 +208,11 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
199
208
|
TestClosestPointEdgeEdgeMethods,
|
|
200
209
|
TestCodeGen,
|
|
201
210
|
TestCodeGenInstancing,
|
|
202
|
-
|
|
211
|
+
TestCollision,
|
|
212
|
+
TestColoring,
|
|
203
213
|
TestConditional,
|
|
214
|
+
TestConstants,
|
|
215
|
+
TestContext,
|
|
204
216
|
TestCopy,
|
|
205
217
|
TestCTypes,
|
|
206
218
|
TestDense,
|
|
@@ -215,6 +227,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
215
227
|
TestFastMath,
|
|
216
228
|
TestFem,
|
|
217
229
|
TestFemShapeFunctions,
|
|
230
|
+
TestFemUtilities,
|
|
218
231
|
TestFp16,
|
|
219
232
|
TestFunc,
|
|
220
233
|
TestFutureAnnotations,
|
|
@@ -226,6 +239,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
226
239
|
TestImport,
|
|
227
240
|
TestIndexedArray,
|
|
228
241
|
TestIntersect,
|
|
242
|
+
TestIter,
|
|
229
243
|
TestJax,
|
|
230
244
|
TestLarge,
|
|
231
245
|
TestLaunch,
|
|
@@ -269,9 +283,14 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader)
|
|
|
269
283
|
TestSparse,
|
|
270
284
|
TestSpatial,
|
|
271
285
|
TestSpecialValues,
|
|
286
|
+
TestStatic,
|
|
272
287
|
TestStreams,
|
|
273
288
|
TestStruct,
|
|
274
289
|
TestTape,
|
|
290
|
+
TestTile,
|
|
291
|
+
TestTileMathDx,
|
|
292
|
+
TestTileReduce,
|
|
293
|
+
TestTileSharedMemory,
|
|
275
294
|
TestTorch,
|
|
276
295
|
TestTransientModule,
|
|
277
296
|
TestTriangleClosestPoint,
|
|
@@ -329,6 +348,7 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
|
|
|
329
348
|
from warp.tests.test_rounding import TestRounding
|
|
330
349
|
from warp.tests.test_runlength_encode import TestRunlengthEncode
|
|
331
350
|
from warp.tests.test_sparse import TestSparse
|
|
351
|
+
from warp.tests.test_static import TestStatic
|
|
332
352
|
from warp.tests.test_streams import TestStreams
|
|
333
353
|
from warp.tests.test_tape import TestTape
|
|
334
354
|
from warp.tests.test_transient_module import TestTransientModule
|
|
@@ -374,6 +394,7 @@ def kit_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader):
|
|
|
374
394
|
TestRounding,
|
|
375
395
|
TestRunlengthEncode,
|
|
376
396
|
TestSparse,
|
|
397
|
+
TestStatic,
|
|
377
398
|
TestStreams,
|
|
378
399
|
TestTape,
|
|
379
400
|
TestTransientModule,
|
warp/tests/unittest_utils.py
CHANGED
|
@@ -232,6 +232,10 @@ def create_test_func(func, device, check_output, **kwargs):
|
|
|
232
232
|
else:
|
|
233
233
|
func(self, device, **kwargs)
|
|
234
234
|
|
|
235
|
+
# Copy the __unittest_expecting_failure__ attribute from func to test_func
|
|
236
|
+
if hasattr(func, "__unittest_expecting_failure__"):
|
|
237
|
+
test_func.__unittest_expecting_failure__ = func.__unittest_expecting_failure__
|
|
238
|
+
|
|
235
239
|
return test_func
|
|
236
240
|
|
|
237
241
|
|
warp/types.py
CHANGED
|
@@ -12,9 +12,10 @@ import ctypes
|
|
|
12
12
|
import inspect
|
|
13
13
|
import struct
|
|
14
14
|
import zlib
|
|
15
|
-
from typing import Any, Callable, Generic, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
|
|
15
|
+
from typing import Any, Callable, Generic, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
|
+
import numpy.typing as npt
|
|
18
19
|
|
|
19
20
|
import warp
|
|
20
21
|
|
|
@@ -100,8 +101,10 @@ def vector(length, dtype):
|
|
|
100
101
|
|
|
101
102
|
if dtype is bool:
|
|
102
103
|
_type_ = ctypes.c_bool
|
|
103
|
-
elif dtype in
|
|
104
|
+
elif dtype in (Scalar, Float):
|
|
104
105
|
_type_ = ctypes.c_float
|
|
106
|
+
elif dtype is Int:
|
|
107
|
+
_type_ = ctypes.c_int
|
|
105
108
|
else:
|
|
106
109
|
_type_ = dtype._type_
|
|
107
110
|
|
|
@@ -289,8 +292,10 @@ def matrix(shape, dtype):
|
|
|
289
292
|
|
|
290
293
|
if dtype is bool:
|
|
291
294
|
_type_ = ctypes.c_bool
|
|
292
|
-
elif dtype in
|
|
295
|
+
elif dtype in (Scalar, Float):
|
|
293
296
|
_type_ = ctypes.c_float
|
|
297
|
+
elif dtype is Int:
|
|
298
|
+
_type_ = ctypes.c_int
|
|
294
299
|
else:
|
|
295
300
|
_type_ = dtype._type_
|
|
296
301
|
|
|
@@ -991,43 +996,6 @@ vector_types = (
|
|
|
991
996
|
spatial_matrixd,
|
|
992
997
|
)
|
|
993
998
|
|
|
994
|
-
atomic_vector_types = (
|
|
995
|
-
vec2i,
|
|
996
|
-
vec2ui,
|
|
997
|
-
vec2l,
|
|
998
|
-
vec2ul,
|
|
999
|
-
vec2h,
|
|
1000
|
-
vec2f,
|
|
1001
|
-
vec2d,
|
|
1002
|
-
vec3i,
|
|
1003
|
-
vec3ui,
|
|
1004
|
-
vec3l,
|
|
1005
|
-
vec3ul,
|
|
1006
|
-
vec3h,
|
|
1007
|
-
vec3f,
|
|
1008
|
-
vec3d,
|
|
1009
|
-
vec4i,
|
|
1010
|
-
vec4ui,
|
|
1011
|
-
vec4l,
|
|
1012
|
-
vec4ul,
|
|
1013
|
-
vec4h,
|
|
1014
|
-
vec4f,
|
|
1015
|
-
vec4d,
|
|
1016
|
-
mat22h,
|
|
1017
|
-
mat22f,
|
|
1018
|
-
mat22d,
|
|
1019
|
-
mat33h,
|
|
1020
|
-
mat33f,
|
|
1021
|
-
mat33d,
|
|
1022
|
-
mat44h,
|
|
1023
|
-
mat44f,
|
|
1024
|
-
mat44d,
|
|
1025
|
-
quath,
|
|
1026
|
-
quatf,
|
|
1027
|
-
quatd,
|
|
1028
|
-
)
|
|
1029
|
-
atomic_types = float_types + (int32, uint32, int64, uint64) + atomic_vector_types
|
|
1030
|
-
|
|
1031
999
|
np_dtype_to_warp_type = {
|
|
1032
1000
|
# Numpy scalar types
|
|
1033
1001
|
np.bool_: bool,
|
|
@@ -1076,6 +1044,14 @@ warp_type_to_np_dtype = {
|
|
|
1076
1044
|
float64: np.float64,
|
|
1077
1045
|
}
|
|
1078
1046
|
|
|
1047
|
+
non_atomic_types = (
|
|
1048
|
+
int8,
|
|
1049
|
+
uint8,
|
|
1050
|
+
int16,
|
|
1051
|
+
uint16,
|
|
1052
|
+
int64,
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1079
1055
|
|
|
1080
1056
|
def dtype_from_numpy(numpy_dtype):
|
|
1081
1057
|
"""Return the Warp dtype corresponding to a NumPy dtype."""
|
|
@@ -1337,6 +1313,8 @@ def type_typestr(dtype):
|
|
|
1337
1313
|
def type_repr(t):
|
|
1338
1314
|
if is_array(t):
|
|
1339
1315
|
return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
|
|
1316
|
+
if is_tile(t):
|
|
1317
|
+
return str(f"tile(dtype={t.dtype}, m={t.M}, n={t.N})")
|
|
1340
1318
|
if type_is_vector(t):
|
|
1341
1319
|
return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
|
|
1342
1320
|
if type_is_matrix(t):
|
|
@@ -1448,6 +1426,8 @@ def scalars_equal(a, b, match_generic):
|
|
|
1448
1426
|
|
|
1449
1427
|
def types_equal(a, b, match_generic=False):
|
|
1450
1428
|
if match_generic:
|
|
1429
|
+
# Special cases to interpret the types listed in `int_tuple_type_hints`
|
|
1430
|
+
# as generic hints that accept any integer types.
|
|
1451
1431
|
if a in int_tuple_type_hints and isinstance(b, Sequence):
|
|
1452
1432
|
a_length = int_tuple_type_hints[a]
|
|
1453
1433
|
if (a_length == -1 or a_length == len(b)) and all(
|
|
@@ -1466,6 +1446,24 @@ def types_equal(a, b, match_generic=False):
|
|
|
1466
1446
|
if a_length is None or b_length is None or a_length == b_length:
|
|
1467
1447
|
return True
|
|
1468
1448
|
|
|
1449
|
+
a_origin = warp.codegen.get_type_origin(a)
|
|
1450
|
+
b_origin = warp.codegen.get_type_origin(b)
|
|
1451
|
+
if a_origin is tuple and b_origin is tuple:
|
|
1452
|
+
a_args = warp.codegen.get_type_args(a)
|
|
1453
|
+
b_args = warp.codegen.get_type_args(b)
|
|
1454
|
+
if len(a_args) == len(b_args) and all(
|
|
1455
|
+
scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b_args)
|
|
1456
|
+
):
|
|
1457
|
+
return True
|
|
1458
|
+
elif a_origin is tuple and isinstance(b, Sequence):
|
|
1459
|
+
a_args = warp.codegen.get_type_args(a)
|
|
1460
|
+
if len(a_args) == len(b) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(a_args, b)):
|
|
1461
|
+
return True
|
|
1462
|
+
elif b_origin is tuple and isinstance(a, Sequence):
|
|
1463
|
+
b_args = warp.codegen.get_type_args(b)
|
|
1464
|
+
if len(b_args) == len(a) and all(scalars_equal(x, y, match_generic=match_generic) for x, y in zip(b_args, a)):
|
|
1465
|
+
return True
|
|
1466
|
+
|
|
1469
1467
|
# convert to canonical types
|
|
1470
1468
|
if a == float:
|
|
1471
1469
|
a = float32
|
|
@@ -1488,13 +1486,16 @@ def types_equal(a, b, match_generic=False):
|
|
|
1488
1486
|
|
|
1489
1487
|
return True
|
|
1490
1488
|
|
|
1491
|
-
if is_array(a) and type(a) is type(b):
|
|
1489
|
+
if is_array(a) and type(a) is type(b) and types_equal(a.dtype, b.dtype, match_generic=match_generic):
|
|
1492
1490
|
return True
|
|
1493
1491
|
|
|
1494
1492
|
# match NewStructInstance and Struct dtype
|
|
1495
1493
|
if getattr(a, "cls", "a") is getattr(b, "cls", "b"):
|
|
1496
1494
|
return True
|
|
1497
1495
|
|
|
1496
|
+
if is_tile(a) and is_tile(b):
|
|
1497
|
+
return True
|
|
1498
|
+
|
|
1498
1499
|
return scalars_equal(a, b, match_generic)
|
|
1499
1500
|
|
|
1500
1501
|
|
|
@@ -1581,6 +1582,23 @@ def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
|
|
|
1581
1582
|
|
|
1582
1583
|
|
|
1583
1584
|
class array(Array):
|
|
1585
|
+
"""A fixed-size multi-dimensional array containing values of the same type.
|
|
1586
|
+
|
|
1587
|
+
Attributes:
|
|
1588
|
+
dtype (DType): The data type of the array.
|
|
1589
|
+
ndim (int): The number of array dimensions.
|
|
1590
|
+
size (int): The number of items in the array.
|
|
1591
|
+
capacity (int): The amount of memory in bytes allocated for this array.
|
|
1592
|
+
shape (Tuple[int]): Dimensions of the array.
|
|
1593
|
+
strides (Tuple[int]): Number of bytes in each dimension between successive elements of the array.
|
|
1594
|
+
ptr (int): Pointer to underlying memory allocation backing the array.
|
|
1595
|
+
device (Device): The device where the array's memory allocation resides.
|
|
1596
|
+
pinned (bool): Indicates whether the array was allocated in pinned host memory.
|
|
1597
|
+
is_contiguous (bool): Indicates whether this array has a contiguous memory layout.
|
|
1598
|
+
deleter (Callable[[int, int], None]): A function to be called when the array is deleted,
|
|
1599
|
+
taking two arguments: pointer and size. If ``None``, then no function is called.
|
|
1600
|
+
"""
|
|
1601
|
+
|
|
1584
1602
|
# member attributes available during code-gen (e.g.: d = array.shape[0])
|
|
1585
1603
|
# (initialized when needed)
|
|
1586
1604
|
_vars = None
|
|
@@ -1592,21 +1610,21 @@ class array(Array):
|
|
|
1592
1610
|
|
|
1593
1611
|
def __init__(
|
|
1594
1612
|
self,
|
|
1595
|
-
data=None,
|
|
1596
|
-
dtype: DType = Any,
|
|
1597
|
-
shape=None,
|
|
1598
|
-
strides=None,
|
|
1599
|
-
length=None,
|
|
1600
|
-
ptr=None,
|
|
1601
|
-
capacity=None,
|
|
1613
|
+
data: Optional[Union[List, Tuple, npt.NDArray]] = None,
|
|
1614
|
+
dtype: Union[DType, Any] = Any,
|
|
1615
|
+
shape: Optional[Tuple[int, ...]] = None,
|
|
1616
|
+
strides: Optional[Tuple[int, ...]] = None,
|
|
1617
|
+
length: Optional[int] = None,
|
|
1618
|
+
ptr: Optional[int] = None,
|
|
1619
|
+
capacity: Optional[int] = None,
|
|
1602
1620
|
device=None,
|
|
1603
|
-
pinned=False,
|
|
1604
|
-
copy=True,
|
|
1605
|
-
owner=False, # deprecated - pass deleter instead
|
|
1606
|
-
deleter=None,
|
|
1607
|
-
ndim=None,
|
|
1608
|
-
grad=None,
|
|
1609
|
-
requires_grad=False,
|
|
1621
|
+
pinned: bool = False,
|
|
1622
|
+
copy: bool = True,
|
|
1623
|
+
owner: bool = False, # deprecated - pass deleter instead
|
|
1624
|
+
deleter: Optional[Callable[[int, int], None]] = None,
|
|
1625
|
+
ndim: Optional[int] = None,
|
|
1626
|
+
grad: Optional[array] = None,
|
|
1627
|
+
requires_grad: bool = False,
|
|
1610
1628
|
):
|
|
1611
1629
|
"""Constructs a new Warp array object
|
|
1612
1630
|
|
|
@@ -1628,20 +1646,24 @@ class array(Array):
|
|
|
1628
1646
|
are taken into account and no memory is allocated for the array.
|
|
1629
1647
|
|
|
1630
1648
|
Args:
|
|
1631
|
-
data
|
|
1632
|
-
dtype
|
|
1633
|
-
shape
|
|
1634
|
-
strides
|
|
1635
|
-
length
|
|
1636
|
-
ptr
|
|
1637
|
-
capacity
|
|
1649
|
+
data: An object to construct the array from, can be a Tuple, List, or generally any type convertible to an np.array
|
|
1650
|
+
dtype: One of the available `data types <#data-types>`_, such as :class:`warp.float32`, :class:`warp.mat33`, or a custom `struct <#structs>`_. If dtype is ``Any`` and data is an ndarray, then it will be inferred from the array data type
|
|
1651
|
+
shape: Dimensions of the array
|
|
1652
|
+
strides: Number of bytes in each dimension between successive elements of the array
|
|
1653
|
+
length: Number of elements of the data type (deprecated, users should use ``shape`` argument)
|
|
1654
|
+
ptr: Address of an external memory address to alias (``data`` should be ``None``)
|
|
1655
|
+
capacity: Maximum size in bytes of the ``ptr`` allocation (``data`` should be ``None``)
|
|
1638
1656
|
device (Devicelike): Device the array lives on
|
|
1639
|
-
copy
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
|
|
1644
|
-
|
|
1657
|
+
copy: Whether the incoming ``data`` will be copied or aliased. Aliasing requires that
|
|
1658
|
+
the incoming ``data`` already lives on the ``device`` specified and the data types match.
|
|
1659
|
+
owner: Whether the array will try to deallocate the underlying memory when it is deleted
|
|
1660
|
+
(deprecated, pass ``deleter`` if you wish to transfer ownership to Warp)
|
|
1661
|
+
deleter: Function to be called when the array is deleted, taking two arguments: pointer and size
|
|
1662
|
+
requires_grad: Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
|
|
1663
|
+
grad: The array in which to accumulate gradients in the backward pass. If ``None`` and ``requires_grad`` is ``True``,
|
|
1664
|
+
then a gradient array will be allocated automatically.
|
|
1665
|
+
pinned: Whether to allocate pinned host memory, which allows asynchronous host–device transfers
|
|
1666
|
+
(only applicable with ``device="cpu"``)
|
|
1645
1667
|
|
|
1646
1668
|
"""
|
|
1647
1669
|
|
|
@@ -2963,6 +2985,116 @@ def array_type_id(a):
|
|
|
2963
2985
|
raise ValueError("Invalid array type")
|
|
2964
2986
|
|
|
2965
2987
|
|
|
2988
|
+
# tile expression objects
|
|
2989
|
+
class Tile:
|
|
2990
|
+
alignment = 16
|
|
2991
|
+
|
|
2992
|
+
def __init__(self, dtype, M, N, op=None, storage="register", layout="rowmajor", strides=None, owner=True):
|
|
2993
|
+
self.dtype = type_to_warp(dtype)
|
|
2994
|
+
self.M = M
|
|
2995
|
+
self.N = N
|
|
2996
|
+
self.op = op
|
|
2997
|
+
self.storage = storage
|
|
2998
|
+
self.layout = layout
|
|
2999
|
+
|
|
3000
|
+
if strides is None:
|
|
3001
|
+
if layout == "rowmajor":
|
|
3002
|
+
self.strides = (N, 1)
|
|
3003
|
+
elif layout == "colmajor":
|
|
3004
|
+
self.strides = (1, M)
|
|
3005
|
+
else:
|
|
3006
|
+
self.strides = strides
|
|
3007
|
+
|
|
3008
|
+
self.owner = owner
|
|
3009
|
+
|
|
3010
|
+
# generates C-type string
|
|
3011
|
+
def ctype(self):
|
|
3012
|
+
from warp.codegen import Var
|
|
3013
|
+
|
|
3014
|
+
if self.storage == "register":
|
|
3015
|
+
return f"wp::tile_register_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N}>"
|
|
3016
|
+
elif self.storage == "shared":
|
|
3017
|
+
return f"wp::tile_shared_t<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{self.strides[0]}, {self.strides[1]}, {'true' if self.owner else 'false'}>"
|
|
3018
|
+
else:
|
|
3019
|
+
raise RuntimeError(f"Unrecognized tile storage type {self.storage}")
|
|
3020
|
+
|
|
3021
|
+
# generates C-initializer string
|
|
3022
|
+
def cinit(self, requires_grad=False):
|
|
3023
|
+
from warp.codegen import Var
|
|
3024
|
+
|
|
3025
|
+
if self.storage == "register":
|
|
3026
|
+
return self.ctype() + "(0.0)"
|
|
3027
|
+
elif self.storage == "shared":
|
|
3028
|
+
if self.owner:
|
|
3029
|
+
# allocate new shared memory tile
|
|
3030
|
+
return f"wp::tile_alloc_empty<{Var.type_to_ctype(self.dtype)},{self.M},{self.N},{'true' if requires_grad else 'false'}>()"
|
|
3031
|
+
else:
|
|
3032
|
+
# tile will be initialized by another call, e.g.: tile_transpose()
|
|
3033
|
+
return "NULL"
|
|
3034
|
+
|
|
3035
|
+
# return total tile size in bytes
|
|
3036
|
+
def size_in_bytes(self):
|
|
3037
|
+
num_bytes = self.align(type_size_in_bytes(self.dtype) * self.M * self.N)
|
|
3038
|
+
return num_bytes
|
|
3039
|
+
|
|
3040
|
+
# align tile size to natural boundary, default 16-bytes
|
|
3041
|
+
def align(self, bytes):
|
|
3042
|
+
return ((bytes + self.alignment - 1) // self.alignment) * self.alignment
|
|
3043
|
+
|
|
3044
|
+
|
|
3045
|
+
class TileZeros(Tile):
|
|
3046
|
+
def __init__(self, dtype, M, N, storage="register"):
|
|
3047
|
+
Tile.__init__(self, dtype, M, N, op="zeros", storage=storage)
|
|
3048
|
+
|
|
3049
|
+
|
|
3050
|
+
class TileRange(Tile):
|
|
3051
|
+
def __init__(self, dtype, start, stop, step, storage="register"):
|
|
3052
|
+
self.start = start
|
|
3053
|
+
self.stop = stop
|
|
3054
|
+
self.step = step
|
|
3055
|
+
|
|
3056
|
+
M = 1
|
|
3057
|
+
N = int((stop - start) / step)
|
|
3058
|
+
|
|
3059
|
+
Tile.__init__(self, dtype, M, N, op="arange", storage=storage)
|
|
3060
|
+
|
|
3061
|
+
|
|
3062
|
+
class TileConstant(Tile):
|
|
3063
|
+
def __init__(self, dtype, M, N):
|
|
3064
|
+
Tile.__init__(self, dtype, M, N, op="constant", storage="register")
|
|
3065
|
+
|
|
3066
|
+
|
|
3067
|
+
class TileLoad(Tile):
|
|
3068
|
+
def __init__(self, array, M, N, storage="register"):
|
|
3069
|
+
Tile.__init__(self, array.dtype, M, N, op="load", storage=storage)
|
|
3070
|
+
|
|
3071
|
+
|
|
3072
|
+
class TileUnaryMap(Tile):
|
|
3073
|
+
def __init__(self, t, storage="register"):
|
|
3074
|
+
Tile.__init__(self, t.dtype, t.M, t.N, op="unary_map", storage=storage)
|
|
3075
|
+
|
|
3076
|
+
self.t = t
|
|
3077
|
+
|
|
3078
|
+
|
|
3079
|
+
class TileBinaryMap(Tile):
|
|
3080
|
+
def __init__(self, a, b, storage="register"):
|
|
3081
|
+
Tile.__init__(self, a.dtype, a.M, a.N, op="binary_map", storage=storage)
|
|
3082
|
+
|
|
3083
|
+
self.a = a
|
|
3084
|
+
self.b = b
|
|
3085
|
+
|
|
3086
|
+
|
|
3087
|
+
class TileShared(Tile):
|
|
3088
|
+
def __init__(self, t):
|
|
3089
|
+
Tile.__init__(self, t.dtype, t.M, t.N, "shared", storage="shared")
|
|
3090
|
+
|
|
3091
|
+
self.t = t
|
|
3092
|
+
|
|
3093
|
+
|
|
3094
|
+
def is_tile(t):
|
|
3095
|
+
return isinstance(t, Tile)
|
|
3096
|
+
|
|
3097
|
+
|
|
2966
3098
|
class Bvh:
|
|
2967
3099
|
def __new__(cls, *args, **kwargs):
|
|
2968
3100
|
instance = super(Bvh, cls).__new__(cls)
|
|
@@ -3544,9 +3676,9 @@ class Volume:
|
|
|
3544
3676
|
grid_data = bytearray()
|
|
3545
3677
|
while grid_data_offset < file_end:
|
|
3546
3678
|
chunk_size = struct.unpack("<Q", data[grid_data_offset : grid_data_offset + 8])[0]
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3679
|
+
grid_data_offset += 8
|
|
3680
|
+
grid_data += zlib.decompress(data[grid_data_offset : grid_data_offset + chunk_size])
|
|
3681
|
+
grid_data_offset += chunk_size
|
|
3550
3682
|
elif codec == 2: # blosc compression
|
|
3551
3683
|
try:
|
|
3552
3684
|
import blosc
|
|
@@ -3558,8 +3690,9 @@ class Volume:
|
|
|
3558
3690
|
grid_data = bytearray()
|
|
3559
3691
|
while grid_data_offset < file_end:
|
|
3560
3692
|
chunk_size = struct.unpack("<Q", data[grid_data_offset : grid_data_offset + 8])[0]
|
|
3561
|
-
|
|
3562
|
-
|
|
3693
|
+
grid_data_offset += 8
|
|
3694
|
+
grid_data += blosc.decompress(data[grid_data_offset : grid_data_offset + chunk_size])
|
|
3695
|
+
grid_data_offset += chunk_size
|
|
3563
3696
|
else:
|
|
3564
3697
|
raise RuntimeError(f"Unsupported codec code: {codec}")
|
|
3565
3698
|
|
|
@@ -3570,6 +3703,139 @@ class Volume:
|
|
|
3570
3703
|
data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
|
|
3571
3704
|
return cls(data_array)
|
|
3572
3705
|
|
|
3706
|
+
def save_to_nvdb(self, path, codec: Literal["none", "zip", "blosc"] = "none"):
|
|
3707
|
+
"""Serialize the Volume into a NanoVDB (.nvdb) file.
|
|
3708
|
+
|
|
3709
|
+
Args:
|
|
3710
|
+
path: File path to save.
|
|
3711
|
+
codec: Compression codec used
|
|
3712
|
+
"none" - no compression
|
|
3713
|
+
"zip" - ZIP compression
|
|
3714
|
+
"blosc" - BLOSC compression, requires the blosc module to be installed
|
|
3715
|
+
"""
|
|
3716
|
+
|
|
3717
|
+
codec_dict = {"none": 0, "zip": 1, "blosc": 2}
|
|
3718
|
+
|
|
3719
|
+
class FileHeader(ctypes.Structure):
|
|
3720
|
+
_fields_ = [
|
|
3721
|
+
("magic", ctypes.c_uint64),
|
|
3722
|
+
("version", ctypes.c_uint32),
|
|
3723
|
+
("gridCount", ctypes.c_uint16),
|
|
3724
|
+
("codec", ctypes.c_uint16),
|
|
3725
|
+
]
|
|
3726
|
+
|
|
3727
|
+
class FileMetaData(ctypes.Structure):
|
|
3728
|
+
_fields_ = [
|
|
3729
|
+
("gridSize", ctypes.c_uint64),
|
|
3730
|
+
("fileSize", ctypes.c_uint64),
|
|
3731
|
+
("nameKey", ctypes.c_uint64),
|
|
3732
|
+
("voxelCount", ctypes.c_uint64),
|
|
3733
|
+
("gridType", ctypes.c_uint32),
|
|
3734
|
+
("gridClass", ctypes.c_uint32),
|
|
3735
|
+
("worldBBox", ctypes.c_double * 6),
|
|
3736
|
+
("indexBBox", ctypes.c_uint32 * 6),
|
|
3737
|
+
("voxelSize", ctypes.c_double * 3),
|
|
3738
|
+
("nameSize", ctypes.c_uint32),
|
|
3739
|
+
("nodeCount", ctypes.c_uint32 * 4),
|
|
3740
|
+
("tileCount", ctypes.c_uint32 * 3),
|
|
3741
|
+
("codec", ctypes.c_uint16),
|
|
3742
|
+
("padding", ctypes.c_uint16),
|
|
3743
|
+
("version", ctypes.c_uint32),
|
|
3744
|
+
]
|
|
3745
|
+
|
|
3746
|
+
class GridData(ctypes.Structure):
|
|
3747
|
+
_fields_ = [
|
|
3748
|
+
("magic", ctypes.c_uint64),
|
|
3749
|
+
("checksum", ctypes.c_uint64),
|
|
3750
|
+
("version", ctypes.c_uint32),
|
|
3751
|
+
("flags", ctypes.c_uint32),
|
|
3752
|
+
("gridIndex", ctypes.c_uint32),
|
|
3753
|
+
("gridCount", ctypes.c_uint32),
|
|
3754
|
+
("gridSize", ctypes.c_uint64),
|
|
3755
|
+
("gridName", ctypes.c_char * 256),
|
|
3756
|
+
("map", ctypes.c_byte * 264),
|
|
3757
|
+
("worldBBox", ctypes.c_double * 6),
|
|
3758
|
+
("voxelSize", ctypes.c_double * 3),
|
|
3759
|
+
("gridClass", ctypes.c_uint32),
|
|
3760
|
+
("gridType", ctypes.c_uint32),
|
|
3761
|
+
("blindMetadataOffset", ctypes.c_int64),
|
|
3762
|
+
("blindMetadataCount", ctypes.c_uint32),
|
|
3763
|
+
("data0", ctypes.c_uint32),
|
|
3764
|
+
("data1", ctypes.c_uint64),
|
|
3765
|
+
("data2", ctypes.c_uint64),
|
|
3766
|
+
]
|
|
3767
|
+
|
|
3768
|
+
NVDB_MAGIC = 0x304244566F6E614E
|
|
3769
|
+
NVDB_VERSION = 32 << 21 | 3 << 10 | 3
|
|
3770
|
+
|
|
3771
|
+
try:
|
|
3772
|
+
codec_int = codec_dict[codec]
|
|
3773
|
+
except KeyError as err:
|
|
3774
|
+
raise RuntimeError(f"Unsupported codec requested: {codec}") from err
|
|
3775
|
+
|
|
3776
|
+
if codec_int == 2:
|
|
3777
|
+
try:
|
|
3778
|
+
import blosc
|
|
3779
|
+
except ImportError as err:
|
|
3780
|
+
raise RuntimeError(
|
|
3781
|
+
f"blosc compression was requested, but Python module could not be imported: {err}"
|
|
3782
|
+
) from err
|
|
3783
|
+
|
|
3784
|
+
data = self.array().numpy()
|
|
3785
|
+
grid_data = GridData.from_buffer(data)
|
|
3786
|
+
|
|
3787
|
+
if grid_data.gridIndex > 0:
|
|
3788
|
+
raise RuntimeError(
|
|
3789
|
+
"Saving of aliased Volumes is not supported. Use `save_to_nvdb` on the original volume, before any `load_next_grid` calls."
|
|
3790
|
+
)
|
|
3791
|
+
|
|
3792
|
+
file_header = FileHeader(NVDB_MAGIC, NVDB_VERSION, grid_data.gridCount, codec_int)
|
|
3793
|
+
|
|
3794
|
+
grid_data_offset = 0
|
|
3795
|
+
all_file_meta_data = []
|
|
3796
|
+
for i in range(file_header.gridCount):
|
|
3797
|
+
if i > 0:
|
|
3798
|
+
grid_data = GridData.from_buffer(data[grid_data_offset : grid_data_offset + 672])
|
|
3799
|
+
current_grid_data = data[grid_data_offset : grid_data_offset + grid_data.gridSize]
|
|
3800
|
+
if codec_int == 1: # zip compression
|
|
3801
|
+
compressed_data = zlib.compress(current_grid_data)
|
|
3802
|
+
compressed_size = len(compressed_data)
|
|
3803
|
+
elif codec_int == 2: # blosc compression
|
|
3804
|
+
compressed_data = blosc.compress(current_grid_data)
|
|
3805
|
+
compressed_size = len(compressed_data)
|
|
3806
|
+
else: # no compression
|
|
3807
|
+
compressed_data = current_grid_data
|
|
3808
|
+
compressed_size = grid_data.gridSize
|
|
3809
|
+
|
|
3810
|
+
file_meta_data = FileMetaData()
|
|
3811
|
+
file_meta_data.gridSize = grid_data.gridSize
|
|
3812
|
+
file_meta_data.fileSize = compressed_size
|
|
3813
|
+
file_meta_data.gridType = grid_data.gridType
|
|
3814
|
+
file_meta_data.gridClass = grid_data.gridClass
|
|
3815
|
+
file_meta_data.worldBBox = grid_data.worldBBox
|
|
3816
|
+
file_meta_data.voxelSize = grid_data.voxelSize
|
|
3817
|
+
file_meta_data.nameSize = len(grid_data.gridName) + 1 # including the closing 0x0
|
|
3818
|
+
file_meta_data.codec = codec_int
|
|
3819
|
+
file_meta_data.version = NVDB_VERSION
|
|
3820
|
+
|
|
3821
|
+
grid_data_offset += file_meta_data.gridSize
|
|
3822
|
+
|
|
3823
|
+
all_file_meta_data.append((file_meta_data, grid_data.gridName, compressed_data))
|
|
3824
|
+
|
|
3825
|
+
with open(path, "wb") as nvdb:
|
|
3826
|
+
nvdb.write(file_header)
|
|
3827
|
+
for file_meta_data, grid_name, _ in all_file_meta_data:
|
|
3828
|
+
nvdb.write(file_meta_data)
|
|
3829
|
+
nvdb.write(grid_name + b"\x00")
|
|
3830
|
+
|
|
3831
|
+
for file_meta_data, _, compressed_data in all_file_meta_data:
|
|
3832
|
+
if codec_int > 0:
|
|
3833
|
+
chunk_size = struct.pack("<Q", file_meta_data.fileSize)
|
|
3834
|
+
nvdb.write(chunk_size)
|
|
3835
|
+
nvdb.write(compressed_data)
|
|
3836
|
+
|
|
3837
|
+
return path
|
|
3838
|
+
|
|
3573
3839
|
@classmethod
|
|
3574
3840
|
def load_from_address(cls, grid_ptr: int, buffer_size: int = 0, device=None) -> Volume:
|
|
3575
3841
|
"""
|
warp/utils.py
CHANGED
|
@@ -18,6 +18,7 @@ import numpy as np
|
|
|
18
18
|
import warp as wp
|
|
19
19
|
import warp.context
|
|
20
20
|
import warp.types
|
|
21
|
+
from warp.context import Devicelike
|
|
21
22
|
|
|
22
23
|
warnings_seen = set()
|
|
23
24
|
|
|
@@ -554,7 +555,27 @@ def mem_report(): # pragma: no cover
|
|
|
554
555
|
|
|
555
556
|
|
|
556
557
|
class ScopedDevice:
|
|
557
|
-
|
|
558
|
+
"""A context manager to temporarily change the current default device.
|
|
559
|
+
|
|
560
|
+
For CUDA devices, this context manager makes the device's CUDA context
|
|
561
|
+
current and restores the previous CUDA context on exit. This is handy when
|
|
562
|
+
running Warp scripts as part of a bigger pipeline because it avoids any side
|
|
563
|
+
effects of changing the CUDA context in the enclosed code.
|
|
564
|
+
|
|
565
|
+
Attributes:
|
|
566
|
+
device (Device): The device that will temporarily become the default
|
|
567
|
+
device within the context.
|
|
568
|
+
saved_device (Device): The previous default device. This is restored as
|
|
569
|
+
the default device on exiting the context.
|
|
570
|
+
"""
|
|
571
|
+
|
|
572
|
+
def __init__(self, device: Devicelike):
|
|
573
|
+
"""Initializes the context manager with a device.
|
|
574
|
+
|
|
575
|
+
Args:
|
|
576
|
+
device: The device that will temporarily become the default device
|
|
577
|
+
within the context.
|
|
578
|
+
"""
|
|
558
579
|
self.device = wp.get_device(device)
|
|
559
580
|
|
|
560
581
|
def __enter__(self):
|