warp-lang 1.3.2__py3-none-win_amd64.whl → 1.4.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (108)
  1. warp/__init__.py +6 -0
  2. warp/autograd.py +59 -6
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build_dll.py +8 -10
  6. warp/builtins.py +126 -4
  7. warp/codegen.py +435 -53
  8. warp/config.py +1 -1
  9. warp/context.py +678 -403
  10. warp/dlpack.py +2 -0
  11. warp/examples/benchmarks/benchmark_cloth.py +10 -0
  12. warp/examples/core/example_render_opengl.py +12 -10
  13. warp/examples/fem/example_adaptive_grid.py +251 -0
  14. warp/examples/fem/example_apic_fluid.py +1 -1
  15. warp/examples/fem/example_diffusion_3d.py +2 -2
  16. warp/examples/fem/example_magnetostatics.py +1 -1
  17. warp/examples/fem/example_streamlines.py +1 -0
  18. warp/examples/fem/utils.py +23 -4
  19. warp/examples/sim/example_cloth.py +50 -6
  20. warp/fem/__init__.py +2 -0
  21. warp/fem/adaptivity.py +493 -0
  22. warp/fem/field/field.py +2 -1
  23. warp/fem/field/nodal_field.py +18 -26
  24. warp/fem/field/test.py +4 -4
  25. warp/fem/field/trial.py +4 -4
  26. warp/fem/geometry/__init__.py +1 -0
  27. warp/fem/geometry/adaptive_nanogrid.py +843 -0
  28. warp/fem/geometry/nanogrid.py +55 -28
  29. warp/fem/space/__init__.py +1 -1
  30. warp/fem/space/nanogrid_function_space.py +69 -35
  31. warp/fem/utils.py +113 -107
  32. warp/jax_experimental.py +28 -15
  33. warp/native/array.h +0 -1
  34. warp/native/builtin.h +103 -6
  35. warp/native/bvh.cu +2 -0
  36. warp/native/cuda_util.cpp +14 -0
  37. warp/native/cuda_util.h +2 -0
  38. warp/native/error.cpp +4 -2
  39. warp/native/exports.h +99 -17
  40. warp/native/mat.h +97 -0
  41. warp/native/mesh.cpp +36 -0
  42. warp/native/mesh.cu +51 -0
  43. warp/native/mesh.h +1 -0
  44. warp/native/quat.h +43 -0
  45. warp/native/spatial.h +6 -0
  46. warp/native/vec.h +74 -0
  47. warp/native/warp.cpp +2 -1
  48. warp/native/warp.cu +10 -3
  49. warp/native/warp.h +8 -1
  50. warp/paddle.py +382 -0
  51. warp/sim/__init__.py +1 -0
  52. warp/sim/collide.py +519 -0
  53. warp/sim/integrator_euler.py +18 -5
  54. warp/sim/integrator_featherstone.py +5 -5
  55. warp/sim/integrator_vbd.py +1026 -0
  56. warp/sim/model.py +49 -23
  57. warp/stubs.py +459 -0
  58. warp/tape.py +2 -0
  59. warp/tests/aux_test_dependent.py +1 -0
  60. warp/tests/aux_test_name_clash1.py +32 -0
  61. warp/tests/aux_test_name_clash2.py +32 -0
  62. warp/tests/aux_test_square.py +1 -0
  63. warp/tests/test_array.py +222 -0
  64. warp/tests/test_async.py +3 -3
  65. warp/tests/test_atomic.py +6 -0
  66. warp/tests/test_closest_point_edge_edge.py +93 -1
  67. warp/tests/test_codegen.py +62 -15
  68. warp/tests/test_codegen_instancing.py +1457 -0
  69. warp/tests/test_collision.py +486 -0
  70. warp/tests/test_compile_consts.py +3 -28
  71. warp/tests/test_dlpack.py +170 -0
  72. warp/tests/test_examples.py +22 -8
  73. warp/tests/test_fast_math.py +10 -4
  74. warp/tests/test_fem.py +64 -0
  75. warp/tests/test_func.py +46 -0
  76. warp/tests/test_implicit_init.py +49 -0
  77. warp/tests/test_jax.py +58 -0
  78. warp/tests/test_mat.py +84 -0
  79. warp/tests/test_mesh_query_point.py +188 -0
  80. warp/tests/test_module_hashing.py +40 -0
  81. warp/tests/test_multigpu.py +3 -3
  82. warp/tests/test_overwrite.py +8 -0
  83. warp/tests/test_paddle.py +852 -0
  84. warp/tests/test_print.py +89 -0
  85. warp/tests/test_quat.py +111 -0
  86. warp/tests/test_reload.py +31 -1
  87. warp/tests/test_scalar_ops.py +2 -0
  88. warp/tests/test_static.py +412 -0
  89. warp/tests/test_streams.py +64 -3
  90. warp/tests/test_struct.py +4 -4
  91. warp/tests/test_torch.py +24 -0
  92. warp/tests/test_triangle_closest_point.py +137 -0
  93. warp/tests/test_types.py +1 -1
  94. warp/tests/test_vbd.py +386 -0
  95. warp/tests/test_vec.py +143 -0
  96. warp/tests/test_vec_scalar_ops.py +139 -0
  97. warp/tests/test_volume.py +30 -0
  98. warp/tests/unittest_suites.py +12 -0
  99. warp/tests/unittest_utils.py +9 -5
  100. warp/thirdparty/dlpack.py +3 -1
  101. warp/types.py +157 -34
  102. warp/utils.py +37 -14
  103. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/METADATA +10 -8
  104. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/RECORD +107 -95
  105. warp/tests/test_point_triangle_closest_point.py +0 -143
  106. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/LICENSE.md +0 -0
  107. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/WHEEL +0 -0
  108. {warp_lang-1.3.2.dist-info → warp_lang-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1457 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ import unittest
9
+ from typing import Any
10
+
11
+ import warp as wp
12
+ import warp.tests.aux_test_name_clash1 as name_clash_module_1
13
+ import warp.tests.aux_test_name_clash2 as name_clash_module_2
14
+ from warp.tests.unittest_utils import *
15
+
16
+ # =======================================================================
17
+
18
+
19
# ``global_kernel`` is defined twice at module scope. ``global_kernel_1`` saves a
# reference to the first definition (writes 17) before the name is rebound to the
# second definition (writes 42), which ``global_kernel_2`` then saves.
@wp.kernel
def global_kernel(a: wp.array(dtype=int)):
    a[0] = 17


global_kernel_1 = global_kernel


@wp.kernel
def global_kernel(a: wp.array(dtype=int)):
    a[0] = 42


global_kernel_2 = global_kernel


def test_global_kernel_redefine(test, device):
    """Ensure that referenced kernels remain valid and unique, even when redefined.

    ``global_kernel`` (current binding) and ``global_kernel_2`` must run the second
    definition (42), while ``global_kernel_1`` must still run the first one (17).
    """

    with wp.ScopedDevice(device):
        a = wp.zeros(1, dtype=int)

        # current binding of the name -> second definition
        wp.launch(global_kernel, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 42)

        # saved reference to the first definition
        wp.launch(global_kernel_1, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)

        # saved reference to the second definition
        wp.launch(global_kernel_2, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 42)
49
+
50
+
51
+ # =======================================================================
52
+
53
+
54
# ``global_func`` is defined twice at module scope. ``global_func_1`` saves the
# first definition (returns 17); ``global_func_2`` saves the second (returns 42).
@wp.func
def global_func():
    return 17


global_func_1 = global_func


@wp.func
def global_func():
    return 42


global_func_2 = global_func


@wp.kernel
def global_func_kernel(a: wp.array(dtype=int)):
    # bare name: expected to resolve to the latest definition (42)
    a[0] = global_func()
    # saved reference to the first definition (17)
    a[1] = global_func_1()
    # saved reference to the second definition (42)
    a[2] = global_func_2()


def test_global_func_redefine(test, device):
    """Ensure that referenced functions remain valid and unique, even when redefined.

    Expects ``[42, 17, 42]``, one slot per call site in ``global_func_kernel``.
    """

    with wp.ScopedDevice(device):
        a = wp.zeros(3, dtype=int)
        wp.launch(global_func_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([42, 17, 42]))
84
+
85
+
86
+ # =======================================================================
87
+
88
+
89
# ``GlobalStruct`` is defined twice with different field types. ``GlobalStruct1``
# saves the first definition (scalar ``v``); ``GlobalStruct2`` saves the second
# (``wp.vec2`` ``v``), which is also the current binding of ``GlobalStruct``.
@wp.struct
class GlobalStruct:
    v: float


GlobalStruct1 = GlobalStruct


@wp.struct
class GlobalStruct:
    v: wp.vec2


GlobalStruct2 = GlobalStruct


# ``s0: GlobalStruct`` is annotated after the rebinding, so it refers to the vec2
# version (hence the ``s0.v[0]`` / ``s0.v[1]`` indexing).
@wp.kernel
def global_struct_args_kernel(s0: GlobalStruct, s1: GlobalStruct1, s2: GlobalStruct2, a: wp.array(dtype=float)):
    a[0] = s0.v[0]
    a[1] = s0.v[1]
    a[2] = s1.v
    a[3] = s2.v[0]
    a[4] = s2.v[1]


def test_global_struct_args_redefine(test, device):
    """Ensure that referenced structs remain valid and unique, even when redefined.

    Struct instances are created on the host and passed to the kernel as arguments;
    the kernel flattens their fields into ``a``, expected to be ``[1, 2, 3, 4, 5]``.
    """
    with wp.ScopedDevice(device):
        s0 = GlobalStruct()
        s1 = GlobalStruct1()
        s2 = GlobalStruct2()
        s0.v = wp.vec2(1.0, 2.0)
        s1.v = 3.0
        s2.v = wp.vec2(4.0, 5.0)

        a = wp.zeros(5, dtype=float)

        wp.launch(global_struct_args_kernel, dim=1, inputs=[s0, s1, s2, a])

        assert_np_equal(a.numpy(), np.array([1, 2, 3, 4, 5], dtype=np.float32))


# Same layout as above, but the structs are constructed inside the kernel body.
# ``GlobalStruct`` is expected to resolve to the vec2 version here too.
@wp.kernel
def global_struct_ctor_kernel(a: wp.array(dtype=float)):
    s0 = GlobalStruct()
    s1 = GlobalStruct1()
    s2 = GlobalStruct2()
    s0.v = wp.vec2(1.0, 2.0)
    s1.v = 3.0
    s2.v = wp.vec2(4.0, 5.0)
    a[0] = s0.v[0]
    a[1] = s0.v[1]
    a[2] = s1.v
    a[3] = s2.v[0]
    a[4] = s2.v[1]


def test_global_struct_ctor_redefine(test, device):
    """Ensure that referenced structs remain valid and unique, even when redefined.

    Struct instances are constructed inside the kernel; expects ``[1, 2, 3, 4, 5]``.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(5, dtype=float)
        wp.launch(global_struct_ctor_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([1, 2, 3, 4, 5], dtype=np.float32))
152
+
153
+
154
+ # =======================================================================
155
+
156
+
157
# "primary" (first) overload, taking ``int``
@wp.func
def global_func_po(x: int):
    return x * x


# "secondary" overload, taking ``float``
@wp.func
def global_func_po(x: float):
    return x * x


# redefine primary overload with the same ``int`` signature (cube instead of square)
@wp.func
def global_func_po(x: int):
    return x * x * x


@wp.kernel
def global_overload_primary_kernel(a: wp.array(dtype=float)):
    # use primary (int) overload -> expected to be the redefined cube: 2**3 = 8
    a[0] = float(global_func_po(2))
    # use secondary (float) overload -> unchanged square: 2.0**2 = 4
    a[1] = global_func_po(2.0)


def test_global_overload_primary_redefine(test, device):
    """Ensure that redefining a primary overload works and doesn't affect secondary overloads.

    Expects ``[8, 4]``: the int overload was replaced, the float overload was not.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=float)
        wp.launch(global_overload_primary_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([8, 4], dtype=np.float32))
