warp-lang 1.0.1__py3-none-manylinux2014_aarch64.whl → 1.1.0__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -279
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -28
- warp/examples/core/example_dem.py +234 -221
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -129
- warp/examples/core/example_marching_cubes.py +188 -176
- warp/examples/core/example_mesh.py +174 -154
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -169
- warp/examples/core/example_raycast.py +105 -89
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -389
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -249
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -391
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -248
- warp/examples/optim/example_cloth_throw.py +222 -210
- warp/examples/optim/example_diffray.py +566 -535
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -169
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -170
- warp/examples/optim/example_spring_cage.py +239 -234
- warp/examples/optim/example_trajectory.py +223 -201
- warp/examples/optim/example_walker.py +306 -292
- warp/examples/sim/example_cartpole.py +139 -128
- warp/examples/sim/example_cloth.py +196 -184
- warp/examples/sim/example_granular.py +124 -113
- warp/examples/sim/example_granular_collision_sdf.py +197 -185
- warp/examples/sim/example_jacobian_ik.py +236 -213
- warp/examples/sim/example_particle_chain.py +118 -106
- warp/examples/sim/example_quadruped.py +193 -179
- warp/examples/sim/example_rigid_chain.py +197 -189
- warp/examples/sim/example_rigid_contact.py +189 -176
- warp/examples/sim/example_rigid_force.py +127 -126
- warp/examples/sim/example_rigid_gyroscopic.py +109 -97
- warp/examples/sim/example_rigid_soft_contact.py +134 -124
- warp/examples/sim/example_soft_body.py +190 -178
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.1.dist-info/RECORD +0 -352
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.1.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/thirdparty/dlpack.py
CHANGED
|
@@ -1,143 +1,143 @@
|
|
|
1
|
-
import ctypes
|
|
2
|
-
|
|
3
|
-
_c_str_dltensor = b"dltensor"
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class DLDeviceType(ctypes.c_int):
    """The enum that encodes the type of the device where
    DLTensor memory is allocated.

    Codes follow the ``DLDeviceType`` enumeration of the DLPack header
    (``dlpack.h``); values 5 and 6 are intentionally unused there.
    """

    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLROCMHost = 11
    kDLExtDev = 12
    kDLCUDAManaged = 13
    kDLOneAPI = 14
    kDLWebGPU = 15
    kDLHexagon = 16

    def __str__(self):
        """Return a human-readable name for this device type.

        Codes not listed above (e.g. produced by a newer DLPack version) are
        rendered as ``"Unknown(<code>)"`` instead of raising ``KeyError``, so
        ``str()``/``print()`` of a device never fails.
        """
        names = {
            self.kDLCPU: "CPU",
            self.kDLCUDA: "CUDA",
            self.kDLCUDAHost: "CUDAHost",
            self.kDLOpenCL: "OpenCL",
            self.kDLVulkan: "Vulkan",
            self.kDLMetal: "Metal",
            self.kDLVPI: "VPI",
            self.kDLROCM: "ROCM",
            # Was misspelled "ROMCHost".
            self.kDLROCMHost: "ROCMHost",
            self.kDLExtDev: "ExtDev",
            self.kDLCUDAManaged: "CUDAManaged",
            self.kDLOneAPI: "oneAPI",
            self.kDLWebGPU: "WebGPU",
            self.kDLHexagon: "Hexagon",
        }
        return names.get(self.value, f"Unknown({self.value})")
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class DLDevice(ctypes.Structure):
    """Location of the memory backing a DLTensor.

    A device is identified by two fields:

    * ``device_type`` -- a ``DLDeviceType`` code (CPU, CUDA, ...).
    * ``device_id`` -- a ``c_int`` index of the device (presumably the
      ordinal among devices of that type; confirm with producers).
    """

    _fields_ = [("device_type", DLDeviceType), ("device_id", ctypes.c_int)]
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class DLDataTypeCode(ctypes.c_uint8):
    """An integer that encodes the category of DLTensor elements' data type.

    Codes follow the ``DLDataTypeCode`` enumeration of the DLPack header
    (``dlpack.h``); ``kDLOpaquePointer`` corresponds to DLPack's
    ``kDLOpaqueHandle``.
    """

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaquePointer = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6

    def __str__(self):
        """Return a short name for this type category.

        Codes not listed above are rendered as ``"Unknown(<code>)"`` instead of
        raising ``KeyError``.
        """
        names = {
            self.kDLInt: "int",
            self.kDLUInt: "uint",
            self.kDLFloat: "float",
            self.kDLBfloat: "bfloat",
            self.kDLComplex: "complex",
            self.kDLOpaquePointer: "void_p",
            self.kDLBool: "bool",
        }
        return names.get(self.value, f"Unknown({self.value})")
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
class DLDataType(ctypes.Structure):
    """Element data-type descriptor for a DLTensor.

    A type is the triple (``type_code``, ``bits``, ``lanes``): every element
    is ``lanes`` packed values of category ``type_code``, each ``bits`` wide.

    ``TYPE_MAP`` maps a NumPy-style dtype name to its
    ``(type_code, bits, lanes)`` triple.
    """

    _fields_ = [
        ("type_code", DLDataTypeCode),  # category of the values
        ("bits", ctypes.c_uint8),  # width of one value, in bits
        ("lanes", ctypes.c_uint16),  # number of packed values per element
    ]
    # dtype name -> (type_code, bits, lanes); every entry below has lanes == 1.
    TYPE_MAP = {
        # Booleans are described as 1-bit unsigned integers.
        # NOTE(review): DLPack >= 0.8 also defines a dedicated bool code and some
        # producers use 8 bits -- confirm consumers expect this encoding.
        "bool": (DLDataTypeCode.kDLUInt, 1, 1),
        # signed integers
        "int8": (DLDataTypeCode.kDLInt, 8, 1),
        "int16": (DLDataTypeCode.kDLInt, 16, 1),
        "int32": (DLDataTypeCode.kDLInt, 32, 1),
        "int64": (DLDataTypeCode.kDLInt, 64, 1),
        # unsigned integers
        "uint8": (DLDataTypeCode.kDLUInt, 8, 1),
        "uint16": (DLDataTypeCode.kDLUInt, 16, 1),
        "uint32": (DLDataTypeCode.kDLUInt, 32, 1),
        "uint64": (DLDataTypeCode.kDLUInt, 64, 1),
        # IEEE floating point
        "float16": (DLDataTypeCode.kDLFloat, 16, 1),
        "float32": (DLDataTypeCode.kDLFloat, 32, 1),
        "float64": (DLDataTypeCode.kDLFloat, 64, 1),
        # complex (bits = total width of the real + imaginary parts)
        "complex64": (DLDataTypeCode.kDLComplex, 64, 1),
        "complex128": (DLDataTypeCode.kDLComplex, 128, 1),
    }
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
class DLTensor(ctypes.Structure):
    """Strided description of tensor memory.

    Fields:
        data: base address of the tensor memory (``c_void_p``).
        device: ``DLDevice`` on which ``data`` lives.
        ndim: number of dimensions, i.e. how many indices address one element.
        dtype: ``DLDataType`` describing each element.
        shape: pointer to ``int64`` extents, one per dimension
            (presumably ``ndim`` entries).
        strides: pointer to ``int64`` per-dimension steps, counted in array
            elements (not bytes). NOTE(review): DLPack allows this pointer to
            be NULL for compact row-major tensors -- confirm callers handle it.
        byte_offset: offset in bytes from ``data`` to the element at index
            ``(0,) * ndim``.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),  # int64[ndim]
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # int64[ndim] or NULL
        ("byte_offset", ctypes.c_uint64),
    ]
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
class DLManagedTensor(ctypes.Structure):
    """A DLTensor bundled with the bookkeeping needed to release it.

    Fields:
        dl_tensor: the ``DLTensor`` descriptor, embedded by value.
        manager_ctx: opaque pointer to the producer's private context.
        deleter: C callback ``void (*)(void*)`` used to free the tensor.
            NOTE(review): presumably invoked with a pointer to this struct --
            confirm against the DLPack specification and callers.
    """

    _fields_ = [
        ("dl_tensor", DLTensor),
        ("manager_ctx", ctypes.c_void_p),
        ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)),
    ]
|
|
1
|
+
import ctypes
|
|
2
|
+
|
|
3
|
+
_c_str_dltensor = b"dltensor"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DLDeviceType(ctypes.c_int):
    """The enum that encodes the type of the device where
    DLTensor memory is allocated.

    Codes follow the ``DLDeviceType`` enumeration of the DLPack header
    (``dlpack.h``); values 5 and 6 are intentionally unused there.
    """

    kDLCPU = 1
    kDLCUDA = 2
    kDLCUDAHost = 3
    kDLOpenCL = 4
    kDLVulkan = 7
    kDLMetal = 8
    kDLVPI = 9
    kDLROCM = 10
    kDLROCMHost = 11
    kDLExtDev = 12
    kDLCUDAManaged = 13
    kDLOneAPI = 14
    kDLWebGPU = 15
    kDLHexagon = 16

    def __str__(self):
        """Return a human-readable name for this device type.

        Codes not listed above (e.g. produced by a newer DLPack version) are
        rendered as ``"Unknown(<code>)"`` instead of raising ``KeyError``, so
        ``str()``/``print()`` of a device never fails.
        """
        names = {
            self.kDLCPU: "CPU",
            self.kDLCUDA: "CUDA",
            self.kDLCUDAHost: "CUDAHost",
            self.kDLOpenCL: "OpenCL",
            self.kDLVulkan: "Vulkan",
            self.kDLMetal: "Metal",
            self.kDLVPI: "VPI",
            self.kDLROCM: "ROCM",
            # Was misspelled "ROMCHost".
            self.kDLROCMHost: "ROCMHost",
            self.kDLExtDev: "ExtDev",
            self.kDLCUDAManaged: "CUDAManaged",
            self.kDLOneAPI: "oneAPI",
            self.kDLWebGPU: "WebGPU",
            self.kDLHexagon: "Hexagon",
        }
        return names.get(self.value, f"Unknown({self.value})")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DLDevice(ctypes.Structure):
    """Location of the memory backing a DLTensor.

    A device is identified by two fields:

    * ``device_type`` -- a ``DLDeviceType`` code (CPU, CUDA, ...).
    * ``device_id`` -- a ``c_int`` index of the device (presumably the
      ordinal among devices of that type; confirm with producers).
    """

    _fields_ = [("device_type", DLDeviceType), ("device_id", ctypes.c_int)]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class DLDataTypeCode(ctypes.c_uint8):
    """An integer that encodes the category of DLTensor elements' data type.

    Codes follow the ``DLDataTypeCode`` enumeration of the DLPack header
    (``dlpack.h``); ``kDLOpaquePointer`` corresponds to DLPack's
    ``kDLOpaqueHandle``.
    """

    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2
    kDLOpaquePointer = 3
    kDLBfloat = 4
    kDLComplex = 5
    kDLBool = 6

    def __str__(self):
        """Return a short name for this type category.

        Codes not listed above are rendered as ``"Unknown(<code>)"`` instead of
        raising ``KeyError``.
        """
        names = {
            self.kDLInt: "int",
            self.kDLUInt: "uint",
            self.kDLFloat: "float",
            self.kDLBfloat: "bfloat",
            self.kDLComplex: "complex",
            self.kDLOpaquePointer: "void_p",
            self.kDLBool: "bool",
        }
        return names.get(self.value, f"Unknown({self.value})")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class DLDataType(ctypes.Structure):
    """Element data-type descriptor for a DLTensor.

    A type is the triple (``type_code``, ``bits``, ``lanes``): every element
    is ``lanes`` packed values of category ``type_code``, each ``bits`` wide.

    ``TYPE_MAP`` maps a NumPy-style dtype name to its
    ``(type_code, bits, lanes)`` triple.
    """

    _fields_ = [
        ("type_code", DLDataTypeCode),  # category of the values
        ("bits", ctypes.c_uint8),  # width of one value, in bits
        ("lanes", ctypes.c_uint16),  # number of packed values per element
    ]
    # dtype name -> (type_code, bits, lanes); every entry below has lanes == 1.
    TYPE_MAP = {
        # Booleans are described as 1-bit unsigned integers.
        # NOTE(review): DLPack >= 0.8 also defines a dedicated bool code and some
        # producers use 8 bits -- confirm consumers expect this encoding.
        "bool": (DLDataTypeCode.kDLUInt, 1, 1),
        # signed integers
        "int8": (DLDataTypeCode.kDLInt, 8, 1),
        "int16": (DLDataTypeCode.kDLInt, 16, 1),
        "int32": (DLDataTypeCode.kDLInt, 32, 1),
        "int64": (DLDataTypeCode.kDLInt, 64, 1),
        # unsigned integers
        "uint8": (DLDataTypeCode.kDLUInt, 8, 1),
        "uint16": (DLDataTypeCode.kDLUInt, 16, 1),
        "uint32": (DLDataTypeCode.kDLUInt, 32, 1),
        "uint64": (DLDataTypeCode.kDLUInt, 64, 1),
        # IEEE floating point
        "float16": (DLDataTypeCode.kDLFloat, 16, 1),
        "float32": (DLDataTypeCode.kDLFloat, 32, 1),
        "float64": (DLDataTypeCode.kDLFloat, 64, 1),
        # complex (bits = total width of the real + imaginary parts)
        "complex64": (DLDataTypeCode.kDLComplex, 64, 1),
        "complex128": (DLDataTypeCode.kDLComplex, 128, 1),
    }
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class DLTensor(ctypes.Structure):
    """Strided description of tensor memory.

    Fields:
        data: base address of the tensor memory (``c_void_p``).
        device: ``DLDevice`` on which ``data`` lives.
        ndim: number of dimensions, i.e. how many indices address one element.
        dtype: ``DLDataType`` describing each element.
        shape: pointer to ``int64`` extents, one per dimension
            (presumably ``ndim`` entries).
        strides: pointer to ``int64`` per-dimension steps, counted in array
            elements (not bytes). NOTE(review): DLPack allows this pointer to
            be NULL for compact row-major tensors -- confirm callers handle it.
        byte_offset: offset in bytes from ``data`` to the element at index
            ``(0,) * ndim``.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("device", DLDevice),
        ("ndim", ctypes.c_int),
        ("dtype", DLDataType),
        ("shape", ctypes.POINTER(ctypes.c_int64)),  # int64[ndim]
        ("strides", ctypes.POINTER(ctypes.c_int64)),  # int64[ndim] or NULL
        ("byte_offset", ctypes.c_uint64),
    ]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class DLManagedTensor(ctypes.Structure):
    """A DLTensor bundled with the bookkeeping needed to release it.

    Fields:
        dl_tensor: the ``DLTensor`` descriptor, embedded by value.
        manager_ctx: opaque pointer to the producer's private context.
        deleter: C callback ``void (*)(void*)`` used to free the tensor.
            NOTE(review): presumably invoked with a pointer to this struct --
            confirm against the DLPack specification and callers.
    """

    _fields_ = [
        ("dl_tensor", DLTensor),
        ("manager_ctx", ctypes.c_void_p),
        ("deleter", ctypes.CFUNCTYPE(None, ctypes.c_void_p)),
    ]
|