warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
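
Worth noting in this list: the native marching-cubes sources (items 148-150, `warp/native/marching.cpp/.cu/.h`) are deleted while a new 708-line `warp/marching_cubes.py` (item 30) appears, i.e. the implementation moved from hand-written CUDA into Python/Warp kernels. Below is a minimal usage sketch of the `wp.MarchingCubes` API; the one-line change to `example_marching_cubes.py` (item 11) suggests the public interface is essentially unchanged, but the grid size, buffer capacities, and zero-filled field here are illustrative only:

```python
import warp as wp

dim = 64
mc = wp.MarchingCubes(nx=dim, ny=dim, nz=dim, max_verts=2**20, max_tris=2**20, device="cuda:0")

# A 3D scalar field (e.g. an SDF or density grid); zero-filled here purely for illustration.
field = wp.zeros((dim, dim, dim), dtype=float, device="cuda:0")

# Extract the iso-surface where the field crosses the threshold.
mc.surface(field=field, threshold=0.0)

print(mc.verts.shape, mc.indices.shape)  # output vertex and triangle-index arrays
```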
warp/tests/test_module_aot.py ADDED
@@ -0,0 +1,287 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import importlib.util
+import os
+import shutil
+import unittest
+from pathlib import Path
+
+import numpy as np
+
+import warp as wp
+import warp.tests.aux_test_module_aot
+from warp.tests.unittest_utils import *
+
+ADD_KERNEL_START = """import warp as wp
+
+
+@wp.kernel
+def add_kernel(a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), res: wp.array(dtype=wp.int32)):
+    pass
+"""
+
+ADD_KERNEL_FINAL = """import warp as wp
+
+
+@wp.kernel
+def add_kernel(a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), res: wp.array(dtype=wp.int32)):
+    i = wp.tid()
+    res[i] = a[i] + b[i]
+"""
+
+
+def reload_module(module):
+    # Clearing the .pyc file associated with a module is a necessary workaround
+    # for `importlib.reload` to work as expected when run from within Kit.
+    cache_file = importlib.util.cache_from_source(module.__file__)
+    if os.path.exists(cache_file):
+        os.remove(cache_file)
+    importlib.reload(module)
+
+
+TEST_CACHE_DIR = Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "test_module_aot_cache")))
+
+
+def test_disable_hashing(test, device):
+    """Test that module hashing can be disabled.
+
+    A module is run, modified, and run again. The second run should not trigger
+    a recompilation since the hash will not be used to detect changes.
+    """
+
+    try:
+        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+        TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        wp.set_module_options(
+            {"block_dim": 1 if device.is_cpu else 256},
+            warp.tests.aux_test_module_aot,
+        )
+
+        a = wp.ones(10, dtype=wp.int32, device=device)
+        b = wp.ones(10, dtype=wp.int32, device=device)
+        res = wp.zeros((10,), dtype=wp.int32, device=device)
+
+        # Write out the module and import it
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_START)
+        reload_module(warp.tests.aux_test_module_aot)
+
+        # First launch, cold compile, expect res to be unchanged since kernel is empty
+        wp.compile_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)
+        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)
+
+        wp.launch(
+            warp.tests.aux_test_module_aot.add_kernel,
+            dim=a.shape,
+            inputs=[a, b],
+            outputs=[res],
+            device=device,
+        )
+
+        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))
+
+        res.zero_()
+
+        # Write out the modified module and import it
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_FINAL)
+        reload_module(warp.tests.aux_test_module_aot)
+
+        # This time, the hash checks will be skipped so the previously compiled module will be loaded
+        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=True)
+
+        # Kernel is executed with the ADD_KERNEL_START code, not the ADD_KERNEL_FINAL code
+        wp.launch(
+            warp.tests.aux_test_module_aot.add_kernel,
+            dim=a.shape,
+            inputs=[a, b],
+            outputs=[res],
+            device=device,
+        )
+
+        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))
+    finally:
+        # Clear the cache directory
+        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+        # Revert the module default options and auxiliary file to the original states
+        wp.set_module_options({"cuda_output": None, "strip_hash": False}, warp.tests.aux_test_module_aot)
+
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_FINAL)
+
+
+def test_enable_hashing(test, device):
+    """Ensure that the logic of test_disable_hashing is sound.
+
+    This test sets "strip_hash" to False, so normal module hashing rules
+    should be in effect.
+    """
+
+    try:
+        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+        TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        wp.set_module_options(
+            {"block_dim": 1 if device.is_cpu else 256},
+            warp.tests.aux_test_module_aot,
+        )
+
+        a = wp.ones(10, dtype=wp.int32, device=device)
+        b = wp.ones(10, dtype=wp.int32, device=device)
+        res = wp.zeros((10,), dtype=wp.int32, device=device)
+
+        # Write out the module and import it
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_START)
+        reload_module(warp.tests.aux_test_module_aot)
+
+        # First launch, cold compile, expect no-op result
+        wp.compile_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=False)
+        wp.load_aot_module(warp.tests.aux_test_module_aot, device, module_dir=TEST_CACHE_DIR, strip_hash=False)
+        wp.launch(
+            warp.tests.aux_test_module_aot.add_kernel,
+            dim=a.shape,
+            inputs=[a, b],
+            outputs=[res],
+            device=device,
+        )
+
+        assert_np_equal(res.numpy(), np.zeros((10,), dtype=np.int32))
+
+        # Write out the modified module (results in a different hash) and import it
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_FINAL)
+        reload_module(warp.tests.aux_test_module_aot)
+
+        # Trying to load the module should fail since a compiled module with the expected hash does not exist
+        with test.assertRaises(FileNotFoundError):
+            wp.load_aot_module("warp.tests.aux_test_module_aot", device, module_dir=TEST_CACHE_DIR, strip_hash=False)
+    finally:
+        # Clear the cache directory
+        shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+        # Revert the module default options and auxiliary file to the original states
+        wp.set_module_options({"cuda_output": None, "strip_hash": False}, warp.tests.aux_test_module_aot)
+
+        with open(os.path.abspath(os.path.join(os.path.dirname(__file__), "aux_test_module_aot.py")), "w") as f:
+            f.writelines(ADD_KERNEL_FINAL)
+
+
+def test_module_load_resolution(test, device):
+    """Test various ways of resolving a module when loading and compiling."""
+
+    wp.set_module_options(
+        {"block_dim": 1 if device.is_cpu else 256},
+        warp.tests.aux_test_module_aot,
+    )
+
+    a = wp.ones(10, dtype=wp.int32, device=device)
+    b = wp.ones(10, dtype=wp.int32, device=device)
+    res = wp.zeros((10,), dtype=wp.int32, device=device)
+
+    reload_module(warp.tests.aux_test_module_aot)
+    wp.compile_aot_module(warp.tests.aux_test_module_aot, device)
+    wp.load_aot_module(warp.tests.aux_test_module_aot, device)
+
+    wp.launch(
+        warp.tests.aux_test_module_aot.add_kernel,
+        dim=a.shape,
+        inputs=[a, b],
+        outputs=[res],
+        device=device,
+    )
+    assert_np_equal(res.numpy(), np.full((10,), 2, dtype=np.int32))
+
+    reload_module(warp.tests.aux_test_module_aot)
+    res.zero_()
+    wp.compile_aot_module("warp.tests.aux_test_module_aot", device)
+    wp.load_aot_module("warp.tests.aux_test_module_aot", device)
+
+    wp.launch(
+        warp.tests.aux_test_module_aot.add_kernel,
+        dim=a.shape,
+        inputs=[a, b],
+        outputs=[res],
+        device=device,
+    )
+    assert_np_equal(res.numpy(), np.full((10,), 2, dtype=np.int32))
+
+
+class TestModuleAOT(unittest.TestCase):
+    def test_module_compile_specified_arch_ptx(self):
+        """Test that a module can be compiled for a specific architecture or architectures (PTX)."""
+
+        if wp.get_cuda_device_count() == 0:
+            self.skipTest("No CUDA devices found")
+
+        if len(wp.context.runtime.nvrtc_supported_archs) < 2:
+            self.skipTest("NVRTC must support at least two architectures to run this test")
+
+        try:
+            shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+            TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+            archs = list(wp.context.runtime.nvrtc_supported_archs)[:2]
+
+            wp.compile_aot_module(warp.tests.aux_test_module_aot, arch=archs, module_dir=TEST_CACHE_DIR, use_ptx=True)
+
+            # Make sure the expected files exist
+            module_identifier = wp.get_module("warp.tests.aux_test_module_aot").get_module_identifier()
+            for arch in archs:
+                expected_filename = f"{module_identifier}.sm{arch}.ptx"
+                expected_path = TEST_CACHE_DIR / expected_filename
+                self.assertTrue(expected_path.exists(), f"Expected compiled PTX file not found: {expected_path}")
+
+        finally:
+            # Clear the cache directory
+            shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+
+    def test_module_compile_specified_arch_cubin(self):
+        """Test that a module can be compiled for a specific architecture or architectures (CUBIN)."""
+
+        if wp.get_cuda_device_count() == 0:
+            self.skipTest("No CUDA devices found")
+
+        if len(wp.context.runtime.nvrtc_supported_archs) < 2:
+            self.skipTest("NVRTC must support at least two architectures to run this test")
+
+        try:
+            shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+            TEST_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+            archs = list(wp.context.runtime.nvrtc_supported_archs)[:2]
+
+            wp.compile_aot_module(warp.tests.aux_test_module_aot, arch=archs, module_dir=TEST_CACHE_DIR, use_ptx=False)
+
+            # Make sure the expected files exist
+            module_identifier = wp.get_module("warp.tests.aux_test_module_aot").get_module_identifier()
+            for arch in archs:
+                expected_filename = f"{module_identifier}.sm{arch}.cubin"
+                expected_path = TEST_CACHE_DIR / expected_filename
+                self.assertTrue(expected_path.exists(), f"Expected compiled CUBIN file not found: {expected_path}")

+        finally:
+            # Clear the cache directory
+            shutil.rmtree(TEST_CACHE_DIR, ignore_errors=True)
+
+
+devices = get_test_devices()
+add_function_test(TestModuleAOT, "test_disable_hashing", test_disable_hashing, devices=devices)
+add_function_test(TestModuleAOT, "test_enable_hashing", test_enable_hashing, devices=devices)
+add_function_test(TestModuleAOT, "test_module_load_resolution", test_module_load_resolution, devices=devices)
+
+if __name__ == "__main__":
+    wp.clear_kernel_cache()
+    unittest.main(verbosity=2)
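
The file above is the first test coverage for the new ahead-of-time (AOT) compilation API, `wp.compile_aot_module()` and `wp.load_aot_module()`. Here is a minimal sketch of the round trip these tests exercise; the module name `my_kernels` and the `aot_cache` path are hypothetical placeholders, while the function names and keyword arguments (`module_dir`, `strip_hash`, `arch`, `use_ptx`) are taken directly from the tests:

```python
import warp as wp

import my_kernels  # hypothetical user module containing @wp.kernel definitions

device = wp.get_device("cuda:0")

# Precompile the module's kernels into a chosen directory...
wp.compile_aot_module(my_kernels, device, module_dir="aot_cache")

# ...then, in a later process, load the prebuilt binaries instead of recompiling.
# A dotted module name resolves as well, per test_module_load_resolution above.
wp.load_aot_module("my_kernels", device, module_dir="aot_cache")

a = wp.ones(10, dtype=wp.int32, device=device)
b = wp.ones(10, dtype=wp.int32, device=device)
out = wp.zeros(10, dtype=wp.int32, device=device)
wp.launch(my_kernels.add_kernel, dim=a.shape, inputs=[a, b], outputs=[out], device=device)
```

Per the tests, `strip_hash=True` skips the content-hash check on load (so a stale binary loads silently, which is exactly what `test_disable_hashing` verifies), and `compile_aot_module(..., arch=[...], use_ptx=...)` writes one `<module_id>.sm<arch>.ptx` or `.cubin` file per requested architecture.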
warp/tests/test_print.py CHANGED
@@ -43,6 +43,52 @@ def test_print_kernel():
     # fmt: on
 
 
+@wp.func
+def test_print_numeric_func(value: int):
+    b = wp.bool(value)
+    print(b)
+    assert repr(b) == "bool(True)"
+
+    # signed ints
+    i8 = wp.int8(value)
+    print(i8)
+    assert repr(i8) == "int8(-1)"
+    i16 = wp.int16(value)
+    print(i16)
+    assert repr(i16) == "int16(-1)"
+    i32 = wp.int32(value)
+    print(i32)
+    assert repr(i32) == "int32(-1)"
+    i64 = wp.int64(value)
+    print(i64)
+    assert repr(i64) == "int64(-1)"
+
+    # unsigned ints
+    ui8 = wp.uint8(value)
+    print(ui8)
+    assert repr(ui8) == "uint8(255)"
+    ui16 = wp.uint16(value)
+    print(ui16)
+    assert repr(ui16) == "uint16(65535)"
+    ui32 = wp.uint32(value)
+    print(ui32)
+    assert repr(ui32) == "uint32(4294967295)"
+    ui64 = wp.uint64(value)
+    print(ui64)
+    assert repr(ui64) == "uint64(18446744073709551615)"
+
+    # floats
+    f16 = wp.float16(value)
+    print(f16)
+    assert repr(f16) == "float16(-1)"
+    f32 = wp.float32(value)
+    print(f32)
+    assert repr(f32) == "float32(-1)"
+    f64 = wp.float64(value)
+    print(f64)
+    assert repr(f64) == "float64(-1)"
+
+
 @wp.kernel
 def test_print_numeric_kernel(value: int):
     # signed ints
@@ -140,6 +186,29 @@ def test_print_numeric(test, device):
         rf"-1{os.linesep}",
     )
 
+    capture = StdOutCapture()
+    capture.begin()
+    test_print_numeric_func(-1)
+    s = capture.end()
+
+    # We skip the win32 comparison for now since the capture sometimes is an empty string
+    if sys.platform != "win32":
+        test.assertRegex(
+            s,
+            rf"True{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"255{os.linesep}"
+            rf"65535{os.linesep}"
+            rf"4294967295{os.linesep}"
+            rf"18446744073709551615{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"-1{os.linesep}"
+            rf"-1{os.linesep}",
+        )
+
 
 def test_print_boolean(test, device):
     wp.load_module(device=device)
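
For reference, the conversion semantics that the new `test_print_numeric_func` asserts can be reproduced from the Python scope: constructing Warp's sized scalar types from `-1` wraps the unsigned types around to their maximum values (two's-complement reinterpretation). A small sketch, with expected outputs copied verbatim from the test:

```python
import warp as wp

print(repr(wp.int8(-1)))     # int8(-1)
print(repr(wp.uint8(-1)))    # uint8(255)
print(repr(wp.uint16(-1)))   # uint16(65535)
print(repr(wp.uint32(-1)))   # uint32(4294967295)
print(repr(wp.uint64(-1)))   # uint64(18446744073709551615)
print(repr(wp.float16(-1)))  # float16(-1)
print(repr(wp.bool(-1)))     # bool(True), any nonzero value is truthy
```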
warp/tests/test_quat.py CHANGED
@@ -2014,6 +2014,22 @@ def test_py_arithmetic_ops(test, device, dtype):
     test.assertSequenceEqual(wptype(24) / v, make_quat(12, 6, 4, 3))
 
 
+@wp.kernel
+def quat_grad(q: wp.quat):
+    wp.expect_eq(q.w, 1.0)
+
+
+# Test passing of a quaternion in the backward pass
+def test_quat_backward(test, device):
+    q = wp.quat_identity()
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(quat_grad, dim=1, inputs=[q], device=device)
+
+    tape.backward()
+
+
 @wp.kernel
 def quat_len_kernel(
     q: wp.quat,
@@ -2118,39 +2134,6 @@ def test_quat_assign(test, device):
     run(quat_assign_attribute)
 
 
-def test_quat_assign_copy(test, device):
-    saved_enable_vector_component_overwrites_setting = wp.config.enable_vector_component_overwrites
-    try:
-        wp.config.enable_vector_component_overwrites = True
-
-        @wp.kernel
-        def quat_assign_overwrite(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
-            tid = wp.tid()
-
-            a = wp.quat()
-            b = x[tid]
-            a = b
-            a[1] = 3.0
-
-            y[tid] = a
-
-        x = wp.ones(1, dtype=wp.quat, device=device, requires_grad=True)
-        y = wp.zeros(1, dtype=wp.quat, device=device, requires_grad=True)
-
-        tape = wp.Tape()
-        with tape:
-            wp.launch(quat_assign_overwrite, dim=1, inputs=[x, y], device=device)
-
-        y.grad = wp.ones_like(y, requires_grad=False)
-        tape.backward()
-
-        assert_np_equal(y.numpy(), np.array([[1.0, 3.0, 1.0, 1.0]], dtype=float))
-        assert_np_equal(x.grad.numpy(), np.array([[1.0, 0.0, 1.0, 1.0]], dtype=float))
-
-    finally:
-        wp.config.enable_vector_component_overwrites = saved_enable_vector_component_overwrites_setting
-
-
 @wp.kernel
 def quat_array_extract_subscript(x: wp.array2d(dtype=wp.quat), y: wp.array2d(dtype=float)):
     i, j = wp.tid()
@@ -2372,6 +2355,147 @@ def test_quat_array_sub_inplace(test, device):
     assert_np_equal(x.grad.numpy(), np.array([[-1.0, -1.0, -1.0, -1.0]], dtype=float))
 
 
+@wp.kernel
+def scalar_quat_div(x: wp.array(dtype=wp.quat), y: wp.array(dtype=wp.quat)):
+    i = wp.tid()
+    y[i] = 1.0 / x[i]
+
+
+def test_scalar_quat_div(test, device):
+    x = wp.array((wp.quat(1.0, 2.0, 4.0, 8.0),), dtype=wp.quat, requires_grad=True, device=device)
+    y = wp.ones(1, dtype=wp.quat, requires_grad=True, device=device)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(scalar_quat_div, 1, inputs=(x,), outputs=(y,), device=device)
+
+    y.grad = wp.ones_like(y)
+    tape.backward()
+
+    assert_np_equal(y.numpy(), np.array(((1.0, 0.5, 0.25, 0.125),), dtype=float))
+    assert_np_equal(x.grad.numpy(), np.array(((-1.0, -0.25, -0.0625, -0.015625),), dtype=float))
+
+
+def test_quat_indexing_assign(test, device):
+    @wp.func
+    def fn():
+        q = wp.quat(1.0, 2.0, 3.0, 4.0)
+
+        q[0] = 123.0
+        q[1] *= 2.0
+
+        wp.expect_eq(q[0], 123.0)
+        wp.expect_eq(q[1], 4.0)
+        wp.expect_eq(q[2], 3.0)
+        wp.expect_eq(q[3], 4.0)
+
+        q[-1] = 123.0
+        q[-2] *= 2.0
+
+        wp.expect_eq(q[-1], 123.0)
+        wp.expect_eq(q[-2], 6.0)
+        wp.expect_eq(q[-3], 4.0)
+        wp.expect_eq(q[-4], 123.0)
+
+    @wp.kernel(module="unique")
+    def kernel():
+        fn()
+
+    wp.launch(kernel, 1, device=device)
+    wp.synchronize()
+    fn()
+
+
+def test_quat_slicing_assign(test, device):
+    vec0 = wp.vec(0, float)
+    vec1 = wp.vec(1, float)
+    vec2 = wp.vec(2, float)
+    vec3 = wp.vec(3, float)
+    vec4 = wp.vec(4, float)
+
+    @wp.func
+    def fn():
+        q = wp.quat(1.0, 2.0, 3.0, 4.0)
+
+        wp.expect_eq(q[:] == vec4(1.0, 2.0, 3.0, 4.0), True)
+        wp.expect_eq(q[-123:123] == vec4(1.0, 2.0, 3.0, 4.0), True)
+        wp.expect_eq(q[123:] == vec0(), True)
+        wp.expect_eq(q[:-123] == vec0(), True)
+        wp.expect_eq(q[::123] == vec1(1.0), True)
+
+        wp.expect_eq(q[1:] == vec3(2.0, 3.0, 4.0), True)
+        wp.expect_eq(q[-2:] == vec2(3.0, 4.0), True)
+        wp.expect_eq(q[:2] == vec2(1.0, 2.0), True)
+        wp.expect_eq(q[:-1] == vec3(1.0, 2.0, 3.0), True)
+        wp.expect_eq(q[::2] == vec2(1.0, 3.0), True)
+        wp.expect_eq(q[1::2] == vec2(2.0, 4.0), True)
+        wp.expect_eq(q[::-1] == vec4(4.0, 3.0, 2.0, 1.0), True)
+        wp.expect_eq(q[::-2] == vec2(4.0, 2.0), True)
+        wp.expect_eq(q[1::-2] == vec1(2.0), True)
+
+        q[1:] = vec3(5.0, 6.0, 7.0)
+        wp.expect_eq(q == wp.quat(1.0, 5.0, 6.0, 7.0), True)
+
+        q[-2:] = vec2(8.0, 9.0)
+        wp.expect_eq(q == wp.quat(1.0, 5.0, 8.0, 9.0), True)
+
+        q[:2] = vec2(10.0, 11.0)
+        wp.expect_eq(q == wp.quat(10.0, 11.0, 8.0, 9.0), True)
+
+        q[:-1] = vec3(12.0, 13.0, 14.0)
+        wp.expect_eq(q == wp.quat(12.0, 13.0, 14.0, 9.0), True)
+
+        q[::2] = vec2(15.0, 16.0)
+        wp.expect_eq(q == wp.quat(15.0, 13.0, 16.0, 9.0), True)
+
+        q[1::2] = vec2(17.0, 18.0)
+        wp.expect_eq(q == wp.quat(15.0, 17.0, 16.0, 18.0), True)
+
+        q[1::-2] = vec1(19.0)
+        wp.expect_eq(q == wp.quat(15.0, 19.0, 16.0, 18.0), True)
+
+        q[1:] += vec3(20.0, 21.0, 22.0)
+        wp.expect_eq(q == wp.quat(15.0, 39.0, 37.0, 40.0), True)
+
+        q[:-1] -= vec3(23.0, 24.0, 25.0)
+        wp.expect_eq(q == wp.quat(-8.0, 15.0, 12.0, 40.0), True)
+
+    @wp.kernel(module="unique")
+    def kernel():
+        fn()
+
+    wp.launch(kernel, 1, device=device)
+    wp.synchronize()
+    fn()
+
+
+def test_quat_slicing_assign_backward(test, device):
+    @wp.kernel(module="unique")
+    def kernel(arr_x: wp.array(dtype=wp.vec2), arr_y: wp.array(dtype=wp.quat)):
+        i = wp.tid()
+
+        y = arr_y[i]
+
+        y[:2] = arr_x[i]
+        y[1:-1] += arr_x[i][:2]
+        y[3:1:-1] -= arr_x[i][0:]
+
+        arr_y[i] = y
+
+    x = wp.ones(1, dtype=wp.vec2, requires_grad=True, device=device)
+    y = wp.zeros(1, dtype=wp.quat, requires_grad=True, device=device)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(kernel, 1, inputs=(x,), outputs=(y,), device=device)
+
+    y.grad = wp.ones_like(y)
+    tape.backward()
+
+    assert_np_equal(y.numpy(), np.array(((1.0, 2.0, 0.0, -1.0),), dtype=float))
+    assert_np_equal(x.grad.numpy(), np.array(((1.0, 1.0),), dtype=float))
+
+
 devices = get_test_devices()
 
 
@@ -2473,16 +2597,20 @@ for dtype in np_float_types:
         TestQuat, f"test_py_arithmetic_ops_{dtype.__name__}", test_py_arithmetic_ops, devices=None, dtype=dtype
     )
 
+add_function_test(TestQuat, "test_quat_backward", test_quat_backward, devices=devices)
 add_function_test(TestQuat, "test_quat_len", test_quat_len, devices=devices)
 add_function_test(TestQuat, "test_quat_extract", test_quat_extract, devices=devices)
 add_function_test(TestQuat, "test_quat_assign", test_quat_assign, devices=devices)
-add_function_test(TestQuat, "test_quat_assign_copy", test_quat_assign_copy, devices=devices)
 add_function_test(TestQuat, "test_quat_array_extract", test_quat_array_extract, devices=devices)
 add_function_test(TestQuat, "test_quat_array_assign", test_quat_array_assign, devices=devices)
 add_function_test(TestQuat, "test_quat_add_inplace", test_quat_add_inplace, devices=devices)
 add_function_test(TestQuat, "test_quat_sub_inplace", test_quat_sub_inplace, devices=devices)
 add_function_test(TestQuat, "test_quat_array_add_inplace", test_quat_array_add_inplace, devices=devices)
 add_function_test(TestQuat, "test_quat_array_sub_inplace", test_quat_array_sub_inplace, devices=devices)
+add_function_test(TestQuat, "test_scalar_quat_div", test_scalar_quat_div, devices=devices)
+add_function_test(TestQuat, "test_quat_indexing_assign", test_quat_indexing_assign, devices=devices)
+add_function_test(TestQuat, "test_quat_slicing_assign", test_quat_slicing_assign, devices=devices)
+add_function_test(TestQuat, "test_quat_slicing_assign_backward", test_quat_slicing_assign_backward, devices=devices)
 
 
 if __name__ == "__main__":
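
A condensed, standalone sketch of the quaternion indexing and slicing assignment introduced by these tests; the expected values follow the same arithmetic as `test_quat_slicing_assign`, with the built-in `wp.vec3` standing in for the test's `wp.vec(3, float)`:

```python
import warp as wp


@wp.kernel
def quat_slice_demo():
    q = wp.quat(1.0, 2.0, 3.0, 4.0)

    q[0] = 123.0                     # component assignment (negative indices also work)
    q[1:] = wp.vec3(5.0, 6.0, 7.0)   # slice assignment from a size-matched vector
    q[1:] += wp.vec3(1.0, 1.0, 1.0)  # in-place slice ops; differentiable per the backward test

    wp.expect_eq(q == wp.quat(123.0, 6.0, 7.0, 8.0), True)


wp.launch(quat_slice_demo, dim=1)
```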