warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (179)
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -298,7 +298,7 @@ def triangulate(face_counts, face_indices):
298
298
  def test_mesh_query_point(test, device):
299
299
  from pxr import Usd, UsdGeom
300
300
 
301
- mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/spiky.usd")))
301
+ mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "spiky.usd")))
302
302
  mesh_geom = UsdGeom.Mesh(mesh.GetPrimAtPath("/Cube/Cube"))
303
303
 
304
304
  mesh_counts = mesh_geom.GetFaceVertexCountsAttr().Get()
@@ -526,7 +526,7 @@ def mesh_query_point_loss(
526
526
  def test_adj_mesh_query_point(test, device):
527
527
  from pxr import Usd, UsdGeom
528
528
 
529
- mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/torus.usda")))
529
+ mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "torus.usda")))
530
530
  mesh_geom = UsdGeom.Mesh(mesh.GetPrimAtPath("/World/Torus"))
531
531
 
532
532
  mesh_counts = mesh_geom.GetFaceVertexCountsAttr().Get()
@@ -663,7 +663,7 @@ def sample_furthest_points_brute(
663
663
  def test_mesh_query_furthest_point(test, device):
664
664
  from pxr import Usd, UsdGeom
665
665
 
666
- mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/spiky.usd")))
666
+ mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "spiky.usd")))
667
667
  mesh_geom = UsdGeom.Mesh(mesh.GetPrimAtPath("/Cube/Cube"))
668
668
 
669
669
  mesh_counts = mesh_geom.GetFaceVertexCountsAttr().Get()
