warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,602 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import ctypes
|
|
17
|
+
import enum
|
|
18
|
+
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
import warp as wp
|
|
23
|
+
|
|
24
|
+
#######################################################################
|
|
25
|
+
# ctypes structures and enums for XLA's FFI API:
|
|
26
|
+
# https://github.com/openxla/xla/blob/a1a5e62fbffa3a3b6c409d72607456cf5b353a22/xla/ffi/api/c_api.h
|
|
27
|
+
#######################################################################
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# typedef enum {
|
|
31
|
+
# XLA_FFI_Extension_Metadata = 1,
|
|
32
|
+
# } XLA_FFI_Extension_Type;
|
|
33
|
+
class XLA_FFI_Extension_Type(enum.IntEnum):
    """Discriminator stored in ``XLA_FFI_Extension_Base.type``."""

    Metadata = 1  # XLA_FFI_Extension_Metadata


# typedef struct XLA_FFI_Extension_Base {
#   size_t struct_size;
#   XLA_FFI_Extension_Type type;
#   struct XLA_FFI_Extension_Base* next;
# } XLA_FFI_Extension_Base;
class XLA_FFI_Extension_Base(ctypes.Structure):
    """Header shared by XLA FFI extension structs; ``next`` chains extensions together."""


# The field list is attached after the class statement because ``next`` is a
# pointer to this very structure type, which must exist before POINTER() can
# be applied to it.
XLA_FFI_Extension_Base._fields_ = [
    ("struct_size", ctypes.c_size_t),
    ("type", ctypes.c_int),  # XLA_FFI_Extension_Type, stored as a C int
    ("next", ctypes.POINTER(XLA_FFI_Extension_Base)),
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# typedef enum {
|
|
54
|
+
# XLA_FFI_ExecutionStage_INSTANTIATE = 0,
|
|
55
|
+
# XLA_FFI_ExecutionStage_PREPARE = 1,
|
|
56
|
+
# XLA_FFI_ExecutionStage_INITIALIZE = 2,
|
|
57
|
+
# XLA_FFI_ExecutionStage_EXECUTE = 3,
|
|
58
|
+
# } XLA_FFI_ExecutionStage;
|
|
59
|
+
class XLA_FFI_ExecutionStage(enum.IntEnum):
    """Stage at which XLA invokes an FFI handler (mirrors ``XLA_FFI_ExecutionStage``)."""

    INSTANTIATE = 0  # XLA_FFI_ExecutionStage_INSTANTIATE
    PREPARE = 1  # XLA_FFI_ExecutionStage_PREPARE
    INITIALIZE = 2  # XLA_FFI_ExecutionStage_INITIALIZE
    EXECUTE = 3  # XLA_FFI_ExecutionStage_EXECUTE
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# typedef enum {
|
|
67
|
+
# XLA_FFI_DataType_INVALID = 0,
|
|
68
|
+
# XLA_FFI_DataType_PRED = 1,
|
|
69
|
+
# XLA_FFI_DataType_S8 = 2,
|
|
70
|
+
# XLA_FFI_DataType_S16 = 3,
|
|
71
|
+
# XLA_FFI_DataType_S32 = 4,
|
|
72
|
+
# XLA_FFI_DataType_S64 = 5,
|
|
73
|
+
# XLA_FFI_DataType_U8 = 6,
|
|
74
|
+
# XLA_FFI_DataType_U16 = 7,
|
|
75
|
+
# XLA_FFI_DataType_U32 = 8,
|
|
76
|
+
# XLA_FFI_DataType_U64 = 9,
|
|
77
|
+
# XLA_FFI_DataType_F16 = 10,
|
|
78
|
+
# XLA_FFI_DataType_F32 = 11,
|
|
79
|
+
# XLA_FFI_DataType_F64 = 12,
|
|
80
|
+
# XLA_FFI_DataType_BF16 = 16,
|
|
81
|
+
# XLA_FFI_DataType_C64 = 15,
|
|
82
|
+
# XLA_FFI_DataType_C128 = 18,
|
|
83
|
+
# XLA_FFI_DataType_TOKEN = 17,
|
|
84
|
+
# XLA_FFI_DataType_F8E5M2 = 19,
|
|
85
|
+
# XLA_FFI_DataType_F8E3M4 = 29,
|
|
86
|
+
# XLA_FFI_DataType_F8E4M3 = 28,
|
|
87
|
+
# XLA_FFI_DataType_F8E4M3FN = 20,
|
|
88
|
+
# XLA_FFI_DataType_F8E4M3B11FNUZ = 23,
|
|
89
|
+
# XLA_FFI_DataType_F8E5M2FNUZ = 24,
|
|
90
|
+
# XLA_FFI_DataType_F8E4M3FNUZ = 25,
|
|
91
|
+
# XLA_FFI_DataType_F4E2M1FN = 32,
|
|
92
|
+
# XLA_FFI_DataType_F8E8M0FNU = 33,
|
|
93
|
+
# } XLA_FFI_DataType;
|
|
94
|
+
class XLA_FFI_DataType(enum.IntEnum):
    """Element types used by XLA FFI buffers, scalars and arrays.

    The numeric values mirror the C ``XLA_FFI_DataType`` enum. They are not
    contiguous and must stay exactly as XLA defines them.
    """

    INVALID = 0
    PRED = 1
    # signed integers
    S8 = 2
    S16 = 3
    S32 = 4
    S64 = 5
    # unsigned integers
    U8 = 6
    U16 = 7
    U32 = 8
    U64 = 9
    # IEEE / brain floating point
    F16 = 10
    F32 = 11
    F64 = 12
    BF16 = 16
    # complex
    C64 = 15
    C128 = 18
    TOKEN = 17
    # narrow floating point formats
    F8E5M2 = 19
    F8E3M4 = 29
    F8E4M3 = 28
    F8E4M3FN = 20
    F8E4M3B11FNUZ = 23
    F8E5M2FNUZ = 24
    F8E4M3FNUZ = 25
    F4E2M1FN = 32
    F8E8M0FNU = 33
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
# struct XLA_FFI_Buffer {
|
|
124
|
+
# size_t struct_size;
|
|
125
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
126
|
+
#
|
|
127
|
+
# XLA_FFI_DataType dtype;
|
|
128
|
+
# void* data;
|
|
129
|
+
# int64_t rank;
|
|
130
|
+
# int64_t* dims; // length == rank
|
|
131
|
+
# };
|
|
132
|
+
class XLA_FFI_Buffer(ctypes.Structure):
    """ctypes mirror of ``struct XLA_FFI_Buffer``: element type, data pointer and extents."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        # payload description
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType, stored as a C int
        ("data", ctypes.c_void_p),
        ("rank", ctypes.c_int64),
        ("dims", ctypes.POINTER(ctypes.c_int64)),  # `rank` entries
    ]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# typedef enum {
|
|
144
|
+
# XLA_FFI_ArgType_BUFFER = 1,
|
|
145
|
+
# } XLA_FFI_ArgType;
|
|
146
|
+
class XLA_FFI_ArgType(enum.IntEnum):
    """Kind of each entry in ``XLA_FFI_Args.types``."""

    BUFFER = 1  # XLA_FFI_ArgType_BUFFER


# typedef enum {
#   XLA_FFI_RetType_BUFFER = 1,
# } XLA_FFI_RetType;
class XLA_FFI_RetType(enum.IntEnum):
    """Kind of each entry in ``XLA_FFI_Rets.types``."""

    BUFFER = 1  # XLA_FFI_RetType_BUFFER
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# struct XLA_FFI_Args {
|
|
158
|
+
# size_t struct_size;
|
|
159
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
160
|
+
# int64_t size;
|
|
161
|
+
# XLA_FFI_ArgType* types; // length == size
|
|
162
|
+
# void** args; // length == size
|
|
163
|
+
# };
|
|
164
|
+
class XLA_FFI_Args(ctypes.Structure):
    """ctypes mirror of ``struct XLA_FFI_Args``: ``size`` parallel (type, pointer) entries."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        # both arrays below have `size` entries
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_ArgType*
        ("args", ctypes.POINTER(ctypes.c_void_p)),
    ]
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
# struct XLA_FFI_Rets {
|
|
175
|
+
# size_t struct_size;
|
|
176
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
177
|
+
# int64_t size;
|
|
178
|
+
# XLA_FFI_RetType* types; // length == size
|
|
179
|
+
# void** rets; // length == size
|
|
180
|
+
# };
|
|
181
|
+
class XLA_FFI_Rets(ctypes.Structure):
    """ctypes mirror of ``struct XLA_FFI_Rets``: ``size`` parallel (type, pointer) entries."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        # both arrays below have `size` entries
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_RetType*
        ("rets", ctypes.POINTER(ctypes.c_void_p)),
    ]
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# typedef struct XLA_FFI_ByteSpan {
|
|
192
|
+
# const char* ptr;
|
|
193
|
+
# size_t len;
|
|
194
|
+
# } XLA_FFI_ByteSpan;
|
|
195
|
+
class XLA_FFI_ByteSpan(ctypes.Structure):
    """Pointer + length pair describing a byte string (not NUL-terminated by contract)."""

    _fields_ = [
        ("ptr", ctypes.POINTER(ctypes.c_char)),
        ("len", ctypes.c_size_t),
    ]


# typedef struct XLA_FFI_Scalar {
#   XLA_FFI_DataType dtype;
#   void* value;
# } XLA_FFI_Scalar;
class XLA_FFI_Scalar(ctypes.Structure):
    """A single attribute value of type ``dtype`` stored behind ``value``."""

    _fields_ = [
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType
        ("value", ctypes.c_void_p),
    ]


# typedef struct XLA_FFI_Array {
#   XLA_FFI_DataType dtype;
#   size_t size;
#   void* data;
# } XLA_FFI_Array;
class XLA_FFI_Array(ctypes.Structure):
    """A flat attribute array of ``size`` elements of type ``dtype`` stored at ``data``."""

    _fields_ = [
        ("dtype", ctypes.c_int),  # XLA_FFI_DataType
        ("size", ctypes.c_size_t),
        ("data", ctypes.c_void_p),
    ]
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# typedef enum {
|
|
217
|
+
# XLA_FFI_AttrType_ARRAY = 1,
|
|
218
|
+
# XLA_FFI_AttrType_DICTIONARY = 2,
|
|
219
|
+
# XLA_FFI_AttrType_SCALAR = 3,
|
|
220
|
+
# XLA_FFI_AttrType_STRING = 4,
|
|
221
|
+
# } XLA_FFI_AttrType;
|
|
222
|
+
class XLA_FFI_AttrType(enum.IntEnum):
    """Kind of each attribute in ``XLA_FFI_Attrs.types``."""

    ARRAY = 1  # payload is an XLA_FFI_Array
    DICTIONARY = 2  # payload is a nested XLA_FFI_Attrs
    SCALAR = 3  # payload is an XLA_FFI_Scalar
    STRING = 4  # payload is an XLA_FFI_ByteSpan
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# struct XLA_FFI_Attrs {
|
|
230
|
+
# size_t struct_size;
|
|
231
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
232
|
+
# int64_t size;
|
|
233
|
+
# XLA_FFI_AttrType* types; // length == size
|
|
234
|
+
# XLA_FFI_ByteSpan** names; // length == size
|
|
235
|
+
# void** attrs; // length == size
|
|
236
|
+
# };
|
|
237
|
+
class XLA_FFI_Attrs(ctypes.Structure):
    """ctypes mirror of ``struct XLA_FFI_Attrs``: ``size`` named, typed attribute payloads."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("size", ctypes.c_int64),
        # the three arrays below are parallel and have `size` entries each
        ("types", ctypes.POINTER(ctypes.c_int)),  # XLA_FFI_AttrType*
        ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
        ("attrs", ctypes.POINTER(ctypes.c_void_p)),
    ]
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
# struct XLA_FFI_Api_Version {
|
|
249
|
+
# size_t struct_size;
|
|
250
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
251
|
+
# int major_version; // out
|
|
252
|
+
# int minor_version; // out
|
|
253
|
+
# };
|
|
254
|
+
class XLA_FFI_Api_Version(ctypes.Structure):
    """API version record; ``major_version``/``minor_version`` are output fields in C."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("major_version", ctypes.c_int),  # out
        ("minor_version", ctypes.c_int),  # out
    ]
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# enum XLA_FFI_Handler_TraitsBits {
|
|
264
|
+
# // Calls to FFI handler are safe to trace into the command buffer. It means
|
|
265
|
+
# // that calls to FFI handler always launch exactly the same device operations
|
|
266
|
+
# // (can depend on attribute values) that can be captured and then replayed.
|
|
267
|
+
# XLA_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE = 1u << 0,
|
|
268
|
+
# };
|
|
269
|
+
class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
    """Bit flags that may be OR-ed into ``XLA_FFI_Metadata.traits``."""

    # Handler always launches the same device work, so it may be captured
    # into a command buffer and replayed.
    COMMAND_BUFFER_COMPATIBLE = 1 << 0
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
# struct XLA_FFI_Metadata {
|
|
274
|
+
# size_t struct_size;
|
|
275
|
+
# XLA_FFI_Api_Version api_version;
|
|
276
|
+
# XLA_FFI_Handler_Traits traits;
|
|
277
|
+
# };
|
|
278
|
+
class XLA_FFI_Metadata(ctypes.Structure):
    """Handler metadata: supported API version plus a ``traits`` bit set."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("api_version", XLA_FFI_Api_Version),  # embedded by value
        ("traits", ctypes.c_uint32),  # XLA_FFI_Handler_Traits (XLA_FFI_Handler_TraitsBits flags)
    ]


