warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/autograd.py +12 -2
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +463 -372
- warp/codegen.py +196 -124
- warp/config.py +42 -6
- warp/context.py +496 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_cloth.py +1 -1
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/distributed/example_jacobi_mpi.py +507 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/field.py +11 -1
- warp/fem/field/nodal_field.py +56 -88
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +16 -13
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +7 -20
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
- warp/jax_experimental/ffi.py +702 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +312 -116
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +100 -11
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/render/render_opengl.py +19 -17
- warp/render/render_usd.py +93 -3
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +32 -19
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/inertia.py +189 -156
- warp/sim/integrator_euler.py +8 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +8 -5
- warp/sim/model.py +71 -25
- warp/sim/render.py +4 -0
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +217 -20
- warp/tests/__main__.py +0 -15
- warp/tests/assets/torus.usda +1 -1
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
- warp/tests/sim/test_inertia.py +161 -0
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/sim/test_xpbd.py +399 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_codegen.py +24 -3
- warp/tests/test_examples.py +40 -38
- warp/tests/test_fem.py +98 -14
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +577 -156
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +356 -151
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +336 -178
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -62
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +175 -666
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example jax_callable()
|
|
18
|
+
#
|
|
19
|
+
# Examples of calling annotated Python functions from JAX.
|
|
20
|
+
###########################################################################
|
|
21
|
+
|
|
22
|
+
from functools import partial
|
|
23
|
+
|
|
24
|
+
import jax
|
|
25
|
+
import jax.numpy as jnp
|
|
26
|
+
|
|
27
|
+
import warp as wp
|
|
28
|
+
from warp.jax_experimental.ffi import jax_callable
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # Element-wise scale of a float array: each thread writes
    # output[tid] = a[tid] * s for its own index.
    tid = wp.tid()
    output[tid] = a[tid] * s
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Element-wise scale of a wp.vec2 array: each thread multiplies the vector
    # at its own index by the scalar s.
    tid = wp.tid()
    output[tid] = a[tid] * s
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Python function exposed to JAX through jax_callable().
# As with Warp kernels, every parameter carries a Warp type annotation;
# jax_callable relies on these annotations when it wraps the function.
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    """Write ``a * s`` into ``c`` and ``b * s`` into ``d``.

    ``a``, ``b`` and ``s`` are inputs; ``c`` and ``d`` are the outputs the
    kernels fill in. Each launch covers the full shape of its source array.
    """
    # (kernel, source array, destination array), launched in this order
    launches = (
        (scale_kernel, a, c),
        (scale_vec_kernel, b, d),
    )
    for kernel, src, dst in launches:
        wp.launch(kernel, dim=src.shape, inputs=[src, s], outputs=[dst])
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def example1():
    """Call ``example_func`` from a jitted function that builds its own inputs."""
    jax_func = jax_callable(example_func, vmap_method="broadcast_all", num_outputs=2)

    @jax.jit
    def compute():
        # inputs: a float array and a (5, 2) float array viewed as 5 wp.vec2 values
        values = jnp.arange(10, dtype=jnp.float32)
        vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
        factor = 2.0

        # output shapes, keyed by the names of example_func's output parameters
        dims = {"c": values.shape, "d": vectors.shape}

        out_c, out_d = jax_func(values, vectors, factor, output_dims=dims)
        return out_c, out_d

    for result in compute():
        print(result)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def example2():
    """Call ``example_func`` from a jitted function that receives its inputs as arguments."""
    jax_func = jax_callable(example_func, vmap_method="broadcast_all", num_outputs=2)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, b, s):
        # output shapes, keyed by the names of example_func's output parameters
        shapes = {"c": a.shape, "d": b.shape}
        scaled, scaled_vec = jax_func(a, b, s, output_dims=shapes)
        return scaled, scaled_vec

    flat = jnp.arange(10, dtype=jnp.float32)
    pairs = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2

    for result in f(flat, pairs, 3.0):
        print(result)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def main():
    """Initialize Warp, load this module's kernels, and run each example in turn."""
    wp.init()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for run in (example1, example2):
        print(banner)
        print(f"{run.__name__}:")
        run()
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Run the examples only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example register_ffi_callback()
|
|
18
|
+
#
|
|
19
|
+
# Examples of calling Python functions from JAX.
|
|
20
|
+
# Target functions must have the form func(inputs, outputs, attrs, ctx).
|
|
21
|
+
###########################################################################
|
|
22
|
+
|
|
23
|
+
import jax
|
|
24
|
+
import jax.numpy as jnp
|
|
25
|
+
import numpy as np
|
|
26
|
+
|
|
27
|
+
import warp as wp
|
|
28
|
+
from warp.jax import get_jax_device
|
|
29
|
+
from warp.jax_experimental.ffi import register_ffi_callback
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # Element-wise scale of a float array: each thread writes
    # output[tid] = a[tid] * s for its own index.
    tid = wp.tid()
    output[tid] = a[tid] * s
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Element-wise scale of a wp.vec2 array: each thread multiplies the vector
    # at its own index by the scalar s.
    tid = wp.tid()
    output[tid] = a[tid] * s
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def example1():
    """Register a callback that prints what XLA passes to it, then invoke it once.

    The callback receives ``(inputs, outputs, attrs, ctx)`` and prints the
    dtype, shape and data address of every input/output buffer, followed by
    each keyword attribute passed to the call.
    """

    # the Python function to call
    def print_args(inputs, outputs, attrs, ctx):
        def buffer_to_string(b):
            return str(b.dtype) + str(list(b.shape)) + " @%x" % b.data

        print("Inputs: ", ", ".join([buffer_to_string(b) for b in inputs]))
        print("Outputs: ", ", ".join([buffer_to_string(b) for b in outputs]))
        print("Attributes: ", "".join(["\n %s: %s" % (k, str(v)) for k, v in attrs.items()]))

    # register callback
    register_ffi_callback("print_args", print_args)

    # set up call: a single (1, 2, 3) int8 output whose contents are never used
    call = jax.ffi.ffi_call("print_args", jax.ShapeDtypeStruct((1, 2, 3), jnp.int8))

    # call it
    result = call(
        jnp.arange(16),
        jnp.arange(32.0).reshape((4, 8)),
        str_attr="hi",
        f32_attr=np.float32(4.2),
        dict_attr={"a": 1, "b": 6.4},
    )

    # JAX dispatches asynchronously and the result is otherwise discarded; wait
    # for the call to complete so the callback has run (and printed) before this
    # example returns and the next example's output starts.
    result.block_until_ready()
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def example2():
    """Launch Warp kernels from an FFI callback, on the CUDA stream XLA provides."""

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        """Scale inputs[0] into outputs[0] and inputs[1] (vec2 data) into outputs[1]."""
        a, b = inputs[0], inputs[1]
        s = attrs["scale"]
        c, d = outputs[0], outputs[1]

        # run on XLA's stream so the launches are ordered with the rest of the computation
        stream = wp.Stream(wp.device_from_jax(get_jax_device()), cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # arrays of scalars: one thread per element
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
            # arrays of vec2: the buffer shape is the JAX shape (n, 2), so launch
            # over the outer dimension only
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    register_ffi_callback("warp_func", warp_func)

    n = 10
    a = jnp.arange(n, dtype=jnp.float32)
    b = jnp.arange(n, dtype=jnp.float32).reshape((n // 2, 2))  # array of wp.vec2
    s = 2.0

    # one float32 output per input, with matching shapes
    call = jax.ffi.ffi_call(
        "warp_func",
        [jax.ShapeDtypeStruct(x.shape, jnp.float32) for x in (a, b)],
    )

    for out in call(a, b, scale=s):
        print(out)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def main():
    """Initialize Warp, load this module's kernels, and run each example in turn."""
    wp.init()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for run in (example1, example2):
        print(banner)
        print(f"{run.__name__}:")
        run()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Run the examples only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
###########################################################################
|
|
17
|
+
# Example jax_kernel()
|
|
18
|
+
#
|
|
19
|
+
# Examples of calling a Warp kernel from JAX.
|
|
20
|
+
###########################################################################
|
|
21
|
+
|
|
22
|
+
import math
|
|
23
|
+
from functools import partial
|
|
24
|
+
|
|
25
|
+
import jax
|
|
26
|
+
import jax.numpy as jnp
|
|
27
|
+
|
|
28
|
+
import warp as wp
|
|
29
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@wp.kernel
def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
    # Element-wise sum of two int arrays: each thread writes
    # output[tid] = a[tid] + b[tid] for its own index.
    tid = wp.tid()
    output[tid] = a[tid] + b[tid]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
    # One input, two outputs: each thread writes the sine and cosine of
    # angle[tid] to the element at the same index of sin_out and cos_out.
    tid = wp.tid()
    sin_out[tid] = wp.sin(angle[tid])
    cos_out[tid] = wp.cos(angle[tid])
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    # No inputs: each thread writes the constant matrix diag(1, 2, 3) to
    # output[tid], so the launch size decides how many entries are written.
    tid = wp.tid()
    output[tid] = wp.mat33(1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # Naive dense matrix product c = a @ b, one thread per output element (i, j).
    # launch dims should be (N, M)
    i, j = wp.tid()
    N = a.shape[0]
    K = a.shape[1]
    M = b.shape[1]
    # Skip threads outside the N x M output, in case the launch dims are
    # larger than the matrices.
    if i < N and j < M:
        # accumulate the dot product of row i of a with column j of b
        s = wp.float32(0)
        for k in range(K):
            s += a[i, k] * b[k, j]
        c[i, j] = s
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Element-wise scale of a wp.vec2 array: each thread multiplies the vector
    # at its own index by the scalar s.
    tid = wp.tid()
    output[tid] = a[tid] * s
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def example1():
    """Add two int32 arrays with a Warp kernel (two inputs, one output)."""
    jax_add = jax_kernel(add_kernel)

    @jax.jit
    def add_arrays():
        size = 10
        lhs = jnp.arange(size, dtype=jnp.int32)
        rhs = jnp.ones(size, dtype=jnp.int32)
        return jax_add(lhs, rhs)

    print(add_arrays())
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def example2():
    """Compute sin and cos of evenly spaced angles in [0, 2*pi] (one input, two outputs)."""
    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    @jax.jit
    def f():
        n = 32
        # sincos_kernel expects wp.array(dtype=float), i.e. float32. Request it
        # explicitly (as the other examples do) so the input dtype does not
        # change to float64 when JAX's x64 mode is enabled.
        a = jnp.linspace(0, 2 * math.pi, n, dtype=jnp.float32)
        return jax_sincos(a)

    s, c = f()
    print(s)
    print()
    print(c)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def example3():
    """Scale an array of wp.vec2 by a scalar that is a constant inside the traced function."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    @jax.jit
    def scale_vectors():
        vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
        factor = 2.0
        return jax_scale_vec(vectors, factor)

    print(scale_vectors())
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def example4():
    """Scale an array of wp.vec2 by a scalar passed as a static jit argument."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, s):
        return jax_scale_vec(a, s)

    vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
    print(f(vectors, 3.0))
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def example5():
    """Multiply matrices using launch dims fixed when the kernel is wrapped."""
    rows, cols, inner = 3, 4, 2

    # default launch dims: one thread per element of the rows x cols output
    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(rows, cols))

    @jax.jit
    def multiply():
        lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
        rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)
        # no launch_dims here, so the defaults given above are used
        return jax_matmul(lhs, rhs)

    print(multiply())
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def example6():
    """Multiply matrices of different shapes, passing launch dims at each call."""
    # no default launch dims when wrapping the kernel
    jax_matmul = jax_kernel(matmul_kernel)

    @jax.jit
    def f():
        def product(n, m, k):
            # (n x k) filled with 2 times (k x m) filled with 3, launched over (n, m)
            lhs = jnp.full((n, k), 2, dtype=jnp.float32)
            rhs = jnp.full((k, m), 3, dtype=jnp.float32)
            return jax_matmul(lhs, rhs, launch_dims=(n, m))

        first = product(3, 4, 2)
        second = product(4, 3, 2)
        return first, second

    first, second = f()
    print(first)
    print()
    print(second)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def example7():
    """Call a kernel that has no inputs and one output."""
    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def make_diagonals():
        # with no input arrays, the launch dimensions determine the output size
        return jax_diagonal(launch_dims=4)

    matrices = make_diagonals()
    print(matrices)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def main():
    """Initialize Warp, load this module's kernels, and run each example in turn."""
    wp.init()
    wp.load_module(device=wp.get_device())

    banner = "\n==========================================================================="
    for run in (example1, example2, example3, example4, example5, example6, example7):
        print(banner)
        print(f"{run.__name__}:")
        run()
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# Run the examples only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|