warp-lang 1.8.1__py3-none-manylinux_2_34_aarch64.whl → 1.9.0__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's page on the registry for more details.

Files changed (134)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/tests/test_enum.py ADDED
@@ -0,0 +1,136 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import enum
17
+ import unittest
18
+
19
+ import warp as wp
20
+ from warp.tests.unittest_utils import *
21
+
22
+
23
+ class MyIntEnum(enum.IntEnum):
24
+ A = 1
25
+ B = 2
26
+ C = 3
27
+
28
+
29
+ class MyIntFlag(enum.IntFlag):
30
+ A = 1
31
+ B = enum.auto()
32
+ C = enum.auto()
33
+
34
+
35
+ def test_intenum_ints(test, device):
36
+ @wp.kernel
37
+ def expect_intenum_ints():
38
+ wp.expect_eq(MyIntEnum.A, 1)
39
+ wp.expect_eq(MyIntEnum.B, 2)
40
+ wp.expect_eq(MyIntEnum.C, 3)
41
+ wp.expect_eq(MyIntEnum.A + MyIntEnum.B, MyIntEnum.C)
42
+
43
+ wp.launch(expect_intenum_ints, dim=1, device=device)
44
+
45
+
46
+ def test_intflag_ints(test, device):
47
+ @wp.kernel
48
+ def expect_intflag_ints():
49
+ wp.expect_eq(MyIntFlag.A, 1)
50
+ wp.expect_eq(MyIntFlag.B, 2)
51
+ wp.expect_eq(MyIntFlag.C, 4)
52
+ wp.expect_eq(MyIntFlag.A | MyIntFlag.B, 3)
53
+ wp.expect_eq(MyIntFlag.A | MyIntFlag.B | MyIntFlag.C, 7)
54
+
55
+ wp.launch(expect_intflag_ints, dim=1, device=device)
56
+
57
+
58
+ def test_alternative_accessors(test, device):
59
+ @wp.kernel
60
+ def expect_alternative_accessors():
61
+ wp.expect_eq(int(MyIntEnum.A), 1)
62
+ wp.expect_eq(int(MyIntEnum.B.value), 2)
63
+ wp.expect_eq(MyIntEnum.C.value, 3)
64
+ wp.expect_eq(MyIntEnum.A + int(MyIntEnum.B) + 0, MyIntEnum.C)
65
+ wp.expect_eq(int(MyIntFlag.A), 1)
66
+ wp.expect_eq(int(MyIntFlag.B.value), 2)
67
+ wp.expect_eq(MyIntFlag.C.value, 4)
68
+ wp.expect_eq(MyIntFlag.A | int(MyIntFlag.B), 3)
69
+ wp.expect_eq(MyIntFlag.A | MyIntFlag.B.value | MyIntFlag.C, 7)
70
+
71
+ wp.launch(expect_alternative_accessors, dim=1, device=device)
72
+
73
+
74
+ def test_static_accessors(test, device):
75
+ @wp.kernel
76
+ def expect_static_accessors():
77
+ wp.expect_eq(wp.static(MyIntEnum.A), 1)
78
+ wp.expect_eq(wp.static(int(MyIntEnum.A)), 1)
79
+ wp.expect_eq(wp.static(MyIntEnum.A.value), 1)
80
+ wp.expect_eq(wp.static(MyIntFlag.A), 1)
81
+ wp.expect_eq(wp.static(int(MyIntFlag.A)), 1)
82
+ wp.expect_eq(wp.static(MyIntFlag.A.value), 1)
83
+
84
+ wp.launch(expect_static_accessors, dim=1, device=device)
85
+
86
+
87
+ def test_intflag_compare(test, device):
88
+ @wp.kernel
89
+ def compute_intflag_compare(ins: wp.array(dtype=wp.int32), outs: wp.array(dtype=wp.int32)):
90
+ tid = wp.tid()
91
+ if ins[tid] & MyIntFlag.A:
92
+ outs[tid] += MyIntFlag.A
93
+ if ins[tid] & MyIntFlag.B:
94
+ outs[tid] += MyIntFlag.B
95
+ if ins[tid] & MyIntFlag.C:
96
+ outs[tid] += MyIntFlag.C
97
+
98
+ with wp.ScopedDevice(device):
99
+ ins = wp.array(
100
+ [
101
+ 0,
102
+ MyIntFlag.A,
103
+ MyIntFlag.B,
104
+ MyIntFlag.C,
105
+ MyIntFlag.A | MyIntFlag.B,
106
+ MyIntFlag.A | MyIntFlag.B | MyIntFlag.C,
107
+ ],
108
+ dtype=wp.int32,
109
+ )
110
+ outs = wp.zeros(len(ins), dtype=wp.int32)
111
+ wp.launch(compute_intflag_compare, dim=len(ins), inputs=[ins], outputs=[outs])
112
+ outs = outs.numpy()
113
+ test.assertEqual(outs[0], 0)
114
+ test.assertEqual(outs[1], 1)
115
+ test.assertEqual(outs[2], 2)
116
+ test.assertEqual(outs[3], 4)
117
+ test.assertEqual(outs[4], 3)
118
+ test.assertEqual(outs[5], 7)
119
+
120
+
121
+ class TestEnum(unittest.TestCase):
122
+ pass
123
+
124
+
125
+ devices = get_test_devices()
126
+
127
+ add_function_test(TestEnum, "test_intenum_ints", test_intenum_ints, devices=devices)
128
+ add_function_test(TestEnum, "test_intflag_ints", test_intflag_ints, devices=devices)
129
+ add_function_test(TestEnum, "test_intflag_compare", test_intflag_compare, devices=devices)
130
+ add_function_test(TestEnum, "test_alternative_accessors", test_alternative_accessors, devices=devices)
131
+ add_function_test(TestEnum, "test_static_accessors", test_static_accessors, devices=devices)
132
+
133
+
134
+ if __name__ == "__main__":
135
+ wp.clear_kernel_cache()
136
+ unittest.main(verbosity=2)
warp/tests/test_examples.py CHANGED
@@ -52,7 +52,7 @@ from warp.tests.unittest_utils import (
52
52
  )
