warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. (The original page linked to further details; that link was not preserved.)

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/sim/model.py CHANGED
@@ -19,7 +19,7 @@ from __future__ import annotations
19
19
 
20
20
  import copy
21
21
  import math
22
- from typing import List, Optional, Tuple
22
+ from typing import List, Tuple
23
23
 
24
24
  import numpy as np
25
25
 
@@ -187,7 +187,7 @@ class SDF:
187
187
  return self.volume.id
188
188
 
189
189
  def __hash__(self):
190
- return hash((self.volume.id))
190
+ return hash(self.volume.id)
191
191
 
192
192
 
193
193
  class Mesh:
@@ -219,7 +219,7 @@ class Mesh:
219
219
  com (Vec3): The center of mass of the body
220
220
  """
221
221
 
222
- def __init__(self, vertices: List[Vec3], indices: List[int], compute_inertia=True, is_solid=True):
222
+ def __init__(self, vertices: list[Vec3], indices: list[int], compute_inertia=True, is_solid=True):
223
223
  """Construct a Mesh object from a triangle mesh
224
224
 
225
225
  The mesh center of mass and inertia tensor will automatically be
@@ -282,24 +282,24 @@ class State:
282
282
  """
283
283
 
284
284
  def __init__(self):
285
- self.particle_q: Optional[wp.array] = None
285
+ self.particle_q: wp.array | None = None
286
286
  """Array of 3D particle positions with shape ``(particle_count,)`` and type :class:`vec3`."""
287
287
 
288
- self.particle_qd: Optional[wp.array] = None
288
+ self.particle_qd: wp.array | None = None
289
289
  """Array of 3D particle velocities with shape ``(particle_count,)`` and type :class:`vec3`."""
290
290
 
291
- self.particle_f: Optional[wp.array] = None
291
+ self.particle_f: wp.array | None = None
292
292
  """Array of 3D particle forces with shape ``(particle_count,)`` and type :class:`vec3`."""
293
293
 
294
- self.body_q: Optional[wp.array] = None
294
+ self.body_q: wp.array | None = None
295
295
  """Array of body coordinates (7-dof transforms) in maximal coordinates with shape ``(body_count,)`` and type :class:`transform`."""
296
296
 
297
- self.body_qd: Optional[wp.array] = None
297
+ self.body_qd: wp.array | None = None
298
298
  """Array of body velocities in maximal coordinates (first three entries represent angular velocity,
299
299
  last three entries represent linear velocity) with shape ``(body_count,)`` and type :class:`spatial_vector`.
300
300
  """
301
301
 
302
- self.body_f: Optional[wp.array] = None
302
+ self.body_f: wp.array | None = None
303
303
  """Array of body forces in maximal coordinates (first three entries represent torque, last three
304
304
  entries represent linear force) with shape ``(body_count,)`` and type :class:`spatial_vector`.
305
305
 
@@ -309,10 +309,10 @@ class State:
309
309
  assumes the wrenches are measured w.r.t. world origin.
310
310
  """
311
311
 
312
- self.joint_q: Optional[wp.array] = None
312
+ self.joint_q: wp.array | None = None
313
313
  """Array of generalized joint coordinates with shape ``(joint_coord_count,)`` and type ``float``."""
314
314
 
315
- self.joint_qd: Optional[wp.array] = None
315
+ self.joint_qd: wp.array | None = None
316
316
  """Array of generalized joint velocities with shape ``(joint_dof_count,)`` and type ``float``."""
317
317
 
318
318
  def clear_forces(self) -> None:
@@ -372,7 +372,7 @@ class Control:
372
372
  should generally be created using the :func:`Model.control()` function.
373
373
  """
374
374
 
375
- def __init__(self, model: Model = None):
375
+ def __init__(self, model: Model | None = None):
376
376
  if model:
377
377
  wp.utils.warn(
378
378
  "Passing arguments to Control's __init__ is deprecated\n"
@@ -381,16 +381,16 @@ class Control:
381
381
  stacklevel=2,
382
382
  )
383
383
 
384
- self.joint_act: Optional[wp.array] = None
384
+ self.joint_act: wp.array | None = None
385
385
  """Array of joint control inputs with shape ``(joint_axis_count,)`` and type ``float``."""
386
386
 
387
- self.tri_activations: Optional[wp.array] = None
387
+ self.tri_activations: wp.array | None = None
388
388
  """Array of triangle element activations with shape ``(tri_count,)`` and type ``float``."""
389
389
 
390
- self.tet_activations: Optional[wp.array] = None
390
+ self.tet_activations: wp.array | None = None
391
391
  """Array of tetrahedral element activations with shape with shape ``(tet_count,) and type ``float``."""
392
392
 
393
- self.muscle_activations: Optional[wp.array] = None
393
+ self.muscle_activations: wp.array | None = None
394
394
  """Array of muscle activations with shape ``(muscle_count,)`` and type ``float``."""
395
395
 
396
396
  def clear(self) -> None:
@@ -499,7 +499,7 @@ def compute_shape_mass(type, scale, src, density, is_solid, thickness):
499
499
  vertices = np.array(src.vertices) * np.array(scale)
500
500
  m, c, I, vol = compute_mesh_inertia(density, vertices, src.indices, is_solid, thickness)
501
501
  return m, c, I
502
- raise ValueError("Unsupported shape type: {}".format(type))
502
+ raise ValueError(f"Unsupported shape type: {type}")
503
503
 
504
504
 
505
505
  class Model:
@@ -678,7 +678,8 @@ class Model:
678
678
  joint_dof_count (int): Total number of velocity degrees of freedom of all joints in the system