189
+
190
+
191
+ # =======================================================================
192
+
193
+
194
# "primary" (first) overload, taking ``int``
@wp.func
def global_func_so(x: int):
    return x * x


# "secondary" overload, taking ``float``
@wp.func
def global_func_so(x: float):
    return x * x


# redefine secondary overload with the same ``float`` signature (cube instead of square)
@wp.func
def global_func_so(x: float):
    return x * x * x


@wp.kernel
def global_overload_secondary_kernel(a: wp.array(dtype=float)):
    # use primary (int) overload -> unchanged square: 2**2 = 4
    a[0] = float(global_func_so(2))
    # use secondary (float) overload -> expected to be the redefined cube: 2.0**3 = 8
    a[1] = global_func_so(2.0)


def test_global_overload_secondary_redefine(test, device):
    """Ensure that redefining a secondary overload works.

    Expects ``[4, 8]``: the float overload was replaced, the int overload was not.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=float)
        wp.launch(global_overload_secondary_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([4, 8], dtype=np.float32))
226
+
227
+
228
+ # =======================================================================
229
+
230
+
231
# Generic (``Any``-typed) kernel defined twice at module scope.
# ``global_generic_kernel_1`` saves the squaring version; ``global_generic_kernel_2``
# saves the cubing version, which is also the current binding of the name.
@wp.kernel
def global_generic_kernel(x: Any, a: wp.array(dtype=Any)):
    a[0] = x * x


global_generic_kernel_1 = global_generic_kernel


@wp.kernel
def global_generic_kernel(x: Any, a: wp.array(dtype=Any)):
    a[0] = x * x * x


global_generic_kernel_2 = global_generic_kernel


def test_global_generic_kernel_redefine(test, device):
    """Ensure that referenced generic kernels remain valid and unique, even when redefined.

    Each kernel is launched with both an ``int`` and a ``float`` argument.
    """

    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        # current binding -> cubing version
        wp.launch(global_generic_kernel, dim=1, inputs=[2, ai])
        wp.launch(global_generic_kernel, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 8)
        test.assertEqual(af.numpy()[0], 8.0)

        # saved first definition -> squaring version
        wp.launch(global_generic_kernel_1, dim=1, inputs=[2, ai])
        wp.launch(global_generic_kernel_1, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 4)
        test.assertEqual(af.numpy()[0], 4.0)

        # saved second definition -> cubing version
        wp.launch(global_generic_kernel_2, dim=1, inputs=[2, ai])
        wp.launch(global_generic_kernel_2, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 8)
        test.assertEqual(af.numpy()[0], 8.0)
268
+
269
+
270
+ # =======================================================================
271
+
272
+
273
# Generic (``Any``-typed) function defined twice at module scope.
# ``global_generic_func_1`` saves the squaring version; ``global_generic_func_2``
# saves the cubing version, which is also the current binding of the name.
@wp.func
def global_generic_func(x: Any):
    return x * x


global_generic_func_1 = global_generic_func


@wp.func
def global_generic_func(x: Any):
    return x * x * x


global_generic_func_2 = global_generic_func


@wp.kernel
def global_generic_func_kernel(ai: wp.array(dtype=int), af: wp.array(dtype=float)):
    # current binding (expected: cubing version), called with int and float
    ai[0] = global_generic_func(2)
    af[0] = global_generic_func(2.0)

    # saved first definition (squaring)
    ai[1] = global_generic_func_1(2)
    af[1] = global_generic_func_1(2.0)

    # saved second definition (cubing)
    ai[2] = global_generic_func_2(2)
    af[2] = global_generic_func_2(2.0)


def test_global_generic_func_redefine(test, device):
    """Ensure that referenced generic functions remain valid and unique, even when redefined.

    Expects ``[8, 4, 8]`` for both the int and the float results.
    """

    with wp.ScopedDevice(device):
        ai = wp.zeros(3, dtype=int)
        af = wp.zeros(3, dtype=float)
        wp.launch(global_generic_func_kernel, dim=1, inputs=[ai, af])
        assert_np_equal(ai.numpy(), np.array([8, 4, 8], dtype=np.int32))
        assert_np_equal(af.numpy(), np.array([8, 4, 8], dtype=np.float32))
310
+
311
+
312
+ # =======================================================================
313
+
314
+
315
def create_kernel_simple():
    """Return a newly defined kernel ``k`` that writes 17 to ``a[0]``.

    The kernel references no variables from this function's scope.
    """
    # not a closure
    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = 17

    return k


# Two separate calls -> two kernel objects with identical source.
simple_kernel_1 = create_kernel_simple()
simple_kernel_2 = create_kernel_simple()


def test_create_kernel_simple(test, device):
    """Test creating multiple identical simple (non-closure) kernels.

    Both kernels must launch and write 17.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(1, dtype=int)

        wp.launch(simple_kernel_1, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)

        wp.launch(simple_kernel_2, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)
