warp-lang 1.2.2__py3-none-win_amd64.whl → 1.3.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (194) hide show
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +6 -2
  6. warp/builtins.py +1410 -886
  7. warp/codegen.py +503 -166
  8. warp/config.py +48 -18
  9. warp/context.py +400 -198
  10. warp/dlpack.py +8 -0
  11. warp/examples/assets/bunny.usd +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  13. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  14. warp/examples/benchmarks/benchmark_launches.py +1 -1
  15. warp/examples/core/example_cupy.py +78 -0
  16. warp/examples/fem/example_apic_fluid.py +17 -36
  17. warp/examples/fem/example_burgers.py +9 -18
  18. warp/examples/fem/example_convection_diffusion.py +7 -17
  19. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  20. warp/examples/fem/example_deformed_geometry.py +11 -22
  21. warp/examples/fem/example_diffusion.py +7 -18
  22. warp/examples/fem/example_diffusion_3d.py +24 -28
  23. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  24. warp/examples/fem/example_magnetostatics.py +190 -0
  25. warp/examples/fem/example_mixed_elasticity.py +111 -80
  26. warp/examples/fem/example_navier_stokes.py +30 -34
  27. warp/examples/fem/example_nonconforming_contact.py +290 -0
  28. warp/examples/fem/example_stokes.py +17 -32
  29. warp/examples/fem/example_stokes_transfer.py +12 -21
  30. warp/examples/fem/example_streamlines.py +350 -0
  31. warp/examples/fem/utils.py +936 -0
  32. warp/fabric.py +5 -2
  33. warp/fem/__init__.py +13 -3
  34. warp/fem/cache.py +161 -11
  35. warp/fem/dirichlet.py +37 -28
  36. warp/fem/domain.py +105 -14
  37. warp/fem/field/__init__.py +14 -3
  38. warp/fem/field/field.py +454 -11
  39. warp/fem/field/nodal_field.py +33 -18
  40. warp/fem/geometry/deformed_geometry.py +50 -15
  41. warp/fem/geometry/hexmesh.py +12 -24
  42. warp/fem/geometry/nanogrid.py +106 -31
  43. warp/fem/geometry/quadmesh_2d.py +6 -11
  44. warp/fem/geometry/tetmesh.py +103 -61
  45. warp/fem/geometry/trimesh_2d.py +98 -47
  46. warp/fem/integrate.py +231 -186
  47. warp/fem/operator.py +14 -9
  48. warp/fem/quadrature/pic_quadrature.py +35 -9
  49. warp/fem/quadrature/quadrature.py +119 -32
  50. warp/fem/space/basis_space.py +98 -22
  51. warp/fem/space/collocated_function_space.py +3 -1
  52. warp/fem/space/function_space.py +7 -2
  53. warp/fem/space/grid_2d_function_space.py +3 -3
  54. warp/fem/space/grid_3d_function_space.py +4 -4
  55. warp/fem/space/hexmesh_function_space.py +3 -2
  56. warp/fem/space/nanogrid_function_space.py +12 -14
  57. warp/fem/space/partition.py +45 -47
  58. warp/fem/space/restriction.py +19 -16
  59. warp/fem/space/shape/cube_shape_function.py +91 -3
  60. warp/fem/space/shape/shape_function.py +7 -0
  61. warp/fem/space/shape/square_shape_function.py +32 -0
  62. warp/fem/space/shape/tet_shape_function.py +11 -7
  63. warp/fem/space/shape/triangle_shape_function.py +10 -1
  64. warp/fem/space/topology.py +116 -42
  65. warp/fem/types.py +8 -1
  66. warp/fem/utils.py +301 -83
  67. warp/native/array.h +16 -0
  68. warp/native/builtin.h +0 -15
  69. warp/native/cuda_util.cpp +14 -6
  70. warp/native/exports.h +1348 -1308
  71. warp/native/quat.h +79 -0
  72. warp/native/rand.h +27 -4
  73. warp/native/sparse.cpp +83 -81
  74. warp/native/sparse.cu +381 -453
  75. warp/native/vec.h +64 -0
  76. warp/native/volume.cpp +40 -49
  77. warp/native/volume_builder.cu +2 -3
  78. warp/native/volume_builder.h +12 -17
  79. warp/native/warp.cu +3 -3
  80. warp/native/warp.h +69 -59
  81. warp/render/render_opengl.py +17 -9
  82. warp/sim/articulation.py +117 -17
  83. warp/sim/collide.py +35 -29
  84. warp/sim/model.py +123 -18
  85. warp/sim/render.py +3 -1
  86. warp/sparse.py +867 -203
  87. warp/stubs.py +312 -541
  88. warp/tape.py +29 -1
  89. warp/tests/disabled_kinematics.py +1 -1
  90. warp/tests/test_adam.py +1 -1
  91. warp/tests/test_arithmetic.py +1 -1
  92. warp/tests/test_array.py +58 -1
  93. warp/tests/test_array_reduce.py +1 -1
  94. warp/tests/test_async.py +1 -1
  95. warp/tests/test_atomic.py +1 -1
  96. warp/tests/test_bool.py +1 -1
  97. warp/tests/test_builtins_resolution.py +1 -1
  98. warp/tests/test_bvh.py +6 -1
  99. warp/tests/test_closest_point_edge_edge.py +1 -1
  100. warp/tests/test_codegen.py +66 -1
  101. warp/tests/test_compile_consts.py +1 -1
  102. warp/tests/test_conditional.py +1 -1
  103. warp/tests/test_copy.py +1 -1
  104. warp/tests/test_ctypes.py +1 -1
  105. warp/tests/test_dense.py +1 -1
  106. warp/tests/test_devices.py +1 -1
  107. warp/tests/test_dlpack.py +1 -1
  108. warp/tests/test_examples.py +33 -4
  109. warp/tests/test_fabricarray.py +5 -2
  110. warp/tests/test_fast_math.py +1 -1
  111. warp/tests/test_fem.py +213 -6
  112. warp/tests/test_fp16.py +1 -1
  113. warp/tests/test_func.py +1 -1
  114. warp/tests/test_future_annotations.py +90 -0
  115. warp/tests/test_generics.py +1 -1
  116. warp/tests/test_grad.py +1 -1
  117. warp/tests/test_grad_customs.py +1 -1
  118. warp/tests/test_grad_debug.py +247 -0
  119. warp/tests/test_hash_grid.py +6 -1
  120. warp/tests/test_implicit_init.py +354 -0
  121. warp/tests/test_import.py +1 -1
  122. warp/tests/test_indexedarray.py +1 -1
  123. warp/tests/test_intersect.py +1 -1
  124. warp/tests/test_jax.py +1 -1
  125. warp/tests/test_large.py +1 -1
  126. warp/tests/test_launch.py +1 -1
  127. warp/tests/test_lerp.py +1 -1
  128. warp/tests/test_linear_solvers.py +1 -1
  129. warp/tests/test_lvalue.py +1 -1
  130. warp/tests/test_marching_cubes.py +5 -2
  131. warp/tests/test_mat.py +34 -35
  132. warp/tests/test_mat_lite.py +2 -1
  133. warp/tests/test_mat_scalar_ops.py +1 -1
  134. warp/tests/test_math.py +1 -1
  135. warp/tests/test_matmul.py +20 -16
  136. warp/tests/test_matmul_lite.py +1 -1
  137. warp/tests/test_mempool.py +1 -1
  138. warp/tests/test_mesh.py +5 -2
  139. warp/tests/test_mesh_query_aabb.py +1 -1
  140. warp/tests/test_mesh_query_point.py +1 -1
  141. warp/tests/test_mesh_query_ray.py +1 -1
  142. warp/tests/test_mlp.py +1 -1
  143. warp/tests/test_model.py +1 -1
  144. warp/tests/test_module_hashing.py +77 -1
  145. warp/tests/test_modules_lite.py +1 -1
  146. warp/tests/test_multigpu.py +1 -1
  147. warp/tests/test_noise.py +1 -1
  148. warp/tests/test_operators.py +1 -1
  149. warp/tests/test_options.py +1 -1
  150. warp/tests/test_overwrite.py +542 -0
  151. warp/tests/test_peer.py +1 -1
  152. warp/tests/test_pinned.py +1 -1
  153. warp/tests/test_print.py +1 -1
  154. warp/tests/test_quat.py +15 -1
  155. warp/tests/test_rand.py +1 -1
  156. warp/tests/test_reload.py +1 -1
  157. warp/tests/test_rounding.py +1 -1
  158. warp/tests/test_runlength_encode.py +1 -1
  159. warp/tests/test_scalar_ops.py +95 -0
  160. warp/tests/test_sim_grad.py +1 -1
  161. warp/tests/test_sim_kinematics.py +1 -1
  162. warp/tests/test_smoothstep.py +1 -1
  163. warp/tests/test_sparse.py +82 -15
  164. warp/tests/test_spatial.py +1 -1
  165. warp/tests/test_special_values.py +2 -11
  166. warp/tests/test_streams.py +11 -1
  167. warp/tests/test_struct.py +1 -1
  168. warp/tests/test_tape.py +1 -1
  169. warp/tests/test_torch.py +194 -1
  170. warp/tests/test_transient_module.py +1 -1
  171. warp/tests/test_types.py +1 -1
  172. warp/tests/test_utils.py +1 -1
  173. warp/tests/test_vec.py +15 -63
  174. warp/tests/test_vec_lite.py +2 -1
  175. warp/tests/test_vec_scalar_ops.py +65 -1
  176. warp/tests/test_verify_fp.py +1 -1
  177. warp/tests/test_volume.py +28 -2
  178. warp/tests/test_volume_write.py +1 -1
  179. warp/tests/unittest_serial.py +1 -1
  180. warp/tests/unittest_suites.py +9 -1
  181. warp/tests/walkthrough_debug.py +1 -1
  182. warp/thirdparty/unittest_parallel.py +2 -5
  183. warp/torch.py +103 -41
  184. warp/types.py +341 -224
  185. warp/utils.py +11 -2
  186. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  187. warp_lang-1.3.0.dist-info/RECORD +368 -0
  188. warp/examples/fem/bsr_utils.py +0 -378
  189. warp/examples/fem/mesh_utils.py +0 -133
  190. warp/examples/fem/plot_utils.py +0 -292
  191. warp_lang-1.2.2.dist-info/RECORD +0 -359
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  194. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/sim/articulation.py CHANGED
@@ -164,12 +164,10 @@ def invert_3d_rotational_dofs(
164
164
  return angles, velocities
165
165
 
166
166
 
167
- @wp.kernel
168
- def eval_articulation_fk(
169
- articulation_start: wp.array(dtype=int),
170
- articulation_mask: wp.array(
171
- dtype=int
172
- ), # used to enable / disable FK for an articulation, if None then treat all as enabled
167
+ @wp.func
168
+ def eval_single_articulation_fk(
169
+ joint_start: int,
170
+ joint_end: int,
173
171
  joint_q: wp.array(dtype=float),
174
172
  joint_qd: wp.array(dtype=float),
175
173
  joint_q_start: wp.array(dtype=int),
@@ -187,16 +185,6 @@ def eval_articulation_fk(
187
185
  body_q: wp.array(dtype=wp.transform),
188
186
  body_qd: wp.array(dtype=wp.spatial_vector),
189
187
  ):
190
- tid = wp.tid()
191
-
192
- # early out if disabling FK for this articulation
193
- if articulation_mask:
194
- if articulation_mask[tid] == 0:
195
- return
196
-
197
- joint_start = articulation_start[tid]
198
- joint_end = articulation_start[tid + 1]
199
-
200
188
  for i in range(joint_start, joint_end):
201
189
  parent = joint_parent[i]
202
190
  child = joint_child[i]
@@ -374,6 +362,118 @@ def eval_articulation_fk(
374
362
  body_qd[child] = v_wc
375
363
 
376
364
 
365
+ # implementation where mask is an integer array
366
+ @wp.kernel
367
+ def eval_articulation_fk(
368
+ articulation_start: wp.array(dtype=int),
369
+ articulation_mask: wp.array(
370
+ dtype=int
371
+ ), # used to enable / disable FK for an articulation, if None then treat all as enabled
372
+ joint_q: wp.array(dtype=float),
373
+ joint_qd: wp.array(dtype=float),
374
+ joint_q_start: wp.array(dtype=int),
375
+ joint_qd_start: wp.array(dtype=int),
376
+ joint_type: wp.array(dtype=int),
377
+ joint_parent: wp.array(dtype=int),
378
+ joint_child: wp.array(dtype=int),
379
+ joint_X_p: wp.array(dtype=wp.transform),
380
+ joint_X_c: wp.array(dtype=wp.transform),
381
+ joint_axis: wp.array(dtype=wp.vec3),
382
+ joint_axis_start: wp.array(dtype=int),
383
+ joint_axis_dim: wp.array(dtype=int, ndim=2),
384
+ body_com: wp.array(dtype=wp.vec3),
385
+ # outputs
386
+ body_q: wp.array(dtype=wp.transform),
387
+ body_qd: wp.array(dtype=wp.spatial_vector),
388
+ ):
389
+ tid = wp.tid()
390
+
391
+ # early out if disabling FK for this articulation
392
+ if articulation_mask:
393
+ if articulation_mask[tid] == 0:
394
+ return
395
+
396
+ joint_start = articulation_start[tid]
397
+ joint_end = articulation_start[tid + 1]
398
+
399
+ eval_single_articulation_fk(
400
+ joint_start,
401
+ joint_end,
402
+ joint_q,
403
+ joint_qd,
404
+ joint_q_start,
405
+ joint_qd_start,
406
+ joint_type,
407
+ joint_parent,
408
+ joint_child,
409
+ joint_X_p,
410
+ joint_X_c,
411
+ joint_axis,
412
+ joint_axis_start,
413
+ joint_axis_dim,
414
+ body_com,
415
+ # outputs
416
+ body_q,
417
+ body_qd,
418
+ )
419
+
420
+
421
+ # overload where mask is a bool array
422
+ @wp.kernel
423
+ def eval_articulation_fk(
424
+ articulation_start: wp.array(dtype=int),
425
+ articulation_mask: wp.array(
426
+ dtype=bool
427
+ ), # used to enable / disable FK for an articulation, if None then treat all as enabled
428
+ joint_q: wp.array(dtype=float),
429
+ joint_qd: wp.array(dtype=float),
430
+ joint_q_start: wp.array(dtype=int),
431
+ joint_qd_start: wp.array(dtype=int),
432
+ joint_type: wp.array(dtype=int),
433
+ joint_parent: wp.array(dtype=int),
434
+ joint_child: wp.array(dtype=int),
435
+ joint_X_p: wp.array(dtype=wp.transform),
436
+ joint_X_c: wp.array(dtype=wp.transform),
437
+ joint_axis: wp.array(dtype=wp.vec3),
438
+ joint_axis_start: wp.array(dtype=int),
439
+ joint_axis_dim: wp.array(dtype=int, ndim=2),
440
+ body_com: wp.array(dtype=wp.vec3),
441
+ # outputs
442
+ body_q: wp.array(dtype=wp.transform),
443
+ body_qd: wp.array(dtype=wp.spatial_vector),
444
+ ):
445
+ tid = wp.tid()
446
+
447
+ # early out if disabling FK for this articulation
448
+ if articulation_mask:
449
+ if not articulation_mask[tid]:
450
+ return
451
+
452
+ joint_start = articulation_start[tid]
453
+ joint_end = articulation_start[tid + 1]
454
+
455
+ eval_single_articulation_fk(
456
+ joint_start,
457
+ joint_end,
458
+ joint_q,
459
+ joint_qd,
460
+ joint_q_start,
461
+ joint_qd_start,
462
+ joint_type,
463
+ joint_parent,
464
+ joint_child,
465
+ joint_X_p,
466
+ joint_X_c,
467
+ joint_axis,
468
+ joint_axis_start,
469
+ joint_axis_dim,
470
+ body_com,
471
+ # outputs
472
+ body_q,
473
+ body_qd,
474
+ )
475
+
476
+
377
477
  # updates state body information based on joint coordinates
378
478
  def eval_fk(model, joint_q, joint_qd, mask, state):
379
479
  """
@@ -383,7 +483,7 @@ def eval_fk(model, joint_q, joint_qd, mask, state):
383
483
  model (Model): The model to evaluate.
384
484
  joint_q (array): Generalized joint position coordinates, shape [joint_coord_count], float
385
485
  joint_qd (array): Generalized joint velocity coordinates, shape [joint_dof_count], float
386
- mask (array): The mask to use to enable / disable FK for an articulation. If None then treat all as enabled, shape [articulation_count], int
486
+ mask (array): The mask to use to enable / disable FK for an articulation. If None then treat all as enabled, shape [articulation_count], int/bool
387
487
  state (State): The state to update.
388
488
  """
389
489
  wp.launch(
warp/sim/collide.py CHANGED
@@ -859,10 +859,9 @@ def broadphase_collision_pairs(
859
859
  contact_shape0[index + num_contacts_a + i] = actual_shape_b
860
860
  contact_shape1[index + num_contacts_a + i] = actual_shape_a
861
861
  contact_point_id[index + num_contacts_a + i] = i
862
- contact_point_limit[pair_index_ab] = 2
863
- if mesh_contact_max > 0:
862
+ if mesh_contact_max > 0 and contact_point_limit and pair_index_ba < contact_point_limit.shape[0]:
864
863
  num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
865
- contact_point_limit[pair_index_ba] = num_contacts_b
864
+ contact_point_limit[pair_index_ba] = num_contacts_b
866
865
  return
867
866
  else:
868
867
  num_contacts = 2
@@ -877,13 +876,11 @@ def broadphase_collision_pairs(
877
876
  contact_shape0[index + i] = shape_a
878
877
  contact_shape1[index + i] = shape_b
879
878
  contact_point_id[index + i] = i
880
- contact_point_limit[pair_index_ab] = 12
881
879
  # allocate contact points from box B against A
882
880
  for i in range(12):
883
881
  contact_shape0[index + 12 + i] = shape_b
884
882
  contact_shape1[index + 12 + i] = shape_a
885
883
  contact_point_id[index + 12 + i] = i
886
- contact_point_limit[pair_index_ba] = 12
887
884
  return
888
885
  elif actual_type_b == wp.sim.GEO_MESH:
889
886
  num_contacts_a = 8
@@ -908,10 +905,9 @@ def broadphase_collision_pairs(
908
905
  contact_shape1[index + num_contacts_a + i] = actual_shape_a
909
906
  contact_point_id[index + num_contacts_a + i] = i
910
907
 
911
- contact_point_limit[pair_index_ab] = num_contacts_a
912
- if mesh_contact_max > 0:
908
+ if mesh_contact_max > 0 and contact_point_limit and pair_index_ba < contact_point_limit.shape[0]:
913
909
  num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
914
- contact_point_limit[pair_index_ba] = num_contacts_b
910
+ contact_point_limit[pair_index_ba] = num_contacts_b
915
911
  return
916
912
  elif actual_type_b == wp.sim.GEO_PLANE:
917
913
  if geo.scale[actual_shape_b][0] == 0.0 and geo.scale[actual_shape_b][1] == 0.0:
@@ -947,11 +943,13 @@ def broadphase_collision_pairs(
947
943
  contact_shape1[index + num_contacts_a + i] = actual_shape_a
948
944
  contact_point_id[index + num_contacts_a + i] = i
949
945
 
950
- if mesh_contact_max > 0:
946
+ if mesh_contact_max > 0 and contact_point_limit:
951
947
  num_contacts_a = wp.min(mesh_contact_max, num_contacts_a)
952
948
  num_contacts_b = wp.min(mesh_contact_max, num_contacts_b)
953
- contact_point_limit[pair_index_ab] = num_contacts_a
954
- contact_point_limit[pair_index_ba] = num_contacts_b
949
+ if pair_index_ab < contact_point_limit.shape[0]:
950
+ contact_point_limit[pair_index_ab] = num_contacts_a
951
+ if pair_index_ba < contact_point_limit.shape[0]:
952
+ contact_point_limit[pair_index_ba] = num_contacts_b
955
953
  return
956
954
  elif actual_type_a == wp.sim.GEO_PLANE:
957
955
  return # no plane-plane contacts
@@ -969,8 +967,11 @@ def broadphase_collision_pairs(
969
967
  contact_shape0[cp_index] = actual_shape_a
970
968
  contact_shape1[cp_index] = actual_shape_b
971
969
  contact_point_id[cp_index] = i
972
- contact_point_limit[pair_index_ab] = num_contacts
973
- contact_point_limit[pair_index_ba] = 0
970
+ if contact_point_limit:
971
+ if pair_index_ab < contact_point_limit.shape[0]:
972
+ contact_point_limit[pair_index_ab] = num_contacts
973
+ if pair_index_ba < contact_point_limit.shape[0]:
974
+ contact_point_limit[pair_index_ba] = 0
974
975
 
975
976
 
976
977
  @wp.kernel
@@ -1005,12 +1006,14 @@ def handle_contact_pairs(
1005
1006
  if shape_a == shape_b:
1006
1007
  return
1007
1008
 
1009
+ if contact_point_limit:
1010
+ pair_index = shape_a * num_shapes + shape_b
1011
+ contact_limit = contact_point_limit[pair_index]
1012
+ if contact_pairwise_counter[pair_index] >= contact_limit:
1013
+ # reached limit of contact points per contact pair
1014
+ return
1015
+
1008
1016
  point_id = contact_point_id[tid]
1009
- pair_index = shape_a * num_shapes + shape_b
1010
- contact_limit = contact_point_limit[pair_index]
1011
- if contact_pairwise_counter[pair_index] >= contact_limit:
1012
- # reached limit of contact points per contact pair
1013
- return
1014
1017
 
1015
1018
  rigid_a = shape_body[shape_a]
1016
1019
  X_wb_a = wp.transform_identity()
@@ -1404,15 +1407,16 @@ def handle_contact_pairs(
1404
1407
 
1405
1408
  d = distance - thickness
1406
1409
  if d < rigid_contact_margin:
1407
- pair_contact_id = limited_counter_increment(
1408
- contact_pairwise_counter, pair_index, contact_tids, tid, contact_limit
1409
- )
1410
- if pair_contact_id == -1:
1411
- # wp.printf("Reached contact point limit %d >= %d for shape pair %d and %d (pair_index: %d)\n",
1412
- # contact_pairwise_counter[pair_index], contact_limit, shape_a, shape_b, pair_index)
1413
- # reached contact point limit
1414
- return
1415
- index = limited_counter_increment(contact_count, 0, contact_tids, tid, -1)
1410
+ if contact_pairwise_counter:
1411
+ pair_contact_id = limited_counter_increment(
1412
+ contact_pairwise_counter, pair_index, contact_tids, tid, contact_limit
1413
+ )
1414
+ if pair_contact_id == -1:
1415
+ # wp.printf("Reached contact point limit %d >= %d for shape pair %d and %d (pair_index: %d)\n",
1416
+ # contact_pairwise_counter[pair_index], contact_limit, shape_a, shape_b, pair_index)
1417
+ # reached contact point limit
1418
+ return
1419
+ index = counter_increment(contact_count, 0, contact_tids, tid)
1416
1420
  contact_shape0[index] = shape_a
1417
1421
  contact_shape1[index] = shape_b
1418
1422
  # transform from world into body frame (so the contact point includes the shape transform)
@@ -1550,14 +1554,16 @@ def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool =
1550
1554
  model.rigid_contact_normal = wp.clone(model.rigid_contact_normal)
1551
1555
  model.rigid_contact_thickness = wp.clone(model.rigid_contact_thickness)
1552
1556
  model.rigid_contact_count = wp.zeros_like(model.rigid_contact_count)
1553
- model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
1554
1557
  model.rigid_contact_tids = wp.zeros_like(model.rigid_contact_tids)
1555
1558
  model.rigid_contact_shape0 = wp.empty_like(model.rigid_contact_shape0)
1556
1559
  model.rigid_contact_shape1 = wp.empty_like(model.rigid_contact_shape1)
1560
+ if model.rigid_contact_pairwise_counter is not None:
1561
+ model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
1557
1562
  else:
1558
1563
  model.rigid_contact_count.zero_()
1559
- model.rigid_contact_pairwise_counter.zero_()
1560
1564
  model.rigid_contact_tids.zero_()
1565
+ if model.rigid_contact_pairwise_counter is not None:
1566
+ model.rigid_contact_pairwise_counter.zero_()
1561
1567
  model.rigid_contact_shape0.fill_(-1)
1562
1568
  model.rigid_contact_shape1.fill_(-1)
1563
1569
 
warp/sim/model.py CHANGED
@@ -561,7 +561,7 @@ class Model:
561
561
  joint_X_p (array): Joint transform in parent frame, shape [joint_count, 7], float
562
562
  joint_X_c (array): Joint mass frame in child frame, shape [joint_count, 7], float
563
563
  joint_axis (array): Joint axis in child frame, shape [joint_axis_count, 3], float
564
- joint_armature (array): Armature for each joint axis (only used by :class:`FeatherstoneIntegrator`), shape [joint_count], float
564
+ joint_armature (array): Armature for each joint axis (only used by :class:`FeatherstoneIntegrator`), shape [joint_dof_count], float
565
565
  joint_target_ke (array): Joint stiffness, shape [joint_axis_count], float
566
566
  joint_target_kd (array): Joint damping, shape [joint_axis_count], float
567
567
  joint_axis_start (array): Start index of the first axis per joint, shape [joint_count], int
@@ -1014,11 +1014,23 @@ class Model:
1014
1014
  target.rigid_contact_broad_shape0 = wp.zeros(self.rigid_contact_max, dtype=wp.int32)
1015
1015
  target.rigid_contact_broad_shape1 = wp.zeros(self.rigid_contact_max, dtype=wp.int32)
1016
1016
 
1017
- max_pair_count = self.shape_count * self.shape_count
1018
- # maximum number of contact points per contact pair
1019
- target.rigid_contact_point_limit = wp.zeros(max_pair_count, dtype=wp.int32)
1020
- # currently found contacts per contact pair
1021
- target.rigid_contact_pairwise_counter = wp.zeros(max_pair_count, dtype=wp.int32)
1017
+ if self.rigid_mesh_contact_max > 0:
1018
+ # add additional buffers to track how many contact points are generated per contact pair
1019
+ # (significantly increases memory usage, only enable if mesh contacts need to be pruned)
1020
+ if self.shape_count >= 46340:
1021
+ # clip the number of potential contacts to avoid signed 32-bit integer overflow
1022
+ # i.e. when the number of shapes exceeds sqrt(2**31 - 1)
1023
+ max_pair_count = 2**31 - 1
1024
+ else:
1025
+ max_pair_count = self.shape_count * self.shape_count
1026
+ # maximum number of contact points per contact pair
1027
+ target.rigid_contact_point_limit = wp.zeros(max_pair_count, dtype=wp.int32)
1028
+ # currently found contacts per contact pair
1029
+ target.rigid_contact_pairwise_counter = wp.zeros(max_pair_count, dtype=wp.int32)
1030
+ else:
1031
+ target.rigid_contact_point_limit = None
1032
+ target.rigid_contact_pairwise_counter = None
1033
+
1022
1034
  # ID of thread that found the current contact point
1023
1035
  target.rigid_contact_tids = wp.zeros(self.rigid_contact_max, dtype=wp.int32)
1024
1036
 
@@ -1193,7 +1205,6 @@ class ModelBuilder:
1193
1205
  self.body_shapes = {} # mapping from body to shapes
1194
1206
 
1195
1207
  # rigid joints
1196
- self.joint = {}
1197
1208
  self.joint_parent = [] # index of the parent body (constant)
1198
1209
  self.joint_parents = {} # mapping from joint to parent bodies
1199
1210
  self.joint_child = [] # index of the child body (constant)
@@ -1233,7 +1244,7 @@ class ModelBuilder:
1233
1244
  self.joint_axis_total_count = 0
1234
1245
 
1235
1246
  self.up_vector = wp.vec3(up_vector)
1236
- self.up_axis = wp.vec3(np.argmax(np.abs(up_vector)))
1247
+ self.up_axis = int(np.argmax(np.abs(up_vector)))
1237
1248
  self.gravity = gravity
1238
1249
  # indicates whether a ground plane has been created
1239
1250
  self._ground_created = False
@@ -2262,8 +2273,30 @@ class ModelBuilder:
2262
2273
  enabled=enabled,
2263
2274
  )
2264
2275
 
2265
- def plot_articulation(self, plot_shapes=True):
2266
- """Plots the model's articulation."""
2276
+ def plot_articulation(
2277
+ self,
2278
+ show_body_names=True,
2279
+ show_joint_names=True,
2280
+ show_joint_types=True,
2281
+ plot_shapes=True,
2282
+ show_shape_types=True,
2283
+ show_legend=True,
2284
+ ):
2285
+ """
2286
+ Visualizes the model's articulation graph using matplotlib and networkx.
2287
+ Uses the spring layout algorithm from networkx to arrange the nodes.
2288
+ Bodies are shown as orange squares, shapes are shown as blue circles.
2289
+
2290
+ Args:
2291
+ show_body_names (bool): Whether to show the body names or indices
2292
+ show_joint_names (bool): Whether to show the joint names or indices
2293
+ show_joint_types (bool): Whether to show the joint types
2294
+ plot_shapes (bool): Whether to render the shapes connected to the rigid bodies
2295
+ show_shape_types (bool): Whether to show the shape geometry types
2296
+ show_legend (bool): Whether to show a legend
2297
+ """
2298
+ import matplotlib.pyplot as plt
2299
+ import networkx as nx
2267
2300
 
2268
2301
  def joint_type_str(type):
2269
2302
  if type == JOINT_FREE:
@@ -2286,18 +2319,88 @@ class ModelBuilder:
2286
2319
  return "distance"
2287
2320
  return "unknown"
2288
2321
 
2289
- vertices = ["world"] + self.body_name
2322
+ def shape_type_str(type):
2323
+ if type == GEO_SPHERE:
2324
+ return "sphere"
2325
+ if type == GEO_BOX:
2326
+ return "box"
2327
+ if type == GEO_CAPSULE:
2328
+ return "capsule"
2329
+ if type == GEO_CYLINDER:
2330
+ return "cylinder"
2331
+ if type == GEO_CONE:
2332
+ return "cone"
2333
+ if type == GEO_MESH:
2334
+ return "mesh"
2335
+ if type == GEO_SDF:
2336
+ return "sdf"
2337
+ if type == GEO_PLANE:
2338
+ return "plane"
2339
+ if type == GEO_NONE:
2340
+ return "none"
2341
+ return "unknown"
2342
+
2343
+ if show_body_names:
2344
+ vertices = ["world"] + self.body_name
2345
+ else:
2346
+ vertices = ["-1"] + [str(i) for i in range(self.body_count)]
2290
2347
  if plot_shapes:
2291
- vertices += [f"shape_{i}" for i in range(self.shape_count)]
2348
+ for i in range(self.shape_count):
2349
+ shape_label = f"shape_{i}"
2350
+ if show_shape_types:
2351
+ shape_label += f"\n({shape_type_str(self.shape_geo_type[i])})"
2352
+ vertices.append(shape_label)
2292
2353
  edges = []
2293
2354
  edge_labels = []
2294
2355
  for i in range(self.joint_count):
2295
- edges.append((self.joint_child[i] + 1, self.joint_parent[i] + 1))
2296
- edge_labels.append(f"{self.joint_name[i]}\n({joint_type_str(self.joint_type[i])})")
2356
+ edge = (self.joint_child[i] + 1, self.joint_parent[i] + 1)
2357
+ edges.append(edge)
2358
+ if show_joint_names:
2359
+ joint_label = self.joint_name[i]
2360
+ else:
2361
+ joint_label = str(i)
2362
+ if show_joint_types:
2363
+ joint_label += f"\n({joint_type_str(self.joint_type[i])})"
2364
+ edge_labels.append(joint_label)
2365
+
2297
2366
  if plot_shapes:
2298
2367
  for i in range(self.shape_count):
2299
2368
  edges.append((len(self.body_name) + i + 1, self.shape_body[i] + 1))
2300
- wp.plot_graph(vertices, edges, edge_labels=edge_labels)
2369
+
2370
+ # plot graph
2371
+ G = nx.Graph()
2372
+ for i in range(len(vertices)):
2373
+ G.add_node(i, label=vertices[i])
2374
+ for i in range(len(edges)):
2375
+ label = edge_labels[i] if i < len(edge_labels) else ""
2376
+ G.add_edge(edges[i][0], edges[i][1], label=label)
2377
+ pos = nx.spring_layout(G)
2378
+ nx.draw_networkx_edges(G, pos, node_size=0, edgelist=edges[: self.joint_count])
2379
+ # render body vertices
2380
+ draw_args = {"node_size": 100}
2381
+ bodies = nx.subgraph(G, list(range(self.body_count + 1)))
2382
+ nx.draw_networkx_nodes(bodies, pos, node_color="orange", node_shape="s", **draw_args)
2383
+ if plot_shapes:
2384
+ # render shape vertices
2385
+ shapes = nx.subgraph(G, list(range(self.body_count + 1, len(vertices))))
2386
+ nx.draw_networkx_nodes(shapes, pos, node_color="skyblue", **draw_args)
2387
+ nx.draw_networkx_edges(
2388
+ G, pos, node_size=0, edgelist=edges[self.joint_count :], edge_color="gray", style="dashed"
2389
+ )
2390
+ edge_labels = nx.get_edge_attributes(G, "label")
2391
+ nx.draw_networkx_edge_labels(
2392
+ G, pos, edge_labels=edge_labels, font_size=6, bbox={"alpha": 0.6, "color": "w", "lw": 0}
2393
+ )
2394
+ # add node labels
2395
+ nx.draw_networkx_labels(G, pos, dict(enumerate(vertices)), font_size=6)
2396
+ if show_legend:
2397
+ plt.plot([], [], "s", color="orange", label="body")
2398
+ plt.plot([], [], "k-", label="joint")
2399
+ if plot_shapes:
2400
+ plt.plot([], [], "o", color="skyblue", label="shape")
2401
+ plt.plot([], [], "k--", label="shape-body connection")
2402
+ plt.legend(loc="upper left", fontsize=6)
2403
+ plt.show()
2301
2404
 
2302
2405
  def collapse_fixed_joints(self, verbose=wp.config.verbose):
2303
2406
  """Removes fixed joints from the model and merges the bodies they connect. This is useful for simplifying the model for faster and more stable simulation."""
@@ -2342,7 +2445,6 @@ class ModelBuilder:
2342
2445
  "type": self.joint_type[i],
2343
2446
  "q": self.joint_q[q_start : q_start + q_dim],
2344
2447
  "qd": self.joint_qd[qd_start : qd_start + qd_dim],
2345
- "act": self.joint_act[qd_start : qd_start + qd_dim],
2346
2448
  "armature": self.joint_armature[qd_start : qd_start + qd_dim],
2347
2449
  "q_start": q_start,
2348
2450
  "qd_start": qd_start,
@@ -2371,6 +2473,7 @@ class ModelBuilder:
2371
2473
  "limit_kd": self.joint_limit_kd[j],
2372
2474
  "limit_lower": self.joint_limit_lower[j],
2373
2475
  "limit_upper": self.joint_limit_upper[j],
2476
+ "act": self.joint_act[j],
2374
2477
  }
2375
2478
  )
2376
2479
 
@@ -2525,7 +2628,6 @@ class ModelBuilder:
2525
2628
  self.joint_qd_start.append(len(self.joint_qd))
2526
2629
  self.joint_q.extend(joint["q"])
2527
2630
  self.joint_qd.extend(joint["qd"])
2528
- self.joint_act.extend(joint["act"])
2529
2631
  self.joint_armature.extend(joint["armature"])
2530
2632
  self.joint_enabled.append(joint["enabled"])
2531
2633
  self.joint_linear_compliance.append(joint["linear_compliance"])
@@ -2543,6 +2645,7 @@ class ModelBuilder:
2543
2645
  self.joint_limit_upper.append(axis["limit_upper"])
2544
2646
  self.joint_limit_ke.append(axis["limit_ke"])
2545
2647
  self.joint_limit_kd.append(axis["limit_kd"])
2648
+ self.joint_act.append(axis["act"])
2546
2649
 
2547
2650
  # muscles
2548
2651
  def add_muscle(
@@ -4466,7 +4569,9 @@ class ModelBuilder:
4466
4569
 
4467
4570
  # enable ground plane
4468
4571
  m.ground_plane = wp.array(self._ground_params["plane"], dtype=wp.float32, requires_grad=requires_grad)
4469
- m.gravity = np.array(self.up_vector) * self.gravity
4572
+ m.gravity = np.array(self.up_vector, dtype=wp.float32) * self.gravity
4573
+ m.up_axis = self.up_axis
4574
+ m.up_vector = np.array(self.up_vector, dtype=wp.float32)
4470
4575
 
4471
4576
  m.enable_tri_collisions = False
4472
4577
 
warp/sim/render.py CHANGED
@@ -67,13 +67,15 @@ def CreateSimRenderer(renderer):
67
67
  path,
68
68
  scaling=1.0,
69
69
  fps=60,
70
- up_axis="Y",
70
+ up_axis=None,
71
71
  show_rigid_contact_points=False,
72
72
  contact_points_radius=1e-3,
73
73
  show_joints=False,
74
74
  **render_kwargs,
75
75
  ):
76
76
  # create USD stage
77
+ if up_axis is None:
78
+ up_axis = "XYZ"[model.up_axis]
77
79
  super().__init__(path, scaling=scaling, fps=fps, up_axis=up_axis, **render_kwargs)
78
80
  self.scaling = scaling
79
81
  self.cam_axis = "XYZ".index(up_axis.upper())