679
679
  joint_coord_count (int): Total number of position degrees of freedom of all joints in the system
680
680
 
681
- particle_coloring (list of array): The coloring of all the particles, used for VBD's Gauss-Seidel iteration.
681
+ particle_color_groups (list of array): The coloring of all the particles, used for VBD's Gauss-Seidel iteration. Each array contains indices of particles sharing the same color.
682
+ particle_colors (array): Contains the color assignment for every particle
682
683
 
683
684
  device (wp.Device): Device on which the Model was allocated
684
685
 
@@ -858,7 +859,10 @@ class Model:
858
859
  self.joint_dof_count = 0
859
860
  self.joint_coord_count = 0
860
861
 
861
- self.particle_coloring = []
862
+ # indices of particles sharing the same color
863
+ self.particle_color_groups = []
864
+ # the color of each particles
865
+ self.particle_colors = None
862
866
 
863
867
  self.device = wp.get_device(device)
864
868
 
@@ -1183,7 +1187,7 @@ class ModelBuilder:
1183
1187
  self.particle_flags = []
1184
1188
  self.particle_max_velocity = 1e5
1185
1189
  # list of np.array
1186
- self.particle_coloring = []
1190
+ self.particle_color_groups = []
1187
1191
 
1188
1192
  # shapes (each shape has an entry in these arrays)
1189
1193
  # transform from shape to body
@@ -1433,9 +1437,9 @@ class ModelBuilder:
1433
1437
  if builder.tet_count:
1434
1438
  self.tet_indices.extend((np.array(builder.tet_indices, dtype=np.int32) + start_particle_idx).tolist())
1435
1439
 