338
+
339
+
340
+ # =======================================================================
341
+
342
+
343
def create_func_simple():
    """Return a newly defined function ``f`` that returns 17.

    The function references no variables from this function's scope.
    """
    # not a closure
    @wp.func
    def f():
        return 17

    return f


# Two separate calls -> two function objects with identical source.
simple_func_1 = create_func_simple()
simple_func_2 = create_func_simple()


@wp.kernel
def simple_func_kernel(a: wp.array(dtype=int)):
    a[0] = simple_func_1()
    a[1] = simple_func_2()


def test_create_func_simple(test, device):
    """Test creating multiple identical simple (non-closure) functions.

    Both functions are called from one kernel; expects ``[17, 17]``.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=int)
        wp.launch(simple_func_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([17, 17]))
368
+
369
+
370
+ # =======================================================================
371
+
372
+
373
def create_struct_simple():
    """Return a newly defined ``@wp.struct`` class ``S`` with one ``int`` field ``x``."""

    @wp.struct
    class S:
        x: int

    return S


# Two separate calls -> two struct classes with identical layout.
SimpleStruct1 = create_struct_simple()
SimpleStruct2 = create_struct_simple()


@wp.kernel
def simple_struct_args_kernel(s1: SimpleStruct1, s2: SimpleStruct2, a: wp.array(dtype=int)):
    a[0] = s1.x
    a[1] = s2.x


def test_create_struct_simple_args(test, device):
    """Test creating multiple identical structs and passing them as arguments.

    Expects ``[17, 42]``, one value per struct.
    """
    with wp.ScopedDevice(device):
        s1 = SimpleStruct1()
        s2 = SimpleStruct2()
        s1.x = 17
        s2.x = 42
        a = wp.zeros(2, dtype=int)
        wp.launch(simple_struct_args_kernel, dim=1, inputs=[s1, s2, a])
        assert_np_equal(a.numpy(), np.array([17, 42]))


@wp.kernel
def simple_struct_ctor_kernel(a: wp.array(dtype=int)):
    # both struct classes are instantiated inside the kernel body
    s1 = SimpleStruct1()
    s2 = SimpleStruct2()
    s1.x = 17
    s2.x = 42
    a[0] = s1.x
    a[1] = s2.x


def test_create_struct_simple_ctor(test, device):
    """Test creating multiple identical structs and constructing them in kernels.

    Expects ``[17, 42]``, one value per struct.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=int)
        wp.launch(simple_struct_ctor_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([17, 42]))
419
+
420
+
421
+ # =======================================================================
422
+
423
+
424
def create_generic_kernel_simple():
    """Return a newly defined generic kernel ``k`` that writes ``x * x`` to ``a[0]``.

    The kernel references no variables from this function's scope.
    """
    # not a closure
    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * x

    return k


# Two separate calls -> two generic kernels with identical source.
simple_generic_kernel_1 = create_generic_kernel_simple()
simple_generic_kernel_2 = create_generic_kernel_simple()


def test_create_generic_kernel_simple(test, device):
    """Test creating multiple identical simple (non-closure) generic kernels.

    Each kernel is launched with an ``int`` and a ``float`` argument; all results
    are expected to be ``2 * 2 = 4``.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        wp.launch(simple_generic_kernel_1, dim=1, inputs=[2, ai])
        wp.launch(simple_generic_kernel_1, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 4)
        test.assertEqual(af.numpy()[0], 4.0)

        wp.launch(simple_generic_kernel_2, dim=1, inputs=[2, ai])
        wp.launch(simple_generic_kernel_2, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 4)
        test.assertEqual(af.numpy()[0], 4.0)
452
+
453
+
454
+ # =======================================================================
455
+
456
+
457
def create_generic_func_simple():
    """Return a newly defined generic function ``f`` that returns ``x * x``.

    The function references no variables from this function's scope.
    """
    # not a closure
    @wp.func
    def f(x: Any):
        return x * x

    return f


# Two separate calls -> two generic functions with identical source.
simple_generic_func_1 = create_generic_func_simple()
simple_generic_func_2 = create_generic_func_simple()


@wp.kernel
def simple_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # each generic function is called with an int and a float argument
    ai[0] = simple_generic_func_1(2)
    af[0] = simple_generic_func_1(2.0)

    ai[1] = simple_generic_func_2(2)
    af[1] = simple_generic_func_2(2.0)


def test_create_generic_func_simple(test, device):
    """Test creating multiple identical simple (non-closure) generic functions.

    Expects ``[4, 4]`` for both the int and the float results.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(2, dtype=int)
        af = wp.zeros(2, dtype=float)
        wp.launch(simple_generic_func_kernel, dim=1, inputs=[ai, af])
        assert_np_equal(ai.numpy(), np.array([4, 4], dtype=np.int32))
        assert_np_equal(af.numpy(), np.array([4, 4], dtype=np.float32))
