warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.1__py3-none-macosx_10_13_universal2.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's notice for this release for more details.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,229 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+ devices = get_test_devices()
22
+
23
+
24
+ @wp.kernel
25
+ def test_zeros():
26
+ arr = wp.zeros(shape=(2, 3), dtype=int)
27
+ for i in range(arr.shape[0]):
28
+ for j in range(arr.shape[1]):
29
+ arr[i][j] = i * arr.shape[1] + j
30
+
31
+ wp.expect_eq(arr[0][0], 0)
32
+ wp.expect_eq(arr[0][1], 1)
33
+ wp.expect_eq(arr[0][2], 2)
34
+ wp.expect_eq(arr[1][0], 3)
35
+ wp.expect_eq(arr[1][1], 4)
36
+ wp.expect_eq(arr[1][2], 5)
37
+
38
+
39
+ @wp.func
40
+ def test_func_arg_func(arr: wp.array(ndim=2, dtype=int)):
41
+ for i in range(arr.shape[0]):
42
+ for j in range(arr.shape[1]):
43
+ arr[i][j] = i * arr.shape[1] + j
44
+
45
+
46
+ @wp.kernel
47
+ def test_func_arg():
48
+ arr = wp.zeros(shape=(2, 3), dtype=int)
49
+ test_func_arg_func(arr)
50
+
51
+ wp.expect_eq(arr[0][0], 0)
52
+ wp.expect_eq(arr[0][1], 1)
53
+ wp.expect_eq(arr[0][2], 2)
54
+ wp.expect_eq(arr[1][0], 3)
55
+ wp.expect_eq(arr[1][1], 4)
56
+ wp.expect_eq(arr[1][2], 5)
57
+
58
+
59
+ @wp.func
60
+ def test_func_return_func():
61
+ arr = wp.zeros(shape=(2, 3), dtype=int)
62
+ for i in range(arr.shape[0]):
63
+ for j in range(arr.shape[1]):
64
+ arr[i][j] = i * arr.shape[1] + j
65
+
66
+ return arr
67
+
68
+
69
+ @wp.kernel
70
+ def test_func_return():
71
+ arr = test_func_return_func()
72
+
73
+ wp.expect_eq(arr[0][0], 0)
74
+ wp.expect_eq(arr[0][1], 1)
75
+ wp.expect_eq(arr[0][2], 2)
76
+ wp.expect_eq(arr[1][0], 3)
77
+ wp.expect_eq(arr[1][1], 4)
78
+ wp.expect_eq(arr[1][2], 5)
79
+
80
+
81
+ @wp.func
82
+ def test_func_return_annotation_func() -> wp.fixedarray(shape=(2, 3), dtype=int):
83
+ arr = wp.zeros(shape=(2, 3), dtype=int)
84
+ for i in range(arr.shape[0]):
85
+ for j in range(arr.shape[1]):
86
+ arr[i][j] = i * arr.shape[1] + j
87
+
88
+ return arr
89
+
90
+
91
+ @wp.kernel
92
+ def test_func_return_annotation():
93
+ arr = test_func_return_annotation_func()
94
+
95
+ wp.expect_eq(arr[0][0], 0)
96
+ wp.expect_eq(arr[0][1], 1)
97
+ wp.expect_eq(arr[0][2], 2)
98
+ wp.expect_eq(arr[1][0], 3)
99
+ wp.expect_eq(arr[1][1], 4)
100
+ wp.expect_eq(arr[1][2], 5)
101
+
102
+
103
+ def test_error_invalid_func_return_annotation(test, device):
104
+ @wp.func
105
+ def func() -> wp.array(ndim=2, dtype=int):
106
+ arr = wp.zeros(shape=(2, 3), dtype=int)
107
+ for i in range(arr.shape[0]):
108
+ for j in range(arr.shape[1]):
109
+ arr[i][j] = i * arr.shape[1] + j
110
+
111
+ return arr
112
+
113
+ @wp.kernel
114
+ def kernel():
115
+ arr = func()
116
+
117
+ with test.assertRaisesRegex(
118
+ wp.codegen.WarpCodegenError,
119
+ r"The function `func` returns a fixed-size array whereas it has its return type annotated as `Array\[int32\]`.$",
120
+ ):
121
+ wp.launch(kernel, 1, device=device)
122
+
123
+
124
+ def test_error_runtime_shape(test, device):
125
+ @wp.kernel
126
+ def kernel():
127
+ tid = wp.tid()
128
+ wp.zeros(shape=(tid,), dtype=int)
129
+
130
+ with test.assertRaisesRegex(
131
+ RuntimeError,
132
+ r"the `shape` argument must be specified as a constant when zero-initializing an array$",
133
+ ):
134
+ wp.launch(kernel, 1, device=device)
135
+
136
+
137
+ @wp.kernel
138
+ def test_capture_if_kernel():
139
+ arr = wp.zeros(shape=(2, 3), dtype=int)
140
+ for i in range(arr.shape[0]):
141
+ for j in range(arr.shape[1]):
142
+ arr[i][j] = i * arr.shape[1] + j
143
+
144
+ wp.expect_eq(arr[0][0], 0)
145
+ wp.expect_eq(arr[0][1], 1)
146
+ wp.expect_eq(arr[0][2], 2)
147
+ wp.expect_eq(arr[1][0], 3)
148
+ wp.expect_eq(arr[1][1], 4)
149
+ wp.expect_eq(arr[1][2], 5)
150
+
151
+
152
+ def test_capture_if(test, device):
153
+ if (
154
+ not wp.get_device(device).is_cuda
155
+ or wp.context.runtime.toolkit_version < (12, 4)
156
+ or wp.context.runtime.driver_version < (12, 4)
157
+ ):
158
+ return
159
+
160
+ def foo():
161
+ wp.launch(test_capture_if_kernel, dim=512, block_dim=128, device=device)
162
+
163
+ cond = wp.ones(1, dtype=wp.int32, device=device)
164
+ with wp.ScopedCapture(device=device) as capture:
165
+ wp.capture_if(condition=cond, on_true=foo)
166
+
167
+ wp.capture_launch(capture.graph)
168
+
169
+
170
+ @wp.struct
171
+ class test_func_struct_MyStruct:
172
+ offset: int
173
+ dist: float
174
+
175
+
176
+ @wp.func
177
+ def test_func_struct_func():
178
+ arr = wp.zeros(shape=(2, 3), dtype=test_func_struct_MyStruct)
179
+ count = float(arr.shape[0] * arr.shape[1] - 1)
180
+ for i in range(arr.shape[0]):
181
+ for j in range(arr.shape[1]):
182
+ arr[i][j].offset = i * arr.shape[1] + j
183
+ arr[i][j].dist = float(arr[i][j].offset) / count
184
+
185
+ return arr
186
+
187
+
188
+ @wp.kernel
189
+ def test_func_struct():
190
+ arr = test_func_struct_func()
191
+
192
+ wp.expect_eq(arr[0][0].offset, 0)
193
+ wp.expect_near(arr[0][0].dist, 0.0)
194
+ wp.expect_eq(arr[0][1].offset, 1)
195
+ wp.expect_near(arr[0][1].dist, 0.2)
196
+ wp.expect_eq(arr[0][2].offset, 2)
197
+ wp.expect_near(arr[0][2].dist, 0.4)
198
+ wp.expect_eq(arr[1][0].offset, 3)
199
+ wp.expect_near(arr[1][0].dist, 0.6)
200
+ wp.expect_eq(arr[1][1].offset, 4)
201
+ wp.expect_near(arr[1][1].dist, 0.8)
202
+ wp.expect_eq(arr[1][2].offset, 5)
203
+ wp.expect_near(arr[1][2].dist, 1.0)
204
+
205
+
206
+ class TestFixedArray(unittest.TestCase):
207
+ pass
208
+
209
+
210
+ add_kernel_test(TestFixedArray, kernel=test_zeros, name="test_zeros", dim=1, devices=devices)
211
+ add_kernel_test(TestFixedArray, kernel=test_func_arg, name="test_func_arg", dim=1, devices=devices)
212
+ add_kernel_test(TestFixedArray, kernel=test_func_return, name="test_func_return", dim=1, devices=devices)
213
+ add_kernel_test(
214
+ TestFixedArray, kernel=test_func_return_annotation, name="test_func_return_annotation", dim=1, devices=devices
215
+ )
216
+ add_function_test(
217
+ TestFixedArray,
218
+ "test_error_invalid_func_return_annotation",
219
+ test_error_invalid_func_return_annotation,
220
+ devices=devices,
221
+ )
222
+ add_function_test(TestFixedArray, "test_error_runtime_shape", test_error_runtime_shape, devices=devices)
223
+ add_function_test(TestFixedArray, "test_capture_if", test_capture_if, devices=devices)
224
+ add_kernel_test(TestFixedArray, kernel=test_func_struct, name="test_func_struct", dim=1, devices=devices)
225
+
226
+
227
+ if __name__ == "__main__":
228
+ wp.clear_kernel_cache()
229
+ unittest.main(verbosity=2)
warp/tests/test_func.py CHANGED
@@ -269,15 +269,14 @@ def normalize_vector(vec_a: wp.vec3):
269
269
  return wp.normalize(vec_a)