1436
- builder_coloring_translated = [group + start_particle_idx for group in builder.particle_coloring]
1437
- self.particle_coloring = combine_independent_particle_coloring(
1438
- self.particle_coloring, builder_coloring_translated
1440
+ builder_coloring_translated = [group + start_particle_idx for group in builder.particle_color_groups]
1441
+ self.particle_color_groups = combine_independent_particle_coloring(
1442
+ self.particle_color_groups, builder_coloring_translated
1439
1443
  )
1440
1444
 
1441
1445
  start_body_idx = self.body_count
@@ -1593,12 +1597,12 @@ class ModelBuilder:
1593
1597
  # register a rigid body and return its index.
1594
1598
  def add_body(
1595
1599
  self,
1596
- origin: Optional[Transform] = None,
1600
+ origin: Transform | None = None,
1597
1601
  armature: float = 0.0,
1598
- com: Optional[Vec3] = None,
1599
- I_m: Optional[Mat33] = None,
1602
+ com: Vec3 | None = None,
1603
+ I_m: Mat33 | None = None,
1600
1604
  m: float = 0.0,
1601
- name: str = None,
1605
+ name: str | None = None,
1602
1606
  ) -> int:
1603
1607
  """Adds a rigid body to the model.
1604
1608
 
@@ -1657,11 +1661,11 @@ class ModelBuilder:
1657
1661
  joint_type: wp.constant,
1658
1662
  parent: int,
1659
1663
  child: int,
1660
- linear_axes: Optional[List[JointAxis]] = None,
1661
- angular_axes: Optional[List[JointAxis]] = None,
1662
- name: str = None,
1663
- parent_xform: Optional[wp.transform] = None,
1664
- child_xform: Optional[wp.transform] = None,
1664
+ linear_axes: list[JointAxis] | None = None,
1665
+ angular_axes: list[JointAxis] | None = None,
1666
+ name: str | None = None,
1667
+ parent_xform: wp.transform | None = None,
1668
+ child_xform: wp.transform | None = None,
1665
1669
  linear_compliance: float = 0.0,
1666
1670
  angular_compliance: float = 0.0,
1667
1671
  armature: float = 1e-2,
@@ -1797,21 +1801,21 @@ class ModelBuilder:
1797
1801
  self,
1798
1802
  parent: int,
1799
1803
  child: int,
1800
- parent_xform: Optional[wp.transform] = None,
1801
- child_xform: Optional[wp.transform] = None,
1804
+ parent_xform: wp.transform | None = None,
1805
+ child_xform: wp.transform | None = None,
1802
1806
  axis: Vec3 = (1.0, 0.0, 0.0),
1803
- target: float = None,
1807
+ target: float | None = None,
1804
1808
  target_ke: float = 0.0,
1805
1809
  target_kd: float = 0.0,
1806
1810
  mode: int = JOINT_MODE_FORCE,
1807
1811
  limit_lower: float = -2 * math.pi,
1808
1812
  limit_upper: float = 2 * math.pi,
1809
- limit_ke: float = None,
1810
- limit_kd: float = None,
1813
+ limit_ke: float | None = None,
1814
+ limit_kd: float | None = None,
1811
1815
  linear_compliance: float = 0.0,
1812
1816
  angular_compliance: float = 0.0,
1813
1817
  armature: float = 1e-2,
1814
- name: str = None,
1818
+ name: str | None = None,
1815
1819
  collision_filter_parent: bool = True,
1816
1820
  enabled: bool = True,
1817
1821
  ) -> int:
@@ -1887,21 +1891,21 @@ class ModelBuilder:
1887
1891
  self,
1888
1892
  parent: int,
1889
1893
  child: int,
1890
- parent_xform: Optional[wp.transform] = None,
1891
- child_xform: Optional[wp.transform] = None,
1894
+ parent_xform: wp.transform | None = None,
1895
+ child_xform: wp.transform | None = None,
1892
1896
  axis: Vec3 = (1.0, 0.0, 0.0),
1893
- target: float = None,
1897
+ target: float | None = None,
1894
1898
  target_ke: float = 0.0,
1895
1899
  target_kd: float = 0.0,
1896
1900
  mode: int = JOINT_MODE_FORCE,
1897
1901
  limit_lower: float = -1e4,
1898
1902
  limit_upper: float = 1e4,
1899
- limit_ke: float = None,
1900
- limit_kd: float = None,
1903
+ limit_ke: float | None = None,
1904
+ limit_kd: float | None = None,
1901
1905
  linear_compliance: float = 0.0,
1902
1906
  angular_compliance: float = 0.0,
1903
1907
  armature: float = 1e-2,
1904
- name: str = None,
1908
+ name: str | None = None,
1905
1909
  collision_filter_parent: bool = True,
1906
1910
  enabled: bool = True,
1907
1911
  ) -> int:
@@ -1977,12 +1981,12 @@ class ModelBuilder:
1977
1981
  self,
1978
1982
  parent: int,
1979
1983
  child: int,
1980
- parent_xform: Optional[wp.transform] = None,
1981
- child_xform: Optional[wp.transform] = None,
1984
+ parent_xform: wp.transform | None = None,
1985
+ child_xform: wp.transform | None = None,
1982
1986
  linear_compliance: float = 0.0,
1983
1987
  angular_compliance: float = 0.0,
1984
1988
  armature: float = 1e-2,
1985
- name: str = None,
1989
+ name: str | None = None,
1986
1990
  collision_filter_parent: bool = True,
1987
1991
  enabled: bool = True,
1988
1992
  ) -> int:
@@ -2028,12 +2032,12 @@ class ModelBuilder:
2028
2032
  self,
2029
2033
  parent: int,
2030
2034
  child: int,
2031
- parent_xform: Optional[wp.transform] = None,
2032
- child_xform: Optional[wp.transform] = None,
2035
+ parent_xform: wp.transform | None = None,
2036
+ child_xform: wp.transform | None = None,
2033
2037
  linear_compliance: float = 0.0,
2034
2038
  angular_compliance: float = 0.0,
2035
2039
  armature: float = 1e-2,
2036
- name: str = None,
2040
+ name: str | None = None,
2037
2041
  collision_filter_parent: bool = True,
2038
2042
  enabled: bool = True,
2039
2043
  ) -> int:
@@ -2079,11 +2083,11 @@ class ModelBuilder:
2079
2083
  def add_joint_free(
2080
2084
  self,
2081
2085
  child: int,
2082
- parent_xform: Optional[wp.transform] = None,
2083
- child_xform: Optional[wp.transform] = None,
2086
+ parent_xform: wp.transform | None = None,
2087
+ child_xform: wp.transform | None = None,
2084
2088
  armature: float = 0.0,
2085
2089
  parent: int = -1,
2086
- name: str = None,
2090
+ name: str | None = None,
2087
2091
  collision_filter_parent: bool = True,
2088
2092
  enabled: bool = True,
2089
2093
  ) -> int:
@@ -2126,8 +2130,8 @@ class ModelBuilder:
2126
2130
  self,
2127
2131
  parent: int,
2128
2132
  child: int,
2129
- parent_xform: Optional[wp.transform] = None,
2130
- child_xform: Optional[wp.transform] = None,
2133
+ parent_xform: wp.transform | None = None,
2134
+ child_xform: wp.transform | None = None,
2131
2135
  min_distance: float = -1.0,
2132
2136
  max_distance: float = 1.0,
2133
2137
  compliance: float = 0.0,
@@ -2183,12 +2187,12 @@ class ModelBuilder:
2183
2187
  child: int,
2184
2188
  axis_0: JointAxis,
2185
2189
  axis_1: JointAxis,
2186
- parent_xform: Optional[wp.transform] = None,
2187
- child_xform: Optional[wp.transform] = None,
2190
+ parent_xform: wp.transform | None = None,
2191
+ child_xform: wp.transform | None = None,
2188
2192
  linear_compliance: float = 0.0,
2189
2193
  angular_compliance: float = 0.0,
2190
2194
  armature: float = 1e-2,
2191
- name: str = None,
2195
+ name: str | None = None,
2192
2196
  collision_filter_parent: bool = True,
2193
2197
  enabled: bool = True,
2194
2198
  ) -> int:
@@ -2240,12 +2244,12 @@ class ModelBuilder:
2240
2244
  axis_0: JointAxis,
2241
2245
  axis_1: JointAxis,
2242
2246
  axis_2: JointAxis,
2243
- parent_xform: Optional[wp.transform] = None,
2244
- child_xform: Optional[wp.transform] = None,
2247
+ parent_xform: wp.transform | None = None,
2248
+ child_xform: wp.transform | None = None,
2245
2249
  linear_compliance: float = 0.0,
2246
2250
  angular_compliance: float = 0.0,
2247
2251
  armature: float = 1e-2,
2248
- name: str = None,
2252
+ name: str | None = None,
2249
2253
  collision_filter_parent: bool = True,
2250
2254
  enabled: bool = True,
2251
2255
  ) -> int:
@@ -2298,11 +2302,11 @@ class ModelBuilder:
2298
2302
  self,
2299
2303
  parent: int,
2300
2304
  child: int,
2301
- linear_axes: Optional[List[JointAxis]] = None,
2302
- angular_axes: Optional[List[JointAxis]] = None,
2303
- name: str = None,
2304
- parent_xform: Optional[wp.transform] = None,
2305
- child_xform: Optional[wp.transform] = None,
2305
+ linear_axes: list[JointAxis] | None = None,
2306
+ angular_axes: list[JointAxis] | None = None,
2307
+ name: str | None = None,
2308
+ parent_xform: wp.transform | None = None,
2309
+ child_xform: wp.transform | None = None,
2306
2310
  linear_compliance: float = 0.0,
2307
2311
  angular_compliance: float = 0.0,
2308
2312
  armature: float = 1e-2,
@@ -2425,7 +2429,7 @@ class ModelBuilder:
2425
2429
  return "unknown"
2426
2430
 
2427
2431
  if show_body_names:
2428
- vertices = ["world"] + self.body_name
2432
+ vertices = ["world", *self.body_name]
2429
2433
  else:
2430
2434
  vertices = ["-1"] + [str(i) for i in range(self.body_count)]
2431
2435
  if plot_shapes:
@@ -2749,7 +2753,7 @@ class ModelBuilder:
2749
2753
 
2750
2754
  # muscles
2751
2755
  def add_muscle(
2752
- self, bodies: List[int], positions: List[Vec3], f0: float, lm: float, lt: float, lmax: float, pen: float
2756
+ self, bodies: list[int], positions: list[Vec3], f0: float, lm: float, lt: float, lmax: float, pen: float
2753
2757
  ) -> float:
2754
2758
  """Adds a muscle-tendon activation unit.
2755
2759
 
@@ -2784,28 +2788,28 @@ class ModelBuilder:
2784
2788
  # shapes
2785
2789
  def add_shape_plane(
2786
2790
  self,
2787
- plane: Vec4 = (0.0, 1.0, 0.0, 0.0),
2788
- pos: Vec3 = None,
2789
- rot: Quat = None,
2791
+ plane: Vec4 | tuple[float, float, float, float] = (0.0, 1.0, 0.0, 0.0),
2792
+ pos: Vec3 | None = None,
2793
+ rot: Quat | None = None,
2790
2794
  width: float = 10.0,
2791
2795
  length: float = 10.0,
2792
2796
  body: int = -1,
2793
- ke: float = None,
2794
- kd: float = None,
2795
- kf: float = None,
2796
- ka: float = None,
2797
- mu: float = None,
2798
- restitution: float = None,
2799
- thickness: float = None,
2797
+ ke: float | None = None,
2798
+ kd: float | None = None,
2799
+ kf: float | None = None,
2800
+ ka: float | None = None,
2801
+ mu: float | None = None,
2802
+ restitution: float | None = None,
2803
+ thickness: float | None = None,
2800
2804
  has_ground_collision: bool = False,
2801
2805
  has_shape_collision: bool = True,
2802
2806
  is_visible: bool = True,
2803
2807
  collision_group: int = -1,
2804
- ):
2805
- """
2806
- Adds a plane collision shape.
2807
- If pos and rot are defined, the plane is assumed to have its normal as (0, 1, 0).
2808
- Otherwise, the plane equation defined through the `plane` argument is used.
2808
+ ) -> int:
2809
+ """Add a plane collision shape.
2810
+
2811
+ If ``pos`` and ``rot`` are defined, the plane is assumed to have its normal as (0, 1, 0).
2812
+ Otherwise, the plane equation defined through the ``plane`` argument is used.
2809
2813
 
2810
2814
  Args:
2811
2815
  plane: The plane equation in form a*x + b*y + c*z + d = 0
@@ -2869,24 +2873,24 @@ class ModelBuilder:
2869
2873
  def add_shape_sphere(
2870
2874
  self,
2871
2875
  body,
2872
- pos: Vec3 = (0.0, 0.0, 0.0),
2873
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
2876
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
2877
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
2874
2878
  radius: float = 1.0,
2875
- density: float = None,
2876
- ke: float = None,
2877
- kd: float = None,
2878
- kf: float = None,
2879
- ka: float = None,
2880
- mu: float = None,
2881
- restitution: float = None,
2879
+ density: float | None = None,
2880
+ ke: float | None = None,
2881
+ kd: float | None = None,
2882
+ kf: float | None = None,
2883
+ ka: float | None = None,
2884
+ mu: float | None = None,
2885
+ restitution: float | None = None,
2882
2886
  is_solid: bool = True,
2883
- thickness: float = None,
2887
+ thickness: float | None = None,
2884
2888
  has_ground_collision: bool = True,
2885
2889
  has_shape_collision: bool = True,
2886
2890
  collision_group: int = -1,
2887
2891
  is_visible: bool = True,
2888
- ):
2889
- """Adds a sphere collision shape to a body.
2892
+ ) -> int:
2893
+ """Add a sphere collision shape to a body.
2890
2894
 
2891
2895
  Args:
2892
2896
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -2938,26 +2942,26 @@ class ModelBuilder:
2938
2942
  def add_shape_box(
2939
2943
  self,
2940
2944
  body: int,
2941
- pos: Vec3 = (0.0, 0.0, 0.0),
2942
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
2945
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
2946
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
2943
2947
  hx: float = 0.5,
2944
2948
  hy: float = 0.5,
2945
2949
  hz: float = 0.5,
2946
- density: float = None,
2947
- ke: float = None,
2948
- kd: float = None,
2949
- kf: float = None,
2950
- ka: float = None,
2951
- mu: float = None,
2952
- restitution: float = None,
2950
+ density: float | None = None,
2951
+ ke: float | None = None,
2952
+ kd: float | None = None,
2953
+ kf: float | None = None,
2954
+ ka: float | None = None,
2955
+ mu: float | None = None,
2956
+ restitution: float | None = None,
2953
2957
  is_solid: bool = True,
2954
- thickness: float = None,
2958
+ thickness: float | None = None,
2955
2959
  has_ground_collision: bool = True,
2956
2960
  has_shape_collision: bool = True,
2957
2961
  collision_group: int = -1,
2958
2962
  is_visible: bool = True,
2959
- ):
2960
- """Adds a box collision shape to a body.
2963
+ ) -> int:
2964
+ """Add a box collision shape to a body.
2961
2965
 
2962
2966
  Args:
2963
2967
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -2982,7 +2986,6 @@ class ModelBuilder:
2982
2986
 
2983
2987
  Returns:
2984
2988
  The index of the added shape
2985
-
2986
2989
  """
2987
2990
 
2988
2991
  return self._add_shape(
@@ -3010,26 +3013,26 @@ class ModelBuilder:
3010
3013
  def add_shape_capsule(
3011
3014
  self,
3012
3015
  body: int,
3013
- pos: Vec3 = (0.0, 0.0, 0.0),
3014
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
3016
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
3017
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
3015
3018
  radius: float = 1.0,
3016
3019
  half_height: float = 0.5,
3017
3020
  up_axis: int = 1,
3018
- density: float = None,
3019
- ke: float = None,
3020
- kd: float = None,
3021
- kf: float = None,
3022
- ka: float = None,
3023
- mu: float = None,
3024
- restitution: float = None,
3021
+ density: float | None = None,
3022
+ ke: float | None = None,
3023
+ kd: float | None = None,
3024
+ kf: float | None = None,
3025
+ ka: float | None = None,
3026
+ mu: float | None = None,
3027
+ restitution: float | None = None,
3025
3028
  is_solid: bool = True,
3026
- thickness: float = None,
3029
+ thickness: float | None = None,
3027
3030
  has_ground_collision: bool = True,
3028
3031
  has_shape_collision: bool = True,
3029
3032
  collision_group: int = -1,
3030
3033
  is_visible: bool = True,
3031
- ):
3032
- """Adds a capsule collision shape to a body.
3034
+ ) -> int:
3035
+ """Add a capsule collision shape to a body.
3033
3036
 
3034
3037
  Args:
3035
3038
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -3090,26 +3093,26 @@ class ModelBuilder:
3090
3093
  def add_shape_cylinder(
3091
3094
  self,
3092
3095
  body: int,
3093
- pos: Vec3 = (0.0, 0.0, 0.0),
3094
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
3096
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
3097
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
3095
3098
  radius: float = 1.0,
3096
3099
  half_height: float = 0.5,
3097
3100
  up_axis: int = 1,
3098
- density: float = None,
3099
- ke: float = None,
3100
- kd: float = None,
3101
- kf: float = None,
3102
- ka: float = None,
3103
- mu: float = None,
3104
- restitution: float = None,
3101
+ density: float | None = None,
3102
+ ke: float | None = None,
3103
+ kd: float | None = None,
3104
+ kf: float | None = None,
3105
+ ka: float | None = None,
3106
+ mu: float | None = None,
3107
+ restitution: float | None = None,
3105
3108
  is_solid: bool = True,
3106
- thickness: float = None,
3109
+ thickness: float | None = None,
3107
3110
  has_ground_collision: bool = True,
3108
3111
  has_shape_collision: bool = True,
3109
3112
  collision_group: int = -1,
3110
3113
  is_visible: bool = True,
3111
- ):
3112
- """Adds a cylinder collision shape to a body.
3114
+ ) -> int:
3115
+ """Add a cylinder collision shape to a body.
3113
3116
 
3114
3117
  Args:
3115
3118
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -3172,26 +3175,26 @@ class ModelBuilder:
3172
3175
  def add_shape_cone(
3173
3176
  self,
3174
3177
  body: int,
3175
- pos: Vec3 = (0.0, 0.0, 0.0),
3176
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
3178
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
3179
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
3177
3180
  radius: float = 1.0,
3178
3181
  half_height: float = 0.5,
3179
3182
  up_axis: int = 1,
3180
- density: float = None,
3181
- ke: float = None,
3182
- kd: float = None,
3183
- kf: float = None,
3184
- ka: float = None,
3185
- mu: float = None,
3186
- restitution: float = None,
3183
+ density: float | None = None,
3184
+ ke: float | None = None,
3185
+ kd: float | None = None,
3186
+ kf: float | None = None,
3187
+ ka: float | None = None,
3188
+ mu: float | None = None,
3189
+ restitution: float | None = None,
3187
3190
  is_solid: bool = True,
3188
- thickness: float = None,
3191
+ thickness: float | None = None,
3189
3192
  has_ground_collision: bool = True,
3190
3193
  has_shape_collision: bool = True,
3191
3194
  collision_group: int = -1,
3192
3195
  is_visible: bool = True,
3193
- ):
3194
- """Adds a cone collision shape to a body.
3196
+ ) -> int:
3197
+ """Add a cone collision shape to a body.
3195
3198
 
3196
3199
  Args:
3197
3200
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -3254,25 +3257,25 @@ class ModelBuilder:
3254
3257
  def add_shape_mesh(
3255
3258
  self,
3256
3259
  body: int,
3257
- pos: Optional[Vec3] = None,
3258
- rot: Optional[Quat] = None,
3259
- mesh: Optional[Mesh] = None,
3260
- scale: Optional[Vec3] = None,
3261
- density: float = None,
3262
- ke: float = None,
3263
- kd: float = None,
3264
- kf: float = None,
3265
- ka: float = None,
3266
- mu: float = None,
3267
- restitution: float = None,
3260
+ pos: Vec3 | None = None,
3261
+ rot: Quat | None = None,
3262
+ mesh: Mesh | None = None,
3263
+ scale: Vec3 | None = None,
3264
+ density: float | None = None,
3265
+ ke: float | None = None,
3266
+ kd: float | None = None,
3267
+ kf: float | None = None,
3268
+ ka: float | None = None,
3269
+ mu: float | None = None,
3270
+ restitution: float | None = None,
3268
3271
  is_solid: bool = True,
3269
- thickness: float = None,
3272
+ thickness: float | None = None,
3270
3273
  has_ground_collision: bool = True,
3271
3274
  has_shape_collision: bool = True,
3272
3275
  collision_group: int = -1,
3273
3276
  is_visible: bool = True,
3274
- ):
3275
- """Adds a triangle mesh collision shape to a body.
3277
+ ) -> int:
3278
+ """Add a triangle mesh collision shape to a body.
3276
3279
 
3277
3280
  Args:
3278
3281
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -3335,25 +3338,25 @@ class ModelBuilder:
3335
3338
  def add_shape_sdf(
3336
3339
  self,
3337
3340
  body: int,
3338
- pos: Vec3 = (0.0, 0.0, 0.0),
3339
- rot: Quat = (0.0, 0.0, 0.0, 1.0),
3340
- sdf: SDF = None,
3341
- scale: Vec3 = (1.0, 1.0, 1.0),
3342
- density: float = None,
3343
- ke: float = None,
3344
- kd: float = None,
3345
- kf: float = None,
3346
- ka: float = None,
3347
- mu: float = None,
3348
- restitution: float = None,
3341
+ pos: Vec3 | tuple[float, float, float] = (0.0, 0.0, 0.0),
3342
+ rot: Quat | tuple[float, float, float, float] = (0.0, 0.0, 0.0, 1.0),
3343
+ sdf: SDF | None = None,
3344
+ scale: Vec3 | tuple[float, float, float] = (1.0, 1.0, 1.0),
3345
+ density: float | None = None,
3346
+ ke: float | None = None,
3347
+ kd: float | None = None,
3348
+ kf: float | None = None,
3349
+ ka: float | None = None,
3350
+ mu: float | None = None,
3351
+ restitution: float | None = None,
3349
3352
  is_solid: bool = True,
3350
- thickness: float = None,
3353
+ thickness: float | None = None,
3351
3354
  has_ground_collision: bool = True,
3352
3355
  has_shape_collision: bool = True,
3353
3356
  collision_group: int = -1,
3354
3357
  is_visible: bool = True,
3355
- ):
3356
- """Adds SDF collision shape to a body.
3358
+ ) -> int:
3359
+ """Add a SDF collision shape to a body.
3357
3360
 
3358
3361
  Args:
3359
3362
  body: The index of the parent body this shape belongs to (use -1 for static shapes)
@@ -3444,8 +3447,8 @@ class ModelBuilder:
3444
3447
  collision_filter_parent=True,
3445
3448
  has_ground_collision=True,
3446
3449
  has_shape_collision=True,
3447
- is_visible=True,
3448
- ):
3450
+ is_visible: bool = True,
3451
+ ) -> int:
3449
3452
  self.shape_body.append(body)
3450
3453
  shape = self.shape_count
3451
3454
  if body in self.body_shapes:
@@ -3504,7 +3507,7 @@ class ModelBuilder:
3504
3507
  pos: Vec3,
3505
3508
  vel: Vec3,
3506
3509
  mass: float,
3507
- radius: float = None,
3510
+ radius: float | None = None,
3508
3511
  flags: wp.uint32 = PARTICLE_FLAG_ACTIVE,
3509
3512
  ) -> int:
3510
3513
  """Adds a single particle to the model
