warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,299 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import unittest
16
+
17
+ import numpy as np
18
+
19
+ import warp as wp
20
+ from warp.tests.unittest_utils import *
21
+
22
+
23
def create_spinlock_test(dtype):
    """Build a kernel that increments ``counter[0]`` once per thread under a spinlock.

    The lock is element ``0`` of a 1D array: ``0`` means free, ``1`` means held.
    It is acquired with ``wp.atomic_cas`` and released with ``wp.atomic_exch``.

    Args:
        dtype: Warp scalar type used for both the counter and the lock arrays.

    Returns:
        The ``test_spinlock_counter`` kernel, taking ``(counter, lock)``.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype)):
        # Try to acquire the lock by setting it to 1 if it's 0
        # Keep retrying while atomic_cas reports the lock value as 1 (held).
        while wp.atomic_cas(lock, 0, dtype(0), dtype(1)) == 1:
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] through atomics: swap in 0 to obtain the current value,
        # then swap the value back. Workaround for the lack of volatile arrays
        # (see the critical-section comment in the kernel below).
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
57
+
58
+
59
def test_atomic_cas(test, device, warp_type, numpy_type):
    """Check that a CAS spinlock in a 1D array serializes ``n`` counter increments.

    Launches ``n`` threads of the kernel built by ``create_spinlock_test`` on
    ``device`` and asserts the shared counter ends at exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance supplied by the harness (unused).
        device: Warp device to launch on.
        warp_type: Warp scalar type of the counter and lock.
        numpy_type: Matching NumPy dtype used to build the expected value.
    """
    n = 100
    counter = wp.zeros(1, dtype=warp_type, device=device)
    # The lock must start free (0); zeros() makes that explicit and matches the
    # 2D/3D/4D variants.
    lock = wp.zeros(1, dtype=warp_type, device=device)

    test_spinlock_counter = create_spinlock_test(warp_type)
    wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n. The assertion carries the mismatch, so no
    # separate debug print is emitted.
    assert_np_equal(counter.numpy(), np.array([n], dtype=numpy_type))
75
+
76
+
77
def create_spinlock_test_2d(dtype):
    """Build a spinlock-counter kernel whose lock is element ``(0, 0)`` of a 2D array.

    This exercises the 2D-index overloads of ``wp.atomic_cas`` / ``wp.atomic_exch``.
    The counter itself stays 1D.

    Args:
        dtype: Warp scalar type used for the counter and lock arrays.

    Returns:
        The ``test_spinlock_counter`` kernel, taking ``(counter, lock)`` with a
        1D ``counter`` and a 2D ``lock``.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=2)):
        # Try to acquire the lock by setting it to 1 if it's 0
        # Keep retrying while atomic_cas reports the lock value as 1 (held).
        while wp.atomic_cas(lock, 0, 0, dtype(0), dtype(1)) == 1:
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype, ndim=2)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] (1D) through atomics: swap in 0, then swap the value
        # back. Workaround for the lack of volatile arrays.
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=2)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
111
+
112
+
113
def test_atomic_cas_2d(test, device, warp_type, numpy_type):
    """Check that a CAS spinlock stored in a 2D array serializes ``n`` counter increments.

    The lock lives at index ``(0, 0)`` of a ``(1, 1)`` array (kernel from
    ``create_spinlock_test_2d``). ``n`` threads each increment a 1D counter
    under the lock, and the result must be exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance supplied by the harness (unused).
        device: Warp device to launch on.
        warp_type: Warp scalar type of the counter and lock.
        numpy_type: Matching NumPy dtype used to build the expected value.
    """
    n = 100
    counter = wp.zeros(1, dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1), dtype=warp_type, device=device)

    test_spinlock_counter = create_spinlock_test_2d(warp_type)
    wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n. The assertion carries the mismatch, so no
    # separate debug print is emitted.
    assert_np_equal(counter.numpy(), np.array([n], dtype=numpy_type))