270
270
 
271
271
 
272
- # This pair is to test the situation where one overload throws an error, but a second one works.
273
272
  @wp.func
274
- def divide_by_zero_overload(x: wp.float32):
275
- return x / 0
273
+ def divide_float64(x: wp.float64):
274
+ return x / wp.float64(1.23)
276
275
 
277
276
 
278
277
  @wp.func
279
- def divide_by_zero_overload(x: wp.float64):
280
- return wp.div(x, 0.0)
278
+ def get_array_len(arr: wp.array(dtype=wp.float32)):
279
+ return len(arr)
281
280
 
282
281
 
283
282
  class TestFunc(unittest.TestCase):
@@ -444,26 +443,30 @@ class TestFunc(unittest.TestCase):
444
443
  a * b
445
444
 
446
445
  def test_cpython_call_user_function_with_error(self):
447
- # Actually the following also includes a ZeroDivisionError in the message due to exception chaining,
448
- # but I don't know how to test for that.
449
446
  with self.assertRaisesRegex(
450
- RuntimeError,
451
- "Error calling function 'divide_by_zero'. No version succeeded. "
452
- "See above for the error from the last version that was tried.",
447
+ ZeroDivisionError,
448
+ "float division by zero",
453
449
  ):
454
450
  divide_by_zero(1.0)