@@ -750,11 +750,11 @@ def triangle_closest_point_for_test(a: wp.vec3, b: wp.vec3, c: wp.vec3, p: wp.ve
750
750
  return a + v * ab + w * ac, bary
751
751
 
752
752
 
753
- def load_mesh():
753
+ def load_mesh(model_name="bunny"):
754
754
  from pxr import Usd, UsdGeom
755
755
 
756
- usd_stage = Usd.Stage.Open(os.path.join(wp.examples.get_asset_directory(), "bunny.usd"))
757
- usd_geom = UsdGeom.Mesh(usd_stage.GetPrimAtPath("/root/bunny"))
756
+ usd_stage = Usd.Stage.Open(os.path.join(wp.examples.get_asset_directory(), model_name + ".usd"))
757
+ usd_geom = UsdGeom.Mesh(usd_stage.GetPrimAtPath("/root/" + model_name))
758
758
 
759
759
  vertices = np.array(usd_geom.GetPointsAttr().Get())
760
760
  faces = np.array(usd_geom.GetFaceVertexIndicesAttr().Get())
@@ -820,76 +820,79 @@ def test_set_mesh_points(test, device):
820
820
 
821
821
  rng = np.random.default_rng(123)
822
822
 
823
- vs, fs = load_mesh()
824
- vertices1 = wp.array(vs, dtype=wp.vec3, device=device)
825
- velocities1_np = rng.standard_normal(size=(vertices1.shape[0], 3))
826
- velocities1 = wp.array(velocities1_np, dtype=wp.vec3, device=device)
823
+ models = ["bunny", "nonuniform"]
827
824
 
828
- faces = wp.array(fs, dtype=wp.int32, device=device)
825
+ for model in models:
826
+ vs, fs = load_mesh(model)
827
+ vertices1 = wp.array(vs, dtype=wp.vec3, device=device)
828
+ velocities1_np = rng.standard_normal(size=(vertices1.shape[0], 3))
829
+ velocities1 = wp.array(velocities1_np, dtype=wp.vec3, device=device)
829
830
 
830
- n = 1000
831
- query_radius = 0.2
832
- pts1 = wp.array(rng.standard_normal(size=(n, 3)), dtype=wp.vec3, device=device)
831
+ faces = wp.array(fs, dtype=wp.int32, device=device)
833
832
 
834
- query_results_num_cols1 = wp.zeros(n, dtype=wp.int32, device=device)
835
- query_results_min_dist1 = wp.zeros(n, dtype=float, device=device)
836
- query_results_closest_point_velocity1 = wp.zeros(n, dtype=wp.vec3, device=device)
833
+ n = 1000
834
+ query_radius = 0.2
835
+ pts1 = wp.array(rng.standard_normal(size=(n, 3)), dtype=wp.vec3, device=device)
837
836
 
838
- for constructor in constructors:
839
- mesh = wp.Mesh(vertices1, faces, velocities=velocities1, bvh_constructor=constructor)
840
- fs_2D = faces.reshape((-1, 3))
837
+ query_results_num_cols1 = wp.zeros(n, dtype=wp.int32, device=device)
838
+ query_results_min_dist1 = wp.zeros(n, dtype=float, device=device)
839
+ query_results_closest_point_velocity1 = wp.zeros(n, dtype=wp.vec3, device=device)
841
840
 
842
- wp.launch(
843
- kernel=point_query_aabb_and_closest,
844
- inputs=[
845
- query_radius,
846
- mesh.id,
847
- pts1,
848
- vertices1,
849
- fs_2D,
850
- query_results_num_cols1,
851
- query_results_min_dist1,
852
- query_results_closest_point_velocity1,
853
- ],
854
- dim=n,
855
- device=device,
856
- )
841
+ for constructor in constructors:
842
+ mesh = wp.Mesh(vertices1, faces, velocities=velocities1, bvh_constructor=constructor)
843
+ fs_2D = faces.reshape((-1, 3))
844
+
845
+ wp.launch(
846
+ kernel=point_query_aabb_and_closest,
847
+ inputs=[
848
+ query_radius,
849
+ mesh.id,
850
+ pts1,
851
+ vertices1,
852
+ fs_2D,
853
+ query_results_num_cols1,
854
+ query_results_min_dist1,
855
+ query_results_closest_point_velocity1,
856
+ ],
857
+ dim=n,
858
+ device=device,
859
+ )
857
860
 
858
- shift = rng.standard_normal(size=3)
861
+ shift = rng.standard_normal(size=3)
859
862
 
860
- vs_higher = vs + shift
861
- vertices2 = wp.array(vs_higher, dtype=wp.vec3, device=device)
863
+ vs_higher = vs + shift
864
+ vertices2 = wp.array(vs_higher, dtype=wp.vec3, device=device)
862
865
 
863
- velocities2_np = velocities1_np + shift[None, ...]
864
- velocities2 = wp.array(velocities2_np, dtype=wp.vec3, device=device)
866
+ velocities2_np = velocities1_np + shift[None, ...]
867
+ velocities2 = wp.array(velocities2_np, dtype=wp.vec3, device=device)
865
868
 
866
- pts2 = wp.array(pts1.numpy() + shift, dtype=wp.vec3, device=device)
869
+ pts2 = wp.array(pts1.numpy() + shift, dtype=wp.vec3, device=device)
867
870
 
868
- mesh.points = vertices2
869
- mesh.velocities = velocities2
871
+ mesh.points = vertices2
872
+ mesh.velocities = velocities2
870
873
 
871
- query_results_num_cols2 = wp.zeros(n, dtype=wp.int32, device=device)
872
- query_results_min_dist2 = wp.zeros(n, dtype=float, device=device)
873
- query_results_closest_point_velocity2 = wp.array([shift for i in range(n)], dtype=wp.vec3, device=device)
874
+ query_results_num_cols2 = wp.zeros(n, dtype=wp.int32, device=device)
875
+ query_results_min_dist2 = wp.zeros(n, dtype=float, device=device)
876
+ query_results_closest_point_velocity2 = wp.array([shift for i in range(n)], dtype=wp.vec3, device=device)
874
877
 
875
- wp.launch(
876
- kernel=point_query_aabb_and_closest,
877
- inputs=[
878
- query_radius,
879
- mesh.id,
880
- pts2,
881
- vertices2,
882
- fs_2D,
883
- query_results_num_cols2,
884
- query_results_min_dist2,
885
- query_results_closest_point_velocity2,
886
- ],
887
- dim=n,
888
- device=device,
889
- )
878
+ wp.launch(
879
+ kernel=point_query_aabb_and_closest,
880
+ inputs=[
881
+ query_radius,
882
+ mesh.id,
883
+ pts2,
884
+ vertices2,
885
+ fs_2D,
886
+ query_results_num_cols2,
887
+ query_results_min_dist2,
888
+ query_results_closest_point_velocity2,
889
+ ],
890
+ dim=n,
891
+ device=device,
892
+ )
890
893
 
891
- test.assertTrue((query_results_num_cols1.numpy() == query_results_num_cols2.numpy()).all())
892
- test.assertTrue(((query_results_min_dist1.numpy() - query_results_min_dist2.numpy()) < 1e-5).all())
894
+ test.assertTrue((query_results_num_cols1.numpy() == query_results_num_cols2.numpy()).all())
895
+ test.assertTrue(((query_results_min_dist1.numpy() - query_results_min_dist2.numpy()) < 1e-5).all())
893
896
 
894
897
 
895
898
  devices = get_test_devices()
@@ -88,7 +88,7 @@ def test_mesh_query_ray_grad(test, device):
88
88
  # mesh_points = wp.array(np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]), dtype=wp.vec3, device=device)
89
89
  # mesh_indices = wp.array(np.array([0,1,2]), dtype=int, device=device)
90
90
 
91
- mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/torus.usda")))
91
+ mesh = Usd.Stage.Open(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "torus.usda")))
92
92
  mesh_geom = UsdGeom.Mesh(mesh.GetPrimAtPath("/World/Torus"))
