warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (153) hide show
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,574 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import unittest
17
+
18
+ import numpy as np
19
+
20
+ import warp as wp
21
+ from warp.tests.unittest_utils import *
22
+
23
+ np_float_types = [np.float16, np.float32, np.float64]
24
+
25
+ kernel_cache = {}
26
+
27
+
28
def getkernel(func, suffix=""):
    """Return the cached ``wp.Kernel`` for ``func``, creating it on first use.

    Kernels are cached in the module-level ``kernel_cache`` dictionary under
    the key ``"<func.__name__>_<suffix>"``.

    Args:
        func: Python function to wrap in a kernel.
        suffix: String appended to the function name to form the cache key.

    Returns:
        The kernel stored in ``kernel_cache`` for the computed key.
    """
    cache_key = "_".join((func.__name__, suffix))

    # Fast path: this function/suffix pair has already been wrapped.
    if cache_key in kernel_cache:
        return kernel_cache[cache_key]

    compiled = wp.Kernel(func=func, key=cache_key)
    kernel_cache[cache_key] = compiled
    return compiled
33
+
34
+
35
def get_select_kernel(dtype):
    """Return a kernel that copies one element of an array into ``out[0]``.

    Args:
        dtype: Element type of the ``input`` and ``out`` arrays; its
            ``__name__`` is used as the kernel cache suffix.

    Returns:
        The kernel obtained from ``getkernel`` for the inner function.
    """

    def output_select_kernel_fn(input: wp.array(dtype=dtype), index: int, out: wp.array(dtype=dtype)):
        # Copy the selected element into the single-element output array.
        out[0] = input[index]

    return getkernel(output_select_kernel_fn, suffix=dtype.__name__)