129
+
130
+
131
def create_spinlock_test_3d(dtype):
    """Build a spinlock-counter kernel whose lock is element ``(0, 0, 0)`` of a 3D array.

    This exercises the 3D-index overloads of ``wp.atomic_cas`` / ``wp.atomic_exch``.
    The counter itself stays 1D.

    Args:
        dtype: Warp scalar type used for the counter and lock arrays.

    Returns:
        The ``test_spinlock_counter`` kernel, taking ``(counter, lock)`` with a
        1D ``counter`` and a 3D ``lock``.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=3)):
        # Try to acquire the lock by setting it to 1 if it's 0
        # Keep retrying while atomic_cas reports the lock value as 1 (held).
        while wp.atomic_cas(lock, 0, 0, 0, dtype(0), dtype(1)) == 1:
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype, ndim=3)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] (1D) through atomics: swap in 0, then swap the value
        # back. Workaround for the lack of volatile arrays.
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=3)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
165
+
166
+
167
def test_atomic_cas_3d(test, device, warp_type, numpy_type):
    """Check that a CAS spinlock stored in a 3D array serializes ``n`` counter increments.

    The lock lives at index ``(0, 0, 0)`` of a ``(1, 1, 1)`` array (kernel from
    ``create_spinlock_test_3d``). ``n`` threads each increment a 1D counter
    under the lock, and the result must be exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance supplied by the harness (unused).
        device: Warp device to launch on.
        warp_type: Warp scalar type of the counter and lock.
        numpy_type: Matching NumPy dtype used to build the expected value.
    """
    n = 100
    counter = wp.zeros(1, dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1), dtype=warp_type, device=device)

    test_spinlock_counter = create_spinlock_test_3d(warp_type)
    wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n. The assertion carries the mismatch, so no
    # separate debug print is emitted.
    assert_np_equal(counter.numpy(), np.array([n], dtype=numpy_type))
183
+
184
+
185
def create_spinlock_test_4d(dtype):
    """Build a spinlock-counter kernel whose lock is element ``(0, 0, 0, 0)`` of a 4D array.

    This exercises the 4D-index overloads of ``wp.atomic_cas`` / ``wp.atomic_exch``.
    The counter itself stays 1D.

    Args:
        dtype: Warp scalar type used for the counter and lock arrays.

    Returns:
        The ``test_spinlock_counter`` kernel, taking ``(counter, lock)`` with a
        1D ``counter`` and a 4D ``lock``.
    """

    @wp.func
    def spinlock_acquire(lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock by setting it to 1 if it's 0
        # Keep retrying while atomic_cas reports the lock value as 1 (held).
        while wp.atomic_cas(lock, 0, 0, 0, 0, dtype(0), dtype(1)) == 1:
            pass

    @wp.func
    def spinlock_release(lock: wp.array(dtype=dtype, ndim=4)):
        # Release the lock by setting it back to 0
        wp.atomic_exch(lock, 0, 0, 0, 0, dtype(0))

    @wp.func
    def volatile_read(ptr: wp.array(dtype=dtype), index: int):
        # Read ptr[index] (1D) through atomics: swap in 0, then swap the value
        # back. Workaround for the lack of volatile arrays.
        value = wp.atomic_exch(ptr, index, dtype(0))
        wp.atomic_exch(ptr, index, value)
        return value

    @wp.kernel
    def test_spinlock_counter(counter: wp.array(dtype=dtype), lock: wp.array(dtype=dtype, ndim=4)):
        # Try to acquire the lock
        spinlock_acquire(lock)

        # Critical section - increment counter
        # counter[0] = counter[0] + 1 # This gives wrong results - counter should be marked as volatile

        # Work around since warp arrays cannot be marked as volatile
        value = volatile_read(counter, 0)
        counter[0] = value + dtype(1)

        # Release the lock
        spinlock_release(lock)

    return test_spinlock_counter
219
+
220
+
221
def test_atomic_cas_4d(test, device, warp_type, numpy_type):
    """Check that a CAS spinlock stored in a 4D array serializes ``n`` counter increments.

    The lock lives at index ``(0, 0, 0, 0)`` of a ``(1, 1, 1, 1)`` array (kernel
    from ``create_spinlock_test_4d``). ``n`` threads each increment a 1D counter
    under the lock, and the result must be exactly ``n``.

    Args:
        test: The ``unittest.TestCase`` instance supplied by the harness (unused).
        device: Warp device to launch on.
        warp_type: Warp scalar type of the counter and lock.
        numpy_type: Matching NumPy dtype used to build the expected value.
    """
    n = 100
    counter = wp.zeros(1, dtype=warp_type, device=device)
    lock = wp.zeros(shape=(1, 1, 1, 1), dtype=warp_type, device=device)

    test_spinlock_counter = create_spinlock_test_4d(warp_type)
    wp.launch(test_spinlock_counter, dim=n, inputs=[counter, lock], device=device)

    # Verify counter reached n. The assertion carries the mismatch, so no
    # separate debug print is emitted.
    assert_np_equal(counter.numpy(), np.array([n], dtype=numpy_type))
237
+
238
+
239
devices = get_test_devices()


class TestAtomicCAS(unittest.TestCase):
    """Atomic CAS tests; cases are attached below via ``add_function_test``."""


# Every scalar type exercised, paired with its NumPy equivalent.
test_types = [
    (wp.int32, np.int32),
    (wp.uint32, np.uint32),
    (wp.int64, np.int64),
    (wp.uint64, np.uint64),
    (wp.float32, np.float32),
    (wp.float64, np.float64),
]

# For each type, register the 1D, 2D, 3D and 4D lock variants (in that order).
for warp_type, numpy_type in test_types:
    type_name = warp_type.__name__
    for name_infix, runner in (
        ("", test_atomic_cas),
        ("2d_", test_atomic_cas_2d),
        ("3d_", test_atomic_cas_3d),
        ("4d_", test_atomic_cas_4d),
    ):
        add_function_test(
            TestAtomicCAS,
            f"test_cas_{name_infix}{type_name}",
            runner,
            devices=devices,
            warp_type=warp_type,
            numpy_type=numpy_type,
        )

if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)
@@ -435,14 +435,8 @@ def test_error_collection_construct(test, device):
435
435
  x = [1.0, 2.0, 3.0]
436
436
 
437
437
  def kernel_2_fn():
438
- x = (1.0, 2.0, 3.0)
439
-
440
- def kernel_3_fn():
441
438
  x = {"a": 1.0, "b": 2.0, "c": 3.0}
442
439
 
443
- def kernel_4_fn():
444
- wp.length((1.0, 2.0, 3.0))
445
-
446
440
  kernel = wp.Kernel(func=kernel_1_fn)
447
441
  with test.assertRaisesRegex(
448
442
  RuntimeError,
@@ -451,22 +445,9 @@ def test_error_collection_construct(test, device):
451
445
  wp.launch(kernel, dim=1, device=device)
452
446
 
453
447
  kernel = wp.Kernel(func=kernel_2_fn)
454
- with test.assertRaisesRegex(
455
- RuntimeError,
456
- r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` for small collections instead.",
457
- ):
458
- wp.launch(kernel, dim=1, device=device)
459
-
460
- kernel = wp.Kernel(func=kernel_3_fn)
461
448
  with test.assertRaisesRegex(RuntimeError, r"Construct `ast.Dict` not supported in kernels."):
462
449
  wp.launch(kernel, dim=1, device=device)
463
450
 
464
- kernel = wp.Kernel(func=kernel_4_fn)
465
- with test.assertRaisesRegex(
466
- RuntimeError, r"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3\(\)` instead."
467
- ):
468
- wp.launch(kernel, dim=1, device=device)
469
-
470
451
 
471
452
  def test_error_unmatched_arguments(test, device):
472
453
  def kernel_1_fn():
@@ -693,6 +674,138 @@ def test_codegen_return_in_kernel(test, device):
693
674
  test.assertEqual(result.numpy()[0], grid_size - 256)
694
675
 
695
676
 
677
@wp.kernel
def conditional_ifexp(x: float, result: wp.array(dtype=wp.int32)):
    # Conditional expression used only for its side effects: increments
    # result[0] when x > 0.0, otherwise result[1]. A caller can therefore
    # observe whether the non-selected branch was also evaluated.
    wp.atomic_add(result, 0, 1) if x > 0.0 else wp.atomic_add(result, 1, 1)
680
+
681
+
682
def test_ifexp_only_executes_one_branch(test, device):
    """Check that a conditional expression evaluates only the selected branch.

    ``conditional_ifexp`` is launched first with a positive input (true branch,
    increments ``result[0]``), then with a negative input (false branch,
    increments ``result[1]``). After each launch, exactly the expected
    counter has moved.
    """
    result = wp.zeros(2, dtype=wp.int32, device=device)

    # True branch: only the first counter is incremented.
    wp.launch(conditional_ifexp, dim=1, inputs=[1.0, result], device=device)
    values = result.numpy()
    test.assertEqual(values[0], 1)
    test.assertEqual(values[1], 0)

    # False branch: only the second counter is incremented; the first is unchanged.
    wp.launch(conditional_ifexp, dim=1, inputs=[-1.0, result], device=device)
    values = result.numpy()
    test.assertEqual(values[0], 1)
    test.assertEqual(values[1], 1)
691
+
692
+
693
@wp.kernel
def test_multiple_return_values_quat_to_axis_angle_kernel(
    q: wp.quath,
    expected_axis: wp.vec3h,
    expected_angle: wp.float16,
):
    # Unpack both results of quat_to_axis_angle inside a kernel and compare
    # each component against the reference values.
    axis, angle = wp.quat_to_axis_angle(q)
    tol = wp.float16(1e-3)

    for i in range(3):
        wp.expect_near(axis[i], expected_axis[i], tolerance=tol)

    wp.expect_near(angle, expected_angle, tolerance=tol)
706
+
707
+
708
@wp.kernel
def test_multiple_return_values_svd3_kernel(
    A: wp.mat33f,
    expected_U: wp.mat33f,
    expected_sigma: wp.vec3f,
    expected_V: wp.mat33f,
):
    # Unpack the three results of svd3 inside a kernel and compare every entry
    # of U, sigma and V (in that order) against the reference values.
    U, sigma, V = wp.svd3(A)

    for ui in range(3):
        for uj in range(3):
            wp.expect_near(U[ui][uj], expected_U[ui][uj], tolerance=1e-5)

    for si in range(3):
        wp.expect_near(sigma[si], expected_sigma[si], tolerance=1e-5)

    for vi in range(3):
        for vj in range(3):
            wp.expect_near(V[vi][vj], expected_V[vi][vj], tolerance=1e-5)
740
+
741
+
742
def test_multiple_return_values(test, device):
    """Check builtins that return several values, from Python scope and inside kernels.

    Two builtins are covered:

    * ``wp.quat_to_axis_angle``, which returns ``(axis, angle)``.
    * ``wp.svd3``, which returns ``(U, sigma, V)``.

    Each is first called from Python and compared against reference values.
    The matching kernel then repeats the comparison on ``device``.
    """
    q = wp.quath(1.0, 2.0, 3.0, 4.0)
    expected_axis = wp.vec3h(0.26726124, 0.53452247, 0.80178368)
    expected_angle = 1.50408018

    axis, angle = wp.quat_to_axis_angle(q)

    for i in range(3):
        test.assertAlmostEqual(axis[i], expected_axis[i], places=3)
    test.assertAlmostEqual(angle, expected_angle, places=3)

    # Launch on the device under test. Without ``device=``, every per-device
    # registration of this test would launch on the default device.
    wp.launch(
        test_multiple_return_values_quat_to_axis_angle_kernel,
        dim=1,
        inputs=(q, expected_axis, expected_angle),
        device=device,
    )

    # fmt: off
    A = wp.mat33(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
    expected_U = wp.mat33(
        0.21483721, 0.88723058, -0.40824816,
        0.52058744, 0.24964368, 0.81649637,
        0.82633746, -0.38794267, -0.40824834,
    )
    expected_sigma = wp.vec3(16.84809875, 1.06836915, 0.00000019)
    expected_V = wp.mat33(
        0.47967088, -0.77669406, 0.40824246,
        0.57236743, -0.07568054, -0.81649727,
        0.66506463, 0.62531471, 0.40825251,
    )
    # fmt: on

    U, sigma, V = wp.svd3(A)

    for i in range(3):
        for j in range(3):
            test.assertAlmostEqual(U[i][j], expected_U[i][j], places=5)

    for i in range(3):
        test.assertAlmostEqual(sigma[i], expected_sigma[i], places=5)

    for i in range(3):
        for j in range(3):
            test.assertAlmostEqual(V[i][j], expected_V[i][j], places=5)

    wp.launch(
        test_multiple_return_values_svd3_kernel,
        dim=1,
        inputs=(A, expected_U, expected_sigma, expected_V),
        device=device,
    )
807
+
808
+
696
809
class TestCodeGen(unittest.TestCase):
    """Code-generation tests; cases are attached below via add_kernel_test/add_function_test."""
698
811
 
@@ -825,6 +938,16 @@ add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, d
825
938
  add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices)
826
939
  add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices)
827
940
  add_function_test(TestCodeGen, "test_codegen_return_in_kernel", test_codegen_return_in_kernel, devices=devices)
941
add_function_test(
    TestCodeGen,
    "test_ifexp_only_executes_one_branch",
    test_ifexp_only_executes_one_branch,
    devices=devices,
)
add_function_test(
    TestCodeGen,
    "test_multiple_return_values",
    test_multiple_return_values,
    devices=devices,
)
950
+
828
951
 
829
952
  if __name__ == "__main__":
830
953
  wp.clear_kernel_cache()
@@ -58,6 +58,48 @@ def test_conditional_if_else_nested():
58
58
  wp.expect_eq(e, -2.0)
59
59
 
60
60
 
61
@wp.kernel
def test_conditional_ifexp():
    # Basic conditional expression with a runtime comparison.
    a = 0.5
    b = 2.0

    # a > b is false, so the else-value is selected.
    c = 1.0 if a > b else -1.0

    wp.expect_eq(c, -1.0)
69
+
70
+
71
@wp.kernel
def test_conditional_ifexp_nested():
    # Conditional expressions nested in else-branches, plus a compound `and` condition.
    a = 1.0
    b = 2.0

    # a > b is false, so c = 6.0 and d = 7.0.
    c = 3.0 if a > b else 6.0
    d = 4.0 if a > b else 7.0
    # Both a > b and c > d (6.0 > 7.0) are false, so evaluation falls through
    # every else-branch to -2.0.
    e = 1.0 if (a > b and c > d) else (-1.0 if a > b else (2.0 if c > d else -2.0))

    wp.expect_eq(e, -2.0)
81
+
82
+
83
@wp.kernel
def test_conditional_ifexp_constant():
    # Conditions are literals: a bool literal, and a nonzero int literal that
    # is expected to be treated as true.
    a = 1.0 if False else -1.0
    b = 2.0 if 123 else -2.0

    wp.expect_eq(a, -1.0)
    wp.expect_eq(b, 2.0)
90
+
91
+
92
@wp.kernel
def test_conditional_ifexp_constant_nested():
    # Nested conditional expressions whose conditions are all literals. Zero is
    # expected to act as false and nonzero ints (e.g. 321) as true.
    a = 1.0 if False else (2.0 if True else 3.0)
    b = 4.0 if 0 else (5.0 if 0 else (6.0 if False else 7.0))
    c = 8.0 if False else (9.0 if False else (10.0 if 321 else 11.0))

    wp.expect_eq(a, 2.0)
    wp.expect_eq(b, 7.0)
    wp.expect_eq(c, 10.0)
101
+
102
+
61
103
  @wp.kernel
62
104
  def test_boolean_and():
63
105
  a = 1.0
@@ -90,7 +132,7 @@ def test_boolean_compound():
90
132
 
91
133
  d = 1.0
92
134
 
93
- if a > 0.0 and b > 0.0 or c > a:
135
+ if (a > 0.0 and b > 0.0) or c > a:
94
136
  d = -1.0
95
137
 
96
138
  wp.expect_eq(d, -1.0)
@@ -231,6 +273,10 @@ class TestConditional(unittest.TestCase):
231
273
 
232
274
  add_kernel_test(TestConditional, kernel=test_conditional_if_else, dim=1, devices=devices)
233
275
  add_kernel_test(TestConditional, kernel=test_conditional_if_else_nested, dim=1, devices=devices)
276
+ add_kernel_test(TestConditional, kernel=test_conditional_ifexp, dim=1, devices=devices)
277
+ add_kernel_test(TestConditional, kernel=test_conditional_ifexp_nested, dim=1, devices=devices)
278
+ add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant, dim=1, devices=devices)
279
+ add_kernel_test(TestConditional, kernel=test_conditional_ifexp_constant_nested, dim=1, devices=devices)
234
280
  add_kernel_test(TestConditional, kernel=test_boolean_and, dim=1, devices=devices)
235
281
  add_kernel_test(TestConditional, kernel=test_boolean_or, dim=1, devices=devices)
236
282
  add_kernel_test(TestConditional, kernel=test_boolean_compound, dim=1, devices=devices)
warp/tests/test_ctypes.py CHANGED
@@ -541,25 +541,6 @@ def test_scalar_array_types(test, device, load, store):
541
541
  )
542
542
 
543
543
 
544
- @wp.kernel
545
- def test_transform_matrix():
546
- r = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), 0.5)
547
- t = wp.vec3(0.25, 0.5, -0.75)
548
- s = wp.vec3(2.0, 0.5, 0.75)
549
-
550
- m = wp.mat44(t, r, s)
551
-
552
- p = wp.vec3(1.0, 2.0, 3.0)
553
-
554
- r_0 = wp.quat_rotate(r, wp.cw_mul(s, p)) + t
555
- r_1 = wp.transform_point(m, p)
556
-
557
- r_2 = wp.transform_vector(m, p)
558
-
559
- wp.expect_near(r_0, r_1, 1.0e-4)
560
- wp.expect_near(r_2, r_0 - t, 1.0e-4)
561
-
562
-
563
544
  devices = get_test_devices()
564
545
 
565
546
 
@@ -628,7 +609,6 @@ add_function_test(TestCTypes, "test_vec2_transform", test_vec2_transform, device
628
609
  add_function_test(TestCTypes, "test_vec3_arg", test_vec3_arg, devices=devices, n=8)
629
610
  add_function_test(TestCTypes, "test_vec3_transform", test_vec3_transform, devices=devices, n=8)
630
611
  add_function_test(TestCTypes, "test_transform_multiply", test_transform_multiply, devices=devices, n=8)
631
- add_kernel_test(TestCTypes, name="test_transform_matrix", kernel=test_transform_matrix, dim=1, devices=devices)
632
612
  add_function_test(TestCTypes, "test_scalar_array", test_scalar_array, devices=devices)
633
613
  add_function_test(TestCTypes, "test_vector_array", test_vector_array, devices=devices)
634
614
 
@@ -70,6 +70,13 @@ def test_devices_can_access_self(test, device):
70
70
  test.assertNotEqual(device, device_str)
71
71
 
72
72
 
73
+ def test_devices_sm_count(test, device):
74
+ if device.is_cuda:
75
+ test.assertTrue(device.sm_count > 0)
76
+ else:
77
+ test.assertEqual(device.sm_count, 0)
78
+
79
+
73
80
  devices = get_test_devices()
74
81
 
75
82
 
@@ -90,6 +97,7 @@ add_function_test(
90
97
  )
91
98
add_function_test(TestDevices, "test_devices_verify_cuda_device", test_devices_verify_cuda_device, devices=devices)
add_function_test(TestDevices, "test_devices_can_access_self", test_devices_can_access_self, devices=devices)
# Checks device.sm_count: positive on CUDA devices, zero otherwise.
add_function_test(TestDevices, "test_devices_sm_count", test_devices_sm_count, devices=devices)
93
101
 
94
102
 
95
103
  if __name__ == "__main__":
@@ -13,6 +13,8 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  import math
17
19
  import unittest
18
20
  from typing import Any
@@ -70,7 +72,7 @@ def _warp_type_to_fabric(dtype, is_array=False):
70
72
 
71
73
 
72
74
  # returns a fabric array interface constructed from a regular array
73
- def _create_fabric_array_interface(data: wp.array, attrib: str, bucket_sizes: list = None, copy=False):
75
+ def _create_fabric_array_interface(data: wp.array, attrib: str, bucket_sizes: list[int] | None = None, copy=False):
74
76
  assert isinstance(data, wp.array)
75
77
  assert data.ndim == 1
76
78
 
@@ -138,7 +140,7 @@ def _create_fabric_array_interface(data: wp.array, attrib: str, bucket_sizes: li
138
140
 
139
141
 
140
142
  # returns a fabric array array interface constructed from a list of regular arrays
141
- def _create_fabric_array_array_interface(data: list, attrib: str, bucket_sizes: list = None):
143
+ def _create_fabric_array_array_interface(data: list, attrib: str, bucket_sizes: list[int] | None = None):
142
144
  # data should be a list of arrays
143
145
  assert isinstance(data, list)
144
146