93
93
 
94
94
  mesh_counts = mesh_geom.GetFaceVertexCountsAttr().Get()
@@ -370,6 +370,22 @@ def test_volume_store_i(volume: wp.uint64, points: wp.array(dtype=wp.vec3), valu
370
370
  values[tid] = wp.volume_lookup_i(volume, i, j, k)
371
371
 
372
372
 
373
+ @wp.kernel
374
+ def test_volume_store_v4(volume: wp.uint64, points: wp.array(dtype=wp.vec3), values: wp.array(dtype=wp.vec4)):
375
+ tid = wp.tid()
376
+
377
+ p = points[tid]
378
+ i = int(p[0])
379
+ j = int(p[1])
380
+ k = int(p[2])
381
+
382
+ v = wp.vec4(p[0], p[1], p[2], float(i + 100 * j + 10000 * k))
383
+
384
+ wp.volume_store(volume, i, j, k, v)
385
+
386
+ values[tid] = wp.volume_lookup(volume, i, j, k, dtype=wp.vec4)
387
+
388
+
373
389
  devices = get_test_devices()
374
390
  rng = np.random.default_rng(101215)
375
391
 
@@ -393,12 +409,12 @@ rng = np.random.default_rng(101215)
393
409
  # (-90 degrees rotation along X)
394
410
  # voxel size: 0.1
395
411
  volume_paths = {
396
- "float": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/test_grid.nvdb")),
397
- "int32": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/test_int32_grid.nvdb")),
398
- "vec3f": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/test_vec_grid.nvdb")),
399
- "index": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/test_index_grid.nvdb")),
400
- "torus": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/torus.nvdb")),
401
- "float_write": os.path.abspath(os.path.join(os.path.dirname(__file__), "assets/test_grid.nvdb")),
412
+ "float": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "test_grid.nvdb")),
413
+ "int32": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "test_int32_grid.nvdb")),
414
+ "vec3f": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "test_vec_grid.nvdb")),
415
+ "index": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "test_index_grid.nvdb")),
416
+ "torus": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "torus.nvdb")),
417
+ "float_write": os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "assets", "test_grid.nvdb")),
402
418
  }