53
53
  from warp.utils import check_p2p
54
54
 
55
- wp.init() # For wp.context.runtime.core.is_debug_enabled()
55
+ wp.init() # For wp.context.runtime.core.wp_is_debug_enabled()
56
56
 
57
57
 
58
58
  def _build_command_line_options(test_options: Dict[str, Any]) -> list:
@@ -331,7 +331,7 @@ add_example_test(
331
331
  name="optim.example_softbody_properties",
332
332
  devices=test_devices,
333
333
  test_options_cuda={
334
- "train_iters": 1 if warp.context.runtime.core.is_debug_enabled() else 3,
334
+ "train_iters": 1 if warp.context.runtime.core.wp_is_debug_enabled() else 3,
335
335
  },
336
336
  test_options_cpu={"train_iters": 1},
337
337
  )
warp/tests/test_fem.py CHANGED
@@ -46,6 +46,11 @@ def linear_form(s: Sample, u: Field):
46
46
  return u(s)
47
47
 
48
48
 
49
+ @integrand
50
+ def bilinear_form(s: Sample, u: Field, v: Field):
51
+ return u(s) * v(s)
52
+
53
+
49
54
  @integrand
50
55
  def scaled_linear_form(s: Sample, u: Field, scale: wp.array(dtype=float)):
51
56
  return u(s) * scale[0]
@@ -1868,8 +1873,6 @@ def test_vector_spaces(test, device):
1868
1873
  fields={"field": div_field.trace()},
1869
1874
  )
1870
1875
 
1871
- return
1872
-
1873
1876
  with wp.ScopedDevice(device):
1874
1877
  positions, tri_vidx = _gen_trimesh(3, 5)
