warp-lang 1.7.2rc1 → 1.8.1 (wheels: warp_lang-1.7.2rc1-py3-none-win_amd64.whl → warp_lang-1.8.1-py3-none-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,265 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+ from typing import Any, Tuple
18
+
19
+ import warp as wp
20
+ from warp.tests.unittest_utils import *
21
+
22
+
23
+ @wp.struct
24
+ class BasicsStruct:
25
+ origin: wp.vec3
26
+ scale: float
27
+
28
+
29
+ @wp.kernel
30
+ def test_basics():
31
+ tid = wp.tid()
32
+ s = BasicsStruct(wp.vec3(1.1, 2.2, 3.3), 1.23)
33
+
34
+ t = (1, 2.2, wp.vec3(1.1, 2.2, 3.3), wp.mat22(1.1, 2.2, 3.3, 4.4), s, tid)
35
+ wp.expect_eq(len(t), 6)
36
+ wp.expect_eq(wp.static(len(t)), 6)
37
+ wp.expect_eq(t[0], 1)
38
+ wp.expect_eq(t[1], 2.2)
39
+ wp.expect_eq(t[2], wp.vec3(1.1, 2.2, 3.3))
40
+ wp.expect_eq(t[3], wp.mat22(1.1, 2.2, 3.3, 4.4))
41
+ wp.expect_eq(t[4].origin, wp.vec3(1.1, 2.2, 3.3))
42
+ wp.expect_eq(t[4].scale, 1.23)
43
+ wp.expect_eq(t[5], wp.tid())
44
+
45
+ t0, t1, t2, t3, t4, t5 = t
46
+ wp.expect_eq(t0, 1)
47
+ wp.expect_eq(t1, 2.2)
48
+ wp.expect_eq(t2, wp.vec3(1.1, 2.2, 3.3))
49
+ wp.expect_eq(t3, wp.mat22(1.1, 2.2, 3.3, 4.4))
50
+ wp.expect_eq(t4.origin, wp.vec3(1.1, 2.2, 3.3))
51
+ wp.expect_eq(t4.scale, 1.23)
52
+ wp.expect_eq(t5, wp.tid())
53
+
54
+
55
+ @wp.kernel
56
+ def test_builtin_with_multiple_return():
57
+ expected_axis = wp.vec3(0.26726124, 0.53452247, 0.80178368)
58
+ expected_angle = 1.50408018
59
+ q = wp.quat(1.0, 2.0, 3.0, 4.0)
60
+
61
+ t = wp.quat_to_axis_angle(q)
62
+ wp.expect_eq(len(t), 2)
63
+ wp.expect_eq(wp.static(len(t)), 2)
64
+
65
+ axis_1 = t[0]
66
+ angle_1 = t[1]
67
+ wp.expect_near(axis_1[0], expected_axis[0])
68
+ wp.expect_near(axis_1[1], expected_axis[1])
69
+ wp.expect_near(axis_1[2], expected_axis[2])
70
+ wp.expect_near(angle_1, expected_angle)
71
+
72
+ axis_2, angle_2 = t
73
+ wp.expect_near(axis_2[0], expected_axis[0])
74
+ wp.expect_near(axis_2[1], expected_axis[1])
75
+ wp.expect_near(axis_2[2], expected_axis[2])
76
+ wp.expect_near(angle_2, expected_angle)
77
+
78
+ axis_3, angle_3 = wp.quat_to_axis_angle(q)
79
+ wp.expect_near(axis_3[0], expected_axis[0])
80
+ wp.expect_near(axis_3[1], expected_axis[1])
81
+ wp.expect_near(axis_3[2], expected_axis[2])
82
+ wp.expect_near(angle_3, expected_angle)
83
+
84
+
85
+ @wp.func
86
+ def user_func_with_multiple_return(x: int, y: float) -> Tuple[int, float]:
87
+ return (x * 123, y * 1.23)
88
+
89
+
90
+ @wp.kernel
91
+ def test_user_func_with_multiple_return():
92
+ t = user_func_with_multiple_return(4, wp.pow(2.0, 3.0))
93
+ wp.expect_eq(len(t), 2)
94
+ wp.expect_eq(wp.static(len(t)), 2)
95
+
96
+ x_1 = t[0]
97
+ y_1 = t[1]
98
+ wp.expect_eq(x_1, 492)
99
+ wp.expect_near(y_1, 9.84)
100
+
101
+ x_2, y_2 = t
102
+ wp.expect_eq(x_2, 492)
103
+ wp.expect_near(y_2, 9.84)
104
+
105
+ x_3, y_3 = user_func_with_multiple_return(4, wp.pow(2.0, 3.0))
106
+ wp.expect_eq(x_3, 492)
107
+ wp.expect_near(y_3, 9.84)
108
+
109
+
110
+ @wp.func
111
+ def user_func_with_tuple_arg(values: Tuple[wp.vec3, float]) -> float:
112
+ wp.expect_eq(len(values), 2)
113
+ wp.expect_eq(wp.static(len(values)), 2)
114
+ return wp.length(values[0]) * values[1]
115
+
116
+
117
+ @wp.kernel
118
+ def test_user_func_with_tuple_arg():
119
+ t = (wp.vec3(1.0, 2.0, 3.0), wp.pow(2.0, 4.0))
120
+ wp.expect_eq(len(t), 2)
121
+ wp.expect_eq(wp.static(len(t)), 2)
122
+
123
+ x_1 = user_func_with_tuple_arg(t)
124
+ wp.expect_near(x_1, 59.86652)
125
+
126
+ x_2 = user_func_with_tuple_arg((t[0], t[1]))
127
+ wp.expect_near(x_2, 59.86652)
128
+
129
+ x_3 = user_func_with_tuple_arg((wp.vec3(1.0, 2.0, 3.0), wp.pow(2.0, 4.0)))
130
+ wp.expect_near(x_3, 59.86652)
131
+
132
+
133
+ @wp.func
134
+ def loop_user_func(values: Tuple[int, int, int]):
135
+ out = wp.int32(0)
136
+ for i in range(wp.static(len(values))):
137
+ out += values[i]
138
+
139
+ for i in range(len(values)):
140
+ out += values[i] * 2
141
+
142
+ return out
143
+
144
+
145
+ @wp.kernel
146
+ def test_loop():
147
+ t = (1, 2, 3)
148
+ res = loop_user_func(t)
149
+ wp.expect_eq(res, 18)
150
+
151
+
152
+ @wp.func
153
+ def loop_variadic_any_user_func(values: Any):
154
+ out = wp.int32(0)
155
+ for i in range(wp.static(len(values))):
156
+ out += values[i]
157
+
158
+ for i in range(len(values)):
159
+ out += values[i] * 2
160
+
161
+ return out
162
+
163
+
164
+ @wp.kernel
165
+ def test_loop_variadic_any():
166
+ t1 = (1,)
167
+ res = loop_variadic_any_user_func(t1)
168
+ wp.expect_eq(res, 3)
169
+
170
+ t2 = (2, 3)
171
+ res = loop_variadic_any_user_func(t2)
172
+ wp.expect_eq(res, 15)
173
+
174
+ t3 = (3, 4, 5)
175
+ res = loop_variadic_any_user_func(t3)
176
+ wp.expect_eq(res, 36)
177
+
178
+ t4 = (4, 5, 6, 7)
179
+ res = loop_variadic_any_user_func(t4)
180
+ wp.expect_eq(res, 66)
181
+
182
+
183
+ @wp.func
184
+ def loop_variadic_ellipsis_user_func(values: Tuple[int, ...]):
185
+ out = wp.int32(0)
186
+ for i in range(wp.static(len(values))):
187
+ out += values[i]
188
+
189
+ return out
190
+
191
+
192
+ @wp.kernel
193
+ def test_loop_variadic_ellipsis():
194
+ t1 = (1,)
195
+ res = loop_variadic_ellipsis_user_func(t1)
196
+ wp.expect_eq(res, 1)
197
+
198
+ t2 = (2, 3)
199
+ res = loop_variadic_ellipsis_user_func(t2)
200
+ wp.expect_eq(res, 5)
201
+
202
+ t3 = (3, 4, 5)
203
+ res = loop_variadic_ellipsis_user_func(t3)
204
+ wp.expect_eq(res, 12)
205
+
206
+ t4 = (4, 5, 6, 7)
207
+ res = loop_variadic_ellipsis_user_func(t4)
208
+ wp.expect_eq(res, 22)
209
+
210
+
211
+ devices = get_test_devices()
212
+
213
+
214
+ class TestTuple(unittest.TestCase):
215
+ pass
216
+
217
+
218
+ add_kernel_test(TestTuple, name="test_basics", kernel=test_basics, dim=3, devices=devices)
219
+ add_kernel_test(
220
+ TestTuple,
221
+ name="test_builtin_with_multiple_return",
222
+ kernel=test_builtin_with_multiple_return,
223
+ dim=1,
224
+ devices=devices,
225
+ )
226
+ add_kernel_test(
227
+ TestTuple,
228
+ name="test_user_func_with_multiple_return",
229
+ kernel=test_user_func_with_multiple_return,
230
+ dim=1,
231
+ devices=devices,
232
+ )
233
+ add_kernel_test(
234
+ TestTuple,
235
+ name="test_user_func_with_tuple_arg",
236
+ kernel=test_user_func_with_tuple_arg,
237
+ dim=1,
238
+ devices=devices,
239
+ )
240
+ add_kernel_test(
241
+ TestTuple,
242
+ name="test_loop",
243
+ kernel=test_loop,
244
+ dim=1,
245
+ devices=devices,
246
+ )
247
+ add_kernel_test(
248
+ TestTuple,
249
+ name="test_loop_variadic_any",
250
+ kernel=test_loop_variadic_any,
251
+ dim=1,
252
+ devices=devices,
253
+ )
254
+ add_kernel_test(
255
+ TestTuple,
256
+ name="test_loop_variadic_ellipsis",
257
+ kernel=test_loop_variadic_ellipsis,
258
+ dim=1,
259
+ devices=devices,
260
+ )
261
+
262
+
263
+ if __name__ == "__main__":
264
+ wp.clear_kernel_cache()
265
+ unittest.main(verbosity=2)
warp/tests/test_types.py CHANGED
@@ -286,7 +286,7 @@ class TestTypes(unittest.TestCase):
286
286
  v2[:1] = (v2,)
287
287
 
288
288
  def test_matrix(self):
289
- for dtype in tuple(wp.types.float_types) + (float,):
289
+ for dtype in (*wp.types.float_types, float):
290
290
 
291
291
  def make_scalar(x, dtype=dtype):
292
292
  # Cast to the correct integer type to simulate wrapping.
@@ -554,7 +554,7 @@ for dtype in wp.types.int_types:
554
554
  for dtype in wp.types.float_types:
555
555
  add_function_test(TestTypes, f"test_floats_{dtype.__name__}", test_floats, devices=devices, dtype=dtype)
556
556
 
557
- for dtype in tuple(wp.types.scalar_types) + (int, float):
557
+ for dtype in (*wp.types.scalar_types, int, float):
558
558
  add_function_test(TestTypes, f"test_vector_{dtype.__name__}", test_vector, devices=devices, dtype=dtype)
559
559
 
560
560
  if __name__ == "__main__":
warp/tests/test_utils.py CHANGED
@@ -19,6 +19,7 @@ import io
19
19
  import unittest
20
20
 
21
21
  from warp.tests.unittest_utils import *
22
+ from warp.types import type_scalar_type
22
23
 
23
24
 
24
25
  def test_array_scan(test, device):
@@ -61,7 +62,7 @@ def test_array_scan_error_sizes_mismatch(test, device):
61
62
  result = wp.zeros(234, dtype=int, device=device)
62
63
  with test.assertRaisesRegex(
63
64
  RuntimeError,
64
- r"Array storage sizes do not match$",
65
+ r"In and out array storage sizes do not match \(123 vs 234\)$",
65
66
  ):
66
67
  wp.utils.array_scan(values, result, True)
67
68
 
@@ -71,7 +72,7 @@ def test_array_scan_error_dtypes_mismatch(test, device):
71
72
  result = wp.zeros(123, dtype=float, device=device)
72
73
  with test.assertRaisesRegex(
73
74
  RuntimeError,
74
- r"Array data types do not match$",
75
+ r"In and out array data types do not match \(int32 vs float32\)$",
75
76
  ):
76
77
  wp.utils.array_scan(values, result, True)
77
78
 
@@ -81,7 +82,7 @@ def test_array_scan_error_unsupported_dtype(test, device):
81
82
  result = wp.zeros(123, dtype=wp.vec3, device=device)
82
83
  with test.assertRaisesRegex(
83
84
  RuntimeError,
84
- r"Unsupported data type$",
85
+ r"Unsupported data type: vec3f$",
85
86
  ):
86
87
  wp.utils.array_scan(values, result, True)
87
88
 
@@ -142,7 +143,7 @@ def test_radix_sort_pairs_error_insufficient_storage(test, device):
142
143
  values = wp.array((1, 2, 3), dtype=int, device=device)
143
144
  with test.assertRaisesRegex(
144
145
  RuntimeError,
145
- r"Array storage must be large enough to contain 2\*count elements$",
146
+ r"Keys and values array storage must be large enough to contain 2\*count elements$",
146
147
  ):
147
148
  wp.utils.radix_sort_pairs(keys, values, 3)
148
149
 
@@ -167,27 +168,27 @@ def test_segmented_sort_pairs_error_insufficient_storage(test, device):
167
168
 
168
169
 
169
170
  def test_radix_sort_pairs_error_unsupported_dtype(test, device):
170
- keyTypes = [int, wp.float32, wp.int64]
171
+ keyTypes = [wp.int32, wp.float32, wp.int64]
171
172
 
172
173
  for keyType in keyTypes:
173
174
  keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
174
175
  values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
175
176
  with test.assertRaisesRegex(
176
177
  RuntimeError,
177
- r"Unsupported data type$",
178
+ rf"Unsupported keys and values data types: {keyType.__name__}, float32$",
178
179
  ):
179
180
  wp.utils.radix_sort_pairs(keys, values, 1)
180
181
 
181
182
 
182
183
  def test_segmented_sort_pairs_error_unsupported_dtype(test, device):
183
- keyTypes = [int, wp.float32]
184
+ keyTypes = [wp.int32, wp.float32]
184
185
 
185
186
  for keyType in keyTypes:
186
187
  keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
187
188
  values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
188
189
  with test.assertRaisesRegex(
189
190
  RuntimeError,
190
- r"Unsupported data type$",
191
+ rf"Unsupported data type: {keyType.__name__}$",
191
192
  ):
192
193
  wp.utils.segmented_sort_pairs(
193
194
  keys,
@@ -234,30 +235,35 @@ def test_array_sum_error_unsupported_dtype(test, device):
234
235
  values = wp.array((1, 2, 3), dtype=int, device=device)
235
236
  with test.assertRaisesRegex(
236
237
  RuntimeError,
237
- r"Unsupported data type$",
238
+ r"Unsupported data type: int32$",
238
239
  ):
239
240
  wp.utils.array_sum(values)
240
241
 
241
242
 
242
243
  def test_array_inner(test, device):
243
- for dtype in (wp.float32, wp.float64):
244
+ for dtype in (wp.float32, wp.float64, wp.vec3):
244
245
  a = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
245
246
  b = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
246
247
  test.assertEqual(wp.utils.array_inner(a, b), 14.0)
247
248
 
248
249
  a = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
249
250
  b = wp.array((1.0, 2.0, 3.0), dtype=dtype, device=device)
250
- result = wp.empty(shape=(1,), dtype=dtype, device=device)
251
+ result = wp.empty(shape=(1,), dtype=type_scalar_type(dtype), device=device)
251
252
  wp.utils.array_inner(a, b, out=result)
252
253
  test.assertEqual(result.numpy()[0], 14.0)
253
254
 
255
+ # test with different instances of same type
256
+ a = wp.array((1.0, 2.0, 3.0), dtype=wp.vec3, device=device)
257
+ b = wp.array((1.0, 2.0, 3.0), dtype=wp.vec(3, float), device=device)
258
+ test.assertEqual(wp.utils.array_inner(a, b), 14.0)
259
+
254
260
 
255
261
  def test_array_inner_error_sizes_mismatch(test, device):
256
262
  a = wp.array((1.0, 2.0), dtype=wp.float32, device=device)
257
263
  b = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device=device)
258
264
  with test.assertRaisesRegex(
259
265
  RuntimeError,
260
- r"Array storage sizes do not match$",
266
+ r"A and b array storage sizes do not match \(2 vs 3\)$",
261
267
  ):
262
268
  wp.utils.array_inner(a, b)
263
269
 
@@ -267,7 +273,7 @@ def test_array_inner_error_dtypes_mismatch(test, device):
267
273
  b = wp.array((1.0, 2.0, 3.0), dtype=wp.float64, device=device)
268
274
  with test.assertRaisesRegex(
269
275
  RuntimeError,
270
- r"Array data types do not match$",
276
+ r"A and b array data types do not match \(float32 vs float64\)$",
271
277
  ):
272
278
  wp.utils.array_inner(a, b)
273
279
 
@@ -299,7 +305,7 @@ def test_array_inner_error_unsupported_dtype(test, device):
299
305
  b = wp.array((1, 2, 3), dtype=int, device=device)
300
306
  with test.assertRaisesRegex(
301
307
  RuntimeError,
302
- r"Unsupported data type$",
308
+ r"Unsupported data type: int32$",
303
309
  ):
304
310
  wp.utils.array_inner(a, b)
305
311
 
@@ -411,7 +417,7 @@ class TestUtils(unittest.TestCase):
411
417
  result = wp.zeros_like(values, device="cuda:0")
412
418
  with self.assertRaisesRegex(
413
419
  RuntimeError,
414
- r"Array storage devices do not match$",
420
+ r"In and out array storage devices do not match \(cpu vs cuda:0\)$",
415
421
  ):
416
422
  wp.utils.array_scan(values, result, True)
417
423
 
@@ -421,7 +427,7 @@ class TestUtils(unittest.TestCase):
421
427
  values = wp.array((1, 2, 3), dtype=int, device="cuda:0")
422
428
  with self.assertRaisesRegex(
423
429
  RuntimeError,
424
- r"Array storage devices do not match$",
430
+ r"Keys and values array storage devices do not match \(cpu vs cuda:0\)$",
425
431
  ):
426
432
  wp.utils.radix_sort_pairs(keys, values, 1)
427
433
 
@@ -452,7 +458,7 @@ class TestUtils(unittest.TestCase):
452
458
  b = wp.array((1.0, 2.0, 3.0), dtype=wp.float32, device="cuda:0")
453
459
  with self.assertRaisesRegex(
454
460
  RuntimeError,
455
- r"Array storage devices do not match$",
461
+ r"A and b array storage devices do not match \(cpu vs cuda:0\)$",
456
462
  ):
457
463
  wp.utils.array_inner(a, b)
458
464
 
@@ -462,7 +468,7 @@ class TestUtils(unittest.TestCase):
462
468
  result = wp.empty(3, dtype=float, device="cuda:0")
463
469
  with self.assertRaisesRegex(
464
470
  RuntimeError,
465
- r"Array storage devices do not match$",
471
+ r"Array storage devices do not match \(cpu vs cuda:0\)$",
466
472
  ):
467
473
  wp.utils.array_cast(values, result)
468
474