@@ -3569,11 +3572,11 @@ class ModelBuilder:
3569
3572
  i: int,
3570
3573
  j: int,
3571
3574
  k: int,
3572
- tri_ke: float = None,
3573
- tri_ka: float = None,
3574
- tri_kd: float = None,
3575
- tri_drag: float = None,
3576
- tri_lift: float = None,
3575
+ tri_ke: float | None = None,
3576
+ tri_ka: float | None = None,
3577
+ tri_kd: float | None = None,
3578
+ tri_drag: float | None = None,
3579
+ tri_lift: float | None = None,
3577
3580
  ) -> float:
3578
3581
  """Adds a triangular FEM element between three particles in the system.
3579
3582
 
@@ -3634,15 +3637,15 @@ class ModelBuilder:
3634
3637
 
3635
3638
  def add_triangles(
3636
3639
  self,
3637
- i: List[int],
3638
- j: List[int],
3639
- k: List[int],
3640
- tri_ke: Optional[List[float]] = None,
3641
- tri_ka: Optional[List[float]] = None,
3642
- tri_kd: Optional[List[float]] = None,
3643
- tri_drag: Optional[List[float]] = None,
3644
- tri_lift: Optional[List[float]] = None,
3645
- ) -> List[float]:
3640
+ i: list[int],
3641
+ j: list[int],
3642
+ k: list[int],
3643
+ tri_ke: list[float] | None = None,
3644
+ tri_ka: list[float] | None = None,
3645
+ tri_kd: list[float] | None = None,
3646
+ tri_drag: list[float] | None = None,
3647
+ tri_lift: list[float] | None = None,
3648
+ ) -> list[float]:
3646
3649
  """Adds triangular FEM elements between groups of three particles in the system.