403
419
 
404
420
  test_volume_tiles = (
@@ -635,6 +651,22 @@ def test_volume_allocation_i(test, device):
635
651
  np.testing.assert_equal(values_res, values_ref)
636
652
 
637
653
 
654
+ def test_volume_allocation_v4(test, device):
655
+ bg_value = (-1, 2.0, -3, 5)
656
+ points_np = np.append(point_grid, [[8096, 8096, 8096]], axis=0)
657
+
658
+ w_ref = np.array([x + 100 * y + 10000 * z for x, y, z in point_grid])[:, np.newaxis]
659
+ values_ref = np.append(np.hstack((point_grid, w_ref)), [bg_value], axis=0)
660
+
661
+ volume = wp.Volume.allocate(min=[-11, -11, -11], max=[11, 11, 11], voxel_size=0.1, bg_value=bg_value, device=device)
662
+ points = wp.array(points_np, dtype=wp.vec3, device=device)
663
+ values = wp.empty(len(points_np), dtype=wp.vec4, device=device)
664
+ wp.launch(test_volume_store_v4, dim=len(points_np), inputs=[volume.id, points, values], device=device)
665
+
666
+ values_res = values.numpy()
667
+ np.testing.assert_equal(values_res, values_ref)
668
+
669
+
638
670
  def test_volume_introspection(test, device):
639
671
  for volume_names in ("float", "vec3f"):
640
672
  with test.subTest(volume_names=volume_names):
@@ -967,6 +999,9 @@ add_function_test(
967
999
  add_function_test(
968
1000
  TestVolume, "test_volume_allocation_i", test_volume_allocation_i, devices=get_selected_cuda_test_devices()
969
1001
  )
1002
+ add_function_test(
1003
+ TestVolume, "test_volume_allocation_v4", test_volume_allocation_v4, devices=get_selected_cuda_test_devices()
1004
+ )
970
1005
  add_function_test(TestVolume, "test_volume_introspection", test_volume_introspection, devices=devices)
971
1006
  add_function_test(
972
1007
  TestVolume, "test_volume_from_numpy", test_volume_from_numpy, devices=get_selected_cuda_test_devices()
File without changes
@@ -389,11 +389,23 @@ def test_dlpack_paddle_to_warp(test, device):
389
389
  def test_dlpack_warp_to_jax(test, device):
390
390
  import jax
391
391
  import jax.dlpack
392
+ import jax.numpy as jnp
392
393
 
393
- a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
394
+ cpu_device = jax.devices("cpu")[0]
395
+
396
+ # Create a numpy array from a JAX array to respect XLA alignment needs
397
+ with jax.default_device(cpu_device):
398
+ x_jax = jnp.arange(N, dtype=jnp.float32)
399
+ x_numpy = np.asarray(x_jax)
400
+ test.assertEqual(x_jax.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(x_numpy)[0])
401
+
402
+ a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)
403
+
404
+ if device.is_cpu:
405
+ test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(x_numpy)[0])
394
406
 
395
407
  # use generic dlpack conversion
396
- j1 = jax.dlpack.from_dlpack(wp.to_dlpack(a))
408
+ j1 = jax.dlpack.from_dlpack(a, copy=False)
397
409
 
398
410
  # use jax wrapper
399
411
  j2 = wp.to_jax(a)
@@ -423,14 +435,25 @@ def test_dlpack_warp_to_jax(test, device):
423
435
  @unittest.skipUnless(_jax_version() >= (0, 4, 15), "Jax version too old")
424
436
  def test_dlpack_warp_to_jax_v2(test, device):
425
437
  # same as original test, but uses newer __dlpack__() method
426
-
427
438
  import jax
428
439
  import jax.dlpack
440
+ import jax.numpy as jnp
429
441
 
430
- a = wp.array(data=np.arange(N, dtype=np.float32), device=device)
442
+ cpu_device = jax.devices("cpu")[0]
443
+
444
+ # Create a numpy array from a JAX array to respect XLA alignment needs
445
+ with jax.default_device(cpu_device):
446
+ x_jax = jnp.arange(N, dtype=jnp.float32)
447
+ x_numpy = np.asarray(x_jax)
448
+ test.assertEqual(x_jax.unsafe_buffer_pointer(), np.lib.array_utils.byte_bounds(x_numpy)[0])
449
+
450
+ a = wp.array(x_numpy, device=device, dtype=wp.float32, copy=False)
451
+
452
+ if device.is_cpu:
453
+ test.assertEqual(a.ptr, np.lib.array_utils.byte_bounds(x_numpy)[0])
431
454
 
432
455
  # pass warp array directly
433
- j1 = jax.dlpack.from_dlpack(a)
456
+ j1 = jax.dlpack.from_dlpack(a, copy=False)
434
457
 
435
458
  # use jax wrapper
436
459
  j2 = wp.to_jax(a)
File without changes
@@ -17,6 +17,7 @@ import math
17
17
  import unittest
18
18
 
19
19
  import warp as wp
20
+ import warp.examples
20
21
  import warp.sim
21
22
  from warp.tests.unittest_utils import *
22
23
 
@@ -24,11 +25,7 @@ from warp.tests.unittest_utils import *
24
25
  def build_ant(num_envs):
25
26
  builder = wp.sim.ModelBuilder()
26
27
  for i in range(num_envs):
27
- wp.sim.parse_mjcf(
28
- os.path.join(os.path.dirname(__file__), "../../examples/assets/nv_ant.xml"),
29
- builder,
30
- up_axis="y",
31
- )
28
+ wp.sim.parse_mjcf(os.path.join(warp.examples.get_asset_directory(), "nv_ant.xml"), builder, up_axis="y")
32
29
 
33
30
  coord_count = 15
34
31
  dof_count = 14
@@ -37,8 +34,10 @@ def build_ant(num_envs):
37
34
  dof_start = i * dof_count
38
35
 
39
36
  # base
40
- builder.joint_q[coord_start : coord_start + 3] = [i * 2.0, 0.70, 0.0]
41
- builder.joint_q[coord_start + 3 : coord_start + 7] = wp.quat_from_axis_angle((1.0, 0.0, 0.0), -math.pi * 0.5)
37
+ p = [i * 2.0, 0.70, 0.0]
38
+ q = wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), -math.pi * 0.5)
39
+ builder.joint_q[coord_start : coord_start + 3] = p
40
+ builder.joint_q[coord_start + 3 : coord_start + 7] = q
42
41
 