490
+
491
+
492
+ # =======================================================================
493
+
494
+
495
def create_kernel_cond(cond):
    """Return one of two kernels named ``k`` depending on ``cond``.

    If ``cond`` is truthy the kernel writes 17 to ``a[0]``; otherwise it writes 42.
    Neither kernel references variables from this function's scope.
    """
    if cond:

        @wp.kernel
        def k(a: wp.array(dtype=int)):
            a[0] = 17
    else:

        @wp.kernel
        def k(a: wp.array(dtype=int)):
            a[0] = 42

    return k


cond_kernel_1 = create_kernel_cond(True)
cond_kernel_2 = create_kernel_cond(False)


def test_create_kernel_cond(test, device):
    """Test conditionally creating different simple (non-closure) kernels.

    The two kernels share a name but have different bodies (17 vs. 42).
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(1, dtype=int)

        wp.launch(cond_kernel_1, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)

        wp.launch(cond_kernel_2, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 42)
524
+
525
+
526
+ # =======================================================================
527
+
528
+
529
def create_func_cond(cond):
    """Return one of two functions named ``f`` depending on ``cond``.

    If ``cond`` is truthy the function returns 17; otherwise it returns 42.
    """
    if cond:

        @wp.func
        def f():
            return 17
    else:

        @wp.func
        def f():
            return 42

    return f


cond_func_1 = create_func_cond(True)
cond_func_2 = create_func_cond(False)


@wp.kernel
def cond_func_kernel(a: wp.array(dtype=int)):
    a[0] = cond_func_1()
    a[1] = cond_func_2()


def test_create_func_cond(test, device):
    """Test conditionally creating different simple (non-closure) functions.

    Both same-named functions are called from one kernel; expects ``[17, 42]``.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=int)
        wp.launch(cond_func_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([17, 42]))
560
+
561
+
562
+ # =======================================================================
563
+
564
+
565
def create_struct_cond(cond):
    """Return one of two ``@wp.struct`` classes named ``S`` depending on ``cond``.

    If ``cond`` is truthy ``S.v`` is a ``float``; otherwise it is a ``wp.vec2``.
    """
    if cond:

        @wp.struct
        class S:
            v: float
    else:

        @wp.struct
        class S:
            v: wp.vec2

    return S


CondStruct1 = create_struct_cond(True)
CondStruct2 = create_struct_cond(False)


@wp.kernel
def cond_struct_args_kernel(s1: CondStruct1, s2: CondStruct2, a: wp.array(dtype=float)):
    a[0] = s1.v
    a[1] = s2.v[0]
    a[2] = s2.v[1]


def test_create_struct_cond_args(test, device):
    """Test conditionally creating different structs and passing them as arguments.

    Expects ``[1, 2, 3]``: the scalar field of ``CondStruct1`` followed by the two
    components of ``CondStruct2.v``.
    """
    with wp.ScopedDevice(device):
        s1 = CondStruct1()
        s2 = CondStruct2()
        s1.v = 1.0
        s2.v = wp.vec2(2.0, 3.0)
        a = wp.zeros(3, dtype=float)
        wp.launch(cond_struct_args_kernel, dim=1, inputs=[s1, s2, a])
        assert_np_equal(a.numpy(), np.array([1, 2, 3], dtype=np.float32))


@wp.kernel
def cond_struct_ctor_kernel(a: wp.array(dtype=float)):
    # both struct classes are instantiated inside the kernel body
    s1 = CondStruct1()
    s2 = CondStruct2()
    s1.v = 1.0
    s2.v = wp.vec2(2.0, 3.0)
    a[0] = s1.v
    a[1] = s2.v[0]
    a[2] = s2.v[1]
612
+
613
+
614
def test_create_struct_cond_ctor(test, device):
    """Test conditionally creating different structs and constructing them in kernels.

    ``cond_struct_ctor_kernel`` instantiates ``CondStruct1`` (scalar ``v``) and
    ``CondStruct2`` (``wp.vec2`` ``v``) inside the kernel body and writes their
    fields to ``a``; the result is expected to be ``[1, 2, 3]``.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(3, dtype=float)
        wp.launch(cond_struct_ctor_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([1, 2, 3], dtype=np.float32))
620
+
621
+
622
+ # =======================================================================
623
+
624
+
625
def create_generic_kernel_cond(cond):
    """Return one of two generic kernels named ``k`` depending on ``cond``.

    If ``cond`` is truthy the kernel writes ``x * x`` to ``a[0]``; otherwise it
    writes ``x * x * x``.
    """
    if cond:

        @wp.kernel
        def k(x: Any, a: wp.array(dtype=Any)):
            a[0] = x * x
    else:

        @wp.kernel
        def k(x: Any, a: wp.array(dtype=Any)):
            a[0] = x * x * x

    return k


cond_generic_kernel_1 = create_generic_kernel_cond(True)
cond_generic_kernel_2 = create_generic_kernel_cond(False)


def test_create_generic_kernel_cond(test, device):
    """Test conditionally creating different simple (non-closure) generic kernels.

    Each kernel is launched with an ``int`` and a ``float`` argument: the squaring
    kernel must give 4 and the cubing kernel 8.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        wp.launch(cond_generic_kernel_1, dim=1, inputs=[2, ai])
        wp.launch(cond_generic_kernel_1, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 4)
        test.assertEqual(af.numpy()[0], 4.0)

        wp.launch(cond_generic_kernel_2, dim=1, inputs=[2, ai])
        wp.launch(cond_generic_kernel_2, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 8)
        test.assertEqual(af.numpy()[0], 8.0)
659
+
660
+
661
+ # =======================================================================
662
+
663
+
664
def create_generic_func_cond(cond):
    """Return one of two generic functions named ``f`` depending on ``cond``.

    If ``cond`` is truthy the function returns ``x * x``; otherwise ``x * x * x``.
    """
    if cond:

        @wp.func
        def f(x: Any):
            return x * x
    else:

        @wp.func
        def f(x: Any):
            return x * x * x

    return f


cond_generic_func_1 = create_generic_func_cond(True)
cond_generic_func_2 = create_generic_func_cond(False)