3647
3650
 
3648
3651
  Triangles are modeled as viscoelastic elements with elastic stiffness and damping
@@ -3777,10 +3780,10 @@ class ModelBuilder:
3777
3780
  j: int,
3778
3781
  k: int,
3779
3782
  l: int,
3780
- rest: float = None,
3781
- edge_ke: float = None,
3782
- edge_kd: float = None,
3783
- ):
3783
+ rest: float | None = None,
3784
+ edge_ke: float | None = None,
3785
+ edge_kd: float | None = None,
3786
+ ) -> None:
3784
3787
  """Adds a bending edge element between four particles in the system.
3785
3788
 
3786
3789
  Bending elements are designed to be between two connected triangles. Then
@@ -3832,10 +3835,10 @@ class ModelBuilder:
3832
3835
  j,
3833
3836
  k,
3834
3837
  l,
3835
- rest: Optional[List[float]] = None,
3836
- edge_ke: Optional[List[float]] = None,
3837
- edge_kd: Optional[List[float]] = None,
3838
- ):
3838
+ rest: list[float] | None = None,
3839
+ edge_ke: list[float] | None = None,
3840
+ edge_kd: list[float] | None = None,
3841
+ ) -> None:
3839
3842
  """Adds bending edge elements between groups of four particles in the system.