455
451
 
456
- def test_cpython_call_user_function_with_overloads(self):
457
- self.assertEqual(divide_by_zero_overload(1.0), math.inf)
458
-
459
452
  def test_cpython_call_user_function_with_wrong_argument_types(self):
460
453
  with self.assertRaisesRegex(
461
454
  RuntimeError,
462
- "Error calling function 'normalize_vector'. No version succeeded. "
463
- "See above for the error from the last version that was tried.",
455
+ r"^Error calling function 'divide_float64', no overload found for arguments \(1.0,\)$",
456
+ ):
457
+ divide_float64(1.0)
458
+
459
+ with self.assertRaisesRegex(
460
+ RuntimeError,
461
+ r"^Error calling function 'normalize_vector', no overload found for arguments \(1.0,\)$",
464
462
  ):
465
463
  normalize_vector(1.0)
466
464
 
465
+ def test_cpython_call_user_function_with_array_type(self):
466
+ arr = wp.array((1, 2, 3, 4, 5, 6, 7, 8), dtype=wp.float32)
467
+ length = get_array_len(arr)
468
+ assert length == 8
469
+
467
470
 
468
471
  devices = get_test_devices()
469
472
 
@@ -77,20 +77,22 @@ def test_future_annotations(test, device):
77
77
  foo_data.x = 1.23
78
78
  foo_data.y = 2.34
79
79
 
80
- out = wp.empty(1, dtype=float)
80
+ out = wp.empty(1, dtype=float, device=device)
81
81
 
82
82
  kernel_3 = create_kernel_3(foo)
83
83
 
84
- wp.launch(kernel_1, dim=out.shape, outputs=(out,))
85
- wp.launch(kernel_2, dim=out.shape, outputs=(out,))
86
- wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,))
84
+ wp.launch(kernel_1, dim=out.shape, outputs=(out,), device=device)
85
+ wp.launch(kernel_2, dim=out.shape, outputs=(out,), device=device)
86
+ wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,), device=device)
87
87
 
88
88
 
89
89
  class TestFutureAnnotations(unittest.TestCase):
90
90
  pass
91
91
 
92
92
 