@wp.kernel
def cond_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # each generic function is called with an int and a float argument
    ai[0] = cond_generic_func_1(2)
    af[0] = cond_generic_func_1(2.0)

    ai[1] = cond_generic_func_2(2)
    af[1] = cond_generic_func_2(2.0)


def test_create_generic_func_cond(test, device):
    """Test conditionally creating different simple (non-closure) generic functions.

    Expects ``[4, 8]`` (square, then cube of 2) for both the int and the float results.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(2, dtype=int)
        af = wp.zeros(2, dtype=float)
        wp.launch(cond_generic_func_kernel, dim=1, inputs=[ai, af])
        assert_np_equal(ai.numpy(), np.array([4, 8], dtype=np.int32))
        assert_np_equal(af.numpy(), np.array([4, 8], dtype=np.float32))
703
+
704
+
705
+ # =======================================================================
706
+
707
+
708
def create_kernel_closure(value: int):
    """Return a new kernel ``k`` that writes the captured ``value`` to ``a[0]``."""
    # closure
    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = value

    return k


# Same source, different captured values.
closure_kernel_1 = create_kernel_closure(17)
closure_kernel_2 = create_kernel_closure(42)


def test_create_kernel_closure(test, device):
    """Test creating kernel closures.

    Each closure must write its own captured value (17, then 42).
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(1, dtype=int)

        wp.launch(closure_kernel_1, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)

        wp.launch(closure_kernel_2, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 42)
731
+
732
+
733
+ # =======================================================================
734
+
735
+
736
def create_func_closure(value: int):
    """Return a new function ``f`` that returns the captured ``value``."""
    # closure
    @wp.func
    def f():
        return value

    return f


# Same source, different captured values.
closure_func_1 = create_func_closure(17)
closure_func_2 = create_func_closure(42)


@wp.kernel
def closure_func_kernel(a: wp.array(dtype=int)):
    a[0] = closure_func_1()
    a[1] = closure_func_2()


def test_create_func_closure(test, device):
    """Test creating function closures.

    Both closures are called from one kernel; expects ``[17, 42]``.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(2, dtype=int)
        wp.launch(closure_func_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([17, 42]))
761
+
762
+
763
+ # =======================================================================
764
+
765
+
766
def create_func_closure_overload(value: int):
    """Return a closure function ``f`` defined with two signatures in this scope.

    ``f()`` returns ``value`` and ``f(x: int)`` returns ``value * x``. The test
    expects both signatures to be callable through the returned ``f``.
    """

    @wp.func
    def f():
        return value

    @wp.func
    def f(x: int):
        return value * x

    # return overloaded closure function
    return f


closure_func_overload_1 = create_func_closure_overload(2)
closure_func_overload_2 = create_func_closure_overload(3)


@wp.kernel
def closure_func_overload_kernel(a: wp.array(dtype=int)):
    a[0] = closure_func_overload_1()
    a[1] = closure_func_overload_1(2)
    a[2] = closure_func_overload_2()
    a[3] = closure_func_overload_2(2)


def test_create_func_closure_overload(test, device):
    """Test creating overloaded function closures.

    Expects ``[2, 4, 3, 6]``: ``value`` and ``value * 2`` for ``value`` = 2 and 3.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(4, dtype=int)
        wp.launch(closure_func_overload_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([2, 4, 3, 6]))
797
+
798
+
799
+ # =======================================================================
800
+
801
+
802
def create_func_closure_overload_selfref(value: int):
    """Return a closure function ``f`` whose ``int`` overload calls its no-arg overload.

    ``f()`` returns ``value`` and ``f(x: int)`` returns ``f() * x``.
    """

    @wp.func
    def f():
        return value

    @wp.func
    def f(x: int):
        # reference another overload
        return f() * x

    # return overloaded closure function
    return f


closure_func_overload_selfref_1 = create_func_closure_overload_selfref(2)
closure_func_overload_selfref_2 = create_func_closure_overload_selfref(3)


@wp.kernel
def closure_func_overload_selfref_kernel(a: wp.array(dtype=int)):
    a[0] = closure_func_overload_selfref_1()
    a[1] = closure_func_overload_selfref_1(2)
    a[2] = closure_func_overload_selfref_2()
    a[3] = closure_func_overload_selfref_2(2)


def test_create_func_closure_overload_selfref(test, device):
    """Test creating overloaded function closures with self-referencing overloads.

    Expects ``[2, 4, 3, 6]``: ``value`` and ``value * 2`` for ``value`` = 2 and 3.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(4, dtype=int)
        wp.launch(closure_func_overload_selfref_kernel, dim=1, inputs=[a])
        assert_np_equal(a.numpy(), np.array([2, 4, 3, 6]))
834
+
835
+
836
+ # =======================================================================
837
+
838
+
839
def create_func_closure_nonoverload(dtype, value):
    """Return a new closure ``f(x: dtype)`` that returns ``x * value``."""

    @wp.func
    def f(x: dtype):
        return x * value

    return f


# functions created in different scopes should NOT be overloads of each other
# (i.e., creating new functions with the same signature should not replace previous ones)
closure_func_nonoverload_1 = create_func_closure_nonoverload(int, 2)
closure_func_nonoverload_2 = create_func_closure_nonoverload(float, 2.0)
closure_func_nonoverload_3 = create_func_closure_nonoverload(int, 3)
closure_func_nonoverload_4 = create_func_closure_nonoverload(float, 3.0)


@wp.kernel
def closure_func_nonoverload_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    # _1/_3 share the int signature and _2/_4 the float one, but each must keep
    # its own captured multiplier
    ai[0] = closure_func_nonoverload_1(2)
    af[0] = closure_func_nonoverload_2(2.0)
    ai[1] = closure_func_nonoverload_3(2)
    af[1] = closure_func_nonoverload_4(2.0)


def test_create_func_closure_nonoverload(test, device):
    """Test creating function closures that are not overloads of each other (overloads are grouped by scope, not globally).

    Expects ``[4, 6]`` (2 * 2 and 2 * 3) for both the int and the float results.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(2, dtype=int)
        af = wp.zeros(2, dtype=float)
        wp.launch(closure_func_nonoverload_kernel, dim=1, inputs=[ai, af])
        assert_np_equal(ai.numpy(), np.array([4, 6], dtype=np.int32))
        assert_np_equal(af.numpy(), np.array([4, 6], dtype=np.float32))
874
+
875
+
876
+ # =======================================================================
877
+
878
+
879
def create_fk_closure(a, b):
    """Return a closure function and a closure kernel created in the same scope.

    ``f()`` returns the captured ``a``; ``k`` writes ``f() + b`` to its array
    argument.
    """
    # closure
    @wp.func
    def f():
        return a

    # closure
    # NOTE: the kernel parameter ``a`` (an array) shadows the captured integer
    # ``a`` inside ``k``; ``f()`` still returns the captured value.
    @wp.kernel
    def k(a: wp.array(dtype=int)):
        a[0] = f() + b

    return f, k


fk_closure_func_1, fk_closure_kernel_1 = create_fk_closure(10, 7)
fk_closure_func_2, fk_closure_kernel_2 = create_fk_closure(40, 2)