3840
3843
 
3841
3844
  Bending elements are designed to be between two connected triangles. Then
@@ -3914,18 +3917,18 @@ class ModelBuilder:
3914
3917
  fix_right: bool = False,
3915
3918
  fix_top: bool = False,
3916
3919
  fix_bottom: bool = False,
3917
- tri_ke: float = None,
3918
- tri_ka: float = None,
3919
- tri_kd: float = None,
3920
- tri_drag: float = None,
3921
- tri_lift: float = None,
3922
- edge_ke: float = None,
3923
- edge_kd: float = None,
3920
+ tri_ke: float | None = None,
3921
+ tri_ka: float | None = None,
3922
+ tri_kd: float | None = None,
3923
+ tri_drag: float | None = None,
3924
+ tri_lift: float | None = None,
3925
+ edge_ke: float | None = None,
3926
+ edge_kd: float | None = None,
3924
3927
  add_springs: bool = False,
3925
- spring_ke: float = None,
3926
- spring_kd: float = None,
3927
- particle_radius: float = None,
3928
- ):
3928
+ spring_ke: float | None = None,
3929
+ spring_kd: float | None = None,
3930
+ particle_radius: float | None = None,
3931
+ ) -> None:
3929
3932
  """Helper to create a regular planar cloth grid
3930
3933
 
3931
3934
  Creates a rectangular grid of particles with FEM triangles and bending elements
@@ -4054,23 +4057,23 @@ class ModelBuilder:
4054
4057
  rot: Quat,
4055
4058
  scale: float,
4056
4059
  vel: Vec3,
4057
- vertices: List[Vec3],
4058
- indices: List[int],
4060
+ vertices: list[Vec3],
4061
+ indices: list[int],
4059
4062
  density: float,
4060
4063
  edge_callback=None,
4061
4064
  face_callback=None,
4062
- tri_ke: float = None,
4063
- tri_ka: float = None,
4064
- tri_kd: float = None,
4065
- tri_drag: float = None,
4066
- tri_lift: float = None,
4067
- edge_ke: float = None,
4068
- edge_kd: float = None,
4065
+ tri_ke: float | None = None,
4066
+ tri_ka: float | None = None,
4067
+ tri_kd: float | None = None,
4068
+ tri_drag: float | None = None,
4069
+ tri_lift: float | None = None,
4070
+ edge_ke: float | None = None,
4071
+ edge_kd: float | None = None,
4069
4072
  add_springs: bool = False,
4070
- spring_ke: float = None,
4071
- spring_kd: float = None,
4072
- particle_radius: float = None,
4073
- ):
4073
+ spring_ke: float | None = None,
4074
+ spring_kd: float | None = None,
4075
+ particle_radius: float | None = None,
4076
+ ) -> None:
4074
4077
  """Helper to create a cloth model from a regular triangle mesh
