warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/warp-clang.dll +0 -0
  3. warp/bin/warp.dll +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,132 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example register_ffi_callback()
18
+ #
19
+ # Examples of calling Python functions from JAX.
20
+ # Target functions must have the form func(inputs, outputs, attrs, ctx).
21
+ ###########################################################################
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ import warp as wp
28
+ from warp.jax import get_jax_device
29
+ from warp.jax_experimental.ffi import register_ffi_callback
30
+
31
+
32
# Warp kernel: writes a[i] * s into output[i], where i is the thread index.
@wp.kernel
def scale_kernel(
    a: wp.array(dtype=float),
    s: float,
    output: wp.array(dtype=float),
):
    i = wp.tid()
    output[i] = a[i] * s
36
+
37
+
38
# Warp kernel: same scaling as the scalar kernel, but each element is a wp.vec2.
@wp.kernel
def scale_vec_kernel(
    a: wp.array(dtype=wp.vec2),
    s: float,
    output: wp.array(dtype=wp.vec2),
):
    i = wp.tid()
    output[i] = a[i] * s
42
+
43
+
44
def example1():
    """Register a Python callback that prints its arguments and call it from JAX."""

    # the Python function to call; it must accept (inputs, outputs, attrs, ctx)
    def print_args(inputs, outputs, attrs, ctx):
        def describe(buf):
            # dtype, shape and address of a single buffer
            dtype_text = str(buf.dtype)
            shape_text = str(list(buf.shape))
            address_text = " @%x" % buf.data
            return dtype_text + shape_text + address_text

        print("Inputs: ", ", ".join(describe(buf) for buf in inputs))
        print("Outputs: ", ", ".join(describe(buf) for buf in outputs))
        print("Attributes: ", "".join("\n %s: %s" % (k, str(v)) for k, v in attrs.items()))

    # make the callback available under the FFI target name "print_args"
    register_ffi_callback("print_args", print_args)

    # one int8 output of shape (1, 2, 3)
    result_type = jax.ShapeDtypeStruct((1, 2, 3), jnp.int8)
    call = jax.ffi.ffi_call("print_args", result_type)

    # keyword arguments are forwarded to the callback as attributes
    call_attrs = {
        "str_attr": "hi",
        "f32_attr": np.float32(4.2),
        "dict_attr": {"a": 1, "b": 6.4},
    }

    # call it with two input arrays
    call(
        jnp.arange(16),
        jnp.arange(32.0).reshape((4, 8)),
        **call_attrs,
    )
