warp-lang 1.4.2 (py3-none-win_amd64.whl) → 1.5.1 (py3-none-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic. Further details are linked from the original diff page.

Files changed (166):
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/sim/model.py CHANGED
@@ -15,6 +15,7 @@ import numpy as np
15
15
 
16
16
  import warp as wp
17
17
 
18
+ from .graph_coloring import ColoringAlgorithm, color_trimesh, combine_independent_particle_coloring
18
19
  from .inertia import (
19
20
  compute_box_inertia,
20
21
  compute_capsule_inertia,
@@ -577,14 +578,14 @@ class Model:
577
578
 
578
579
  This setting is not supported by :class:`FeatherstoneIntegrator`.
579
580
 
580
- joint_limit_lower (array): Joint lower position limits, shape [joint_count], float
581
- joint_limit_upper (array): Joint upper position limits, shape [joint_count], float
582
- joint_limit_ke (array): Joint position limit stiffness (used by the Euler integrators), shape [joint_count], float
583
- joint_limit_kd (array): Joint position limit damping (used by the Euler integrators), shape [joint_count], float
581
+ joint_limit_lower (array): Joint lower position limits, shape [joint_axis_count], float
582
+ joint_limit_upper (array): Joint upper position limits, shape [joint_axis_count], float
583
+ joint_limit_ke (array): Joint position limit stiffness (used by the Euler integrators), shape [joint_axis_count], float
584
+ joint_limit_kd (array): Joint position limit damping (used by the Euler integrators), shape [joint_axis_count], float
584
585
  joint_twist_lower (array): Joint lower twist limit, shape [joint_count], float
585
586
  joint_twist_upper (array): Joint upper twist limit, shape [joint_count], float
586
- joint_q_start (array): Start index of the first position coordinate per joint, shape [joint_count], int
587
- joint_qd_start (array): Start index of the first velocity coordinate per joint, shape [joint_count], int
587
+ joint_q_start (array): Start index of the first position coordinate per joint (note the last value is an additional sentinel entry to allow for querying the q dimensionality of joint i via ``joint_q_start[i+1] - joint_q_start[i]``), shape [joint_count + 1], int
588
+ joint_qd_start (array): Start index of the first velocity coordinate per joint (note the last value is an additional sentinel entry to allow for querying the qd dimensionality of joint i via ``joint_qd_start[i+1] - joint_qd_start[i]``), shape [joint_count + 1], int
588
589
  articulation_start (array): Articulation start index, shape [articulation_count], int
589
590
  joint_name (list): Joint names, shape [joint_count], str
590
591
  joint_attach_ke (float): Joint attachment force stiffness (used by :class:`SemiImplicitIntegrator`)
@@ -891,7 +892,7 @@ class Model:
891
892
  target.soft_contact_body_pos = wp.zeros(count, dtype=wp.vec3, requires_grad=requires_grad)
892
893
  target.soft_contact_body_vel = wp.zeros(count, dtype=wp.vec3, requires_grad=requires_grad)
893
894
  target.soft_contact_normal = wp.zeros(count, dtype=wp.vec3, requires_grad=requires_grad)
894
- target.soft_contact_tids = wp.zeros(count, dtype=int)
895
+ target.soft_contact_tids = wp.zeros(self.particle_count * (self.shape_count - 1), dtype=int)
895
896
 
896
897
  def allocate_soft_contacts(self, count, requires_grad=False):
897
898
  self._allocate_soft_contacts(self, count, requires_grad)
@@ -1134,6 +1135,8 @@ class ModelBuilder:
1134
1135
  self.particle_radius = []
1135
1136
  self.particle_flags = []
1136
1137
  self.particle_max_velocity = 1e5
1138
+ # list of np.array
1139
+ self.particle_coloring = []
1137
1140
 
1138
1141
  # shapes (each shape has an entry in these arrays)
1139
1142
  # transform from shape to body
@@ -1372,6 +1375,11 @@ class ModelBuilder:
1372
1375
  if builder.tet_count:
1373
1376
  self.tet_indices.extend((np.array(builder.tet_indices, dtype=np.int32) + start_particle_idx).tolist())
1374
1377
 
1378
+ builder_coloring_translated = [group + start_particle_idx for group in builder.particle_coloring]
1379
+ self.particle_coloring = combine_independent_particle_coloring(
1380
+ self.particle_coloring, builder_coloring_translated
1381
+ )
1382
+
1375
1383
  start_body_idx = self.body_count
1376
1384
  start_shape_idx = self.shape_count
1377
1385
  for s, b in enumerate(builder.shape_body):
@@ -1434,12 +1442,14 @@ class ModelBuilder:
1434
1442
  self.shape_collision_filter_pairs.add((i + shape_count, j + shape_count))
1435
1443
  for group, shapes in builder.shape_collision_group_map.items():
1436
1444
  if separate_collision_group:
1437
- group = self.last_collision_group + 1
1445
+ extend_group = self.last_collision_group + 1
1438
1446
  else:
1439
- group = group + self.last_collision_group if group > -1 else -1
1440
- if group not in self.shape_collision_group_map:
1441
- self.shape_collision_group_map[group] = []
1442
- self.shape_collision_group_map[group].extend([s + shape_count for s in shapes])
1447
+ extend_group = group + self.last_collision_group if group > -1 else -1
1448
+
1449
+ if extend_group not in self.shape_collision_group_map:
1450
+ self.shape_collision_group_map[extend_group] = []
1451
+
1452
+ self.shape_collision_group_map[extend_group].extend([s + shape_count for s in shapes])
1443
1453
 
1444
1454
  # update last collision group counter
1445
1455
  if separate_collision_group:
@@ -2608,11 +2618,12 @@ class ModelBuilder:
2608
2618
  joint_remap[joint["original_id"]] = i
2609
2619
  # update articulation_start
2610
2620
  for i, old_i in enumerate(self.articulation_start):
2611
- while old_i not in joint_remap:
2612
- old_i += 1
2613
- if old_i >= self.joint_count:
2621
+ start_i = old_i
2622
+ while start_i not in joint_remap:
2623
+ start_i += 1
2624
+ if start_i >= self.joint_count:
2614
2625
  break
2615
- self.articulation_start[i] = joint_remap.get(old_i, old_i)
2626
+ self.articulation_start[i] = joint_remap.get(start_i, start_i)
2616
2627
  # remove empty articulation starts, i.e. where the start and end are the same
2617
2628
  self.articulation_start = list(set(self.articulation_start))
2618
2629
 
@@ -3421,7 +3432,12 @@ class ModelBuilder:
3421
3432
 
3422
3433
  # particles
3423
3434
  def add_particle(
3424
- self, pos: Vec3, vel: Vec3, mass: float, radius: float = None, flags: wp.uint32 = PARTICLE_FLAG_ACTIVE
3435
+ self,
3436
+ pos: Vec3,
3437
+ vel: Vec3,
3438
+ mass: float,
3439
+ radius: float = None,
3440
+ flags: wp.uint32 = PARTICLE_FLAG_ACTIVE,
3425
3441
  ) -> int:
3426
3442
  """Adds a single particle to the model
3427
3443
 
@@ -3446,7 +3462,9 @@ class ModelBuilder:
3446
3462
  self.particle_radius.append(radius)
3447
3463
  self.particle_flags.append(flags)
3448
3464
 
3449
- return len(self.particle_q) - 1
3465
+ particle_id = self.particle_count - 1
3466
+
3467
+ return particle_id
3450
3468
 
3451
3469
  def add_spring(self, i: int, j, ke: float, kd: float, control: float):
3452
3470
  """Adds a spring between two particles in the system
@@ -3826,6 +3844,7 @@ class ModelBuilder:
3826
3844
  add_springs: bool = False,
3827
3845
  spring_ke: float = default_spring_ke,
3828
3846
  spring_kd: float = default_spring_kd,
3847
+ particle_radius: float = default_particle_radius,
3829
3848
  ):
3830
3849
  """Helper to create a regular planar cloth grid
3831
3850
 
@@ -3846,7 +3865,6 @@ class ModelBuilder:
3846
3865
  fix_right: Make the right-most edge of particles kinematic
3847
3866
  fix_top: Make the top-most edge of particles kinematic
3848
3867
  fix_bottom: Make the bottom-most edge of particles kinematic
3849
-
3850
3868
  """
3851
3869
 
3852
3870
  def grid_index(x, y, dim_x):
@@ -3876,7 +3894,7 @@ class ModelBuilder:
3876
3894
  m = 0.0
3877
3895
  particle_flag = wp.uint32(int(particle_flag) & ~int(PARTICLE_FLAG_ACTIVE))
3878
3896
 
3879
- self.add_particle(p, vel, m, flags=particle_flag)
3897
+ self.add_particle(p, vel, m, flags=particle_flag, radius=particle_radius)
3880
3898
 
3881
3899
  if x > 0 and y > 0:
3882
3900
  if reverse_winding:
@@ -3960,6 +3978,7 @@ class ModelBuilder:
3960
3978
  add_springs: bool = False,
3961
3979
  spring_ke: float = default_spring_ke,
3962
3980
  spring_kd: float = default_spring_kd,
3981
+ particle_radius: float = default_particle_radius,
3963
3982
  ):