4075
4078
 
4076
4079
  Creates one FEM triangle element and one bending element for every face
@@ -4179,9 +4182,9 @@ class ModelBuilder:
4179
4182
  cell_z: float,
4180
4183
  mass: float,
4181
4184
  jitter: float,
4182
- radius_mean: float = None,
4185
+ radius_mean: float | None = None,
4183
4186
  radius_std: float = 0.0,
4184
- ):
4187
+ ) -> None:
4185
4188
  radius_mean = radius_mean if radius_mean is not None else self.default_particle_radius
4186
4189
 
4187
4190
  rng = np.random.default_rng(42)
@@ -4218,12 +4221,12 @@ class ModelBuilder:
4218
4221
  fix_right: bool = False,
4219
4222
  fix_top: bool = False,
4220
4223
  fix_bottom: bool = False,
4221
- tri_ke: float = None,
4222
- tri_ka: float = None,
4223
- tri_kd: float = None,
4224
- tri_drag: float = None,
4225
- tri_lift: float = None,
4226
- ):
4224
+ tri_ke: float | None = None,
4225
+ tri_ka: float | None = None,
4226
+ tri_kd: float | None = None,
4227
+ tri_drag: float | None = None,
4228
+ tri_lift: float | None = None,
4229
+ ) -> None:
4227
4230
  """Helper to create a rectangular tetrahedral FEM grid
4228
4231
 
4229
4232
  Creates a regular grid of FEM tetrahedra and surface triangles. Useful for example
@@ -4339,18 +4342,18 @@ class ModelBuilder:
4339
4342
  rot: Quat,
4340
4343
  scale: float,
4341
4344
  vel: Vec3,
4342
- vertices: List[Vec3],
4343
- indices: List[int],
4345
+ vertices: list[Vec3],
4346
+ indices: list[int],
4344
4347
  density: float,
4345
4348
  k_mu: float,
4346
4349
  k_lambda: float,
4347
4350
  k_damp: float,
4348
- tri_ke: float = None,
4349
- tri_ka: float = None,
4350
- tri_kd: float = None,
4351
- tri_drag: float = None,
4352
- tri_lift: float = None,
4353
- ):
4351
+ tri_ke: float | None = None,
4352
+ tri_ka: float | None = None,
4353
+ tri_kd: float | None = None,
4354
+ tri_drag: float | None = None,
4355
+ tri_lift: float | None = None,
4356
+ ) -> None:
4354
4357
  """Helper to create a tetrahedral model from an input tetrahedral mesh