40
+
41
+
42
def test_anon_constructor_error_shape_arg_missing(test, device):
    """Expect a ``RuntimeError`` when ``wp.matrix`` gets values but no ``shape``.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.matrix(1.0, 2.0, 3.0)

    expected_message = r"the `shape` argument must be specified when initializing a matrix by value$"

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
52
+
53
+
54
def test_anon_constructor_error_shape_mismatch(test, device):
    """Expect a ``RuntimeError`` when copy constructing with a different ``shape``.

    The kernel passes a (1, 2) matrix to ``wp.matrix`` while requesting
    ``shape=(3, 4)``.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.matrix(wp.matrix(shape=(1, 2), dtype=float), shape=(3, 4), dtype=float)

    expected_message = (
        r"incompatible matrix of shape \(3, 4\) given when copy constructing a matrix of shape \(1, 2\)$"
    )

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
64
+
65
+
66
def test_anon_constructor_error_type_mismatch(test, device):
    """Expect a ``RuntimeError`` when the fill value does not match ``dtype``.

    The kernel fills a ``wp.float16`` matrix with the literal ``1.0``.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.matrix(1.0, shape=(3, 2), dtype=wp.float16)

    expected_message = r"the value used to fill this matrix is expected to be of the type `float16`$"

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
76
+
77
+
78
def test_anon_constructor_error_invalid_arg_count(test, device):
    """Expect a ``RuntimeError`` when ``wp.matrix`` gets 3 values for a (2, 2) shape.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.matrix(1.0, 2.0, 3.0, shape=(2, 2), dtype=float)

    expected_message = r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$"

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
88
+
89
+
90
def test_tpl_constructor_error_incompatible_sizes(test, device):
    """Expect a ``RuntimeError`` when ``wp.mat33`` is built from a ``wp.mat22``.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.mat33(wp.mat22(1.0, 2.0, 3.0, 4.0))

    expected_message = (
        r"incompatible matrix of shape \(3, 3\) given when copy constructing a matrix of shape \(2, 2\)$"
    )

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
100
+
101
+
102
def test_tpl_constructor_error_invalid_arg_count(test, device):
    """Expect a ``RuntimeError`` when ``wp.mat22`` is given only 3 values.

    Args:
        test: Test-case object providing ``assertRaisesRegex``.
        device: Device passed to ``wp.launch``.
    """

    @wp.kernel
    def kernel():
        wp.mat22(1.0, 2.0, 3.0)

    expected_message = r"incompatible number of values given \(3\) when constructing a matrix of shape \(2, 2\)$"

    # The launch call is the statement expected to raise.
    with test.assertRaisesRegex(RuntimeError, expected_message):
        wp.launch(kernel, dim=1, inputs=[], device=device)
112
+
113
+
114
def test_matrix_from_vecs_runtime(test, device):
    """Exercise ``matrix_from_cols`` / ``matrix_from_rows`` outside of a kernel.

    Four matrices are built (3x3 from columns, 3x3 from rows, 3x2 from two
    columns, 2x3 from two rows). For each one the test checks element
    indexing, ``get_row`` / ``get_col``, and that ``set_row`` / ``set_col``
    update the matrix in place. The statements are order dependent: later
    assertions rely on the mutations made by the preceding ``set_*`` calls.

    Args:
        test: Test-case object; not referenced in this body (plain ``assert``
            statements are used instead).
        device: Not referenced in this body.
    """
    # --- m1: 3x3 built from three column vectors ---
    m1 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    # Element [i, j] is component i of column j.
    assert m1[0, 0] == 1.0
    assert m1[0, 1] == 4.0
    assert m1[0, 2] == 7.0
    assert m1[1, 0] == 2.0
    assert m1[1, 1] == 5.0
    assert m1[1, 2] == 8.0
    assert m1[2, 0] == 3.0
    assert m1[2, 1] == 6.0
    assert m1[2, 2] == 9.0

    assert m1.get_row(0) == wp.vec3(1.0, 4.0, 7.0)
    assert m1.get_row(1) == wp.vec3(2.0, 5.0, 8.0)
    assert m1.get_row(2) == wp.vec3(3.0, 6.0, 9.0)
    assert m1.get_col(0) == wp.vec3(1.0, 2.0, 3.0)
    assert m1.get_col(1) == wp.vec3(4.0, 5.0, 6.0)
    assert m1.get_col(2) == wp.vec3(7.0, 8.0, 9.0)

    # Overwrite every row, then compare against a row-built matrix.
    m1.set_row(0, wp.vec3(8.0, 9.0, 10.0))
    m1.set_row(1, wp.vec3(11.0, 12.0, 13.0))
    m1.set_row(2, wp.vec3(14.0, 15.0, 16.0))

    assert m1 == wp.matrix_from_rows(
        wp.vec3(8.0, 9.0, 10.0),
        wp.vec3(11.0, 12.0, 13.0),
        wp.vec3(14.0, 15.0, 16.0),
    )

    # Overwrite every column, then compare against a column-built matrix.
    m1.set_col(0, wp.vec3(8.0, 9.0, 10.0))
    m1.set_col(1, wp.vec3(11.0, 12.0, 13.0))
    m1.set_col(2, wp.vec3(14.0, 15.0, 16.0))

    assert m1 == wp.matrix_from_cols(
        wp.vec3(8.0, 9.0, 10.0),
        wp.vec3(11.0, 12.0, 13.0),
        wp.vec3(14.0, 15.0, 16.0),
    )

    # --- m2: 3x3 built from three row vectors ---
    m2 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    # Element [i, j] is component j of row i.
    assert m2[0, 0] == 1.0
    assert m2[0, 1] == 2.0
    assert m2[0, 2] == 3.0
    assert m2[1, 0] == 4.0
    assert m2[1, 1] == 5.0
    assert m2[1, 2] == 6.0
    assert m2[2, 0] == 7.0
    assert m2[2, 1] == 8.0
    assert m2[2, 2] == 9.0

    assert m2.get_row(0) == wp.vec3(1.0, 2.0, 3.0)
    assert m2.get_row(1) == wp.vec3(4.0, 5.0, 6.0)
    assert m2.get_row(2) == wp.vec3(7.0, 8.0, 9.0)
    assert m2.get_col(0) == wp.vec3(1.0, 4.0, 7.0)
    assert m2.get_col(1) == wp.vec3(2.0, 5.0, 8.0)
    assert m2.get_col(2) == wp.vec3(3.0, 6.0, 9.0)

    m2.set_row(0, wp.vec3(8.0, 9.0, 10.0))
    m2.set_row(1, wp.vec3(11.0, 12.0, 13.0))
    m2.set_row(2, wp.vec3(14.0, 15.0, 16.0))

    assert m2 == wp.matrix_from_rows(
        wp.vec3(8.0, 9.0, 10.0),
        wp.vec3(11.0, 12.0, 13.0),
        wp.vec3(14.0, 15.0, 16.0),
    )

    m2.set_col(0, wp.vec3(8.0, 9.0, 10.0))
    m2.set_col(1, wp.vec3(11.0, 12.0, 13.0))
    m2.set_col(2, wp.vec3(14.0, 15.0, 16.0))

    assert m2 == wp.matrix_from_cols(
        wp.vec3(8.0, 9.0, 10.0),
        wp.vec3(11.0, 12.0, 13.0),
        wp.vec3(14.0, 15.0, 16.0),
    )

    # --- m3: non-square, two 3-component columns (rows are vec2) ---
    m3 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    assert m3[0, 0] == 1.0
    assert m3[0, 1] == 4.0
    assert m3[1, 0] == 2.0
    assert m3[1, 1] == 5.0
    assert m3[2, 0] == 3.0
    assert m3[2, 1] == 6.0

    assert m3.get_row(0) == wp.vec2(1.0, 4.0)
    assert m3.get_row(1) == wp.vec2(2.0, 5.0)
    assert m3.get_row(2) == wp.vec2(3.0, 6.0)
    assert m3.get_col(0) == wp.vec3(1.0, 2.0, 3.0)
    assert m3.get_col(1) == wp.vec3(4.0, 5.0, 6.0)

    m3.set_row(0, wp.vec2(7.0, 8.0))
    m3.set_row(1, wp.vec2(9.0, 10.0))
    m3.set_row(2, wp.vec2(11.0, 12.0))

    assert m3 == wp.matrix_from_rows(
        wp.vec2(7.0, 8.0),
        wp.vec2(9.0, 10.0),
        wp.vec2(11.0, 12.0),
    )

    m3.set_col(0, wp.vec3(7.0, 8.0, 9.0))
    m3.set_col(1, wp.vec3(10.0, 11.0, 12.0))

    assert m3 == wp.matrix_from_cols(
        wp.vec3(7.0, 8.0, 9.0),
        wp.vec3(10.0, 11.0, 12.0),
    )

    # --- m4: non-square, two 3-component rows (columns are vec2) ---
    m4 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    assert m4[0, 0] == 1.0
    assert m4[0, 1] == 2.0
    assert m4[0, 2] == 3.0
    assert m4[1, 0] == 4.0
    assert m4[1, 1] == 5.0
    assert m4[1, 2] == 6.0

    assert m4.get_row(0) == wp.vec3(1.0, 2.0, 3.0)
    assert m4.get_row(1) == wp.vec3(4.0, 5.0, 6.0)
    assert m4.get_col(0) == wp.vec2(1.0, 4.0)
    assert m4.get_col(1) == wp.vec2(2.0, 5.0)
    assert m4.get_col(2) == wp.vec2(3.0, 6.0)

    m4.set_row(0, wp.vec3(7.0, 8.0, 9.0))
    m4.set_row(1, wp.vec3(10.0, 11.0, 12.0))

    assert m4 == wp.matrix_from_rows(
        wp.vec3(7.0, 8.0, 9.0),
        wp.vec3(10.0, 11.0, 12.0),
    )

    m4.set_col(0, wp.vec2(7.0, 8.0))
    m4.set_col(1, wp.vec2(9.0, 10.0))
    m4.set_col(2, wp.vec2(11.0, 12.0))

    assert m4 == wp.matrix_from_cols(
        wp.vec2(7.0, 8.0),
        wp.vec2(9.0, 10.0),
        wp.vec2(11.0, 12.0),
    )

    # A scalar argument is expected to fill the whole row with that value.
    m4.set_row(0, 13.0)

    assert m4 == wp.matrix_from_rows(
        wp.vec3(13.0, 13.0, 13.0),
        wp.vec3(8.0, 10.0, 12.0),
    )

    # Likewise, a scalar argument is expected to fill the whole column.
    m4.set_col(2, 14.0)

    assert m4 == wp.matrix_from_rows(
        wp.vec3(13.0, 13.0, 14.0),
        wp.vec3(8.0, 10.0, 14.0),
    )
282
+
283
+
284
+ # Test matrix constructors using explicit type (float16)
285
+ # note that these tests are specifically not using generics / closure
286
+ # args to create kernels dynamically (like the rest of this file)
287
+ # as those use different code paths to resolve arg types which
288
+ # has led to regressions.
289
@wp.kernel
def test_constructors_explicit_precision():
    # construction for custom matrix types
    eye = wp.identity(dtype=wp.float16, n=2)
    zeros = wp.matrix(shape=(2, 2), dtype=wp.float16)
    # no dtype keyword here: every value is already a wp.float16
    custom = wp.matrix(wp.float16(0.0), wp.float16(1.0), wp.float16(2.0), wp.float16(3.0), shape=(2, 2))

    for i in range(2):
        for j in range(2):
            # identity: ones on the diagonal, zeros elsewhere
            if i == j:
                wp.expect_eq(eye[i, j], wp.float16(1.0))
            else:
                wp.expect_eq(eye[i, j], wp.float16(0.0))

            wp.expect_eq(zeros[i, j], wp.float16(0.0))
            # the values 0..3 were given in row-major order, so element [i, j] is 2*i + j
            wp.expect_eq(custom[i, j], wp.float16(i) * wp.float16(2.0) + wp.float16(j))
305
+
306
+
307
+ # Same as above but with a default (float/int) type
308
+ # which tests some different code paths that
309
+ # need to ensure types are correctly canonicalized
310
+ # during codegen
311
@wp.kernel
def test_constructors_default_precision():
    # construction for default (float) matrix types
    eye = wp.identity(dtype=float, n=2)
    zeros = wp.matrix(shape=(2, 2), dtype=float)
    # no dtype keyword here: the values are plain float literals
    custom = wp.matrix(0.0, 1.0, 2.0, 3.0, shape=(2, 2))

    for i in range(2):
        for j in range(2):
            # identity: ones on the diagonal, zeros elsewhere
            if i == j:
                wp.expect_eq(eye[i, j], 1.0)
            else:
                wp.expect_eq(eye[i, j], 0.0)

            wp.expect_eq(zeros[i, j], 0.0)
            # the values 0..3 were given in row-major order, so element [i, j] is 2*i + j
            wp.expect_eq(custom[i, j], float(i) * 2.0 + float(j))
327
+
328
+
329
+ # NOTE: Compile time is highly sensitive to shape so we use small values now
330
# Matrix dimensions declared as Warp constants; used as the `shape` argument
# and loop bounds in test_constructors_constant_shape.
CONSTANT_SHAPE_ROWS = wp.constant(2)
CONSTANT_SHAPE_COLS = wp.constant(2)
332
+
333
+
334
+ # tests that we can use global constants in shape keyword argument
335
+ # for matrix constructor
336
@wp.kernel
def test_constructors_constant_shape():
    # the shape comes from module-level constants rather than literals
    m = wp.matrix(shape=(CONSTANT_SHAPE_ROWS, CONSTANT_SHAPE_COLS), dtype=float)

    # write every element; nothing is read back or asserted in this kernel
    for i in range(CONSTANT_SHAPE_ROWS):
        for j in range(CONSTANT_SHAPE_COLS):
            m[i, j] = float(i * j)
343
+
344
+
345
@wp.kernel
def test_matrix_from_vecs():
    # 3x3 from columns: element [i, j] is component i of column j
    m1 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    wp.expect_eq(m1[0, 0], 1.0)
    wp.expect_eq(m1[0, 1], 4.0)
    wp.expect_eq(m1[0, 2], 7.0)
    wp.expect_eq(m1[1, 0], 2.0)
    wp.expect_eq(m1[1, 1], 5.0)
    wp.expect_eq(m1[1, 2], 8.0)
    wp.expect_eq(m1[2, 0], 3.0)
    wp.expect_eq(m1[2, 1], 6.0)
    wp.expect_eq(m1[2, 2], 9.0)

    # 3x3 from rows: element [i, j] is component j of row i
    m2 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
        wp.vec3(7.0, 8.0, 9.0),
    )
    wp.expect_eq(m2[0, 0], 1.0)
    wp.expect_eq(m2[0, 1], 2.0)
    wp.expect_eq(m2[0, 2], 3.0)
    wp.expect_eq(m2[1, 0], 4.0)
    wp.expect_eq(m2[1, 1], 5.0)
    wp.expect_eq(m2[1, 2], 6.0)
    wp.expect_eq(m2[2, 0], 7.0)
    wp.expect_eq(m2[2, 1], 8.0)
    wp.expect_eq(m2[2, 2], 9.0)

    # non-square: two 3-component columns, indexed up to [2, 1]
    m3 = wp.matrix_from_cols(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    wp.expect_eq(m3[0, 0], 1.0)
    wp.expect_eq(m3[0, 1], 4.0)
    wp.expect_eq(m3[1, 0], 2.0)
    wp.expect_eq(m3[1, 1], 5.0)
    wp.expect_eq(m3[2, 0], 3.0)
    wp.expect_eq(m3[2, 1], 6.0)

    # non-square: two 3-component rows, indexed up to [1, 2]
    m4 = wp.matrix_from_rows(
        wp.vec3(1.0, 2.0, 3.0),
        wp.vec3(4.0, 5.0, 6.0),
    )
    wp.expect_eq(m4[0, 0], 1.0)
    wp.expect_eq(m4[0, 1], 2.0)
    wp.expect_eq(m4[0, 2], 3.0)
    wp.expect_eq(m4[1, 0], 4.0)
    wp.expect_eq(m4[1, 1], 5.0)
    wp.expect_eq(m4[1, 2], 6.0)
398
+
399
+
400
# Matrix type with shape (3, 2) and float64 elements, used by
# test_matrix_constructor_value_func.
mat32d = wp.mat(shape=(3, 2), dtype=wp.float64)
401
+
402
+
403
@wp.kernel
def test_matrix_constructor_value_func():
    # None of the results are read back or asserted in this kernel; it only
    # exercises the different constructor call forms.
    a = wp.mat22()  # predefined type, no arguments
    b = wp.matrix(a, shape=(2, 2))  # anonymous constructor given an existing matrix
    c = mat32d()  # user-defined type, no arguments
    d = mat32d(c, shape=(3, 2))  # user-defined type given an existing matrix and a shape
    e = mat32d(wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0), wp.float64(1.0), wp.float64(2.0))
    f = wp.matrix(1.0, 2.0, 3.0, 4.0, shape=(2, 2), dtype=float)  # values plus explicit shape and dtype
411
+
412
+
413
def test_quat_constructor(test, device, dtype, register_kernels=False):
    """Compare ``wp.transform_compose`` with a matrix assembled by hand.

    A kernel builds a 4x4 matrix twice from a translation, a rotation
    quaternion and a scale: once with ``wp.transform_compose`` and once with
    ``wp.matrix_from_cols`` on explicitly scaled rotation columns. The test
    checks that both the 16 output components and their gradients with
    respect to the three inputs agree.

    Args:
        test: Test-case object; not referenced in this body.
        device: Device on which arrays are allocated and kernels launched.
        dtype: NumPy floating-point type selecting the Warp scalar type and
            the comparison tolerance.
        register_kernels: When true, only create the kernels and return
            before launching anything.
    """
    rng = np.random.default_rng(123)

    # Comparison tolerance per precision; unknown dtypes require exact equality.
    tol = {
        np.float16: 1.0e-3,
        np.float32: 1.0e-6,
        np.float64: 1.0e-8,
    }.get(dtype, 0)

    wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
    mat44 = wp.types.matrix(shape=(4, 4), dtype=wptype)
    vec4 = wp.types.vector(length=4, dtype=wptype)
    vec3 = wp.types.vector(length=3, dtype=wptype)
    quat = wp.types.quaternion(dtype=wptype)

    output_select_kernel = get_select_kernel(wptype)

    def check_mat_quat_constructor(
        p: wp.array(dtype=vec3),
        r: wp.array(dtype=quat),
        s: wp.array(dtype=vec3),
        outcomponents: wp.array(dtype=wptype),
        outcomponents_alt: wp.array(dtype=wptype),
    ):
        m = wp.transform_compose(p[0], r[0], s[0])

        R = wp.transpose(wp.quat_to_matrix(r[0]))
        c0 = s[0][0] * R[0]
        c1 = s[0][1] * R[1]
        c2 = s[0][2] * R[2]
        m_alt = wp.matrix_from_cols(
            vec4(c0[0], c0[1], c0[2], wptype(0.0)),
            vec4(c1[0], c1[1], c1[2], wptype(0.0)),
            vec4(c2[0], c2[1], c2[2], wptype(0.0)),
            vec4(p[0][0], p[0][1], p[0][2], wptype(1.0)),
        )

        idx = 0
        for i in range(4):
            for j in range(4):
                outcomponents[idx] = m[i, j]
                outcomponents_alt[idx] = m_alt[i, j]
                idx = idx + 1

    kernel = getkernel(check_mat_quat_constructor, suffix=dtype.__name__)

    if register_kernels:
        return

    # translation:
    p = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)

    # generate a normalized quaternion for the rotation:
    r = rng.standard_normal(size=(1, 4))
    r /= np.linalg.norm(r)
    r = wp.array(r.astype(dtype), dtype=quat, requires_grad=True, device=device)

    # scale:
    s = wp.array(rng.standard_normal(size=(1, 3)).astype(dtype), dtype=vec3, requires_grad=True, device=device)

    # just going to generate the matrix using the constructor, then
    # more manually, and make sure the values/gradients are the same:
    outcomponents = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
    outcomponents_alt = wp.zeros(4 * 4, dtype=wptype, requires_grad=True, device=device)
    wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
    # Use the dtype-specific tolerance, as the gradient checks below do; a
    # fixed 1e-6 is below float16 resolution.
    assert_np_equal(outcomponents.numpy(), outcomponents_alt.numpy(), tol=tol)

    idx = 0
    out = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    out_alt = wp.zeros(1, dtype=wptype, requires_grad=True, device=device)
    for _i in range(4):
        for _j in range(4):
            # Record one forward pass, selecting component `idx` of each
            # result as the scalar to differentiate.
            tape = wp.Tape()
            with tape:
                wp.launch(kernel, dim=1, inputs=[p, r, s], outputs=[outcomponents, outcomponents_alt], device=device)
                wp.launch(output_select_kernel, dim=1, inputs=[outcomponents, idx], outputs=[out], device=device)
                wp.launch(
                    output_select_kernel, dim=1, inputs=[outcomponents_alt, idx], outputs=[out_alt], device=device
                )

            # `1.0 *` produces a new array, so the values survive tape.zero().
            tape.backward(loss=out)
            p_grad = 1.0 * tape.gradients[p].numpy()[0]
            r_grad = 1.0 * tape.gradients[r].numpy()[0]
            s_grad = 1.0 * tape.gradients[s].numpy()[0]
            tape.zero()

            tape.backward(loss=out_alt)
            p_grad_alt = 1.0 * tape.gradients[p].numpy()[0]
            r_grad_alt = 1.0 * tape.gradients[r].numpy()[0]
            s_grad_alt = 1.0 * tape.gradients[s].numpy()[0]
            tape.zero()

            assert_np_equal(p_grad, p_grad_alt, tol=tol)
            assert_np_equal(r_grad, r_grad_alt, tol=tol)
            assert_np_equal(s_grad, s_grad_alt, tol=tol)

            idx = idx + 1
510
+
511
+
512
# Devices on which every test below is registered.
devices = get_test_devices()


# Empty container class; test methods are attached to it by the
# add_function_test / add_kernel_test calls that follow.
class TestMatConstructors(unittest.TestCase):
    pass
517
+
518
+
519
# Host-side test functions, each registered under its own function name.
for _function_test in (
    test_anon_constructor_error_shape_arg_missing,
    test_anon_constructor_error_shape_mismatch,
    test_anon_constructor_error_type_mismatch,
    test_anon_constructor_error_invalid_arg_count,
    test_tpl_constructor_error_incompatible_sizes,
    test_tpl_constructor_error_invalid_arg_count,
    test_matrix_from_vecs_runtime,
):
    add_function_test(TestMatConstructors, _function_test.__name__, _function_test, devices=devices)
del _function_test

# Kernels launched directly as tests with a single thread.
for _kernel_test in (
    test_constructors_explicit_precision,
    test_constructors_default_precision,
    test_constructors_constant_shape,
    test_matrix_from_vecs,
    test_matrix_constructor_value_func,
):
    add_kernel_test(TestMatConstructors, _kernel_test, dim=1, devices=devices)
del _kernel_test

# One transform-compose test per floating-point precision.
for dtype in np_float_types:
    add_function_test_register_kernel(
        TestMatConstructors,
        "test_quat_constructor_" + dtype.__name__,
        test_quat_constructor,
        devices=devices,
        dtype=dtype,
    )
571
+
572
# When run as a script: clear Warp's kernel cache first, then run the tests
# verbosely and stop at the first failure.
if __name__ == "__main__":
    wp.clear_kernel_cache()
    unittest.main(verbosity=2, failfast=True)