68
+
69
+
70
def example2():
    """Register a callback that launches two Warp kernels and call it from JAX."""

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # input arrays: scalars first, then the vec2 data
        scalars_in, vecs_in = inputs[0], inputs[1]

        # scalar attribute passed as a keyword argument at the call site
        factor = attrs["scale"]

        # output arrays, in the same order as the inputs
        scalars_out, vecs_out = outputs[0], outputs[1]

        # wrap the stream handed over through ctx so both launches are issued on it
        warp_device = wp.device_from_jax(get_jax_device())
        warp_stream = wp.Stream(warp_device, cuda_stream=ctx.stream)

        with wp.ScopedStream(warp_stream):
            # launch with arrays of scalars
            wp.launch(scale_kernel, dim=scalars_in.shape, inputs=[scalars_in, factor], outputs=[scalars_out])

            # launch with arrays of vec2
            # NOTE: the shapes come from JAX arrays, so the inner dimension has to be
            # stripped for vec2 arrays; only the leading dimension is used here
            wp.launch(scale_vec_kernel, dim=vecs_in.shape[0], inputs=[vecs_in, factor], outputs=[vecs_out])

    # register callback
    register_ffi_callback("warp_func", warp_func)

    count = 10

    # inputs
    scalars = jnp.arange(count, dtype=jnp.float32)
    vectors = jnp.arange(count, dtype=jnp.float32).reshape((count // 2, 2))  # array of wp.vec2
    scale = 2.0

    # one output per input, with matching shapes
    out_types = [
        jax.ShapeDtypeStruct(scalars.shape, jnp.float32),
        jax.ShapeDtypeStruct(vectors.shape, jnp.float32),  # array of wp.vec2
    ]

    # set up the call and invoke it
    scaled_scalars, scaled_vectors = jax.ffi.ffi_call("warp_func", out_types)(scalars, vectors, scale=scale)

    print(scaled_scalars)
    print(scaled_vectors)
117
+
118
+
119
def main():
    """Initialize Warp, load the module, and run every example under a banner."""
    wp.init()
    wp.load_module(device=wp.get_device())

    for example in (example1, example2):
        print("\n===========================================================================")
        print(f"{example.__name__}:")
        example()


if __name__ == "__main__":
    main()
@@ -0,0 +1,205 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example jax_kernel()
18
+ #
19
+ # Examples of calling a Warp kernel from JAX.
20
+ ###########################################################################
21
+
22
+ import math
23
+ from functools import partial
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+
28
+ import warp as wp
29
+ from warp.jax_experimental.ffi import jax_kernel
30
+
31
+
32
# Warp kernel: element-wise integer sum, output[i] = a[i] + b[i].
@wp.kernel
def add_kernel(
    a: wp.array(dtype=int),
    b: wp.array(dtype=int),
    output: wp.array(dtype=int),
):
    i = wp.tid()
    output[i] = a[i] + b[i]
36
+
37
+
38
# Warp kernel: one input array, two output arrays (sine and cosine of each angle).
@wp.kernel
def sincos_kernel(
    angle: wp.array(dtype=float),
    sin_out: wp.array(dtype=float),
    cos_out: wp.array(dtype=float),
):
    i = wp.tid()
    theta = angle[i]
    sin_out[i] = wp.sin(theta)
    cos_out[i] = wp.cos(theta)
43
+
44
+
45
# Warp kernel with no inputs: stores the same mat33 built from the nine
# constants below into every element of the output array.
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    i = wp.tid()
    output[i] = wp.mat33(
        1.0, 0.0, 0.0,
        0.0, 2.0, 0.0,
        0.0, 0.0, 3.0,
    )
49
+
50
+
51
# Warp kernel: c = a @ b, with one thread per output entry.
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # launch dims should be (N, M)
    row, col = wp.tid()

    n_rows = a.shape[0]
    n_inner = a.shape[1]
    n_cols = b.shape[1]

    # threads outside the NxM output do nothing
    if row < n_rows and col < n_cols:
        acc = wp.float32(0)
        for k in range(n_inner):
            acc += a[row, k] * b[k, col]
        c[row, col] = acc
67
+
68
+
69
# Warp kernel: multiplies every wp.vec2 element of a by the scalar s.
@wp.kernel
def scale_vec_kernel(
    a: wp.array(dtype=wp.vec2),
    s: float,
    output: wp.array(dtype=wp.vec2),
):
    i = wp.tid()
    output[i] = a[i] * s
73
+
74
+
75
def example1():
    """Call a kernel with two inputs and one output from a jitted function."""
    jax_add = jax_kernel(add_kernel)

    def add_ones():
        count = 10
        values = jnp.arange(count, dtype=jnp.int32)
        ones = jnp.ones(count, dtype=jnp.int32)
        return jax_add(values, ones)

    print(jax.jit(add_ones)())
87
+
88
+
89
def example2():
    """Call a kernel with one input and two outputs from a jitted function."""
    # num_outputs tells jax_kernel that two of the kernel arguments are outputs
    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    def sincos():
        samples = 32
        angles = jnp.linspace(0, 2 * math.pi, samples)
        return jax_sincos(angles)

    sines, cosines = jax.jit(sincos)()
    print(sines)
    print()
    print(cosines)
103
+
104
+
105
def example3():
    """Multiply an array of vec2 by a scalar defined inside the jitted function."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    def scale():
        vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
        factor = 2.0
        return jax_scale_vec(vectors, factor)

    scaled = jax.jit(scale)()
    print(scaled)
117
+
118
+
119
def example4():
    """Multiply an array of vec2 by a scalar passed in as a static jit argument."""
    jax_scale_vec = jax_kernel(scale_vec_kernel)

    def scale(a, s):
        return jax_scale_vec(a, s)

    # NOTE: scalar arguments must be static compile-time constants
    scale_jit = jax.jit(scale, static_argnames=["s"])

    vectors = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # array of vec2
    factor = 3.0

    print(scale_jit(vectors, factor))
133
+
134
+
135
def example5():
    """Matrix multiply using launch dimensions fixed when the kernel is wrapped."""
    rows, cols, inner = 3, 4, 2

    # specify default launch dims
    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(rows, cols))

    def multiply():
        lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
        rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)

        # no launch_dims here, so the defaults given above apply
        return jax_matmul(lhs, rhs)

    print(jax.jit(multiply)())
150
+
151
+
152
def example6():
    """Matrix multiply twice, passing launch dimensions with each call."""
    # don't specify default launch dims
    jax_matmul = jax_kernel(matmul_kernel)

    def multiply_both():
        def product(rows, cols, inner):
            lhs = jnp.full((rows, inner), 2, dtype=jnp.float32)
            rhs = jnp.full((inner, cols), 3, dtype=jnp.float32)

            # use custom launch dims
            return jax_matmul(lhs, rhs, launch_dims=(rows, cols))

        first = product(3, 4, 2)
        second = product(4, 3, 2)
        return first, second

    first_result, second_result = jax.jit(multiply_both)()
    print(first_result)
    print()
    print(second_result)
178
+
179
+
180
def example7():
    """Call a kernel that has no inputs and a single output."""
    jax_diagonal = jax_kernel(diagonal_kernel)

    def fill():
        # launch dimensions determine output size
        return jax_diagonal(launch_dims=4)

    print(jax.jit(fill)())
190
+
191
+
192
def main():
    """Initialize Warp, load the module, and run every example under a banner."""
    wp.init()
    wp.load_module(device=wp.get_device())

    for example in (example1, example2, example3, example4, example5, example6, example7):
        print("\n===========================================================================")
        print(f"{example.__name__}:")
        example()


if __name__ == "__main__":
    main()