1875
1878
 
@@ -2055,6 +2058,45 @@ def test_array_axpy(test, device):
2055
2058
  assert_np_equal(y.grad.numpy(), beta * np.ones(N))
2056
2059
 
2057
2060
 
2061
+ def test_integrate_high_order(test_field, device):
2062
+ with wp.ScopedDevice(device):
2063
+ geo = fem.Grid3D(res=(1, 1, 1))
2064
+ space = fem.make_polynomial_space(geo, degree=4)
2065
+ test_field = fem.make_test(space)
2066
+ trial_field = fem.make_trial(space)
2067
+
2068
+ # compare consistency of tile-based "dispatch" assembly and generic
2069
+ v0 = fem.integrate(
2070
+ linear_form, fields={"u": test_field}, assembly="dispatch", kernel_options={"enable_backward": False}
2071
+ )
2072
+ v1 = fem.integrate(
2073
+ linear_form, fields={"u": test_field}, assembly="generic", kernel_options={"enable_backward": False}
2074
+ )
2075
+
2076
+ assert_np_equal(v0.numpy(), v1.numpy(), tol=1.0e-6)
2077
+
2078
+ h0 = fem.integrate(
2079
+ bilinear_form,
2080
+ fields={"v": test_field, "u": trial_field},
2081
+ assembly="dispatch",
2082
+ kernel_options={"enable_backward": False},
2083
+ )
2084
+ h1 = fem.integrate(
2085
+ bilinear_form,
2086
+ fields={"v": test_field, "u": trial_field},
2087
+ assembly="generic",
2088
+ kernel_options={"enable_backward": False},
2089
+ )
2090
+
2091
+ h0_nnz = h0.nnz_sync()
2092
+ h1_nnz = h1.nnz_sync()
2093
+ assert h0.shape == h1.shape
2094
+ assert h0_nnz == h1_nnz
2095
+ assert_array_equal(h0.offsets[: h0.nrow + 1], h1.offsets[: h1.nrow + 1])
2096
+ assert_array_equal(h0.columns[:h0_nnz], h1.columns[:h1_nnz])
2097
+ assert_np_equal(h0.values[:h0_nnz].numpy(), h1.values[:h1_nnz].numpy(), tol=1.0e-6)
2098
+
2099
+
2058
2100
  devices = get_test_devices()
2059
2101
  cuda_devices = get_selected_cuda_test_devices()
2060
2102
 
@@ -2088,6 +2130,7 @@ add_function_test(TestFem, "test_point_basis", test_point_basis)
2088
2130
  add_function_test(TestFem, "test_particle_quadratures", test_particle_quadratures)
2089
2131
  add_function_test(TestFem, "test_nodal_quadrature", test_nodal_quadrature)
2090
2132
  add_function_test(TestFem, "test_implicit_fields", test_implicit_fields)
2133
+ add_function_test(TestFem, "test_integrate_high_order", test_integrate_high_order, devices=cuda_devices)
2091
2134
 
2092
2135
 
2093
2136
  class TestFemUtilities(unittest.TestCase):