4355
4358
 
4356
4359
  Args:
@@ -4460,15 +4463,15 @@ class ModelBuilder:
4460
4463
  self,
4461
4464
  normal=None,
4462
4465
  offset=0.0,
4463
- ke: float = None,
4464
- kd: float = None,
4465
- kf: float = None,
4466
- mu: float = None,
4467
- restitution: float = None,
4468
- ):
4469
- """
4470
- Creates a ground plane for the world. If the normal is not specified,
4471
- the up_vector of the ModelBuilder is used.
4466
+ ke: float | None = None,
4467
+ kd: float | None = None,
4468
+ kf: float | None = None,
4469
+ mu: float | None = None,
4470
+ restitution: float | None = None,
4471
+ ) -> None:
4472
+ """Create a ground plane for the world.
4473
+
4474
+ If the normal is not specified, the ``up_vector`` of the :class:`ModelBuilder` is used.
4472
4475
  """
4473
4476
  ke = ke if ke is not None else self.default_shape_ke
4474
4477
  kd = kd if kd is not None else self.default_shape_kd
@@ -4489,26 +4492,25 @@ class ModelBuilder:
4489
4492
  "restitution": restitution,
4490
4493
  }
4491
4494
 
4492
- def _create_ground_plane(self):
4495
+ def _create_ground_plane(self) -> None:
4493
4496
  ground_id = self.add_shape_plane(**self._ground_params)
4494
4497
  self._ground_created = True
4495
4498
  # disable ground collisions as they will be treated separately
4496
4499
  for i in range(self.shape_count - 1):
4497
4500
  self.shape_collision_filter_pairs.add((i, ground_id))
4498
4501
 
4499
- def set_coloring(self, particle_coloring):
4500
- """
4501
- Set coloring information with user-provided coloring.
4502
+ def set_coloring(self, particle_color_groups):
4503
+ """Set coloring information with user-provided coloring.
4502
4504
 
4503
4505
  Args:
4504
- particle_coloring: A list of list or `np.array` with `dtype`=`int`. The length of the list is the number of colors
4505
- and each list or `np.array` contains the indices of vertices with this color.
4506
+ particle_color_groups: A list of list or `np.array` with `dtype`=`int`. The length of the list is the number of colors
4507
+ and each list or `np.array` contains the indices of vertices with this color.
4506
4508
  """
4507
- particle_coloring = [
4509
+ particle_color_groups = [
4508
4510
  color_group if isinstance(color_group, np.ndarray) else np.array(color_group)
4509
- for color_group in particle_coloring
4511
+ for color_group in particle_color_groups
4510
4512
  ]
4511
- self.particle_coloring = particle_coloring
4513
+ self.particle_color_groups = particle_color_groups
4512
4514
 
4513
4515
  def color(
4514
4516
  self,
@@ -4516,9 +4518,8 @@ class ModelBuilder:
4516
4518
  balance_colors=True,
4517
4519
  target_max_min_color_ratio=1.1,
4518
4520
  coloring_algorithm=ColoringAlgorithm.MCS,
4519
- ):
4520
- """
4521
- Run coloring algorithm to generate coloring information.
4521
+ ) -> None:
4522
+ """Run coloring algorithm to generate coloring information.
4522
4523
 
4523
4524
  Args:
4524
4525
  include_bending_energy: Whether to consider bending energy for trimeshes in the coloring process. If set to `True`, the generated
@@ -4544,7 +4545,7 @@ class ModelBuilder:
4544
4545
  # ignore bending energy if it is too small
4545
4546
  edge_indices = np.array(self.edge_indices)
4546
4547
 
4547
- self.particle_coloring = color_trimesh(
4548
+ self.particle_color_groups = color_trimesh(
4548
4549
  len(self.particle_q),
4549
4550
  edge_indices,
4550
4551
  include_bending,
@@ -4604,7 +4605,11 @@ class ModelBuilder:
4604
4605
  m.particle_max_radius = np.max(self.particle_radius) if len(self.particle_radius) > 0 else 0.0
4605
4606
  m.particle_max_velocity = self.particle_max_velocity
4606
4607
 
4607
- m.particle_coloring = [wp.array(group, dtype=int) for group in self.particle_coloring]
4608
+ particle_colors = np.empty(self.particle_count, dtype=int)
4609
+ for color in range(len(self.particle_color_groups)):
4610
+ particle_colors[self.particle_color_groups[color]] = color
4611
+ m.particle_colors = wp.array(particle_colors, dtype=int)
4612
+ m.particle_color_groups = [wp.array(group, dtype=int) for group in self.particle_color_groups]
4608
4613
 
4609
4614
  # hash-grid for particle interactions
4610
4615
  m.particle_grid = wp.HashGrid(128, 128, 128)