# use generated functions in a new kernel
@wp.kernel
def fk_closure_combine_kernel(a: wp.array(dtype=int)):
    a[0] = fk_closure_func_1() + fk_closure_func_2()


def test_create_fk_closure(test, device):
    """Test creating function and kernel closures together, then reusing the functions in another kernel.

    Expected values: 10 + 7 = 17, 40 + 2 = 42, and 10 + 40 = 50 for the
    combining kernel.
    """
    with wp.ScopedDevice(device):
        a = wp.zeros(1, dtype=int)

        wp.launch(fk_closure_kernel_1, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 17)

        wp.launch(fk_closure_kernel_2, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 42)

        wp.launch(fk_closure_combine_kernel, dim=1, inputs=[a])
        test.assertEqual(a.numpy()[0], 50)
916
+
917
+
918
+ # =======================================================================
919
+
920
+
921
def create_generic_kernel_closure(value):
    """Return a new generic closure kernel that writes ``x * value`` to ``a[0]``.

    ``type(x)(value)`` converts the captured ``value`` to the type of ``x``
    before multiplying.
    """

    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * type(x)(value)

    return k


generic_closure_kernel_1 = create_generic_kernel_closure(2)
generic_closure_kernel_2 = create_generic_kernel_closure(3)


def test_create_generic_kernel_closure(test, device):
    """Test creating generic closure kernels.

    Each kernel is launched with an ``int`` and a ``float`` argument of 2; expected
    results are 4 for ``value`` = 2 and 6 for ``value`` = 3.
    """
    with wp.ScopedDevice(device):
        ai = wp.zeros(1, dtype=int)
        af = wp.zeros(1, dtype=float)

        wp.launch(generic_closure_kernel_1, dim=1, inputs=[2, ai])
        wp.launch(generic_closure_kernel_1, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 4)
        test.assertEqual(af.numpy()[0], 4.0)

        wp.launch(generic_closure_kernel_2, dim=1, inputs=[2, ai])
        wp.launch(generic_closure_kernel_2, dim=1, inputs=[2.0, af])
        test.assertEqual(ai.numpy()[0], 6)
        test.assertEqual(af.numpy()[0], 6.0)
948
+
949
+
950
+ # =======================================================================
951
+
952
+
953
def create_generic_kernel_overload_closure(value, dtype):
    """Create a generic closure kernel and return only its concrete overload for ``dtype``.

    Args:
        value: Factor captured by the kernel; converted to the type of ``x`` inside it.
        dtype: Scalar type used for both the ``x`` argument and the array element type.

    Returns:
        The result of ``wp.overload`` for the signature ``(dtype, wp.array(dtype=dtype))``.
        The generic kernel itself is not returned to the caller.
    """

    @wp.kernel
    def k(x: Any, a: wp.array(dtype=Any)):
        a[0] = x * type(x)(value)

    # return only the overload, not the generic kernel
    return wp.overload(k, [dtype, wp.array(dtype=dtype)])


generic_closure_kernel_overload_i1 = create_generic_kernel_overload_closure(2, int)
generic_closure_kernel_overload_i2 = create_generic_kernel_overload_closure(3, int)
generic_closure_kernel_overload_f1 = create_generic_kernel_overload_closure(2, float)
generic_closure_kernel_overload_f2 = create_generic_kernel_overload_closure(3, float)
966
+
967
+
968
def test_create_generic_kernel_overload_closure(test, device):
    """Launch the concrete overloads returned by create_generic_kernel_overload_closure."""
    with wp.ScopedDevice(device):
        out_int = wp.zeros(1, dtype=int)
        out_float = wp.zeros(1, dtype=float)

        # (int overload, float overload, factor captured by both)
        cases = (
            (generic_closure_kernel_overload_i1, generic_closure_kernel_overload_f1, 2),
            (generic_closure_kernel_overload_i2, generic_closure_kernel_overload_f2, 3),
        )
        for int_kernel, float_kernel, factor in cases:
            wp.launch(int_kernel, dim=1, inputs=[2, out_int])
            wp.launch(float_kernel, dim=1, inputs=[2.0, out_float])
            test.assertEqual(out_int.numpy()[0], 2 * factor)
            test.assertEqual(out_float.numpy()[0], 2.0 * factor)
983
+
984
+
985
+ # =======================================================================
986
+
987
+
988
def create_generic_func_closure(value):
    """Create a generic warp function ``f(x)`` returning ``x * type(x)(value)``.

    ``value`` is captured by the closure and converted to the type of ``x``,
    so the function can be called with int and float arguments.
    """

    @wp.func
    def f(x: Any):
        return x * type(x)(value)

    return f


generic_closure_func_1 = create_generic_func_closure(2)
generic_closure_func_2 = create_generic_func_closure(3)
998
+
999
+
1000
# Calls each generic closure function with both an int and a float argument.
# Index 0 uses the factor-2 closure and index 1 the factor-3 closure, so for an
# input of 2 the expected outputs are 4 and 6.
@wp.kernel
def closure_generic_func_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    ai[0] = generic_closure_func_1(2)
    af[0] = generic_closure_func_1(2.0)

    ai[1] = generic_closure_func_2(2)
    af[1] = generic_closure_func_2(2.0)
1010
+
1011
+
1012
def test_create_generic_func_closure(test, device):
    """Call the generic closure functions with int and float arguments from one kernel."""
    with wp.ScopedDevice(device):
        out_int = wp.zeros(2, dtype=int)
        out_float = wp.zeros(2, dtype=float)
        wp.launch(closure_generic_func_kernel, dim=1, inputs=[out_int, out_float])

        # 2 * 2 and 2 * 3, identical for both element types
        expected = [4, 6]
        assert_np_equal(out_int.numpy(), np.array(expected, dtype=np.int32))
        assert_np_equal(out_float.numpy(), np.array(expected, dtype=np.float32))
1020
+
1021
+
1022
+ # =======================================================================
1023
+
1024
+
1025
def create_generic_func_closure_overload(value):
    """Create a generic warp function closure with one- and two-argument forms.

    ``f(x)`` returns ``x * type(x)(value)`` and ``f(x, y)`` returns ``f(x + y)``.
    Both definitions share the name ``f``. The returned object is then called
    with either arity.
    """

    @wp.func
    def f(x: Any):
        return x * type(x)(value)

    # Same name, different signature: meant to act as a second overload of ``f``.
    # The inner call passes a single argument and so targets the one-argument form.
    @wp.func
    def f(x: Any, y: Any):
        return f(x + y)

    # return overloaded generic closure function
    return f