warp/tests/test_fixedarray.py ADDED
@@ -0,0 +1,229 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import warp as wp
19
+ from warp.tests.unittest_utils import *
20
+
21
+ devices = get_test_devices()
22
+
23
+
24
+ @wp.kernel
25
+ def test_zeros():
26
+ arr = wp.zeros(shape=(2, 3), dtype=int)
27
+ for i in range(arr.shape[0]):
28
+ for j in range(arr.shape[1]):
29
+ arr[i][j] = i * arr.shape[1] + j
30
+
31
+ wp.expect_eq(arr[0][0], 0)
32
+ wp.expect_eq(arr[0][1], 1)
33
+ wp.expect_eq(arr[0][2], 2)
34
+ wp.expect_eq(arr[1][0], 3)
35
+ wp.expect_eq(arr[1][1], 4)
36
+ wp.expect_eq(arr[1][2], 5)
37
+
38
+
39
+ @wp.func
40
+ def test_func_arg_func(arr: wp.array(ndim=2, dtype=int)):
41
+ for i in range(arr.shape[0]):
42
+ for j in range(arr.shape[1]):
43
+ arr[i][j] = i * arr.shape[1] + j
44
+
45
+
46
+ @wp.kernel
47
+ def test_func_arg():
48
+ arr = wp.zeros(shape=(2, 3), dtype=int)
49
+ test_func_arg_func(arr)
50
+
51
+ wp.expect_eq(arr[0][0], 0)
52
+ wp.expect_eq(arr[0][1], 1)
53
+ wp.expect_eq(arr[0][2], 2)
54
+ wp.expect_eq(arr[1][0], 3)
55
+ wp.expect_eq(arr[1][1], 4)
56
+ wp.expect_eq(arr[1][2], 5)
57
+
58
+
59
+ @wp.func
60
+ def test_func_return_func():
61
+ arr = wp.zeros(shape=(2, 3), dtype=int)
62
+ for i in range(arr.shape[0]):
63
+ for j in range(arr.shape[1]):
64
+ arr[i][j] = i * arr.shape[1] + j
65
+
66
+ return arr
67
+
68
+
69
+ @wp.kernel
70
+ def test_func_return():
71
+ arr = test_func_return_func()
72
+
73
+ wp.expect_eq(arr[0][0], 0)
74
+ wp.expect_eq(arr[0][1], 1)
75
+ wp.expect_eq(arr[0][2], 2)
76
+ wp.expect_eq(arr[1][0], 3)
77
+ wp.expect_eq(arr[1][1], 4)
78
+ wp.expect_eq(arr[1][2], 5)
79
+
80
+
81
+ @wp.func
82
+ def test_func_return_annotation_func() -> wp.fixedarray(shape=(2, 3), dtype=int):
83
+ arr = wp.zeros(shape=(2, 3), dtype=int)
84
+ for i in range(arr.shape[0]):
85
+ for j in range(arr.shape[1]):
86
+ arr[i][j] = i * arr.shape[1] + j
87
+
88
+ return arr
89
+
90
+
91
+ @wp.kernel
92
+ def test_func_return_annotation():
93
+ arr = test_func_return_annotation_func()
94
+
95
+ wp.expect_eq(arr[0][0], 0)
96
+ wp.expect_eq(arr[0][1], 1)
97
+ wp.expect_eq(arr[0][2], 2)
98
+ wp.expect_eq(arr[1][0], 3)
99
+ wp.expect_eq(arr[1][1], 4)
100
+ wp.expect_eq(arr[1][2], 5)
101
+
102
+
103
+ def test_error_invalid_func_return_annotation(test, device):
104
+ @wp.func
105
+ def func() -> wp.array(ndim=2, dtype=int):
106
+ arr = wp.zeros(shape=(2, 3), dtype=int)
107
+ for i in range(arr.shape[0]):
108
+ for j in range(arr.shape[1]):
109
+ arr[i][j] = i * arr.shape[1] + j
110
+
111
+ return arr
112
+
113
+ @wp.kernel
114
+ def kernel():
115
+ arr = func()
116
+
117
+ with test.assertRaisesRegex(
118
+ wp.codegen.WarpCodegenError,
119
+ r"The function `func` returns a fixed-size array whereas it has its return type annotated as `Array\[int32\]`.$",
120
+ ):
121
+ wp.launch(kernel, 1, device=device)
122
+
123
+
124
+ def test_error_runtime_shape(test, device):
125
+ @wp.kernel
126
+ def kernel():
127
+ tid = wp.tid()
128
+ wp.zeros(shape=(tid,), dtype=int)
129
+
130
+ with test.assertRaisesRegex(
131
+ RuntimeError,
132
+ r"the `shape` argument must be specified as a constant when zero-initializing an array$",
133
+ ):
134
+ wp.launch(kernel, 1, device=device)
135
+
136
+
137
+ @wp.kernel
138
+ def test_capture_if_kernel():
139
+ arr = wp.zeros(shape=(2, 3), dtype=int)
140
+ for i in range(arr.shape[0]):
141
+ for j in range(arr.shape[1]):
142
+ arr[i][j] = i * arr.shape[1] + j
143
+
144
+ wp.expect_eq(arr[0][0], 0)
145
+ wp.expect_eq(arr[0][1], 1)
146
+ wp.expect_eq(arr[0][2], 2)
147
+ wp.expect_eq(arr[1][0], 3)
148
+ wp.expect_eq(arr[1][1], 4)
149
+ wp.expect_eq(arr[1][2], 5)
150
+
151
+
152
+ def test_capture_if(test, device):
153
+ if (
154
+ not wp.get_device(device).is_cuda
155
+ or wp.context.runtime.toolkit_version < (12, 4)
156
+ or wp.context.runtime.driver_version < (12, 4)
157
+ ):
158
+ return
159
+
160
+ def foo():
161
+ wp.launch(test_capture_if_kernel, dim=512, block_dim=128, device=device)
162
+
163
+ cond = wp.ones(1, dtype=wp.int32, device=device)
164
+ with wp.ScopedCapture(device=device) as capture:
165
+ wp.capture_if(condition=cond, on_true=foo)
166
+
167
+ wp.capture_launch(capture.graph)
168
+
169
+
170
+ @wp.struct
171
+ class test_func_struct_MyStruct:
172
+ offset: int
173
+ dist: float
174
+
175
+
176
+ @wp.func
177
+ def test_func_struct_func():
178
+ arr = wp.zeros(shape=(2, 3), dtype=test_func_struct_MyStruct)
179
+ count = float(arr.shape[0] * arr.shape[1] - 1)
180
+ for i in range(arr.shape[0]):
181
+ for j in range(arr.shape[1]):
182
+ arr[i][j].offset = i * arr.shape[1] + j
183
+ arr[i][j].dist = float(arr[i][j].offset) / count
184
+
185
+ return arr
186
+
187
+
188
+ @wp.kernel
189
+ def test_func_struct():
190
+ arr = test_func_struct_func()
191
+
192
+ wp.expect_eq(arr[0][0].offset, 0)
193
+ wp.expect_near(arr[0][0].dist, 0.0)
194
+ wp.expect_eq(arr[0][1].offset, 1)
195
+ wp.expect_near(arr[0][1].dist, 0.2)
196
+ wp.expect_eq(arr[0][2].offset, 2)
197
+ wp.expect_near(arr[0][2].dist, 0.4)
198
+ wp.expect_eq(arr[1][0].offset, 3)
199
+ wp.expect_near(arr[1][0].dist, 0.6)
200
+ wp.expect_eq(arr[1][1].offset, 4)
201
+ wp.expect_near(arr[1][1].dist, 0.8)
202
+ wp.expect_eq(arr[1][2].offset, 5)
203
+ wp.expect_near(arr[1][2].dist, 1.0)
204
+
205
+
206
+ class TestFixedArray(unittest.TestCase):
207
+ pass
208
+
209
+
210
+ add_kernel_test(TestFixedArray, kernel=test_zeros, name="test_zeros", dim=1, devices=devices)
211
+ add_kernel_test(TestFixedArray, kernel=test_func_arg, name="test_func_arg", dim=1, devices=devices)
212
+ add_kernel_test(TestFixedArray, kernel=test_func_return, name="test_func_return", dim=1, devices=devices)
213
+ add_kernel_test(
214
+ TestFixedArray, kernel=test_func_return_annotation, name="test_func_return_annotation", dim=1, devices=devices
215
+ )
216
+ add_function_test(
217
+ TestFixedArray,
218
+ "test_error_invalid_func_return_annotation",
219
+ test_error_invalid_func_return_annotation,
220
+ devices=devices,
221
+ )
222
+ add_function_test(TestFixedArray, "test_error_runtime_shape", test_error_runtime_shape, devices=devices)
223
+ add_function_test(TestFixedArray, "test_capture_if", test_capture_if, devices=devices)
224
+ add_kernel_test(TestFixedArray, kernel=test_func_struct, name="test_func_struct", dim=1, devices=devices)
225
+
226
+
227
+ if __name__ == "__main__":
228
+ wp.clear_kernel_cache()
229
+ unittest.main(verbosity=2)
warp/tests/test_func.py CHANGED
@@ -269,15 +269,14 @@ def normalize_vector(vec_a: wp.vec3):
269
269
  return wp.normalize(vec_a)
