warp-lang 1.7.0__py3-none-manylinux_2_34_aarch64.whl → 1.7.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (45)
  1. warp/autograd.py +12 -2
  2. warp/bin/warp.so +0 -0
  3. warp/build.py +1 -1
  4. warp/builtins.py +11 -10
  5. warp/codegen.py +17 -5
  6. warp/config.py +1 -1
  7. warp/context.py +6 -0
  8. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  9. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  10. warp/fem/field/field.py +11 -1
  11. warp/fem/field/nodal_field.py +36 -22
  12. warp/fem/geometry/adaptive_nanogrid.py +7 -3
  13. warp/fem/geometry/trimesh.py +4 -12
  14. warp/jax_experimental/custom_call.py +14 -2
  15. warp/jax_experimental/ffi.py +5 -1
  16. warp/native/tile.h +11 -11
  17. warp/native/warp.cu +1 -1
  18. warp/render/render_opengl.py +19 -17
  19. warp/render/render_usd.py +93 -3
  20. warp/sim/collide.py +11 -9
  21. warp/sim/inertia.py +189 -156
  22. warp/sim/integrator_euler.py +3 -0
  23. warp/sim/integrator_xpbd.py +3 -0
  24. warp/sim/model.py +29 -12
  25. warp/sim/render.py +4 -0
  26. warp/stubs.py +1 -1
  27. warp/tests/assets/torus.usda +1 -1
  28. warp/tests/sim/test_collision.py +237 -206
  29. warp/tests/sim/test_inertia.py +161 -0
  30. warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +4 -0
  31. warp/tests/sim/test_xpbd.py +399 -0
  32. warp/tests/test_codegen.py +24 -3
  33. warp/tests/test_examples.py +16 -6
  34. warp/tests/test_fem.py +75 -10
  35. warp/tests/test_mat.py +370 -103
  36. warp/tests/test_quat.py +321 -137
  37. warp/tests/test_vec.py +320 -174
  38. warp/tests/tile/test_tile_load.py +97 -0
  39. warp/tests/unittest_suites.py +2 -5
  40. warp/types.py +65 -8
  41. {warp_lang-1.7.0.dist-info → warp_lang-1.7.1.dist-info}/METADATA +21 -9
  42. {warp_lang-1.7.0.dist-info → warp_lang-1.7.1.dist-info}/RECORD +45 -42
  43. {warp_lang-1.7.0.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  44. {warp_lang-1.7.0.dist-info → warp_lang-1.7.1.dist-info}/licenses/LICENSE.md +0 -26
  45. {warp_lang-1.7.0.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
@@ -126,7 +126,14 @@ def _create_jax_warp_primitive():
126
126
 
127
127
  # Create and register the primitive.
128
128
  # TODO add default implementation that calls the kernel via warp.
129
- _jax_warp_p = jax.core.Primitive("jax_warp")
129
+ try:
130
+ # newer JAX versions
131
+ import jax.extend
132
+
133
+ _jax_warp_p = jax.extend.core.Primitive("jax_warp")
134
+ except (ImportError, AttributeError):
135
+ # older JAX versions
136
+ _jax_warp_p = jax.core.Primitive("jax_warp")
130
137
  _jax_warp_p.multiple_results = True
131
138
 
132
139
  # TODO Just launch the kernel directly, but make sure the argument
@@ -262,7 +269,12 @@ def _create_jax_warp_primitive():
262
269
  capsule = PyCapsule_New(ccall_address.value, b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0))
263
270
 
264
271
  # Register the callback in XLA.
265
- jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
272
+ try:
273
+ # newer JAX versions
274
+ jax.ffi.register_ffi_target("warp_call", capsule, platform="gpu", api_version=0)
275
+ except AttributeError:
276
+ # older JAX versions
277
+ jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu")
266
278
 
267
279
  def default_layout(shape):
268
280
  return range(len(shape) - 1, -1, -1)
@@ -560,7 +560,11 @@ class FfiCallable:
560
560
 
561
561
  # call the Python function with reconstructed arguments
562
562
  with wp.ScopedStream(stream, sync_enter=False):
563
- self.func(*arg_list)
563
+ if stream.is_capturing:
564
+ with wp.ScopedCapture(stream=stream, external=True):
565
+ self.func(*arg_list)
566
+ else:
567
+ self.func(*arg_list)
564
568
 
565
569
  except Exception as e:
566
570
  print(traceback.format_exc())
warp/native/tile.h CHANGED
@@ -1133,17 +1133,17 @@ struct tile_shared_t
1133
1133
  constexpr int lastdim = Layout::Shape::N-1;
1134
1134
  constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
1135
1135
  const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
1136
- const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
1136
+ const int elements = min(Layout::Shape::dim(1), (dest.data.shape[lastdim] - dest.offset[lastdim]));
1137
1137
  const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
1138
1138
 
1139
1139
  float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
1140
1140
  const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
1141
1141
 
1142
- if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
1143
- {
1144
- constexpr int M = Layout::Shape::dim(0);
1145
- constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
1142
+ constexpr int M = Layout::Shape::dim(0);
1143
+ constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
1146
1144
 
1145
+ if (contiguous_dest && contiguous_src && aligned_size && aligned_dst && N)
1146
+ {
1147
1147
  // alias of shared tile with 128bit type
1148
1148
  using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
1149
1149
  tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
@@ -1222,17 +1222,17 @@ struct tile_shared_t
1222
1222
  constexpr int lastdim = Layout::Shape::N-1;
1223
1223
  constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
1224
1224
  const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
1225
- const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
1225
+ const int elements = min(Layout::Shape::dim(1), (src.data.shape[lastdim] - src.offset[lastdim]));
1226
1226
  const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
1227
-
1227
+
1228
1228
  float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
1229
1229
  const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
1230
1230
 
1231
- if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
1232
- {
1233
- constexpr int M = Layout::Shape::dim(0);
1234
- constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
1231
+ constexpr int M = Layout::Shape::dim(0);
1232
+ constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
1235
1233
 
1234
+ if (contiguous_dest && contiguous_src && aligned_size && aligned_src && N)
1235
+ {
1236
1236
  // alias of shared tile with 128bit type
1237
1237
  using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
1238
1238
  tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
warp/native/warp.cu CHANGED
@@ -3027,7 +3027,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
3027
3027
  fprintf(stderr, "Warp error: num_ltoirs > 0 but ltoir_input_types, ltoirs or ltoir_sizes are NULL\n");
3028
3028
  return size_t(-1);
3029
3029
  }
3030
- nvJitLinkHandle handle;
3030
+ nvJitLinkHandle handle = nullptr;
3031
3031
  std::vector<const char *> lopts = {"-dlto", arch_opt_lto};
3032
3032
  if (use_ptx) {
3033
3033
  lopts.push_back("-ptx");
@@ -13,11 +13,13 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  import ctypes
17
19
  import sys
18
20
  import time
19
21
  from collections import defaultdict
20
- from typing import List, Optional, Tuple, Union
22
+ from typing import List, Union
21
23
 
22
24
  import numpy as np
23
25
 
@@ -1500,16 +1502,16 @@ class OpenGLRenderer:
1500
1502
 
1501
1503
  def setup_tiled_rendering(
1502
1504
  self,
1503
- instances: List[List[int]],
1505
+ instances: list[list[int]],
1504
1506
  rescale_window: bool = False,
1505
- tile_width: Optional[int] = None,
1506
- tile_height: Optional[int] = None,
1507
- tile_ncols: Optional[int] = None,
1508
- tile_nrows: Optional[int] = None,
1509
- tile_positions: Optional[List[Tuple[int]]] = None,
1510
- tile_sizes: Optional[List[Tuple[int]]] = None,
1511
- projection_matrices: Optional[List[Mat44]] = None,
1512
- view_matrices: Optional[List[Mat44]] = None,
1507
+ tile_width: int | None = None,
1508
+ tile_height: int | None = None,
1509
+ tile_ncols: int | None = None,
1510
+ tile_nrows: int | None = None,
1511
+ tile_positions: list[tuple[int]] | None = None,
1512
+ tile_sizes: list[tuple[int]] | None = None,
1513
+ projection_matrices: list[Mat44] | None = None,
1514
+ view_matrices: list[Mat44] | None = None,
1513
1515
  ):
1514
1516
  """
1515
1517
  Set up tiled rendering where the render buffer is split into multiple tiles that can visualize
@@ -1602,11 +1604,11 @@ class OpenGLRenderer:
1602
1604
  def update_tile(
1603
1605
  self,
1604
1606
  tile_id,
1605
- instances: Optional[List[int]] = None,
1606
- projection_matrix: Optional[Mat44] = None,
1607
- view_matrix: Optional[Mat44] = None,
1608
- tile_size: Optional[Tuple[int]] = None,
1609
- tile_position: Optional[Tuple[int]] = None,
1607
+ instances: list[int] | None = None,
1608
+ projection_matrix: Mat44 | None = None,
1609
+ view_matrix: Mat44 | None = None,
1610
+ tile_size: tuple[int] | None = None,
1611
+ tile_position: tuple[int] | None = None,
1610
1612
  ):
1611
1613
  """
1612
1614
  Update the shape instances, projection matrix, view matrix, tile size, or tile position
@@ -1806,7 +1808,7 @@ class OpenGLRenderer:
1806
1808
 
1807
1809
  return np.array((scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, 1), dtype=np.float32)
1808
1810
 
1809
- def update_model_matrix(self, model_matrix: Optional[Mat44] = None):
1811
+ def update_model_matrix(self, model_matrix: Mat44 | None = None):
1810
1812
  gl = OpenGLRenderer.gl
1811
1813
 
1812
1814
  self._switch_context()
@@ -3092,7 +3094,7 @@ Instances: {len(self._instances)}"""
3092
3094
  parent_body: str = None,
3093
3095
  is_template: bool = False,
3094
3096
  up_axis: int = 1,
3095
- color: Tuple[float, float, float] = None,
3097
+ color: tuple[float, float, float] = None,
3096
3098
  ):
3097
3099
  """Add a arrow for visualization
3098
3100
 
warp/render/render_usd.py CHANGED
@@ -13,10 +13,19 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from __future__ import annotations
17
+
16
18
  import numpy as np
17
19
 
18
20
  import warp as wp
19
21
 
22
+ UP_AXIS_TOKEN = ("X", "Y", "Z")
23
+ UP_AXIS_VEC = (
24
+ np.array((1.0, 0.0, 0.0), dtype=float),
25
+ np.array((0.0, 1.0, 0.0), dtype=float),
26
+ np.array((0.0, 0.0, 1.0), dtype=float),
27
+ )
28
+
20
29
 
21
30
  def _usd_add_xform(prim):
22
31
  from pxr import UsdGeom
@@ -29,7 +38,13 @@ def _usd_add_xform(prim):
29
38
  prim.AddScaleOp()
30
39
 
31
40
 
32
- def _usd_set_xform(xform, pos: tuple, rot: tuple, scale: tuple, time):
41
+ def _usd_set_xform(
42
+ xform,
43
+ pos: tuple | None = None,
44
+ rot: tuple | None = None,
45
+ scale: tuple | None = None,
46
+ time: float = 0.0,
47
+ ):
33
48
  from pxr import Gf, UsdGeom
34
49
 
35
50
  xform = UsdGeom.Xform(xform)
@@ -108,7 +123,7 @@ class UsdRenderer:
108
123
  self.stage.SetDefaultPrim(self.root.GetPrim())
109
124
  self.stage.SetStartTimeCode(0.0)
110
125
  self.stage.SetEndTimeCode(0.0)
111
- self.stage.SetTimeCodesPerSecond(self.fps)
126
+ self.stage.SetFramesPerSecond(self.fps)
112
127
 
113
128
  if up_axis == "X":
114
129
  UsdGeom.SetStageUpAxis(self.stage, UsdGeom.Tokens.x)
@@ -622,7 +637,82 @@ class UsdRenderer:
622
637
 
623
638
  return prim_path
624
639
 
625
- def render_line_list(self, name, vertices, indices, color, radius):
640
+ def render_arrow(
641
+ self,
642
+ name: str,
643
+ pos: tuple,
644
+ rot: tuple,
645
+ base_radius: float,
646
+ base_height: float,
647
+ cap_radius: float = None,
648
+ cap_height: float = None,
649
+ parent_body: str = None,
650
+ is_template: bool = False,
651
+ up_axis: int = 1,
652
+ color: tuple[float, float, float] = None,
653
+ visible: bool = True,
654
+ ):
655
+ from pxr import Gf, Sdf, UsdGeom
656
+
657
+ if is_template:
658
+ prim_path = self._resolve_path(name, parent_body, is_template)
659
+ blueprint = UsdGeom.Scope.Define(self.stage, prim_path)
660
+ blueprint_prim = blueprint.GetPrim()
661
+ blueprint_prim.SetInstanceable(True)
662
+ blueprint_prim.SetSpecifier(Sdf.SpecifierClass)
663
+ arrow_path = prim_path.AppendChild("arrow")
664
+ else:
665
+ arrow_path = self._resolve_path(name, parent_body)
666
+ prim_path = arrow_path
667
+
668
+ arrow = UsdGeom.Xform.Get(self.stage, arrow_path)
669
+ if not arrow:
670
+ arrow = UsdGeom.Xform.Define(self.stage, arrow_path)
671
+ _usd_add_xform(arrow)
672
+
673
+ base_path = arrow_path.AppendChild("base")
674
+ base = UsdGeom.Xform.Get(self.stage, base_path)
675
+ if not base:
676
+ base = UsdGeom.Cylinder.Define(self.stage, base_path)
677
+ _usd_add_xform(base)
678
+
679
+ base.GetRadiusAttr().Set(float(base_radius))
680
+ base.GetHeightAttr().Set(float(base_height))
681
+ base.GetAxisAttr().Set(UP_AXIS_TOKEN[up_axis])
682
+ _usd_set_xform(base, UP_AXIS_VEC[up_axis] * base_height * 0.5)
683
+
684
+ cap_path = arrow_path.AppendChild("cap")
685
+ cap = UsdGeom.Xform.Get(self.stage, cap_path)
686
+ if not cap:
687
+ cap = UsdGeom.Cone.Define(self.stage, arrow_path.AppendChild("cap"))
688
+ _usd_add_xform(cap)
689
+
690
+ cap.GetRadiusAttr().Set(float(cap_radius))
691
+ cap.GetHeightAttr().Set(float(cap_height))
692
+ cap.GetAxisAttr().Set(UP_AXIS_TOKEN[up_axis])
693
+ _usd_set_xform(cap, UP_AXIS_VEC[up_axis] * (base_height + cap_height * 0.5))
694
+
695
+ if color is not None:
696
+ base.GetDisplayColorAttr().Set([Gf.Vec3f(color)], self.time)
697
+ cap.GetDisplayColorAttr().Set([Gf.Vec3f(color)], self.time)
698
+
699
+ self._shape_constructors[name] = UsdGeom.Xform
700
+
701
+ if not is_template:
702
+ _usd_set_xform(arrow, pos, rot, (1.0, 1.0, 1.0), self.time)
703
+
704
+ arrow.GetVisibilityAttr().Set("inherited" if visible else "invisible", self.time)
705
+ return prim_path
706
+
707
+ def render_line_list(
708
+ self,
709
+ name: str,
710
+ vertices,
711
+ indices,
712
+ color: tuple = None,
713
+ radius: float = 0.01,
714
+ visible: bool = True,
715
+ ):
626
716
  """Debug helper to add a line list as a set of capsules
627
717
 
628
718
  Args:
warp/sim/collide.py CHANGED
@@ -614,9 +614,9 @@ def volume_grad(volume: wp.uint64, p: wp.vec3):
614
614
  @wp.func
615
615
  def counter_increment(counter: wp.array(dtype=int), counter_index: int, tids: wp.array(dtype=int), tid: int):
616
616
  # increment counter, remember which thread received which counter value
617
- next_count = wp.atomic_add(counter, counter_index, 1)
618
- tids[tid] = next_count
619
- return next_count
617
+ count = wp.atomic_add(counter, counter_index, 1)
618
+ tids[tid] = count
619
+ return count
620
620
 
621
621
 
622
622
  @wp.func_replay(counter_increment)
@@ -629,10 +629,10 @@ def limited_counter_increment(
629
629
  counter: wp.array(dtype=int), counter_index: int, tids: wp.array(dtype=int), tid: int, index_limit: int
630
630
  ):
631
631
  # increment counter but only if it is smaller than index_limit, remember which thread received which counter value
632
- next_count = wp.atomic_add(counter, counter_index, 1)
633
- if next_count < index_limit or index_limit < 0:
634
- tids[tid] = next_count
635
- return next_count
632
+ count = wp.atomic_add(counter, counter_index, 1)
633
+ if count < index_limit or index_limit < 0:
634
+ tids[tid] = count
635
+ return count
636
636
  tids[tid] = -1
637
637
  return -1
638
638
 
@@ -1547,6 +1547,8 @@ def handle_contact_pairs(
1547
1547
  # reached contact point limit
1548
1548
  return
1549
1549
  index = counter_increment(contact_count, 0, contact_tids, tid)
1550
+ if index == -1:
1551
+ return
1550
1552
  contact_shape0[index] = shape_a
1551
1553
  contact_shape1[index] = shape_b
1552
1554
  # transform from world into body frame (so the contact point includes the shape transform)
@@ -1690,7 +1692,7 @@ def collide(
1690
1692
  model.rigid_contact_normal = wp.empty_like(model.rigid_contact_normal)
1691
1693
  model.rigid_contact_thickness = wp.empty_like(model.rigid_contact_thickness)
1692
1694
  model.rigid_contact_count = wp.zeros_like(model.rigid_contact_count)
1693
- model.rigid_contact_tids = wp.zeros_like(model.rigid_contact_tids)
1695
+ model.rigid_contact_tids = wp.full_like(model.rigid_contact_tids, -1)
1694
1696
  model.rigid_contact_shape0 = wp.empty_like(model.rigid_contact_shape0)
1695
1697
  model.rigid_contact_shape1 = wp.empty_like(model.rigid_contact_shape1)
1696
1698
 
@@ -1698,7 +1700,7 @@ def collide(
1698
1700
  model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
1699
1701
  else:
1700
1702
  model.rigid_contact_count.zero_()
1701
- model.rigid_contact_tids.zero_()
1703
+ model.rigid_contact_tids.fill_(-1)
1702
1704
 
1703
1705
  if model.rigid_contact_pairwise_counter is not None:
1704
1706
  model.rigid_contact_pairwise_counter.zero_()