warp-lang 1.4.2-py3-none-manylinux2014_x86_64.whl → 1.5.1-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/tests/test_func.py CHANGED
@@ -162,7 +162,7 @@ def user_func_with_defaults(a: int = 123, b: int = 234) -> int:
 
 
 @wp.kernel
-def test_user_func_with_defaults():
+def user_func_with_defaults_kernel():
     a = user_func_with_defaults()
     wp.expect_eq(a, 357)
 
@@ -179,6 +179,25 @@ def test_user_func_with_defaults():
     wp.expect_eq(e, 234)
 
 
+def test_user_func_with_defaults(test, device):
+    wp.launch(user_func_with_defaults_kernel, dim=1, device=device)
+
+    a = user_func_with_defaults()
+    assert a == 357
+
+    b = user_func_with_defaults(111)
+    assert b == 345
+
+    c = user_func_with_defaults(111, 222)
+    assert c == 333
+
+    d = user_func_with_defaults(a=111)
+    assert d == 345
+
+    e = user_func_with_defaults(b=111)
+    assert e == 234
+
+
 @wp.func
 def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]:
     return a + a, b * b
@@ -222,6 +241,16 @@ def test_user_func_overload_resolution(test, device):
     assert a1.numpy()[0] == 12
 
 
+@wp.func
+def user_func_return_none() -> None:
+    pass
+
+
+@wp.kernel
+def test_return_annotation_none() -> None:
+    user_func_return_none()
+
+
 devices = get_test_devices()
 
 
@@ -396,9 +425,7 @@ add_function_test(TestFunc, func=test_func_closure_capture, name="test_func_clos
 add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices)
 add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices)
 add_kernel_test(TestFunc, kernel=test_builtin_shadowing, name="test_builtin_shadowing", dim=1, devices=devices)
-add_kernel_test(
-    TestFunc, kernel=test_user_func_with_defaults, name="test_user_func_with_defaults", dim=1, devices=devices
-)
+add_function_test(TestFunc, func=test_user_func_with_defaults, name="test_user_func_with_defaults", devices=devices)
 add_kernel_test(
     TestFunc,
     kernel=test_user_func_return_multiple_values,
@@ -409,6 +436,9 @@ add_kernel_test(
 add_function_test(
     TestFunc, func=test_user_func_overload_resolution, name="test_user_func_overload_resolution", devices=devices
 )
+add_kernel_test(
+    TestFunc, kernel=test_return_annotation_none, name="test_return_annotation_none", dim=1, devices=devices
+)
 
 
 if __name__ == "__main__":
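
Note: the change above converts the default-arguments test from a kernel test into a function test because a @wp.func with defaults is now exercised both inside a kernel and called directly from Python scope (the new test calls user_func_with_defaults() outside any kernel). A minimal sketch of that dual use; the names here are illustrative, not from the package:

    import warp as wp

    @wp.func
    def offset(x: int, delta: int = 10) -> int:
        return x + delta

    @wp.kernel
    def offset_kernel(out: wp.array(dtype=int)):
        out[0] = offset(1)      # default applied inside the kernel
        out[1] = offset(1, 2)

    out = wp.zeros(2, dtype=int)
    wp.launch(offset_kernel, dim=1, inputs=[out])
    print(out.numpy())  # [11  3]
    print(offset(5))    # 15, the same @wp.func invoked from Python scope
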
warp/tests/test_generics.py CHANGED
@@ -522,6 +522,57 @@ def test_type_attribute_error(test, device):
     )
 
 
+@wp.func
+def vec_int_annotation_func(v: wp.vec(3, wp.Int)) -> wp.Int:
+    return v[0] + v[1] + v[2]
+
+
+@wp.func
+def vec_float_annotation_func(v: wp.vec(3, wp.Float)) -> wp.Float:
+    return v[0] + v[1] + v[2]
+
+
+@wp.func
+def vec_scalar_annotation_func(v: wp.vec(3, wp.Scalar)) -> wp.Scalar:
+    return v[0] + v[1] + v[2]
+
+
+@wp.func
+def mat_int_annotation_func(m: wp.mat((2, 2), wp.Int)) -> wp.Int:
+    return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]
+
+
+@wp.func
+def mat_float_annotation_func(m: wp.mat((2, 2), wp.Float)) -> wp.Float:
+    return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]
+
+
+@wp.func
+def mat_scalar_annotation_func(m: wp.mat((2, 2), wp.Scalar)) -> wp.Scalar:
+    return m[0, 0] + m[0, 1] + m[1, 0] + m[1, 1]
+
+
+mat22s = wp.mat((2, 2), wp.int16)
+mat22d = wp.mat((2, 2), wp.float64)
+
+
+@wp.kernel
+def test_annotations_kernel():
+    vi16 = wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3))
+    vf64 = wp.vec3d(wp.float64(1), wp.float64(2), wp.float64(3))
+    wp.expect_eq(vec_int_annotation_func(vi16), wp.int16(6))
+    wp.expect_eq(vec_float_annotation_func(vf64), wp.float64(6))
+    wp.expect_eq(vec_scalar_annotation_func(vi16), wp.int16(6))
+    wp.expect_eq(vec_scalar_annotation_func(vf64), wp.float64(6))
+
+    mi16 = mat22s(wp.int16(1), wp.int16(2), wp.int16(3), wp.int16(4))
+    mf64 = mat22d(wp.float64(1), wp.float64(2), wp.float64(3), wp.float64(4))
+    wp.expect_eq(mat_int_annotation_func(mi16), wp.int16(10))
+    wp.expect_eq(mat_float_annotation_func(mf64), wp.float64(10))
+    wp.expect_eq(mat_scalar_annotation_func(mi16), wp.int16(10))
+    wp.expect_eq(mat_scalar_annotation_func(mf64), wp.float64(10))
+
+
 class TestGenerics(unittest.TestCase):
     pass
 
@@ -590,6 +641,7 @@ add_kernel_test(
 )
 add_function_test(TestGenerics, "test_type_operator_misspell", test_type_operator_misspell, devices=devices)
 add_function_test(TestGenerics, "test_type_attribute_error", test_type_attribute_error, devices=devices)
+add_kernel_test(TestGenerics, name="test_annotations_kernel", kernel=test_annotations_kernel, dim=1, devices=devices)
 
 if __name__ == "__main__":
     wp.clear_kernel_cache()
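
Note: these new tests cover generic type annotations, where wp.vec(length, wp.Int | wp.Float | wp.Scalar) and wp.mat(shape, ...) let one @wp.func body specialize per concrete scalar type at each call site. A condensed sketch of the wp.Scalar case; the function and kernel names are illustrative:

    import warp as wp

    @wp.func
    def vsum(v: wp.vec(3, wp.Scalar)) -> wp.Scalar:
        # Specializes per call: int16 input returns int16, float64 returns float64.
        return v[0] + v[1] + v[2]

    @wp.kernel
    def check_vsum():
        wp.expect_eq(vsum(wp.vec3s(wp.int16(1), wp.int16(2), wp.int16(3))), wp.int16(6))
        wp.expect_eq(vsum(wp.vec3d(wp.float64(1.0), wp.float64(2.0), wp.float64(3.0))), wp.float64(6.0))

    wp.launch(check_vsum, dim=1)
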
warp/tests/test_iter.py ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import unittest
+
+import warp as wp
+from warp.tests.unittest_utils import *
+
+
+@wp.kernel
+def reversed_kernel(
+    start: wp.int32,
+    end: wp.int32,
+    step: wp.int32,
+    out_count: wp.array(dtype=wp.int32),
+    out_values: wp.array(dtype=wp.int32),
+):
+    count = wp.int32(0)
+    for i in reversed(range(start, end, step)):
+        out_values[count] = i
+        count += 1
+
+    out_count[0] = count
+
+
+def test_reversed(test, device):
+    count = wp.empty(1, dtype=wp.int32)
+    values = wp.empty(32, dtype=wp.int32)
+
+    start, end, step = (-2, 8, 3)
+    wp.launch(
+        reversed_kernel,
+        dim=1,
+        inputs=(start, end, step),
+        outputs=(count, values),
+    )
+    expected = tuple(reversed(range(start, end, step)))
+    assert count.numpy()[0] == len(expected)
+    assert_np_equal(values.numpy()[: len(expected)], expected)
+
+    start, end, step = (9, -3, -2)
+    wp.launch(
+        reversed_kernel,
+        dim=1,
+        inputs=(start, end, step),
+        outputs=(count, values),
+    )
+    expected = tuple(reversed(range(start, end, step)))
+    assert count.numpy()[0] == len(expected)
+    assert_np_equal(values.numpy()[: len(expected)], expected)
+
+
+devices = get_test_devices()
+
+
+class TestIter(unittest.TestCase):
+    pass
+
+
+add_function_test(TestIter, "test_reversed", test_reversed, devices=devices)
+
+if __name__ == "__main__":
+    wp.clear_kernel_cache()
+    unittest.main(verbosity=2)
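
Note: test_iter.py is a new file pinning down kernel-side support for reversed() over range(), including negative bounds and steps, with Python's own iteration semantics as the reference. A condensed sketch of the same idiom; the kernel name is illustrative:

    import warp as wp

    @wp.kernel
    def countdown(out: wp.array(dtype=wp.int32)):
        idx = wp.int32(0)
        # reversed(range(-2, 8, 3)) yields 7, 4, 1, -2, exactly as in Python.
        for i in reversed(range(-2, 8, 3)):
            out[idx] = i
            idx += 1

    out = wp.zeros(4, dtype=wp.int32)
    wp.launch(countdown, dim=1, inputs=[out])
    print(out.numpy())  # [ 7  4  1 -2]
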
warp/tests/test_lerp.py CHANGED
@@ -31,66 +31,14 @@ class TestData:
 
 TEST_DATA = {
     wp.float32: (
-        TestData(
-            a=1.0,
-            b=5.0,
-            t=0.75,
-            expected=4.0,
-            expected_adj_a=0.25,
-            expected_adj_b=0.75,
-            expected_adj_t=4.0,
-        ),
-        TestData(
-            a=-2.0,
-            b=5.0,
-            t=0.25,
-            expected=-0.25,
-            expected_adj_a=0.75,
-            expected_adj_b=0.25,
-            expected_adj_t=7.0,
-        ),
-        TestData(
-            a=1.23,
-            b=2.34,
-            t=0.5,
-            expected=1.785,
-            expected_adj_a=0.5,
-            expected_adj_b=0.5,
-            expected_adj_t=1.11,
-        ),
-    ),
-    wp.vec2: (
-        TestData(
-            a=[1, 2],
-            b=[3, 4],
-            t=0.5,
-            expected=[2, 3],
-        ),
-    ),
-    wp.vec3: (
-        TestData(
-            a=[1, 2, 3],
-            b=[3, 4, 5],
-            t=0.5,
-            expected=[2, 3, 4],
-        ),
-    ),
-    wp.vec4: (
-        TestData(
-            a=[1, 2, 3, 4],
-            b=[3, 4, 5, 6],
-            t=0.5,
-            expected=[2, 3, 4, 5],
-        ),
-    ),
-    wp.mat22: (
-        TestData(
-            a=[[1, 2], [2, 1]],
-            b=[[3, 4], [4, 3]],
-            t=0.5,
-            expected=[[2, 3], [3, 2]],
-        ),
+        TestData(a=1.0, b=5.0, t=0.75, expected=4.0, expected_adj_a=0.25, expected_adj_b=0.75, expected_adj_t=4.0),
+        TestData(a=-2.0, b=5.0, t=0.25, expected=-0.25, expected_adj_a=0.75, expected_adj_b=0.25, expected_adj_t=7.0),
+        TestData(a=1.23, b=2.34, t=0.5, expected=1.785, expected_adj_a=0.5, expected_adj_b=0.5, expected_adj_t=1.11),
     ),
+    wp.vec2: (TestData(a=[1, 2], b=[3, 4], t=0.5, expected=[2, 3]),),
+    wp.vec3: (TestData(a=[1, 2, 3], b=[3, 4, 5], t=0.5, expected=[2, 3, 4]),),
+    wp.vec4: (TestData(a=[1, 2, 3, 4], b=[3, 4, 5, 6], t=0.5, expected=[2, 3, 4, 5]),),
+    wp.mat22: (TestData(a=[[1, 2], [2, 1]], b=[[3, 4], [4, 3]], t=0.5, expected=[[2, 3], [3, 2]]),),
     wp.mat33: (
         TestData(
             a=[[1, 2, 3], [3, 1, 2], [2, 3, 1]],
@@ -107,30 +55,9 @@ TEST_DATA = {
             expected=[[2, 3, 4, 5], [5, 2, 3, 4], [4, 5, 2, 3], [3, 4, 5, 2]],
         ),
     ),
-    wp.quat: (
-        TestData(
-            a=[1, 2, 3, 4],
-            b=[3, 4, 5, 6],
-            t=0.5,
-            expected=[2, 3, 4, 5],
-        ),
-    ),
-    wp.transform: (
-        TestData(
-            a=[1, 2, 3, 4, 5, 6, 7],
-            b=[3, 4, 5, 6, 7, 8, 9],
-            t=0.5,
-            expected=[2, 3, 4, 5, 6, 7, 8],
-        ),
-    ),
-    wp.spatial_vector: (
-        TestData(
-            a=[1, 2, 3, 4, 5, 6],
-            b=[3, 4, 5, 6, 7, 8],
-            t=0.5,
-            expected=[2, 3, 4, 5, 6, 7],
-        ),
-    ),
+    wp.quat: (TestData(a=[1, 2, 3, 4], b=[3, 4, 5, 6], t=0.5, expected=[2, 3, 4, 5]),),
+    wp.transform: (TestData(a=[1, 2, 3, 4, 5, 6, 7], b=[3, 4, 5, 6, 7, 8, 9], t=0.5, expected=[2, 3, 4, 5, 6, 7, 8]),),
+    wp.spatial_vector: (TestData(a=[1, 2, 3, 4, 5, 6], b=[3, 4, 5, 6, 7, 8], t=0.5, expected=[2, 3, 4, 5, 6, 7]),),
     wp.spatial_matrix: (
         TestData(
             a=[
@@ -175,12 +102,12 @@ def test_lerp(test, device):
 
         return fn
 
-    for data_type in TEST_DATA:
+    for data_type, test_data_set in TEST_DATA.items():
         kernel_fn = make_kernel_fn(data_type)
         kernel = wp.Kernel(func=kernel_fn, key=f"test_lerp_{data_type.__name__}_kernel")
 
         with test.subTest(data_type=data_type):
-            for test_data in TEST_DATA[data_type]:
+            for test_data in test_data_set:
                 a = wp.array([test_data.a], dtype=data_type, device=device, requires_grad=True)
                 b = wp.array([test_data.b], dtype=data_type, device=device, requires_grad=True)
                 t = wp.array([test_data.t], dtype=float, device=device, requires_grad=True)
@@ -188,8 +115,7 @@ def test_lerp(test, device):
                     [0] * wp.types.type_length(data_type), dtype=data_type, device=device, requires_grad=True
                 )
 
-                tape = wp.Tape()
-                with tape:
+                with wp.Tape() as tape:
                     wp.launch(kernel, dim=1, inputs=[a, b, t, out], device=device)
 
                 assert_np_equal(out.numpy(), np.array([test_data.expected]), tol=1e-6)
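
Note: the last hunk is a pure style cleanup, constructing the wp.Tape and entering it in a single `with` statement; both forms record the launches issued inside the block for a later backward pass. A minimal sketch of the idiom outside the test harness; the kernel and arrays are illustrative:

    import warp as wp

    @wp.kernel
    def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        i = wp.tid()
        y[i] = x[i] * x[i]

    x = wp.array([3.0], dtype=float, requires_grad=True)
    y = wp.zeros(1, dtype=float, requires_grad=True)

    with wp.Tape() as tape:  # records the launch below
        wp.launch(square, dim=1, inputs=[x], outputs=[y])

    tape.backward(grads={y: wp.array([1.0], dtype=float)})
    print(x.grad.numpy())  # [6.] since dy/dx = 2x at x = 3
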
warp/tests/test_mat_scalar_ops.py CHANGED
@@ -1501,7 +1501,7 @@ def test_matmat_multiplication(test, device, dtype, register_kernels=False):
     tol = {
         np.float16: 2.0e-2,
         np.float32: 5.0e-6,
-        np.float64: 1.0e-8,
+        np.float64: 5.0e-7,
     }.get(dtype, 0)
 
     wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
warp/tests/test_matmul.py CHANGED
@@ -5,6 +5,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import itertools
 import unittest
 from typing import Any
 
@@ -105,19 +106,15 @@ class gemm_test_bed_runner:
         assert_np_equal(C.grad.numpy(), adj_C_np)
 
     def run(self):
-        Ms = [64, 128, 256]
-        Ns = [64, 128, 256]
-        Ks = [64, 128, 256]
+        Ms = [16, 32, 64]
+        Ns = [16, 32, 64]
+        Ks = [16, 32, 64]
         batch_counts = [1, 4]
        betas = [0.0, 1.0]
         alpha = 1.0
 
-        for batch_count in batch_counts:
-            for m in Ms:
-                for n in Ns:
-                    for k in Ks:
-                        for beta in betas:
-                            self.run_and_verify(m, n, k, batch_count, alpha, beta)
+        for batch_count, m, n, k, beta in itertools.product(batch_counts, Ms, Ns, Ks, betas):
+            self.run_and_verify(m, n, k, batch_count, alpha, beta)
 
 
 class gemm_test_bed_runner_transpose:
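
Note: two independent changes here. The GEMM sweep shrinks from sizes 64/128/256 per dimension to 16/32/64 (presumably to cut test runtime), and the five nested loops collapse into one itertools.product call, which visits the same tuples in the same order (the leftmost iterable varies slowest). A quick standalone illustration:

    import itertools

    batch_counts, Ms, betas = [1, 4], [16, 32, 64], [0.0, 1.0]

    # Equivalent to: for batch_count: for m: for beta: ...
    for batch_count, m, beta in itertools.product(batch_counts, Ms, betas):
        print(batch_count, m, beta)
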
warp/tests/test_matmul_lite.py CHANGED
@@ -102,19 +102,14 @@ class gemm_test_bed_runner:
         assert_np_equal(C.grad.numpy(), adj_C_np)
 
     def run(self):
-        Ms = [8]
-        Ns = [16]
-        Ks = [32]
-        batch_counts = [1]
-        betas = [1.0]
+        m = 8
+        n = 16
+        k = 32
+        batch_count = 1
+        beta = 1.0
         alpha = 1.0
 
-        for batch_count in batch_counts:
-            for m in Ms:
-                for n in Ns:
-                    for k in Ks:
-                        for beta in betas:
-                            self.run_and_verify(m, n, k, batch_count, alpha, beta)
+        self.run_and_verify(m, n, k, batch_count, alpha, beta)
 
 
 class gemm_test_bed_runner_transpose:
warp/tests/test_mesh_query_point.py CHANGED
@@ -805,7 +805,7 @@ def test_set_mesh_points(test, device):
         device=device,
     )
 
-    shift = np.random.randn(3)
+    shift = rng.standard_normal(size=3)
 
     vs_higher = vs + shift
     vertices2 = wp.array(vs_higher, dtype=wp.vec3, device=device)
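
Note: a single-line modernization. The legacy global np.random.randn is replaced with NumPy's Generator API; the rng instance is created elsewhere in the test file, presumably seeded, which makes the random shift reproducible across runs. The idiom, with an assumed seed:

    import numpy as np

    rng = np.random.default_rng(123)     # seeded Generator, no global state
    shift = rng.standard_normal(size=3)  # same distribution as np.random.randn(3)
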
warp/tests/test_module_hashing.py CHANGED
@@ -214,12 +214,35 @@ def test_function_generic_overload_hashing(test, device):
     test.assertNotEqual(hash4, hash1)
 
 
+SIMPLE_MODULE = """# -*- coding: utf-8 -*-
+import warp as wp
+
+@wp.kernel
+def k():
+    pass
+"""
+
+
+def test_module_load(test, device):
+    """Ensure that loading a module does not change its hash"""
+    m = load_code_as_module(SIMPLE_MODULE, "simple_module")
+
+    hash1 = m.hash_module()
+    m.load(device)
+    hash2 = m.hash_module()
+
+    test.assertEqual(hash1, hash2)
+
+
 class TestModuleHashing(unittest.TestCase):
     pass
 
 
+devices = get_test_devices()
+
 add_function_test(TestModuleHashing, "test_function_overload_hashing", test_function_overload_hashing)
 add_function_test(TestModuleHashing, "test_function_generic_overload_hashing", test_function_generic_overload_hashing)
+add_function_test(TestModuleHashing, "test_module_load", test_module_load, devices=devices)
 
 
 if __name__ == "__main__":
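
Note: the new test asserts that loading (building) a module leaves its content hash untouched, which is what keeps kernel-cache lookups stable across builds. load_code_as_module is a test helper; a rough equivalent using an ordinary module and the internal Module API (hash_module() and load() are internal and may change) might look like:

    import warp as wp

    @wp.kernel
    def noop():
        pass

    module = noop.module            # the warp.context.Module owning this kernel
    h1 = module.hash_module()
    module.load(wp.get_device())    # triggers codegen and compilation
    h2 = module.hash_module()
    assert h1 == h2                 # loading must not perturb the hash
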
warp/tests/test_overwrite.py CHANGED
@@ -507,6 +507,50 @@ def test_kernel_read_func_write(test, device):
         wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
 
 
+@wp.func
+def atomic_func(
+    a: wp.array(dtype=wp.int32),
+    b: wp.array(dtype=wp.int32),
+    c: wp.array(dtype=wp.int32),
+    d: wp.array(dtype=wp.int32),
+    i: int,
+):
+    wp.atomic_add(a, i, 1)
+    wp.atomic_sub(b, i, 1)
+    wp.atomic_min(c, i, 1)
+    wp.atomic_max(d, i, 3)
+
+
+@wp.kernel(enable_backward=False)
+def atomic_kernel(
+    a: wp.array(dtype=wp.int32), b: wp.array(dtype=wp.int32), c: wp.array(dtype=wp.int32), d: wp.array(dtype=wp.int32)
+):
+    i = wp.tid()
+    atomic_func(a, b, c, d, i)
+
+
+# atomic operations should mark arrays as WRITE
+def test_atomic_operations(test, device):
+    saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access
+    try:
+        wp.config.verify_autograd_array_access = True
+
+        a = wp.array((1, 2, 3), dtype=wp.int32, device=device)
+        b = wp.array((1, 2, 3), dtype=wp.int32, device=device)
+        c = wp.array((1, 2, 3), dtype=wp.int32, device=device)
+        d = wp.array((1, 2, 3), dtype=wp.int32, device=device)
+
+        wp.launch(atomic_kernel, dim=a.shape, inputs=(a, b, c, d), device=device)
+
+        test.assertEqual(atomic_kernel.adj.args[0].is_write, True)
+        test.assertEqual(atomic_kernel.adj.args[1].is_write, True)
+        test.assertEqual(atomic_kernel.adj.args[2].is_write, True)
+        test.assertEqual(atomic_kernel.adj.args[3].is_write, True)
+
+    finally:
+        wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting
+
+
 class TestOverwrite(unittest.TestCase):
     pass
 
@@ -535,6 +579,7 @@ add_function_test(TestOverwrite, "test_reset", test_reset, devices=devices)
 add_function_test(TestOverwrite, "test_copy", test_copy, devices=devices)
 add_function_test(TestOverwrite, "test_matmul", test_matmul, devices=devices)
 add_function_test(TestOverwrite, "test_batched_matmul", test_batched_matmul, devices=devices)
+add_function_test(TestOverwrite, "test_atomic_operations", test_atomic_operations, devices=devices)
 
 # Some warning are only issued during codegen, and codegen only runs on cuda_0 in the MGPU case.
 cuda_device = get_cuda_test_devices(mode="basic")
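
Note: the added test codifies that wp.atomic_add, wp.atomic_sub, wp.atomic_min, and wp.atomic_max mark their array argument as written (is_write) when wp.config.verify_autograd_array_access is enabled, even when the atomic sits inside a nested @wp.func. A sketch of the user-facing pattern the check protects; the kernel and data are illustrative:

    import warp as wp

    @wp.kernel
    def histogram(values: wp.array(dtype=wp.int32), counts: wp.array(dtype=wp.int32)):
        i = wp.tid()
        # Concurrent increments: the atomic marks `counts` as written
        # for the autograd array-access verification.
        wp.atomic_add(counts, values[i], 1)

    values = wp.array([0, 1, 1, 2], dtype=wp.int32)
    counts = wp.zeros(3, dtype=wp.int32)
    wp.config.verify_autograd_array_access = True
    wp.launch(histogram, dim=values.shape, inputs=[values], outputs=[counts])
    print(counts.numpy())  # [1 2 1]
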