warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/tests/test_rand.py CHANGED
@@ -26,6 +26,8 @@ def test_kernel(
26
26
  kernel_seed: int,
27
27
  int_a: wp.array(dtype=int),
28
28
  int_ab: wp.array(dtype=int),
29
+ uint_a: wp.array(dtype=wp.uint32),
30
+ uint_ab: wp.array(dtype=wp.uint32),
29
31
  float_01: wp.array(dtype=float),
30
32
  float_ab: wp.array(dtype=float),
31
33
  ):
@@ -35,6 +37,8 @@ def test_kernel(
35
37
 
36
38
  int_a[tid] = wp.randi(state)
37
39
  int_ab[tid] = wp.randi(state, 0, 100)
40
+ uint_a[tid] = wp.randu(state)
41
+ uint_ab[tid] = wp.randu(state, wp.uint32(0), wp.uint32(100))
38
42
  float_01[tid] = wp.randf(state)
39
43
  float_ab[tid] = wp.randf(state, 0.0, 100.0)
40
44
 
@@ -42,37 +46,25 @@ def test_kernel(
42
46
  def test_rand(test, device):
43
47
  N = 10
44
48
 
45
- int_a_device = wp.zeros(N, dtype=int, device=device)
46
- int_a_host = wp.zeros(N, dtype=int, device="cpu")
47
- int_ab_device = wp.zeros(N, dtype=int, device=device)
48
- int_ab_host = wp.zeros(N, dtype=int, device="cpu")
49
+ int_a = wp.zeros(N, dtype=int, device=device)
50
+ int_ab = wp.zeros(N, dtype=int, device=device)
49
51
 
50
- float_01_device = wp.zeros(N, dtype=float, device=device)
51
- float_01_host = wp.zeros(N, dtype=float, device="cpu")
52
- float_ab_device = wp.zeros(N, dtype=float, device=device)
53
- float_ab_host = wp.zeros(N, dtype=float, device="cpu")
52
+ uint_a = wp.zeros(N, dtype=wp.uint32, device=device)
53
+ uint_ab = wp.zeros(N, dtype=wp.uint32, device=device)
54
+
55
+ float_01 = wp.zeros(N, dtype=float, device=device)
56
+ float_ab = wp.zeros(N, dtype=float, device=device)
54
57
 
55
58
  seed = 42
56
59
 
57
60
  wp.launch(
58
61
  kernel=test_kernel,
59
62
  dim=N,
60
- inputs=[seed, int_a_device, int_ab_device, float_01_device, float_ab_device],
63
+ inputs=[seed, int_a, int_ab, uint_a, uint_ab, float_01, float_ab],
61
64
  outputs=[],
62
65
  device=device,
63
66
  )
64
67
 
65
- wp.copy(int_a_host, int_a_device)
66
- wp.copy(int_ab_host, int_ab_device)
67
- wp.copy(float_01_host, float_01_device)
68
- wp.copy(float_ab_host, float_ab_device)
69
- wp.synchronize_device(device)
70
-
71
- int_a = int_a_host.numpy()
72
- int_ab = int_ab_host.numpy()
73
- float_01 = float_01_host.numpy()
74
- float_ab = float_ab_host.numpy()
75
-
76
68
  int_a_true = np.array(
77
69
  [
78
70
  -575632308,
@@ -88,32 +80,47 @@ def test_rand(test, device):
88
80
  ]
89
81
  )
90
82
  int_ab_true = np.array([46, 58, 46, 83, 85, 39, 72, 99, 18, 41])
83
+ uint_a_true = np.array(
84
+ [
85
+ 3133687854,
86
+ 3702303309,
87
+ 1235698096,
88
+ 3516599792,
89
+ 800302729,
90
+ 2620462179,
91
+ 2423739693,
92
+ 3024873594,
93
+ 2783682377,
94
+ 1188846332,
95
+ ]
96
+ )
97
+ uint_ab_true = np.array([6, 55, 2, 92, 55, 93, 65, 23, 48, 0])
91
98
  float_01_true = np.array(
92
99
  [
93
- 0.72961855,
94
- 0.86200964,
95
- 0.28770837,
96
- 0.8187722,
97
- 0.186335,
98
- 0.6101239,
99
- 0.56432086,
100
- 0.70428324,
101
- 0.64812654,
102
- 0.27679986,
100
+ 0.8265858,
101
+ 0.5874614,
102
+ 0.1508659,
103
+ 0.9498008,
104
+ 0.02531803,
105
+ 0.8520948,
106
+ 0.0001185536,
107
+ 0.4855958,
108
+ 0.06277305,
109
+ 0.2214079,
103
110
  ]
104
111
  )
105
112
  float_ab_true = np.array(
106
- [96.04259, 73.33809, 63.601555, 38.647305, 71.813896, 64.65809, 77.79791, 46.579605, 94.614456, 91.921814]
113
+ [79.84678, 76.362206, 32.135242, 99.70866, 70.45863, 20.6523, 45.164482, 55.583008, 76.60291, 35.36277]
107
114
  )
108
115
 
109
- test.assertTrue((int_a == int_a_true).all())
110
- test.assertTrue((int_ab == int_ab_true).all())
116
+ assert_np_equal(int_a.numpy(), int_a_true)
117
+ assert_np_equal(int_ab.numpy(), int_ab_true)
111
118
 
112
- err = np.max(np.abs(float_01 - float_01_true))
113
- test.assertTrue(err < 1e-04)
119
+ assert_np_equal(uint_a.numpy(), uint_a_true)
120
+ assert_np_equal(uint_ab.numpy(), uint_ab_true)
114
121
 
115
- err = np.max(np.abs(float_ab - float_ab_true))
116
- test.assertTrue(err < 1e-04)
122
+ assert_np_equal(float_01.numpy(), float_01_true, 1e-04)
123
+ assert_np_equal(float_ab.numpy(), float_ab_true, 1e-04)
117
124
 
118
125
 
119
126
  @wp.kernel
warp/tests/test_sparse.py CHANGED
@@ -19,10 +19,12 @@ import numpy as np
19
19
 
20
20
  import warp as wp
21
21
  from warp.sparse import (
22
+ bsr_assign,
22
23
  bsr_axpy,
23
24
  bsr_axpy_work_arrays,
24
25
  bsr_copy,
25
26
  bsr_diag,
27
+ bsr_from_triplets,
26
28
  bsr_get_diag,
27
29
  bsr_identity,
28
30
  bsr_mm,
@@ -232,18 +234,43 @@ def test_bsr_split_merge(test, device):
232
234
  with test.assertRaisesRegex(ValueError, "Incompatible dest and src block shapes"):
233
235
  bsr_copy(bsr, block_shape=(3, 3))
234
236
 
235
- with test.assertRaisesRegex(
236
- ValueError, r"Dest block shape \(5, 5\) is not an exact multiple of src block shape \(4, 2\)"
237
- ):
237
+ with test.assertRaisesRegex(ValueError, "Incompatible dest and src block shapes"):
238
238
  bsr_copy(bsr, block_shape=(5, 5))
239
239
 
240
240
  with test.assertRaisesRegex(
241
241
  ValueError,
242
- "The total rows and columns of the src matrix cannot be evenly divided using the requested block shape",
242
+ "The requested block shape does not evenly divide the source matrix",
243
243
  ):
244
244
  bsr_copy(bsr, block_shape=(32, 32))
245
245
 
246
246
 
247
+ def test_bsr_assign_masked(test, device):
248
+ rng = np.random.default_rng(123)
249
+
250
+ block_shape = (1, 2)
251
+ nrow = 16
252
+ ncol = 8
253
+ shape = (block_shape[0] * nrow, block_shape[1] * ncol)
254
+ n = 20
255
+
256
+ rows = wp.array(rng.integers(0, high=nrow, size=n, dtype=int), dtype=int, device=device)
257
+ cols = wp.array(rng.integers(0, high=ncol, size=n, dtype=int), dtype=int, device=device)
258
+ vals = wp.array(rng.random(size=(n, block_shape[0], block_shape[1])), dtype=float, device=device)
259
+
260
+ A = bsr_from_triplets(nrow, ncol, rows, cols, vals)
261
+
262
+ # Extract coarse diagonal with copy + diag funcs, for reference
263
+ A_coarse = bsr_copy(A, block_shape=(4, 4))
264
+ ref = _bsr_to_dense(bsr_diag(bsr_get_diag(A_coarse)))
265
+
266
+ # Extract coarse diagonal with masked assign (more memory efficient)
267
+ diag_masked = bsr_diag(rows_of_blocks=shape[0] // 4, block_type=A_coarse.dtype, device=device)
268
+ bsr_assign(src=A, dest=diag_masked, masked=True)
269
+ res = _bsr_to_dense(diag_masked)
270
+
271
+ assert_np_equal(res, ref, 0.0001)
272
+
273
+
247
274
  def make_test_bsr_transpose(block_shape, scalar_type):
248
275
  def test_bsr_transpose(test, device):
249
276
  rng = np.random.default_rng(123)
@@ -316,6 +343,12 @@ def make_test_bsr_axpy(block_shape, scalar_type):
316
343
  res = _bsr_to_dense(y)
317
344
  assert_np_equal(res, ref, 0.0001)
318
345
 
346
+ # test masked
347
+ y_mask = bsr_from_triplets(nrow, ncol, y.uncompress_rows()[:1], y.columns[:1], y.values[:1])
348
+ bsr_axpy(y, y_mask, masked=True)
349
+ assert y_mask.nnz_sync() == 1
350
+ assert_np_equal(y_mask.values.numpy(), 2.0 * y.values[:1].numpy(), 0.0001)
351
+
319
352
  # test incompatible shapes
320
353
  y.ncol = y.ncol + 1
321
354
  with test.assertRaisesRegex(ValueError, "Matrices must have the same number of rows and columns"):
@@ -383,6 +416,13 @@ def make_test_bsr_mm(block_shape, scalar_type):
383
416
  bsr_mm(x, y, z, alpha, beta, work_arrays=work_arrays, reuse_topology=True)
384
417
  assert_np_equal(res, ref, 0.0001)
385
418
 
419
+ # test masked mm
420
+ z = bsr_diag(rows_of_blocks=z.nrow, block_type=z.dtype, device=z.device)
421
+ bsr_mm(x, y, z, masked=True)
422
+ res = _bsr_to_dense(z)
423
+ ref = _bsr_to_dense(bsr_diag(bsr_get_diag(x @ y)))
424
+ assert_np_equal(res, ref, 0.0001)
425
+
386
426
  # using overloaded operators
387
427
  x = (alpha * x) @ y
388
428
  assert_np_equal(res, ref, 0.0001)
@@ -479,12 +519,12 @@ def make_test_bsr_mv(block_shape, scalar_type):
479
519
  assert_np_equal(res, ref, 0.0001)
480
520
 
481
521
  A.ncol = A.ncol + 1
482
- with test.assertRaisesRegex(ValueError, "Number of columns"):
522
+ with test.assertRaisesRegex(ValueError, "Incompatible 'x'"):
483
523
  bsr_mv(A, x, y)
484
524
 
485
525
  A.ncol = A.ncol - 1
486
526
  A.nrow = A.nrow - 1
487
- with test.assertRaisesRegex(ValueError, "Number of rows"):
527
+ with test.assertRaisesRegex(ValueError, "Incompatible 'y'"):
488
528
  bsr_mv(A, x, y)
489
529
 
490
530
  return test_bsr_mv
@@ -518,6 +558,7 @@ add_function_test(TestSparse, "test_csr_from_triplets", test_csr_from_triplets,
518
558
  add_function_test(TestSparse, "test_bsr_from_triplets", test_bsr_from_triplets, devices=devices)
519
559
  add_function_test(TestSparse, "test_bsr_get_diag", test_bsr_get_set_diag, devices=devices)
520
560
  add_function_test(TestSparse, "test_bsr_split_merge", test_bsr_split_merge, devices=devices)
561
+ add_function_test(TestSparse, "test_bsr_assign_masked", test_bsr_assign_masked, devices=devices)
521
562
 
522
563
  add_function_test(TestSparse, "test_csr_transpose", make_test_bsr_transpose((1, 1), wp.float32), devices=devices)
523
564
  add_function_test(TestSparse, "test_bsr_transpose_1_3", make_test_bsr_transpose((1, 3), wp.float32), devices=devices)
warp/tests/test_spatial.py CHANGED
@@ -1969,6 +1969,67 @@ def test_transform_anon_type_instance(test, device, dtype, register_kernels=Fals
1969
1969
  tape.zero()
1970
1970
 
1971
1971
 
1972
+ def test_transform_from_matrix(test, device, dtype, register_kernels=False):
1973
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
1974
+ mat44 = wp.types.matrix((4, 4), wptype)
1975
+ vec3 = wp.types.vector(3, wptype)
1976
+ quat = wp.types.quaternion(wptype)
1977
+
1978
+ def transform_from_matrix_kernel():
1979
+ # fmt: off
1980
+ m = mat44(
1981
+ wptype(0.6), wptype(0.48), wptype(0.64), wptype(1.0),
1982
+ wptype(-0.8), wptype(0.36), wptype(0.48), wptype(2.0),
1983
+ wptype(0.0), wptype(-0.8), wptype(0.6), wptype(3.0),
1984
+ wptype(0.0), wptype(0.0), wptype(0.0), wptype(1.0),
1985
+ )
1986
+ # fmt: on
1987
+ t = wp.transform_from_matrix(m)
1988
+ p = wp.transform_get_translation(t)
1989
+ q = wp.transform_get_rotation(t)
1990
+ wp.expect_near(p, vec3(wptype(1.0), wptype(2.0), wptype(3.0)), tolerance=wptype(1e-3))
1991
+ wp.expect_near(q, quat(wptype(-0.4), wptype(0.2), wptype(-0.4), wptype(0.8)), tolerance=wptype(1e-3))
1992
+
1993
+ kernel = getkernel(transform_from_matrix_kernel, suffix=dtype.__name__)
1994
+
1995
+ if register_kernels:
1996
+ return
1997
+
1998
+ wp.launch(kernel, dim=1, device=device)
1999
+
2000
+
2001
+ def test_transform_to_matrix(test, device, dtype, register_kernels=False):
2002
+ wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]
2003
+ mat44 = wp.types.matrix((4, 4), wptype)
2004
+ vec3 = wp.types.vector(3, wptype)
2005
+ quat = wp.types.quaternion(wptype)
2006
+
2007
+ def transform_to_matrix_kernel():
2008
+ p = vec3(wptype(1.0), wptype(2.0), wptype(3.0))
2009
+ q = quat(wptype(-0.4), wptype(0.2), wptype(-0.4), wptype(0.8))
2010
+ t = wp.transformation(p, q)
2011
+ m = wp.transform_to_matrix(t)
2012
+ # fmt: off
2013
+ wp.expect_near(
2014
+ m,
2015
+ mat44(
2016
+ wptype(0.6), wptype(0.48), wptype(0.64), wptype(1.0),
2017
+ wptype(-0.8), wptype(0.36), wptype(0.48), wptype(2.0),
2018
+ wptype(0.0), wptype(-0.8), wptype(0.6), wptype(3.0),
2019
+ wptype(0.0), wptype(0.0), wptype(0.0), wptype(1.0),
2020
+ ),
2021
+ tolerance=wptype(1e-3),
2022
+ )
2023
+ # fmt: on
2024
+
2025
+ kernel = getkernel(transform_to_matrix_kernel, suffix=dtype.__name__)
2026
+
2027
+ if register_kernels:
2028
+ return
2029
+
2030
+ wp.launch(kernel, dim=1, device=device)
2031
+
2032
+
1972
2033
  devices = get_test_devices()
1973
2034
 
1974
2035
 
@@ -2145,6 +2206,20 @@ for dtype in np_float_types:
2145
2206
  add_function_test_register_kernel(
2146
2207
  TestSpatial, f"test_spatial_adjoint_{dtype.__name__}", test_spatial_adjoint, devices=devices, dtype=dtype
2147
2208
  )
2209
+ add_function_test_register_kernel(
2210
+ TestSpatial,
2211
+ f"test_transform_from_matrix_{dtype.__name__}",
2212
+ test_transform_from_matrix,
2213
+ devices=devices,
2214
+ dtype=dtype,
2215
+ )
2216
+ add_function_test_register_kernel(
2217
+ TestSpatial,
2218
+ f"test_transform_to_matrix_{dtype.__name__}",
2219
+ test_transform_to_matrix,
2220
+ devices=devices,
2221
+ dtype=dtype,
2222
+ )
2148
2223
 
2149
2224
  # \TODO: test spatial_mass and spatial_jacobian
2150
2225
 
warp/tests/test_static.py CHANGED
@@ -307,7 +307,7 @@ def test_function_lookup(test, device):
307
307
 
308
308
  def count_ssa_occurrences(kernel: wp.Kernel, ssas: List[str]) -> Dict[str, int]:
309
309
  # analyze the generated code
310
- counts = {ssa: 0 for ssa in ssas}
310
+ counts = dict.fromkeys(ssas, 0)
311
311
  for line in kernel.adj.blocks[0].body_forward:
312
312
  for ssa in ssas:
313
313
  if ssa in line:
warp/tests/test_utils.py CHANGED
@@ -87,7 +87,7 @@ def test_array_scan_error_unsupported_dtype(test, device):
87
87
 
88
88
 
89
89
  def test_radix_sort_pairs(test, device):
90
- keyTypes = [int, wp.float32]
90
+ keyTypes = [int, wp.float32, wp.int64]
91
91
 
92
92
  for keyType in keyTypes:
93
93
  keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
@@ -97,18 +97,46 @@ def test_radix_sort_pairs(test, device):
97
97
  assert_np_equal(values.numpy()[:8], np.array((5, 2, 8, 4, 7, 6, 1, 3)))
98
98
 
99
99
 
100
- def test_radix_sort_pairs_empty(test, device):
100
+ def test_segmented_sort_pairs(test, device):
101
101
  keyTypes = [int, wp.float32]
102
102
 
103
+ for keyType in keyTypes:
104
+ keys = wp.array((7, 2, 8, 4, 1, 6, 5, 3, 0, 0, 0, 0, 0, 0, 0, 0), dtype=keyType, device=device)
105
+ values = wp.array((1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0), dtype=int, device=device)
106
+ wp.utils.segmented_sort_pairs(
107
+ keys,
108
+ values,
109
+ 8,
110
+ wp.array((0, 4), dtype=int, device=device),
111
+ wp.array((4, 8), dtype=int, device=device),
112
+ )
113
+ assert_np_equal(keys.numpy()[:8], np.array((2, 4, 7, 8, 1, 3, 5, 6)))
114
+ assert_np_equal(values.numpy()[:8], np.array((2, 4, 1, 3, 5, 8, 7, 6)))
115
+
116
+
117
+ def test_radix_sort_pairs_empty(test, device):
118
+ keyTypes = [int, wp.float32, wp.int64]
119
+
103
120
  for keyType in keyTypes:
104
121
  keys = wp.array((), dtype=keyType, device=device)
105
122
  values = wp.array((), dtype=int, device=device)
106
123
  wp.utils.radix_sort_pairs(keys, values, 0)
107
124
 
108
125
 
109
- def test_radix_sort_pairs_error_insufficient_storage(test, device):
126
+ def test_segmented_sort_pairs_empty(test, device):
110
127
  keyTypes = [int, wp.float32]
111
128
 
129
+ for keyType in keyTypes:
130
+ keys = wp.array((), dtype=keyType, device=device)
131
+ values = wp.array((), dtype=int, device=device)
132
+ wp.utils.segmented_sort_pairs(
133
+ keys, values, 0, wp.array((), dtype=int, device=device), wp.array((), dtype=int, device=device)
134
+ )
135
+
136
+
137
+ def test_radix_sort_pairs_error_insufficient_storage(test, device):
138
+ keyTypes = [int, wp.float32, wp.int64]
139
+
112
140
  for keyType in keyTypes:
113
141
  keys = wp.array((1, 2, 3), dtype=keyType, device=device)
114
142
  values = wp.array((1, 2, 3), dtype=int, device=device)
@@ -119,9 +147,28 @@ def test_radix_sort_pairs_error_insufficient_storage(test, device):
119
147
  wp.utils.radix_sort_pairs(keys, values, 3)
120
148
 
121
149
 
122
- def test_radix_sort_pairs_error_unsupported_dtype(test, device):
150
+ def test_segmented_sort_pairs_error_insufficient_storage(test, device):
123
151
  keyTypes = [int, wp.float32]
124
152
 
153
+ for keyType in keyTypes:
154
+ keys = wp.array((1, 2, 3), dtype=keyType, device=device)
155
+ values = wp.array((1, 2, 3), dtype=int, device=device)
156
+ with test.assertRaisesRegex(
157
+ RuntimeError,
158
+ r"Array storage must be large enough to contain 2\*count elements$",
159
+ ):
160
+ wp.utils.segmented_sort_pairs(
161
+ keys,
162
+ values,
163
+ 3,
164
+ wp.array((0,), dtype=int, device=device),
165
+ wp.array((3,), dtype=int, device=device),
166
+ )
167
+
168
+
169
+ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
170
+ keyTypes = [int, wp.float32, wp.int64]
171
+
125
172
  for keyType in keyTypes:
126
173
  keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
127
174
  values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
@@ -132,6 +179,25 @@ def test_radix_sort_pairs_error_unsupported_dtype(test, device):
132
179
  wp.utils.radix_sort_pairs(keys, values, 1)
133
180
 
134
181
 
182
+ def test_segmented_sort_pairs_error_unsupported_dtype(test, device):
183
+ keyTypes = [int, wp.float32]
184
+
185
+ for keyType in keyTypes:
186
+ keys = wp.array((1.0, 2.0, 3.0), dtype=keyType, device=device)
187
+ values = wp.array((1.0, 2.0, 3.0), dtype=float, device=device)
188
+ with test.assertRaisesRegex(
189
+ RuntimeError,
190
+ r"Unsupported data type$",
191
+ ):
192
+ wp.utils.segmented_sort_pairs(
193
+ keys,
194
+ values,
195
+ 1,
196
+ wp.array((0,), dtype=int, device=device),
197
+ wp.array((3,), dtype=int, device=device),
198
+ )
199
+
200
+
135
201
  def test_array_sum(test, device):
136
202
  for dtype in (wp.float32, wp.float64):
137
203
  with test.subTest(dtype=dtype):
@@ -468,6 +534,20 @@ add_function_test(
468
534
  test_radix_sort_pairs_error_unsupported_dtype,
469
535
  devices=devices,
470
536
  )
537
+ add_function_test(TestUtils, "test_segmented_sort_pairs", test_segmented_sort_pairs, devices=devices)
538
+ add_function_test(TestUtils, "test_segmented_sort_pairs_empty", test_segmented_sort_pairs, devices=devices)
539
+ add_function_test(
540
+ TestUtils,
541
+ "test_segmented_sort_pairs_error_insufficient_storage",
542
+ test_segmented_sort_pairs_error_insufficient_storage,
543
+ devices=devices,
544
+ )
545
+ add_function_test(
546
+ TestUtils,
547
+ "test_segmented_sort_pairs_error_unsupported_dtype",
548
+ test_segmented_sort_pairs_error_unsupported_dtype,
549
+ devices=devices,
550
+ )
471
551
  add_function_test(TestUtils, "test_array_sum", test_array_sum, devices=devices)
472
552
  add_function_test(
473
553
  TestUtils, "test_array_sum_error_out_dtype_mismatch", test_array_sum_error_out_dtype_mismatch, devices=devices