generic_closure_func_overload_1 = create_generic_func_closure_overload(2)
generic_closure_func_overload_2 = create_generic_func_closure_overload(3)
1040
+
1041
+
1042
# Exercises both arities of both overloaded closures, once with int and once with
# float arguments; the expected values are noted per line.
@wp.kernel
def generic_closure_func_overload_kernel(
    ai: wp.array(dtype=int),
    af: wp.array(dtype=float),
):
    ai[0] = generic_closure_func_overload_1(1)  # 1 * 2 = 2
    ai[1] = generic_closure_func_overload_2(1)  # 1 * 3 = 3
    ai[2] = generic_closure_func_overload_1(1, 2)  # (1 + 2) * 2 = 6
    ai[3] = generic_closure_func_overload_2(1, 2)  # (1 + 2) * 3 = 9

    af[0] = generic_closure_func_overload_1(1.0)  # 1 * 2 = 2
    af[1] = generic_closure_func_overload_2(1.0)  # 1 * 3 = 3
    af[2] = generic_closure_func_overload_1(1.0, 2.0)  # (1 + 2) * 2 = 6
    af[3] = generic_closure_func_overload_2(1.0, 2.0)  # (1 + 2) * 3 = 9
1056
+
1057
+
1058
def test_create_generic_func_closure_overload(test, device):
    """Check both arities of the overloaded generic closure functions for int and float."""
    with wp.ScopedDevice(device):
        out_int = wp.zeros(4, dtype=int)
        out_float = wp.zeros(4, dtype=float)
        wp.launch(generic_closure_func_overload_kernel, dim=1, inputs=[out_int, out_float])

        # 1*2, 1*3, (1+2)*2, (1+2)*3
        expected = [2, 3, 6, 9]
        assert_np_equal(out_int.numpy(), np.array(expected, dtype=np.int32))
        assert_np_equal(out_float.numpy(), np.array(expected, dtype=np.float32))
1066
+
1067
+
1068
+ # =======================================================================
1069
+
1070
+
1071
def create_type_closure_scalar(scalar_type):
    """Create a kernel that converts its float argument with the captured ``scalar_type``.

    Args:
        scalar_type: Scalar type (for example ``int``, ``float`` or ``wp.uint8``),
            captured by the kernel closure and used as a conversion constructor.

    Returns:
        A kernel taking ``(value: float, expected: float)`` that checks on the device
        that ``float(scalar_type(value)) == expected``.
    """

    # The first argument is named ``value`` rather than ``input`` so that it does not
    # shadow the Python builtin. Launches pass it positionally, so callers are unaffected.
    @wp.kernel
    def k(value: float, expected: float):
        x = scalar_type(value)
        wp.expect_eq(float(x), expected)

    return k


type_closure_kernel_int = create_type_closure_scalar(int)
type_closure_kernel_float = create_type_closure_scalar(float)
type_closure_kernel_uint8 = create_type_closure_scalar(wp.uint8)
1083
+
1084
+
1085
def test_type_closure_scalar(test, device):
    """Convert -1.5 with each captured scalar type and compare on the device."""
    with wp.ScopedDevice(device):
        # (kernel, launch inputs [value, expected])
        cases = (
            (type_closure_kernel_int, [-1.5, -1.0]),
            (type_closure_kernel_float, [-1.5, -1.5]),
        )
        for kernel, args in cases:
            wp.launch(kernel, dim=1, inputs=args)

        # FIXME: a problem with type conversions breaks this case
        # wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-1.5, 255.0])
1092
+
1093
+
1094
+ # =======================================================================
1095
+
1096
+
1097
def create_type_closure_vector(vec_type):
    """Create a kernel that builds ``vec_type(1.0)`` and checks its squared length.

    The kernel takes ``expected: float`` and calls
    ``wp.expect_eq(wp.length_sq(v), expected)`` on the device, with ``vec_type``
    captured by the closure.
    """

    @wp.kernel
    def k(expected: float):
        v = vec_type(1.0)
        wp.expect_eq(wp.length_sq(v), expected)

    return k


type_closure_kernel_vec2 = create_type_closure_vector(wp.vec2)
type_closure_kernel_vec3 = create_type_closure_vector(wp.vec3)
1108
+
1109
+
1110
def test_type_closure_vector(test, device):
    """Launch the vector-type closure kernels with their expected squared lengths."""
    with wp.ScopedDevice(device):
        # vec2(1.0) -> 2.0, vec3(1.0) -> 3.0
        for kernel, length_sq in ((type_closure_kernel_vec2, 2.0), (type_closure_kernel_vec3, 3.0)):
            wp.launch(kernel, dim=1, inputs=[length_sq])
1114
+
1115
+
1116
+ # =======================================================================
1117
+
1118
+
1119
# Two struct types with different field layouts. Each one selects a different
# overload of closure_struct_func when passed to it.
@wp.struct
class ClosureStruct1:
    # scalar payload
    v: float


@wp.struct
class ClosureStruct2:
    # 2-component vector payload
    v: wp.vec2
1127
+
1128
+
1129
# Overloads chosen by argument struct type. The argument ``s`` is unused and only
# drives overload selection.
# ClosureStruct1 -> 17.0
@wp.func
def closure_struct_func(s: ClosureStruct1):
    return 17.0


# ClosureStruct2 -> 42.0
@wp.func
def closure_struct_func(s: ClosureStruct2):
    return 42.0
1137
+
1138
+
1139
def create_type_closure_struct(struct_type):
    """Create a kernel that default-constructs the captured ``struct_type`` in device code.

    The kernel passes the new struct to ``closure_struct_func``, so the overload it
    calls depends on ``struct_type``. It then checks the result against its
    ``expected: float`` argument.
    """

    @wp.kernel
    def k(expected: float):
        s = struct_type()
        result = closure_struct_func(s)
        wp.expect_eq(result, expected)

    return k


type_closure_kernel_struct1 = create_type_closure_struct(ClosureStruct1)
type_closure_kernel_struct2 = create_type_closure_struct(ClosureStruct2)
1151
+
1152
+
1153
def test_type_closure_struct(test, device):
    """Launch the struct-type closure kernels with the value of the overload each selects."""
    with wp.ScopedDevice(device):
        cases = ((type_closure_kernel_struct1, 17.0), (type_closure_kernel_struct2, 42.0))
        for kernel, expected in cases:
            wp.launch(kernel, dim=1, inputs=[expected])
1157
+
1158
+
1159
+ # =======================================================================
1160
+
1161
+
1162
# Calls identically named functions from two different modules in one kernel.
# test_name_clash_func expects [99, 99, 17, 42]: ``same_func`` gives the same
# result in both modules, while ``different_func`` does not.
@wp.kernel
def name_clash_func_kernel(a: wp.array(dtype=int)):
    a[0] = name_clash_module_1.same_func()
    a[1] = name_clash_module_2.same_func()
    a[2] = name_clash_module_1.different_func()
    a[3] = name_clash_module_2.different_func()
1168
+
1169
+
1170
def test_name_clash_func(test, device):
    """Check that same-named functions from two modules each resolve to their own module."""
    with wp.ScopedDevice(device):
        result = wp.zeros(4, dtype=int)
        wp.launch(name_clash_func_kernel, dim=1, inputs=[result])
        expected = np.array([99, 99, 17, 42])
        assert_np_equal(result.numpy(), expected)