270
270
 
271
271
 
272
- # This pair is to test the situation where one overload throws an error, but a second one works.
273
272
  @wp.func
274
- def divide_by_zero_overload(x: wp.float32):
275
- return x / 0
273
+ def divide_float64(x: wp.float64):
274
+ return x / wp.float64(1.23)
276
275
 
277
276
 
278
277
  @wp.func
279
- def divide_by_zero_overload(x: wp.float64):
280
- return wp.div(x, 0.0)
278
+ def get_array_len(arr: wp.array(dtype=wp.float32)):
279
+ return len(arr)
281
280
 
282
281
 
283
282
  class TestFunc(unittest.TestCase):
@@ -444,26 +443,30 @@ class TestFunc(unittest.TestCase):
444
443
  a * b
445
444
 
446
445
  def test_cpython_call_user_function_with_error(self):
447
- # Actually the following also includes a ZeroDivisionError in the message due to exception chaining,
448
- # but I don't know how to test for that.
449
446
  with self.assertRaisesRegex(
450
- RuntimeError,
451
- "Error calling function 'divide_by_zero'. No version succeeded. "
452
- "See above for the error from the last version that was tried.",
447
+ ZeroDivisionError,
448
+ "float division by zero",
453
449
  ):
454
450
  divide_by_zero(1.0)