43
42
  # joints
44
43
  builder.joint_q[coord_start + 7 : coord_start + coord_count] = [0.0, 1.0, 0.0, -1.0, 0.0, -1.0, 0.0, 1.0]
@@ -56,9 +55,9 @@ def build_complex_joint_mechanism(chain_length):
56
55
  ax1 = wp.normalize(wp.vec3(4.0, -1.0, 2.0))
57
56
  ax2 = wp.normalize(wp.vec3(-3.0, 4.0, -1.0))
58
57
  # declare some transforms with nonzero translation and orientation
59
- tf0 = wp.transform(wp.vec3(1.0, 2.0, 3.0), wp.quat_from_axis_angle((1.0, 0.0, 0.0), math.pi * 0.25))
60
- tf1 = wp.transform(wp.vec3(4.0, 5.0, 6.0), wp.quat_from_axis_angle((0.0, 1.0, 0.0), math.pi * 0.5))
61
- tf2 = wp.transform(wp.vec3(7.0, 8.0, 9.0), wp.quat_from_axis_angle((0.0, 0.0, 1.0), math.pi * 0.75))
58
+ tf0 = wp.transform(wp.vec3(1.0, 2.0, 3.0), wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), math.pi * 0.25))
59
+ tf1 = wp.transform(wp.vec3(4.0, 5.0, 6.0), wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), math.pi * 0.5))
60
+ tf2 = wp.transform(wp.vec3(7.0, 8.0, 9.0), wp.quat_from_axis_angle(wp.vec3(0.0, 0.0, 1.0), math.pi * 0.75))
62
61
 