93
- add_function_test(TestFutureAnnotations, "test_future_annotations", test_future_annotations)
93
+ devices = get_test_devices()
94
+
95
+ add_function_test(TestFutureAnnotations, "test_future_annotations", test_future_annotations, devices=devices)
94
96
 
95
97
 
96
98
  if __name__ == "__main__":
@@ -18,10 +18,19 @@ import unittest
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
+ from warp.context import assert_conditional_graph_support
21
22
  from warp.optim.linear import bicgstab, cg, cr, gmres, preconditioner
22
23
  from warp.tests.unittest_utils import *
23
24
 
24
25
 
26
+ def check_conditional_graph_support():
27
+ try:
28
+ assert_conditional_graph_support()
29
+ except RuntimeError:
30
+ return False
31
+ return True
32
+
33
+
25
34
  def _check_linear_solve(test, A, b, func, *args, **kwargs):
26
35
  # test from zero
27
36
  x = wp.zeros_like(b)
@@ -30,10 +39,30 @@ def _check_linear_solve(test, A, b, func, *args, **kwargs):
30
39
 
31
40
  test.assertLessEqual(err, atol)
32
41
 
42
+ # Test with capturable graph
43
+ if A.device.is_cuda and check_conditional_graph_support():
44
+ x.zero_()
45
+ with wp.ScopedDevice(A.device):
46
+ with wp.ScopedCapture() as capture:
47
+ niter, err, atol = func(A, b, x, *args, use_cuda_graph=True, check_every=0, **kwargs)
48
+
49
+ wp.capture_launch(capture.graph)
50
+
51
+ niter = niter.numpy()[0]
52
+ err = np.sqrt(err.numpy()[0])
53
+ atol = np.sqrt(atol.numpy()[0])
54
+
55
+ test.assertLessEqual(err, atol)
56
+
33
57
  # test with warm start
34
58
  with wp.ScopedDevice(A.device):
35
59
  niter_warm, err, atol = func(A, b, x, *args, use_cuda_graph=False, **kwargs)
36
60
 
61
+ if isinstance(niter_warm, wp.array):
62
+ niter_warm = niter_warm.numpy()[0]
63
+ err = np.sqrt(err.numpy()[0])
64
+ atol = np.sqrt(atol.numpy()[0])
65
+
37
66
  test.assertLessEqual(err, atol)
38
67
 
39
68
  if func in [cr, gmres]:
@@ -45,6 +74,7 @@ def _check_linear_solve(test, A, b, func, *args, **kwargs):
45
74
  # This can lead to accumulated inaccuracies over iterations, esp in float32
46
75
  residual = A.numpy() @ x.numpy() - b.numpy()
47
76
  err_np = np.linalg.norm(residual)
77
+
48
78
  if A.dtype == wp.float64:
49
79
  test.assertLessEqual(err_np, 2.0 * atol)
50
80
  else:
warp/tests/test_map.py CHANGED
@@ -318,7 +318,7 @@ def test_input_validity(test, device):
318
318
 
319
319
  with test.assertRaisesRegex(
320
320
  TypeError,
321
- 'Incorrect input provided for argument "i": received array of dtype float32, expected int$',
321
+ "Function test_input_validity__locals__int_function does not support the provided argument types float32",
322
322
  ):
323
323
  wp.map(int_function, a1)
324
324
 
@@ -476,6 +476,20 @@ add_function_test(TestMap, "test_graph_capture", test_graph_capture, devices=cud
476
476
  add_function_test(TestMap, "test_renamed_warp_module", test_renamed_warp_module, devices=devices)
477
477
 
478
478
 
479
+ class TestMapDebug(unittest.TestCase):
480
+ @classmethod
481
+ def setUpClass(cls):
482
+ cls._saved_mode = wp.config.mode
483
+ wp.config.mode = "debug"
484
+
485
+ @classmethod
486
+ def tearDownClass(cls):
487
+ wp.config.mode = cls._saved_mode
488
+
489
+
490
+ add_function_test(TestMapDebug, "test_mixed_inputs", test_mixed_inputs, devices=devices)
491
+ add_function_test(TestMapDebug, "test_kernel_creation", test_kernel_creation, devices=devices)
492
+
479
493
  if __name__ == "__main__":
480
494
  wp.clear_kernel_cache()
481
495
  unittest.main(verbosity=2)