# struct XLA_FFI_Metadata_Extension {
#   XLA_FFI_Extension_Base extension_base;
#   XLA_FFI_Metadata* metadata;
# };
class XLA_FFI_Metadata_Extension(ctypes.Structure):
    """Extension whose first member is the common header, followed by a metadata pointer."""

    _fields_ = [
        ("extension_base", XLA_FFI_Extension_Base),
        ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
    ]
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# typedef enum {
|
|
295
|
+
# XLA_FFI_Error_Code_OK = 0,
|
|
296
|
+
# XLA_FFI_Error_Code_CANCELLED = 1,
|
|
297
|
+
# XLA_FFI_Error_Code_UNKNOWN = 2,
|
|
298
|
+
# XLA_FFI_Error_Code_INVALID_ARGUMENT = 3,
|
|
299
|
+
# XLA_FFI_Error_Code_DEADLINE_EXCEEDED = 4,
|
|
300
|
+
# XLA_FFI_Error_Code_NOT_FOUND = 5,
|
|
301
|
+
# XLA_FFI_Error_Code_ALREADY_EXISTS = 6,
|
|
302
|
+
# XLA_FFI_Error_Code_PERMISSION_DENIED = 7,
|
|
303
|
+
# XLA_FFI_Error_Code_RESOURCE_EXHAUSTED = 8,
|
|
304
|
+
# XLA_FFI_Error_Code_FAILED_PRECONDITION = 9,
|
|
305
|
+
# XLA_FFI_Error_Code_ABORTED = 10,
|
|
306
|
+
# XLA_FFI_Error_Code_OUT_OF_RANGE = 11,
|
|
307
|
+
# XLA_FFI_Error_Code_UNIMPLEMENTED = 12,
|
|
308
|
+
# XLA_FFI_Error_Code_INTERNAL = 13,
|
|
309
|
+
# XLA_FFI_Error_Code_UNAVAILABLE = 14,
|
|
310
|
+
# XLA_FFI_Error_Code_DATA_LOSS = 15,
|
|
311
|
+
# XLA_FFI_Error_Code_UNAUTHENTICATED = 16
|
|
312
|
+
# } XLA_FFI_Error_Code;
|
|
313
|
+
class XLA_FFI_Error_Code(enum.IntEnum):
    """Status codes accepted by ``XLA_FFI_Error_Create`` (mirrors ``XLA_FFI_Error_Code``)."""

    OK = 0
    CANCELLED = 1
    UNKNOWN = 2
    INVALID_ARGUMENT = 3
    DEADLINE_EXCEEDED = 4
    NOT_FOUND = 5
    ALREADY_EXISTS = 6
    PERMISSION_DENIED = 7
    RESOURCE_EXHAUSTED = 8
    FAILED_PRECONDITION = 9
    ABORTED = 10
    OUT_OF_RANGE = 11
    UNIMPLEMENTED = 12
    INTERNAL = 13
    UNAVAILABLE = 14
    DATA_LOSS = 15
    UNAUTHENTICATED = 16
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# struct XLA_FFI_Error_Create_Args {
|
|
334
|
+
# size_t struct_size;
|
|
335
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
336
|
+
# const char* message;
|
|
337
|
+
# XLA_FFI_Error_Code errc;
|
|
338
|
+
# };
|
|
339
|
+
class XLA_FFI_Error_Create_Args(ctypes.Structure):
    """Argument block passed to ``XLA_FFI_Error_Create``."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("message", ctypes.c_char_p),
        ("errc", ctypes.c_int),  # XLA_FFI_Error_Code
    ]


# Function-pointer type for the XLA_FFI_Api entry: takes a pointer to the
# args block and returns an opaque pointer (an XLA_FFI_Error* in the C API).
XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
# struct XLA_FFI_Stream_Get_Args {
|
|
352
|
+
# size_t struct_size;
|
|
353
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
354
|
+
# XLA_FFI_ExecutionContext* ctx;
|
|
355
|
+
# void* stream; // out
|
|
356
|
+
# };
|
|
357
|
+
class XLA_FFI_Stream_Get_Args(ctypes.Structure):
    """Argument block passed to ``XLA_FFI_Stream_Get``; ``stream`` is filled in by XLA."""

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("stream", ctypes.c_void_p),  # out
    ]


# Function-pointer type for the XLA_FFI_Api entry: takes a pointer to the
# args block and returns an opaque pointer (an XLA_FFI_Error* in the C API).
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# struct XLA_FFI_Api {
|
|
370
|
+
# size_t struct_size;
|
|
371
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
372
|
+
#
|
|
373
|
+
# XLA_FFI_Api_Version api_version;
|
|
374
|
+
# XLA_FFI_InternalApi* internal_api;
|
|
375
|
+
#
|
|
376
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Create);
|
|
377
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_GetMessage);
|
|
378
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Error_Destroy);
|
|
379
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Handler_Register);
|
|
380
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Stream_Get);
|
|
381
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_TypeId_Register);
|
|
382
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ExecutionContext_Get);
|
|
383
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Set);
|
|
384
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_State_Get);
|
|
385
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Allocate);
|
|
386
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_DeviceMemory_Free);
|
|
387
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_Schedule);
|
|
388
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_ThreadPool_NumThreads);
|
|
389
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_Create);
|
|
390
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetAvailable);
|
|
391
|
+
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
|
|
392
|
+
# };
|
|
393
|
+
class XLA_FFI_Api(ctypes.Structure):
    """Function table handed to FFI handlers through ``XLA_FFI_CallFrame.api``.

    Only ``XLA_FFI_Error_Create`` and ``XLA_FFI_Stream_Get`` carry callable
    prototypes here. Every other entry is kept as an opaque ``void*`` so that
    the field offsets still match the C layout.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("api_version", XLA_FFI_Api_Version),
        ("internal_api", ctypes.c_void_p),  # XLA_FFI_InternalApi*
        # errors
        ("XLA_FFI_Error_Create", XLA_FFI_Error_Create),
        ("XLA_FFI_Error_GetMessage", ctypes.c_void_p),
        ("XLA_FFI_Error_Destroy", ctypes.c_void_p),
        # registration and execution context
        ("XLA_FFI_Handler_Register", ctypes.c_void_p),
        ("XLA_FFI_Stream_Get", XLA_FFI_Stream_Get),
        ("XLA_FFI_TypeId_Register", ctypes.c_void_p),
        ("XLA_FFI_ExecutionContext_Get", ctypes.c_void_p),
        # state
        ("XLA_FFI_State_Set", ctypes.c_void_p),
        ("XLA_FFI_State_Get", ctypes.c_void_p),
        # device memory
        ("XLA_FFI_DeviceMemory_Allocate", ctypes.c_void_p),
        ("XLA_FFI_DeviceMemory_Free", ctypes.c_void_p),
        # thread pool
        ("XLA_FFI_ThreadPool_Schedule", ctypes.c_void_p),
        ("XLA_FFI_ThreadPool_NumThreads", ctypes.c_void_p),
        # futures
        ("XLA_FFI_Future_Create", ctypes.c_void_p),
        ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p),
        ("XLA_FFI_Future_SetError", ctypes.c_void_p),
    ]
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# struct XLA_FFI_CallFrame {
|
|
419
|
+
# size_t struct_size;
|
|
420
|
+
# XLA_FFI_Extension_Base* extension_start;
|
|
421
|
+
# const XLA_FFI_Api* api;
|
|
422
|
+
# XLA_FFI_ExecutionContext* ctx;
|
|
423
|
+
# XLA_FFI_ExecutionStage stage;
|
|
424
|
+
# XLA_FFI_Args args;
|
|
425
|
+
# XLA_FFI_Rets rets;
|
|
426
|
+
# XLA_FFI_Attrs attrs;
|
|
427
|
+
#
|
|
428
|
+
# // XLA FFI handler implementation can use `future` to signal a result of
|
|
429
|
+
# // asynchronous computation to the XLA runtime. XLA runtime will keep all
|
|
430
|
+
# // arguments, results and attributes alive until `future` is completed.
|
|
431
|
+
# XLA_FFI_Future* future; // out
|
|
432
|
+
# };
|
|
433
|
+
class XLA_FFI_CallFrame(ctypes.Structure):
    """Everything XLA passes to one FFI handler invocation.

    ``args``, ``rets`` and ``attrs`` are embedded by value. ``future`` is an
    output slot the handler may set for asynchronous completion.
    """

    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
        ("api", ctypes.POINTER(XLA_FFI_Api)),
        ("ctx", ctypes.c_void_p),  # XLA_FFI_ExecutionContext*
        ("stage", ctypes.c_int),  # XLA_FFI_ExecutionStage
        ("args", XLA_FFI_Args),
        ("rets", XLA_FFI_Rets),
        ("attrs", XLA_FFI_Attrs),
        ("future", ctypes.c_void_p),  # XLA_FFI_Future*, out
    ]
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
# Maps XLA FFI element types to the JAX scalar types used to build NumPy/JAX
# dtypes. INVALID, TOKEN, F4E2M1FN and F8E8M0FNU are intentionally not mapped,
# so looking them up raises KeyError (or returns None via .get()).
_xla_data_type_to_constructor = {
    XLA_FFI_DataType.PRED: jnp.bool,
    # signed integers
    XLA_FFI_DataType.S8: jnp.int8,
    XLA_FFI_DataType.S16: jnp.int16,
    XLA_FFI_DataType.S32: jnp.int32,
    XLA_FFI_DataType.S64: jnp.int64,
    # unsigned integers
    XLA_FFI_DataType.U8: jnp.uint8,
    XLA_FFI_DataType.U16: jnp.uint16,
    XLA_FFI_DataType.U32: jnp.uint32,
    XLA_FFI_DataType.U64: jnp.uint64,
    # floating point
    XLA_FFI_DataType.F16: jnp.float16,
    XLA_FFI_DataType.F32: jnp.float32,
    XLA_FFI_DataType.F64: jnp.float64,
    XLA_FFI_DataType.BF16: jnp.bfloat16,
    # complex
    XLA_FFI_DataType.C64: jnp.complex64,
    XLA_FFI_DataType.C128: jnp.complex128,
    # 8-bit floating point variants
    XLA_FFI_DataType.F8E5M2: jnp.float8_e5m2,
    XLA_FFI_DataType.F8E3M4: jnp.float8_e3m4,
    XLA_FFI_DataType.F8E4M3: jnp.float8_e4m3,
    XLA_FFI_DataType.F8E4M3FN: jnp.float8_e4m3fn,
    XLA_FFI_DataType.F8E4M3B11FNUZ: jnp.float8_e4m3b11fnuz,
    XLA_FFI_DataType.F8E5M2FNUZ: jnp.float8_e5m2fnuz,
    XLA_FFI_DataType.F8E4M3FNUZ: jnp.float8_e4m3fnuz,
}
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
########################################################################
|
|
478
|
+
# Helpers for translating between ctypes and python types
|
|
479
|
+
#######################################################################
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def decode_bytespan(span: XLA_FFI_ByteSpan):
    """Decode an ``XLA_FFI_ByteSpan`` into a Python ``str`` using UTF-8.

    Exactly ``span.len`` bytes are read, because the span carries an explicit
    length. Embedded NUL bytes are therefore kept rather than ending the string
    early. An empty span decodes to ``""`` without touching its pointer, which
    may be NULL in that case.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    length = span.len
    if length == 0:
        return ""
    # string_at copies exactly `length` bytes. Unlike c_char_array.value, it
    # does not stop at the first NUL byte.
    return ctypes.string_at(span.ptr, length).decode("utf-8")
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def decode_scalar(scalar: XLA_FFI_Scalar):
    """Convert an ``XLA_FFI_Scalar`` attribute into a 0-d NumPy array.

    The element type comes from ``_xla_data_type_to_constructor``. A dtype
    missing from that table raises ``KeyError``.
    """
    # TODO validate if dtype supported
    np_dtype = jnp.dtype(_xla_data_type_to_constructor[scalar.dtype])
    # Copy the value's raw bytes out of XLA-owned memory before viewing them.
    raw = ctypes.string_at(scalar.value, np_dtype.itemsize)
    as_array = np.frombuffer(raw, dtype=np_dtype)
    return as_array.reshape(())
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def decode_array(array: XLA_FFI_Array):
    """Convert an ``XLA_FFI_Array`` attribute into a 1-D NumPy array of ``array.size`` items.

    The element type comes from ``_xla_data_type_to_constructor``. A dtype
    missing from that table raises ``KeyError``.
    """
    # TODO validate if dtype supported
    np_dtype = jnp.dtype(_xla_data_type_to_constructor[array.dtype])
    byte_count = np_dtype.itemsize * array.size
    # Copy the payload out of XLA-owned memory, then reinterpret it.
    raw = ctypes.string_at(array.data, byte_count)
    return np.frombuffer(raw, dtype=np_dtype)
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def decode_attrs(attrs: XLA_FFI_Attrs):
    """Decode an ``XLA_FFI_Attrs`` structure into a ``{name: value}`` dictionary.

    Each payload is cast to the struct matching its ``XLA_FFI_AttrType`` and
    converted with the matching helper:

    - strings become ``str``;
    - scalars and arrays become NumPy arrays;
    - nested dictionaries are decoded recursively.

    Raises:
        Exception: If an attribute carries a type other than the four handled here.
    """
    # attribute type -> (ctypes struct the payload points to, decoder for it)
    decoders = {
        XLA_FFI_AttrType.STRING: (XLA_FFI_ByteSpan, decode_bytespan),
        XLA_FFI_AttrType.SCALAR: (XLA_FFI_Scalar, decode_scalar),
        XLA_FFI_AttrType.ARRAY: (XLA_FFI_Array, decode_array),
        XLA_FFI_AttrType.DICTIONARY: (XLA_FFI_Attrs, decode_attrs),
    }
    decoded = {}
    for idx in range(attrs.size):
        name = decode_bytespan(attrs.names[idx].contents)
        entry = decoders.get(attrs.types[idx])
        if entry is None:
            raise Exception("Unexpected attr type")
        struct_type, decode = entry
        payload = ctypes.cast(attrs.attrs[idx], ctypes.POINTER(struct_type))
        decoded[name] = decode(payload.contents)
    return decoded
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
# error-string to XLA_FFI_Error
|
|
526
|
+
def create_ffi_error(api, errc, message):
    """Create an XLA FFI error object with code ``errc`` and text ``message``.

    Args:
        api: Pointer to the ``XLA_FFI_Api`` table, as stored in a call frame.
        errc: An ``XLA_FFI_Error_Code`` value.
        message: Error text; it is encoded as UTF-8 before being passed to XLA.

    Returns:
        The result of ``XLA_FFI_Error_Create``. Because of its ``c_void_p``
        return type this is an integer address, or ``None`` for NULL.
    """
    encoded_message = ctypes.c_char_p(message.encode("utf-8"))
    args = XLA_FFI_Error_Create_Args(
        struct_size=ctypes.sizeof(XLA_FFI_Error_Create_Args),
        extension_start=ctypes.POINTER(XLA_FFI_Extension_Base)(),  # NULL: no extensions
        message=encoded_message,
        errc=errc,
    )
    error_create = api.contents.XLA_FFI_Error_Create
    return error_create(args)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def create_invalid_argument_ffi_error(api, message):
    """Shorthand for ``create_ffi_error`` with the ``INVALID_ARGUMENT`` code."""
    code = XLA_FFI_Error_Code.INVALID_ARGUMENT
    return create_ffi_error(api, code, message)
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
# Extract CUDA stream from XLA_FFI_CallFrame.
|
|
541
|
+
def get_stream_from_callframe(call_frame):
    """Return the stream XLA associates with an FFI call frame.

    Args:
        call_frame: An ``XLA_FFI_CallFrame`` structure (not a pointer to one).

    Returns:
        The ``stream`` output of ``XLA_FFI_Stream_Get``. Because the field is
        ``c_void_p``, this is an integer address, or ``None`` for NULL.

    Raises:
        RuntimeError: If ``XLA_FFI_Stream_Get`` reports an error instead of a
            stream.
    """
    api = call_frame.api
    get_stream_args = XLA_FFI_Stream_Get_Args(
        ctypes.sizeof(XLA_FFI_Stream_Get_Args),
        ctypes.POINTER(XLA_FFI_Extension_Base)(),
        call_frame.ctx,
        None,
    )
    # Per the XLA FFI C API, a non-NULL return value is an XLA_FFI_Error*.
    # Without this check, a failure would hand callers a NULL/None stream
    # silently.
    # NOTE(review): the error object is not destroyed here, because
    # XLA_FFI_Error_Destroy has no ctypes prototype in XLA_FFI_Api.
    error = api.contents.XLA_FFI_Stream_Get(get_stream_args)
    if error:
        raise RuntimeError("XLA_FFI_Stream_Get failed to retrieve the stream for this call frame")
    return get_stream_args.stream
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
_dtype_from_ffi = {
|
|
552
|
+
XLA_FFI_DataType.S8: wp.int8,
|
|
553
|
+
XLA_FFI_DataType.S16: wp.int16,
|
|
554
|
+
XLA_FFI_DataType.S32: wp.int32,
|
|
555
|
+
XLA_FFI_DataType.S64: wp.int64,
|
|
556
|
+
XLA_FFI_DataType.U8: wp.uint8,
|
|
557
|
+
XLA_FFI_DataType.U16: wp.uint16,
|
|
558
|
+
XLA_FFI_DataType.U32: wp.uint32,
|
|
559
|
+
XLA_FFI_DataType.U64: wp.uint64,
|
|
560
|
+
XLA_FFI_DataType.F16: wp.float16,
|
|
561
|
+
XLA_FFI_DataType.F32: wp.float32,
|
|
562
|
+
XLA_FFI_DataType.F64: wp.float64,
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def dtype_from_ffi(ffi_dtype):
    """Translate an XLA FFI element type into the matching Warp scalar type.

    Returns ``None`` when ``ffi_dtype`` has no entry in ``_dtype_from_ffi``.
    """
    try:
        return _dtype_from_ffi[ffi_dtype]
    except KeyError:
        return None
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def jax_dtype_from_ffi(ffi_dtype):
    """Translate an XLA FFI element type via ``_xla_data_type_to_constructor``.

    The result is the registered entry for ``ffi_dtype`` (presumably a JAX
    scalar-type constructor -- confirm against the table's definition), or
    ``None`` when the element type is not registered.
    """
    try:
        return _xla_data_type_to_constructor[ffi_dtype]
    except KeyError:
        return None
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
# Execution context (stream, stage)
|
|
575
|
+
class ExecutionContext:
    """Execution information pulled out of an XLA FFI call frame.

    Attributes:
        stage: The call frame's ``stage`` wrapped in ``XLA_FFI_ExecutionStage``.
        stream: Stream handle returned by ``get_stream_from_callframe``.
    """

    stage: XLA_FFI_ExecutionStage
    stream: int

    def __init__(self, callframe: XLA_FFI_CallFrame):
        raw_stage = callframe.stage
        self.stage = XLA_FFI_ExecutionStage(raw_stage)
        self.stream = get_stream_from_callframe(callframe)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
class FfiBuffer:
    """Wrapper around an XLA FFI buffer that exposes ``__cuda_array_interface__``.

    Attributes:
        dtype: dtype object built with ``jnp.dtype`` from the buffer's XLA
            element type (not a string).
        data: Integer device address of the buffer; ``0`` when the buffer has no
            data pointer.
        shape: Buffer dimensions, one entry per rank.
    """

    dtype: "numpy.dtype"
    data: int
    shape: tuple[int, ...]

    def __init__(self, xla_buffer):
        # TODO: validate the buffer itself (e.g. NULL data with a non-empty shape).
        element_type = xla_buffer.dtype
        if element_type not in _xla_data_type_to_constructor:
            raise KeyError(f"Unsupported XLA FFI buffer element type: {element_type}")
        self.dtype = jnp.dtype(_xla_data_type_to_constructor[element_type])
        self.shape = tuple(xla_buffer.dims[i] for i in range(xla_buffer.rank))
        # If ``data`` is declared as ``c_void_p``, a NULL pointer reads back as
        # None through ctypes; the CUDA Array Interface requires an integer
        # address (0 for empty arrays).
        self.data = xla_buffer.data or 0

    @property
    def __cuda_array_interface__(self):
        """CUDA Array Interface (version 2) description of this buffer.

        ``typestr`` must use the array-interface type-string format (byte order,
        kind character and item size, e.g. ``"<f4"``), which is ``dtype.str``;
        the bare one-letter ``dtype.char`` code does not conform to the spec.
        """
        return {
            "shape": self.shape,
            "typestr": self.dtype.str,
            "data": (self.data, False),  # (address, read-only flag)
            "version": 2,
        }
|
warp/math.py
CHANGED
|
@@ -27,6 +27,8 @@ __all__ = [
|
|
|
27
27
|
"norm_huber",
|
|
28
28
|
"norm_pseudo_huber",
|
|
29
29
|
"smooth_normalize",
|
|
30
|
+
"transform_from_matrix",
|
|
31
|
+
"transform_to_matrix",
|
|
30
32
|
]
|
|
31
33
|
|
|
32
34
|
|
|
@@ -131,6 +133,85 @@ def smooth_normalize(v: Any, delta: float = 1.0):
|
|
|
131
133
|
return v / norm_pseudo_huber(v, delta)
|
|
132
134
|
|
|
133
135
|
|
|
136
|
+
def create_transform_from_matrix_func(dtype):
    """Build a Warp-compilable ``transform_from_matrix`` for scalar type ``dtype``."""
    # Scalar-specific types; the nested function refers to them through its closure.
    mat44 = wp.types.matrix((4, 4), dtype)
    vec3 = wp.types.vector(3, dtype)
    transform = wp.types.transformation(dtype)

    def transform_from_matrix(mat: mat44) -> transform:
        """
        Construct a transformation from a 4x4 matrix.

        The translation is taken from the last column of the first three rows,
        and the rotation is extracted with ``wp.quat_from_matrix``.

        Args:
            mat (Matrix[4, 4, Float]): Matrix to convert.

        Returns:
            Transformation[Float]: The transformation.
        """
        translation = vec3(mat[0, 3], mat[1, 3], mat[2, 3])
        rotation = wp.quat_from_matrix(mat)
        return transform(translation, rotation)

    return transform_from_matrix


# The float32 variant is kept as the module attribute; the float16/float64
# variants are registered under the same key.
transform_from_matrix = wp.func(
    create_transform_from_matrix_func(wp.float32),
    name="transform_from_matrix",
)
wp.func(create_transform_from_matrix_func(wp.float16), name="transform_from_matrix")
wp.func(create_transform_from_matrix_func(wp.float64), name="transform_from_matrix")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def create_transform_to_matrix_func(dtype):
    """Build a Warp-compilable ``transform_to_matrix`` for scalar type ``dtype``."""
    # Scalar-specific types; the nested function refers to them through its closure.
    mat44 = wp.types.matrix((4, 4), dtype)
    transform = wp.types.transformation(dtype)

    def transform_to_matrix(xform: transform) -> mat44:
        """
        Convert a transformation to a 4x4 matrix.

        The upper-left 3x3 block is the rotation matrix of the transform's
        quaternion, the last column holds the translation, and the bottom row
        is ``(0, 0, 0, 1)``.

        Args:
            xform (Transformation[Float]): Transformation to convert.

        Returns:
            Matrix[4, 4, Float]: The matrix.
        """
        translation = wp.transform_get_translation(xform)
        rot = wp.quat_to_matrix(wp.transform_get_rotation(xform))
        zero = dtype(0.0)
        one = dtype(1.0)
        # fmt: off
        return mat44(
            rot[0, 0], rot[0, 1], rot[0, 2], translation[0],
            rot[1, 0], rot[1, 1], rot[1, 2], translation[1],
            rot[2, 0], rot[2, 1], rot[2, 2], translation[2],
            zero, zero, zero, one,
        )
        # fmt: on

    return transform_to_matrix


# The float32 variant is kept as the module attribute; the float16/float64
# variants are registered under the same key.
transform_to_matrix = wp.func(
    create_transform_to_matrix_func(wp.float32),
    name="transform_to_matrix",
)
wp.func(create_transform_to_matrix_func(wp.float16), name="transform_to_matrix")
wp.func(create_transform_to_matrix_func(wp.float64), name="transform_to_matrix")
|
|
213
|
+
|
|
214
|
+
|
|
134
215
|
# register API functions so they appear in the documentation
|
|
135
216
|
|
|
136
217
|
wp.context.register_api_function(
|
|
@@ -153,3 +234,11 @@ wp.context.register_api_function(
|
|
|
153
234
|
smooth_normalize,
|
|
154
235
|
group="Vector Math",
|
|
155
236
|
)
|
|
237
|
+
# List the transform <-> matrix conversion helpers under "Transformations" in the API documentation.
wp.context.register_api_function(transform_from_matrix, group="Transformations")
wp.context.register_api_function(transform_to_matrix, group="Transformations")
|