1176
+
1177
+
1178
+ # =======================================================================
1179
+
1180
+
1181
# Takes same-named struct types from two modules as kernel arguments and copies
# their fields into ``a``. ``a`` needs at least 5 elements, because ``d2.v`` is
# indexed as a 2-component vector.
@wp.kernel
def name_clash_structs_args_kernel(
    s1: name_clash_module_1.SameStruct,
    s2: name_clash_module_2.SameStruct,
    d1: name_clash_module_1.DifferentStruct,
    d2: name_clash_module_2.DifferentStruct,
    a: wp.array(dtype=float),
):
    a[0] = s1.x
    a[1] = s2.x
    a[2] = d1.v
    a[3] = d2.v[0]
    a[4] = d2.v[1]
1194
+
1195
+
1196
def test_name_clash_struct_args(test, device):
    """Pass same-named structs from two modules as kernel arguments and read their fields back."""
    with wp.ScopedDevice(device):
        same_1 = name_clash_module_1.SameStruct()
        same_1.x = 1.0
        same_2 = name_clash_module_2.SameStruct()
        same_2.x = 2.0
        diff_1 = name_clash_module_1.DifferentStruct()
        diff_1.v = 3.0
        diff_2 = name_clash_module_2.DifferentStruct()
        diff_2.v = wp.vec2(4.0, 5.0)

        out = wp.zeros(5, dtype=float)
        wp.launch(name_clash_structs_args_kernel, dim=1, inputs=[same_1, same_2, diff_1, diff_2, out])
        assert_np_equal(out.numpy(), np.array([1, 2, 3, 4, 5], dtype=np.float32))
1209
+
1210
+
1211
+ # =======================================================================
1212
+
1213
+
1214
# Constructs same-named struct types from two modules inside device code, sets
# their fields, and copies them into ``a`` (at least 5 elements required).
@wp.kernel
def name_clash_structs_ctor_kernel(
    a: wp.array(dtype=float),
):
    s1 = name_clash_module_1.SameStruct()
    s2 = name_clash_module_2.SameStruct()
    d1 = name_clash_module_1.DifferentStruct()
    d2 = name_clash_module_2.DifferentStruct()

    s1.x = 1.0
    s2.x = 2.0
    d1.v = 3.0
    d2.v = wp.vec2(4.0, 5.0)

    a[0] = s1.x
    a[1] = s2.x
    a[2] = d1.v
    a[3] = d2.v[0]
    a[4] = d2.v[1]
1233
+
1234
+
1235
def test_name_clash_struct_ctor(test, device):
    """Construct same-named structs from two modules in a kernel and check the copied fields."""
    with wp.ScopedDevice(device):
        out = wp.zeros(5, dtype=float)
        wp.launch(name_clash_structs_ctor_kernel, dim=1, inputs=[out])
        expected = np.arange(1, 6, dtype=np.float32)  # [1, 2, 3, 4, 5]
        assert_np_equal(out.numpy(), expected)
1240
+
1241
+
1242
+ # =======================================================================
1243
+
1244
+
1245
def test_create_kernel_loop(test, device):
    """
    Test creating a kernel in a loop. The kernel is always the same,
    so the module hash doesn't change and the module shouldn't be reloaded.
    This test ensures that the kernel hooks are found for new duplicate kernels.
    """

    with wp.ScopedDevice(device):
        for _ in range(5):

            # a fresh but identical kernel object on every iteration
            @wp.kernel
            def k():
                pass

            # launching each new duplicate must succeed; syncing surfaces any launch error here
            wp.launch(k, dim=1)
            wp.synchronize_device()
1261
+
1262
+
1263
+ # =======================================================================
1264
+
1265
+
1266
def test_module_mark_modified(test, device):
    """Test that Module.mark_modified() forces module rehashing and reloading."""

    with wp.ScopedDevice(device):

        # ``k`` reads ``C`` from this function's scope as a closure variable. ``C`` is
        # assigned only after the kernel is defined; Python closures capture the variable,
        # not its value, so the value is presumably read when warp builds the module.
        # The check runs on the device via wp.expect_eq; ``test`` itself is unused.
        @wp.kernel
        def k(expected: int):
            wp.expect_eq(C, expected)

        C = 17
        wp.launch(k, dim=1, inputs=[17])
        wp.synchronize_device()

        # redefine constant and force rehashing on next launch
        C = 42
        k.module.mark_modified()

        # NOTE(review): without mark_modified() the previously built module (C == 17)
        # would presumably be reused and this expectation would fail — confirm.
        wp.launch(k, dim=1, inputs=[42])
        wp.synchronize_device()
1285
+
1286
+
1287
+ # =======================================================================
1288
+
1289
+
1290
class TestCodeGenInstancing(unittest.TestCase):
    """Container for the codegen instancing tests.

    The class body is empty; test methods are attached at module level through
    add_function_test, together with the list of test devices.
    """

    pass
1292
+
1293
+
1294
devices = get_test_devices()

# Every test is registered under its own function name, so the registrations are
# driven from a single ordered table. The grouping mirrors the sections of the file.
_instancing_tests = (
    # global redefinitions with retained references
    test_global_kernel_redefine,
    test_global_func_redefine,
    test_global_struct_args_redefine,
    test_global_struct_ctor_redefine,
    test_global_overload_primary_redefine,
    test_global_overload_secondary_redefine,
    test_global_generic_kernel_redefine,
    test_global_generic_func_redefine,
    # create identical simple kernels, functions, and structs
    test_create_kernel_simple,
    test_create_func_simple,
    test_create_struct_simple_args,
    test_create_struct_simple_ctor,
    test_create_generic_kernel_simple,
    test_create_generic_func_simple,
    # create different simple kernels, functions, and structs
    test_create_kernel_cond,
    test_create_func_cond,
    test_create_struct_cond_args,
    test_create_struct_cond_ctor,
    test_create_generic_kernel_cond,
    test_create_generic_func_cond,
    # closure kernels and functions
    test_create_kernel_closure,
    test_create_func_closure,
    test_create_func_closure_overload,
    test_create_func_closure_overload_selfref,
    test_create_func_closure_nonoverload,
    test_create_fk_closure,
    test_create_generic_kernel_closure,
    test_create_generic_kernel_overload_closure,
    test_create_generic_func_closure,
    test_create_generic_func_closure_overload,
    # type closures
    test_type_closure_scalar,
    test_type_closure_vector,
    test_type_closure_struct,
    # test name clashes between modules
    test_name_clash_func,
    test_name_clash_struct_args,
    test_name_clash_struct_ctor,
    # miscellaneous tests
    test_create_kernel_loop,
    test_module_mark_modified,
)

for _test_fn in _instancing_tests:
    add_function_test(TestCodeGenInstancing, func=_test_fn, name=_test_fn.__name__, devices=devices)

# keep the module namespace free of the registration helpers
del _test_fn, _instancing_tests
1453
+
1454
+
1455
if __name__ == "__main__":
    # Clear warp's kernel cache before running, presumably so this run compiles
    # fresh rather than reusing artifacts from an earlier run.
    wp.clear_kernel_cache()
    unittest.main(verbosity=2)