63
62
  parent = -1
64
63
  for _i in range(chain_length):
@@ -595,8 +595,8 @@ class TestCollision(unittest.TestCase):
595
595
  pass
596
596
 
597
597
 
598
- add_function_test(TestCollision, "test_vertex_triangle_collision", test_vertex_triangle_collision, devices=devices)
599
- add_function_test(TestCollision, "test_edge_edge_collision", test_vertex_triangle_collision, devices=devices)
598
+ # add_function_test(TestCollision, "test_vertex_triangle_collision", test_vertex_triangle_collision, devices=devices)
599
+ # add_function_test(TestCollision, "test_edge_edge_collision", test_vertex_triangle_collision, devices=devices)
600
600
  add_function_test(TestCollision, "test_particle_collision", test_particle_collision, devices=devices)
601
601
 
602
602
  if __name__ == "__main__":
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import math
16
17
  import unittest
17
18
 
18
19
  import numpy as np
@@ -178,6 +179,45 @@ class TestModel(unittest.TestCase):
178
179
  assert builder2.articulation_count == 2 * builder.articulation_count
179
180
  assert builder2.articulation_start == [0, 1, 2, 3]
180
181
 
182
+ def test_add_builder_with_open_edges(self):
183
+ builder = wp.sim.ModelBuilder()
184
+
185
+ dim_x = 16
186
+ dim_y = 16
187
+
188
+ env_builder = wp.sim.ModelBuilder()
189
+ env_builder.add_cloth_grid(
190
+ pos=wp.vec3(0.0, 0.0, 0.0),
191
+ vel=wp.vec3(0.1, 0.1, 0.0),
192
+ rot=wp.quat_from_axis_angle(wp.vec3(1.0, 0.0, 0.0), -math.pi * 0.25),
193
+ dim_x=dim_x,
194
+ dim_y=dim_y,
195
+ cell_x=1.0 / dim_x,
196
+ cell_y=1.0 / dim_y,
197
+ mass=1.0,
198
+ )
199
+
200
+ num_envs = 2
201
+ env_offsets = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
202
+
203
+ builder_open_edge_count = np.sum(np.array(builder.edge_indices) == -1)
204
+ env_builder_open_edge_count = np.sum(np.array(env_builder.edge_indices) == -1)
205
+
206
+ for i in range(num_envs):
207
+ xform = wp.transform(env_offsets[i], wp.quat_identity())
208
+ builder.add_builder(
209
+ env_builder,
210
+ xform,
211
+ update_num_env_count=True,
212
+ separate_collision_group=True,
213
+ )
214
+
215
+ self.assertEqual(
216
+ np.sum(np.array(builder.edge_indices) == -1),
217
+ builder_open_edge_count + num_envs * env_builder_open_edge_count,
218
+ "builder does not have the expected number of open edges",
219
+ )
220
+
181
221
 
182
222
  if __name__ == "__main__":
183
223
  wp.clear_kernel_cache()
@@ -18,6 +18,7 @@ import os
18
18
  import unittest
19
19
 
20
20
  import warp as wp
21
+ import warp.examples
21
22
  import warp.sim
22
23
  from warp.tests.unittest_utils import *
23
24
 
@@ -29,7 +30,7 @@ def test_fk_ik(test, device):
29
30
 
30
31
  for i in range(num_envs):
31
32
  wp.sim.parse_mjcf(
32
- os.path.join(os.path.dirname(__file__), "../examples/assets/nv_ant.xml"),
33
+ os.path.join(warp.examples.get_asset_directory(), "nv_ant.xml"),
33
34
  builder,
34
35
  stiffness=0.0,
35
36
  damping=1.0,