455
451
 
456
- def test_cpython_call_user_function_with_overloads(self):
457
- self.assertEqual(divide_by_zero_overload(1.0), math.inf)
458
-
459
452
  def test_cpython_call_user_function_with_wrong_argument_types(self):
460
453
  with self.assertRaisesRegex(
461
454
  RuntimeError,
462
- "Error calling function 'normalize_vector'. No version succeeded. "
463
- "See above for the error from the last version that was tried.",
455
+ r"^Error calling function 'divide_float64', no overload found for arguments \(1.0,\)$",
456
+ ):
457
+ divide_float64(1.0)
458
+
459
+ with self.assertRaisesRegex(
460
+ RuntimeError,
461
+ r"^Error calling function 'normalize_vector', no overload found for arguments \(1.0,\)$",
464
462
  ):
465
463
  normalize_vector(1.0)
466
464
 
465
+ def test_cpython_call_user_function_with_array_type(self):
466
+ arr = wp.array((1, 2, 3, 4, 5, 6, 7, 8), dtype=wp.float32)
467
+ length = get_array_len(arr)
468
+ assert length == 8
469
+
467
470
 
468
471
  devices = get_test_devices()
469
472
 
warp/tests/test_future_annotations.py CHANGED
@@ -77,20 +77,22 @@ def test_future_annotations(test, device):
77
77
  foo_data.x = 1.23
78
78
  foo_data.y = 2.34
79
79
 
80
- out = wp.empty(1, dtype=float)
80
+ out = wp.empty(1, dtype=float, device=device)
81
81
 
82
82
  kernel_3 = create_kernel_3(foo)
83
83
 