3964
3983
  """Helper to create a cloth model from a regular triangle mesh
3965
3984
 
@@ -3975,7 +3994,7 @@ class ModelBuilder:
3975
3994
  density: The density per-area of the mesh
3976
3995
  edge_callback: A user callback when an edge is created
3977
3996
  face_callback: A user callback when a face is created
3978
-
3997
+ particle_radius: The particle_radius which controls particle based collisions.
3979
3998
  Note:
3980
3999
 
3981
4000
  The mesh should be two manifold.
@@ -3989,7 +4008,7 @@ class ModelBuilder:
3989
4008
  for v in vertices:
3990
4009
  p = wp.quat_rotate(rot, v * scale) + pos
3991
4010
 
3992
- self.add_particle(p, vel, 0.0)
4011
+ self.add_particle(p, vel, 0.0, radius=particle_radius)
3993
4012
 
3994
4013
  # triangles
3995
4014
  inds = start_vertex + np.array(indices)
@@ -4016,22 +4035,22 @@ class ModelBuilder:
4016
4035
 
4017
4036
  adj = wp.utils.MeshAdjacency(self.tri_indices[start_tri:end_tri], end_tri - start_tri)
4018
4037
 
4019
- edgeinds = np.fromiter(
4038
+ edge_indices = np.fromiter(
4020
4039
  (x for e in adj.edges.values() for x in (e.o0, e.o1, e.v0, e.v1)),
4021
4040
  int,
4022
4041
  ).reshape(-1, 4)
4023
4042
  self.add_edges(
4024
- edgeinds[:, 0],
4025
- edgeinds[:, 1],
4026
- edgeinds[:, 2],
4027
- edgeinds[:, 3],
4028
- edge_ke=[edge_ke] * len(edgeinds),
4029
- edge_kd=[edge_kd] * len(edgeinds),
4043
+ edge_indices[:, 0],
4044
+ edge_indices[:, 1],
4045
+ edge_indices[:, 2],
4046
+ edge_indices[:, 3],
4047
+ edge_ke=[edge_ke] * len(edge_indices),
4048
+ edge_kd=[edge_kd] * len(edge_indices),
4030
4049
  )
4031
4050
 
4032
4051
  if add_springs:
4033
4052
  spring_indices = set()
4034
- for i, j, k, l in edgeinds:
4053
+ for i, j, k, l in edge_indices:
4035
4054
  spring_indices.add((min(i, j), max(i, j)))
4036
4055
  spring_indices.add((min(i, k), max(i, k)))
4037
4056
  spring_indices.add((min(i, l), max(i, l)))
@@ -4253,8 +4272,7 @@ class ModelBuilder:
4253
4272
  pos = wp.vec3(pos[0], pos[1], pos[2])
4254
4273
  # add particles
4255
4274
  for v in vertices:
4256
- v = wp.vec3(v[0], v[1], v[2])
4257
- p = wp.quat_rotate(rot, v * scale) + pos
4275
+ p = wp.quat_rotate(rot, wp.vec3(v[0], v[1], v[2]) * scale) + pos
4258
4276
 
4259
4277
  self.add_particle(p, vel, 0.0)
4260
4278
 
@@ -4356,6 +4374,63 @@ class ModelBuilder:
4356
4374
  for i in range(self.shape_count - 1):
4357
4375
  self.shape_collision_filter_pairs.add((i, ground_id))
4358
4376
 
4377
+ def set_coloring(self, particle_coloring):
4378
+ """
4379
+ Set coloring information with user-provided coloring.
4380
+
4381
+ Args:
4382
+ particle_coloring: A list of list or `np.array` with `dtype`=`int`. The length of the list is the number of colors
4383
+ and each list or `np.array` contains the indices of vertices with this color.
4384
+ """
4385
+ particle_coloring = [
4386
+ color_group if isinstance(color_group, np.ndarray) else np.array(color_group)
4387
+ for color_group in particle_coloring
4388
+ ]
4389
+ self.particle_coloring = particle_coloring
4390
+
4391
+ def color(
4392
+ self,
4393
+ include_bending=False,
4394
+ balance_colors=True,
4395
+ target_max_min_color_ratio=1.1,
4396
+ coloring_algorithm=ColoringAlgorithm.MCS,
4397
+ ):
4398
+ """
4399
+ Run coloring algorithm to generate coloring information.
4400
+
4401
+ Args:
4402
+ include_bending_energy: Whether to consider bending energy for trimeshes in the coloring process. If set to `True`, the generated
4403
+ graph will contain all the edges connecting o1 and o2; otherwise, the graph will be equivalent to the trimesh.
4404
+ balance_colors: Whether to apply the color balancing algorithm to balance the size of each color
4405
+ target_max_min_color_ratio: the color balancing algorithm will stop when the ratio between the largest color and
4406
+ the smallest color reaches this value
4407
+ algorithm: Value should be an enum type of ColoringAlgorithm, otherwise it will raise an error. ColoringAlgorithm.mcs means using the MCS coloring algorithm,
4408
+ while ColoringAlgorithm.ordered_greedy means using the degree-ordered greedy algorithm. The MCS algorithm typically generates 30% to 50% fewer colors
4409
+ compared to the ordered greedy algorithm, while maintaining the same linear complexity. Although MCS has a constant overhead that makes it about twice
4410
+ as slow as the greedy algorithm, it produces significantly better coloring results. We recommend using MCS, especially if coloring is only part of the
4411
+ preprocessing.
4412
+
4413
+ Note:
4414
+
4415
+ References to the coloring algorithm:
4416
+
4417
+ MCS: Pereira, F. M. Q., & Palsberg, J. (2005, November). Register allocation via coloring of chordal graphs. In Asian Symposium on Programming Languages and Systems (pp. 315-329). Berlin, Heidelberg: Springer Berlin Heidelberg.
4418
+
4419
+ Ordered Greedy: Ton-That, Q. M., Kry, P. G., & Andrews, S. (2023). Parallel block Neo-Hookean XPBD using graph clustering. Computers & Graphics, 110, 1-10.
4420
+
4421
+ """
4422
+ # ignore bending energy if it is too small
4423
+ edge_indices = np.array(self.edge_indices)
4424
+
4425
+ self.particle_coloring = color_trimesh(
4426
+ len(self.particle_q),
4427
+ edge_indices,
4428
+ include_bending,
4429
+ algorithm=coloring_algorithm,
4430
+ balance_colors=balance_colors,
4431
+ target_max_min_color_ratio=target_max_min_color_ratio,
4432
+ )
4433
+
4359
4434
  def finalize(self, device=None, requires_grad=False) -> Model:
4360
4435
  """Convert this builder object to a concrete model for simulation.
4361
4436
 
@@ -4407,6 +4482,8 @@ class ModelBuilder:
4407
4482
  m.particle_max_radius = np.max(self.particle_radius) if len(self.particle_radius) > 0 else 0.0
4408
4483
  m.particle_max_velocity = self.particle_max_velocity
4409
4484
 
4485
+ m.particle_coloring = [wp.array(group, dtype=int) for group in self.particle_coloring]
4486
+
4410
4487
  # hash-grid for particle interactions
4411
4488
  m.particle_grid = wp.HashGrid(128, 128, 128)
4412
4489
 
warp/sparse.py CHANGED
@@ -8,7 +8,7 @@ from warp.types import Array, Cols, Rows, Scalar, Vector
8
8
 
9
9
  # typing hints
10
10
 
11
- _BlockType = TypeVar("BlockType")
11
+ _BlockType = TypeVar("BlockType") # noqa: PLC0132
12
12
 
13
13
 
14
14
  class _MatrixBlockType(Generic[Rows, Cols, Scalar]):