84
- wp.launch(kernel_1, dim=out.shape, outputs=(out,))
85
- wp.launch(kernel_2, dim=out.shape, outputs=(out,))
86
- wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,))
84
+ wp.launch(kernel_1, dim=out.shape, outputs=(out,), device=device)
85
+ wp.launch(kernel_2, dim=out.shape, outputs=(out,), device=device)
86
+ wp.launch(kernel_3, dim=out.shape, inputs=(foo_data,), outputs=(out,), device=device)
87
87
 
88
88
 
89
89
  class TestFutureAnnotations(unittest.TestCase):
90
90
  pass
91
91
 
92
92
 
93
- add_function_test(TestFutureAnnotations, "test_future_annotations", test_future_annotations)
93
+ devices = get_test_devices()
94
+
95
+ add_function_test(TestFutureAnnotations, "test_future_annotations", test_future_annotations, devices=devices)
94
96
 
95
97
 
96
98
  if __name__ == "__main__":
warp/tests/test_linear_solvers.py CHANGED
@@ -18,10 +18,19 @@ import unittest
18
18
  import numpy as np
19
19
 
20
20
  import warp as wp
21
+ from warp.context import assert_conditional_graph_support
21
22
  from warp.optim.linear import bicgstab, cg, cr, gmres, preconditioner
22
23
  from warp.tests.unittest_utils import *
23
24
 
24
25
 
26
+ def check_conditional_graph_support():
27
+ try:
28
+ assert_conditional_graph_support()
29
+ except RuntimeError:
30
+ return False
31
+ return True
32
+
33
+
25
34
  def _check_linear_solve(test, A, b, func, *args, **kwargs):
26
35
  # test from zero
27
36
  x = wp.zeros_like(b)
@@ -30,10 +39,30 @@ def _check_linear_solve(test, A, b, func, *args, **kwargs):
30
39
 
31
40
  test.assertLessEqual(err, atol)
32
41
 
42
+ # Test with capturable graph
43
+ if A.device.is_cuda and check_conditional_graph_support():
44
+ x.zero_()
45
+ with wp.ScopedDevice(A.device):
46
+ with wp.ScopedCapture() as capture:
47
+ niter, err, atol = func(A, b, x, *args, use_cuda_graph=True, check_every=0, **kwargs)
48
+
49
+ wp.capture_launch(capture.graph)
50
+
51
+ niter = niter.numpy()[0]
52
+ err = np.sqrt(err.numpy()[0])
53
+ atol = np.sqrt(atol.numpy()[0])
54
+
55
+ test.assertLessEqual(err, atol)
56
+
33
57
  # test with warm start
34
58
  with wp.ScopedDevice(A.device):
35
59
  niter_warm, err, atol = func(A, b, x, *args, use_cuda_graph=False, **kwargs)
36
60
 
61
+ if isinstance(niter_warm, wp.array):
62
+ niter_warm = niter_warm.numpy()[0]
63
+ err = np.sqrt(err.numpy()[0])
64
+ atol = np.sqrt(atol.numpy()[0])
65
+
37
66
  test.assertLessEqual(err, atol)
38
67
 
39
68
  if func in [cr, gmres]:
@@ -45,6 +74,7 @@ def _check_linear_solve(test, A, b, func, *args, **kwargs):
45
74
  # This can lead to accumulated inaccuracies over iterations, esp in float32
46
75
  residual = A.numpy() @ x.numpy() - b.numpy()
47
76
  err_np = np.linalg.norm(residual)
77
+
48
78
  if A.dtype == wp.float64:
49
79
  test.assertLessEqual(err_np, 2.0 * atol)
50
80
  else:
warp/tests/test_map.py CHANGED
@@ -318,7 +318,7 @@ def test_input_validity(test, device):
318
318
 
319
319
  with test.assertRaisesRegex(
320
320
  TypeError,
321
- 'Incorrect input provided for argument "i": received array of dtype float32, expected int$',
321
+ "Function test_input_validity__locals__int_function does not support the provided argument types float32",
322
322
  ):
323
323
  wp.map(int_function, a1)
324
324