warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry page for this release for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/sim/particles.py CHANGED
@@ -11,6 +11,7 @@ from .model import PARTICLE_FLAG_ACTIVE
11
11
 
12
12
  @wp.func
13
13
  def particle_force(n: wp.vec3, v: wp.vec3, c: float, k_n: float, k_d: float, k_f: float, k_mu: float):
14
+ # compute normal and tangential friction force for a single contact
14
15
  vn = wp.dot(n, v)
15
16
  jn = c * k_n
16
17
  jd = min(vn, 0.0) * k_d
@@ -89,7 +90,7 @@ def eval_particle_forces_kernel(
89
90
 
90
91
 
91
92
  def eval_particle_forces(model, state, forces):
92
- if model.particle_max_radius > 0.0:
93
+ if model.particle_count > 1 and model.particle_max_radius > 0.0:
93
94
  wp.launch(
94
95
  kernel=eval_particle_forces_kernel,
95
96
  dim=model.particle_count,
warp/sim/render.py CHANGED
@@ -23,6 +23,7 @@ NAN = wp.constant(-1.0e8)
23
23
  def compute_contact_points(
24
24
  body_q: wp.array(dtype=wp.transform),
25
25
  shape_body: wp.array(dtype=int),
26
+ contact_count: wp.array(dtype=int),
26
27
  contact_shape0: wp.array(dtype=int),
27
28
  contact_shape1: wp.array(dtype=int),
28
29
  contact_point0: wp.array(dtype=wp.vec3),
@@ -32,6 +33,11 @@ def compute_contact_points(
32
33
  contact_pos1: wp.array(dtype=wp.vec3),
33
34
  ):
34
35
  tid = wp.tid()
36
+ count = contact_count[0]
37
+ if tid >= count:
38
+ contact_pos0[tid] = wp.vec3(NAN, NAN, NAN)
39
+ contact_pos1[tid] = wp.vec3(NAN, NAN, NAN)
40
+ return
35
41
  shape_a = contact_shape0[tid]
36
42
  shape_b = contact_shape1[tid]
37
43
  if shape_a == shape_b:
@@ -124,11 +130,12 @@ def CreateSimRenderer(renderer):
124
130
  shape_geo_thickness = model.shape_geo.thickness.numpy()
125
131
  shape_geo_is_solid = model.shape_geo.is_solid.numpy()
126
132
  shape_transform = model.shape_transform.numpy()
133
+ shape_visible = model.shape_visible.numpy()
127
134
 
128
135
  p = np.zeros(3, dtype=np.float32)
129
136
  q = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32)
130
137
  scale = np.ones(3)
131
- color = np.ones(3)
138
+ color = (1.0, 1.0, 1.0)
132
139
  # loop over shapes excluding the ground plane
133
140
  for s in range(model.shape_count - 1):
134
141
  geo_type = shape_geo_type[s]
@@ -136,8 +143,6 @@ def CreateSimRenderer(renderer):
136
143
  geo_thickness = float(shape_geo_thickness[s])
137
144
  geo_is_solid = bool(shape_geo_is_solid[s])
138
145
  geo_src = shape_geo_src[s]
139
- if self.use_unique_colors:
140
- color = self._get_new_color()
141
146
  name = f"shape_{s}"
142
147
 
143
148
  # shape transform in body frame
@@ -147,6 +152,9 @@ def CreateSimRenderer(renderer):
147
152
  else:
148
153
  body = None
149
154
 
155
+ if self.use_unique_colors and body is not None:
156
+ color = self._get_new_color()
157
+
150
158
  # shape transform in body frame
151
159
  X_bs = wp.transform_expand(shape_transform[s])
152
160
  # check whether we can instance an already created shape with the same geometry
@@ -167,25 +175,25 @@ def CreateSimRenderer(renderer):
167
175
  )
168
176
 
169
177
  elif geo_type == warp.sim.GEO_SPHERE:
170
- shape = self.render_sphere(name, p, q, geo_scale[0], parent_body=body, is_template=True)
178
+ shape = self.render_sphere(name, p, q, geo_scale[0], parent_body=body, is_template=True, color=color)
171
179
 
172
180
  elif geo_type == warp.sim.GEO_CAPSULE:
173
181
  shape = self.render_capsule(
174
- name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True
182
+ name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True, color=color
175
183
  )
176
184
 
177
185
  elif geo_type == warp.sim.GEO_CYLINDER:
178
186
  shape = self.render_cylinder(
179
- name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True
187
+ name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True, color=color
180
188
  )
181
189
 
182
190
  elif geo_type == warp.sim.GEO_CONE:
183
191
  shape = self.render_cone(
184
- name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True
192
+ name, p, q, geo_scale[0], geo_scale[1], parent_body=body, is_template=True, color=color
185
193
  )
186
194
 
187
195
  elif geo_type == warp.sim.GEO_BOX:
188
- shape = self.render_box(name, p, q, geo_scale, parent_body=body, is_template=True)
196
+ shape = self.render_box(name, p, q, geo_scale, parent_body=body, is_template=True, color=color)
189
197
 
190
198
  elif geo_type == warp.sim.GEO_MESH:
191
199
  if not geo_is_solid:
@@ -210,7 +218,9 @@ def CreateSimRenderer(renderer):
210
218
 
211
219
  self.geo_shape[geo_hash] = shape
212
220
 
213
- self.add_shape_instance(name, shape, body, X_bs.p, X_bs.q, scale)
221
+ if shape_visible[s]:
222
+ # TODO support dynamic visibility
223
+ self.add_shape_instance(name, shape, body, X_bs.p, X_bs.q, scale, custom_index=s, visible=shape_visible[s])
214
224
  self.instance_count += 1
215
225
 
216
226
  if self.show_joints and model.joint_count:
@@ -274,7 +284,7 @@ def CreateSimRenderer(renderer):
274
284
  self.instance_count += 1
275
285
 
276
286
  if model.ground:
277
- self.render_ground()
287
+ self.render_ground(plane=model.ground_plane_params)
278
288
 
279
289
  if hasattr(self, "complete_setup"):
280
290
  self.complete_setup()
@@ -283,6 +293,12 @@ def CreateSimRenderer(renderer):
283
293
  return tab10_color_map(self.instance_count)
284
294
 
285
295
  def render(self, state: warp.sim.State):
296
+ """
297
+ Updates the renderer with the given simulation state.
298
+
299
+ Args:
300
+ state (warp.sim.State): The simulation state to render.
301
+ """
286
302
  if self.skip_rendering:
287
303
  return
288
304
 
@@ -290,15 +306,20 @@ def CreateSimRenderer(renderer):
290
306
  particle_q = state.particle_q.numpy()
291
307
 
292
308
  # render particles
293
- self.render_points("particles", particle_q, radius=self.model.particle_radius.numpy())
309
+ self.render_points("particles", particle_q, radius=self.model.particle_radius.numpy(), colors=((0.8, 0.3, 0.2),) * len(particle_q))
294
310
 
295
311
  # render tris
296
312
  if self.model.tri_count:
297
- self.render_mesh("surface", particle_q, self.model.tri_indices.numpy().flatten())
313
+ self.render_mesh(
314
+ "surface",
315
+ particle_q,
316
+ self.model.tri_indices.numpy().flatten(),
317
+ colors=(((0.75, 0.25, 0.0),) * len(particle_q)),
318
+ )
298
319
 
299
320
  # render springs
300
321
  if self.model.spring_count:
301
- self.render_line_list("springs", particle_q, self.model.spring_indices.numpy().flatten(), [], 0.05)
322
+ self.render_line_list("springs", particle_q, self.model.spring_indices.numpy().flatten(), (0.25, 0.5, 0.25), 0.02)
302
323
 
303
324
  # render muscles
304
325
  if self.model.muscle_count:
@@ -348,6 +369,7 @@ def CreateSimRenderer(renderer):
348
369
  inputs=[
349
370
  state.body_q,
350
371
  self.model.shape_body,
372
+ self.model.rigid_contact_count,
351
373
  self.model.rigid_contact_shape0,
352
374
  self.model.rigid_contact_shape1,
353
375
  self.model.rigid_contact_point0,
warp/sim/utils.py CHANGED
@@ -1,13 +1,20 @@
1
1
  import warp as wp
2
+ import numpy as np
2
3
 
3
- PI = wp.constant(3.14159265359)
4
- PI_2 = wp.constant(1.57079632679)
4
+ from typing import Tuple, List
5
5
 
6
6
 
7
7
@wp.func
def velocity_at_point(qd: wp.spatial_vector, r: wp.vec3):
    """
    Computes the velocity of a point rigidly attached to a frame that moves with the given spatial velocity.

    The top three components of ``qd`` are crossed with ``r`` and the bottom three are added,
    i.e. the result is ``spatial_top(qd) x r + spatial_bottom(qd)``.

    Args:
        qd (spatial_vector): The spatial velocity of the frame.
        r (vec3): The position of the point relative to the frame.

    Returns:
        vec3: The velocity of the point.
    """
    angular = wp.spatial_top(qd)
    linear = wp.spatial_bottom(qd)
    return wp.cross(angular, r) + linear
13
20
 
@@ -52,7 +59,7 @@ def quat_decompose(q: wp.quat):
52
59
  phi = wp.atan2(R[1, 2], R[2, 2])
53
60
  sinp = -R[0, 2]
54
61
  if wp.abs(sinp) >= 1.0:
55
- theta = 1.57079632679 * wp.sign(sinp)
62
+ theta = wp.HALF_PI * wp.sign(sinp)
56
63
  else:
57
64
  theta = wp.asin(-R[0, 2])
58
65
  psi = wp.atan2(R[0, 1], R[0, 0])
@@ -63,7 +70,7 @@ def quat_decompose(q: wp.quat):
63
70
  @wp.func
64
71
  def quat_to_rpy(q: wp.quat):
65
72
  """
66
- Convert a quaternion into euler angles (roll, pitch, yaw)
73
+ Convert a quaternion into Euler angles (roll, pitch, yaw)
67
74
  roll is rotation around x in radians (counterclockwise)
68
75
  pitch is rotation around y in radians (counterclockwise)
69
76
  yaw is rotation around z in radians (counterclockwise)
@@ -90,11 +97,27 @@ def quat_to_rpy(q: wp.quat):
90
97
  @wp.func
91
98
  def quat_to_euler(q: wp.quat, i: int, j: int, k: int) -> wp.vec3:
92
99
  """
93
- Convert a quaternion into euler angles
94
- i, j, k are the indices in [1,2,3] of the axes to use
95
- (i != j, j != k)
100
+ Convert a quaternion into Euler angles.
101
+
102
+ :math:`i, j, k` are the indices in :math:`[0, 1, 2]` of the axes to use
103
+ (:math:`i \\neq j, j \\neq k`).
104
+
105
+ Reference: https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0276302
106
+
107
+ Args:
108
+ q (quat): The quaternion to convert
109
+ i (int): The index of the first axis
110
+ j (int): The index of the second axis
111
+ k (int): The index of the third axis
112
+
113
+ Returns:
114
+ vec3: The Euler angles (in radians)
96
115
  """
97
- # https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0276302
116
+ # i, j, k are actually assumed to follow 1-based indexing but
117
+ # we want to be compatible with quat_from_euler
118
+ i += 1
119
+ j += 1
120
+ k += 1
98
121
  not_proper = True
99
122
  if i == k:
100
123
  not_proper = False
@@ -116,17 +139,54 @@ def quat_to_euler(q: wp.quat, i: int, j: int, k: int) -> wp.vec3:
116
139
  t3 = 0.0
117
140
  if wp.abs(t2) < 1e-6:
118
141
  t3 = 2.0 * tp - t1
119
- elif wp.abs(t2 - PI_2) < 1e-6:
142
+ elif wp.abs(t2 - wp.HALF_PI) < 1e-6:
120
143
  t3 = 2.0 * tm + t1
121
144
  else:
122
145
  t1 = tp - tm
123
146
  t3 = tp + tm
124
147
  if not_proper:
125
- t2 -= PI_2
148
+ t2 -= wp.HALF_PI
126
149
  t3 *= e
127
150
  return wp.vec3(t1, t2, t3)
128
151
 
129
152
 
153
@wp.func
def quat_from_euler(e: wp.vec3, i: int, j: int, k: int) -> wp.quat:
    """
    Convert Euler angles to a quaternion.

    :math:`i, j, k` are the indices in :math:`[0, 1, 2]` of the axes in which the Euler angles are provided
    (:math:`i \\neq j, j \\neq k`), e.g. (0, 1, 2) for Euler sequence XYZ.

    Args:
        e (vec3): The Euler angles (in radians)
        i (int): The index of the first axis
        j (int): The index of the second axis
        k (int): The index of the third axis

    Returns:
        quat: The quaternion
    """
    # NOTE(review): (i, j, k) only select which entries of `e` are used as the
    # first/second/third angle; the product formula and the x/y/z/w component
    # layout below are the same for every sequence. Confirm this is the intended
    # convention for sequences other than XYZ.
    half_angles = e / 2.0

    # cosine / sine of the half angle for the first, second and third axis
    c1 = wp.cos(half_angles[i])
    s1 = wp.sin(half_angles[i])
    c2 = wp.cos(half_angles[j])
    s2 = wp.sin(half_angles[j])
    c3 = wp.cos(half_angles[k])
    s3 = wp.sin(half_angles[k])

    qx = c3 * s1 * c2 - s3 * c1 * s2
    qy = c3 * c1 * s2 + s3 * s1 * c2
    qz = s3 * c1 * c2 - c3 * s1 * s2
    qw = c3 * c1 * c2 + s3 * s1 * s2
    return wp.quat(qx, qy, qz, qw)
188
+
189
+
130
190
  @wp.func
131
191
  def transform_twist(t: wp.transform, x: wp.spatial_vector):
132
192
  # Frank & Park definition 3.20, pg 100
@@ -181,6 +241,35 @@ def transform_inertia(t: wp.transform, I: wp.spatial_matrix):
181
241
  return wp.mul(wp.mul(wp.transpose(T), I), T)
182
242
 
183
243
 
244
@wp.func
def boltzmann(a: float, b: float, alpha: float):
    """
    Boltzmann operator: the average of ``a`` and ``b`` where each value is weighted by ``exp(alpha * value)``.

    Large positive ``alpha`` pushes the result toward the larger input, large negative ``alpha``
    toward the smaller one.
    """
    # NOTE(review): the exponentials are not shifted by the maximum, so a large
    # |alpha * a| or |alpha * b| can overflow to inf and make the result NaN.
    weight_a = wp.exp(alpha * a)
    weight_b = wp.exp(alpha * b)
    weighted_sum = a * weight_a + b * weight_b
    return weighted_sum / (weight_a + weight_b)
249
+
250
+
251
@wp.func
def smooth_max(a: float, b: float, eps: float):
    """
    Smooth approximation of ``max(a, b)``: ``0.5 * (a + b + sqrt((a - b)^2 + eps))``.

    With ``eps == 0`` this is exactly ``max(a, b)``; larger ``eps`` rounds off the corner at ``a == b``.
    """
    diff = a - b
    return 0.5 * (a + b + wp.sqrt(diff * diff + eps))
255
+
256
+
257
@wp.func
def smooth_min(a: float, b: float, eps: float):
    """
    Smooth approximation of ``min(a, b)``: ``0.5 * (a + b - sqrt((a - b)^2 + eps))``.

    With ``eps == 0`` this is exactly ``min(a, b)``; larger ``eps`` rounds off the corner at ``a == b``.
    """
    diff = a - b
    return 0.5 * (a + b - wp.sqrt(diff * diff + eps))
261
+
262
+
263
@wp.func
def leaky_max(a: float, b: float):
    """Smooth maximum of ``a`` and ``b`` using a fixed smoothing term of ``1e-5``."""
    smoothing = 1e-5
    return smooth_max(a, b, smoothing)
266
+
267
+
268
@wp.func
def leaky_min(a: float, b: float):
    """Smooth minimum of ``a`` and ``b`` using a fixed smoothing term of ``1e-5``."""
    smoothing = 1e-5
    return smooth_min(a, b, smoothing)
271
+
272
+
184
273
@wp.func
def vec_min(a: wp.vec3, b: wp.vec3):
    """Component-wise minimum of two ``vec3`` values."""
    x = wp.min(a[0], b[0])
    y = wp.min(a[1], b[1])
    z = wp.min(a[2], b[2])
    return wp.vec3(x, y, z)
@@ -191,9 +280,131 @@ def vec_max(a: wp.vec3, b: wp.vec3):
191
280
  return wp.vec3(wp.max(a[0], b[0]), wp.max(a[1], b[1]), wp.max(a[2], b[2]))
192
281
 
193
282
 
283
@wp.func
def vec_leaky_min(a: wp.vec3, b: wp.vec3):
    """Component-wise smooth minimum of two ``vec3`` values (see ``leaky_min``)."""
    x = leaky_min(a[0], b[0])
    y = leaky_min(a[1], b[1])
    z = leaky_min(a[2], b[2])
    return wp.vec3(x, y, z)
286
+
287
+
288
@wp.func
def vec_leaky_max(a: wp.vec3, b: wp.vec3):
    """Component-wise smooth maximum of two ``vec3`` values (see ``leaky_max``)."""
    x = leaky_max(a[0], b[0])
    y = leaky_max(a[1], b[1])
    z = leaky_max(a[2], b[2])
    return wp.vec3(x, y, z)
291
+
292
+
194
293
@wp.func
def vec_abs(a: wp.vec3):
    """Component-wise absolute value of a ``vec3``."""
    x = wp.abs(a[0])
    y = wp.abs(a[1])
    z = wp.abs(a[2])
    return wp.vec3(x, y, z)
197
296
 
198
297
 
199
-
298
def load_mesh(filename: str, method: str = None):
    """
    Loads a 3D triangular surface mesh from a file.

    Args:
        filename (str): The path to the 3D model file (obj, and other formats supported by the different methods) to load.
        method (str): The method to use for loading the mesh (default None). Can be either `"trimesh"`, `"meshio"`, `"pcu"`, or `"openmesh"`.
            Any other non-None value falls back to trimesh. If None, every method is tried and the first successful mesh import
            where the number of vertices is greater than 0 is returned.

    Returns:
        Tuple of (mesh_points, mesh_indices), where mesh_points is a numpy array of vertex positions (float32,
        Nx3 for 3D meshes), and mesh_indices is a flat numpy array of vertex indices (int32) holding three
        consecutive entries per triangular face (length 3*M for M faces), regardless of the loading method.

    Raises:
        ValueError: If the file does not exist, if no method yields a mesh with at least one vertex
            (when ``method`` is None), or if the requested method yields no vertices.
            Errors raised by an explicitly requested loader (e.g. ImportError) propagate unchanged.
    """
    import os

    if not os.path.exists(filename):
        raise ValueError(f"File not found: {filename}")

    def load_mesh_with_method(method):
        if method == "meshio":
            import meshio

            m = meshio.read(filename)
            # NOTE(review): only the first cell block is used; files mixing cell types may need filtering
            points = m.points
            indices = m.cells[0].data
        elif method == "openmesh":
            import openmesh

            m = openmesh.read_trimesh(filename)
            points = m.points()
            indices = m.face_vertex_indices()
        elif method == "pcu":
            import point_cloud_utils as pcu

            points, indices = pcu.load_mesh_vf(filename)
        else:
            import trimesh

            m = trimesh.load(filename)
            if hasattr(m, "geometry"):
                # multiple meshes are contained in a scene; combine to one mesh, offsetting
                # each geometry's face indices by the number of vertices that precede it
                point_chunks = []
                index_chunks = []
                index_offset = 0
                for geom in m.geometry.values():
                    vertices = np.asarray(geom.vertices, dtype=np.float32)
                    faces = np.asarray(geom.faces, dtype=np.int32).flatten()
                    point_chunks.append(vertices)
                    index_chunks.append(faces + index_offset)
                    index_offset += len(vertices)
                if not point_chunks:
                    # empty scene: report zero vertices instead of failing inside np.concatenate
                    return np.zeros((0, 3), dtype=np.float32), np.zeros(0, dtype=np.int32)
                points = np.concatenate(point_chunks, axis=0)
                indices = np.concatenate(index_chunks)
            else:
                # a single mesh
                points = m.vertices
                indices = m.faces

        # Normalize every loader to the same output layout: float32 points and a flat
        # int32 index array (three entries per triangle), as the trimesh path produces.
        return np.asarray(points, dtype=np.float32), np.asarray(indices, dtype=np.int32).flatten()

    if method is None:
        methods = ["trimesh", "meshio", "pcu", "openmesh"]
        last_error = None
        for candidate in methods:
            try:
                mesh = load_mesh_with_method(candidate)
            except Exception as e:
                # best effort: the loader may not be installed or may not support this format;
                # remember the failure so it can be chained onto the final error
                last_error = e
                continue
            if len(mesh[0]) > 0:
                return mesh
        raise ValueError(f"Failed to load mesh using any of the methods: {methods}") from last_error

    mesh = load_mesh_with_method(method)
    if len(mesh[0]) == 0:
        raise ValueError(f"Failed to load mesh using method {method}")
    return mesh
371
+
372
+
373
def visualize_meshes(
    meshes: List[Tuple[list, list]], num_cols=0, num_rows=0, titles=None, scale_axes=True, show_plot=True
):
    """
    Renders triangle meshes in a grid of matplotlib 3D subplots.

    Args:
        meshes: Sequence of (vertices, faces) pairs. Vertices may be any array-like that reshapes to (N, 3)
            (e.g. a flat list of coordinates); faces may be any array-like of vertex indices that reshapes
            to (M, 3) (e.g. a flat index list with three entries per triangle).
        num_cols (int): Number of grid columns; 0 derives it from ``num_rows``.
        num_rows (int): Number of grid rows; 0 derives it from ``num_cols``. If both are 0, all meshes
            are placed in a single row. If both are given but the grid is too small, rows are added.
        titles (list of str): Optional per-subplot titles; meshes without a title are left untitled.
        scale_axes (bool): If True, every subplot uses the same axis extent (the largest bounding-box
            extent among all meshes), centered on the mesh's own bounding-box center.
        show_plot (bool): If True, calls ``plt.show()`` before returning.

    Returns:
        The matplotlib figure containing the subplots.
    """
    # render meshes in a grid with matplotlib
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the "3d" projection on older matplotlib

    if titles is None:
        titles = []

    num_meshes = len(meshes)
    num_cols = min(num_cols, num_meshes)
    num_rows = min(num_rows, num_meshes)
    if num_cols and num_rows:
        # both given: honor them, but grow the row count if the grid cannot hold every mesh
        if num_cols * num_rows < num_meshes:
            num_rows = int(np.ceil(num_meshes / num_cols))
    elif num_cols:
        num_rows = int(np.ceil(num_meshes / num_cols))
    elif num_rows:
        num_cols = int(np.ceil(num_meshes / num_rows))
    else:
        num_cols = num_meshes
        num_rows = 1

    # normalize every mesh to (N, 3) vertex and (M, 3) face arrays so flat inputs work as well
    vertices = [np.asarray(v).reshape((-1, 3)) for v, _ in meshes]
    faces = [np.asarray(f, dtype=np.int32).reshape((-1, 3)) for _, f in meshes]
    if scale_axes and vertices:
        ranges = np.array([v.max(axis=0) - v.min(axis=0) for v in vertices])
        max_range = ranges.max()
        mid_points = np.array([v.max(axis=0) + v.min(axis=0) for v in vertices]) * 0.5

    fig = plt.figure(figsize=(12, 6))
    for i, (mesh_vertices, mesh_faces) in enumerate(zip(vertices, faces)):
        ax = fig.add_subplot(num_rows, num_cols, i + 1, projection="3d")
        if i < len(titles):
            ax.set_title(titles[i])
        ax.plot_trisurf(
            mesh_vertices[:, 0], mesh_vertices[:, 1], mesh_vertices[:, 2], triangles=mesh_faces, edgecolor="k"
        )
        if scale_axes:
            mid = mid_points[i]
            ax.set_xlim(mid[0] - max_range, mid[0] + max_range)
            ax.set_ylim(mid[1] - max_range, mid[1] + max_range)
            ax.set_zlim(mid[2] - max_range, mid[2] + max_range)
    if show_plot:
        plt.show()
    return fig
warp/stubs.py CHANGED
@@ -56,6 +56,8 @@ from warp.context import get_device, set_device, synchronize_device
56
56
  from warp.context import (
57
57
  zeros,
58
58
  zeros_like,
59
+ ones,
60
+ ones_like,
59
61
  full,
60
62
  full_like,
61
63
  clone,
@@ -74,9 +76,15 @@ from warp.context import Kernel, Function, Launch
74
76
  from warp.context import Stream, get_stream, set_stream, synchronize_stream
75
77
  from warp.context import Event, record_event, wait_event, wait_stream
76
78
  from warp.context import RegisteredGLBuffer
79
+ from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
80
+ from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
81
+ from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
82
+ from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
77
83
 
78
84
  from warp.tape import Tape
79
85
  from warp.utils import ScopedTimer, ScopedDevice, ScopedStream
86
+ from warp.utils import ScopedMempool, ScopedMempoolAccess, ScopedPeerAccess
87
+ from warp.utils import ScopedCapture
80
88
  from warp.utils import transform_expand, quat_between_vectors
81
89
 
82
90
  from warp.torch import from_torch, to_torch
warp/tape.py CHANGED
@@ -95,7 +95,11 @@ class Tape:
95
95
  # existing code before we added wp.array.grad attribute
96
96
  if grads:
97
97
  for a, g in grads.items():
98
- a.grad = g
98
+ if a.grad is None:
99
+ a.grad = g
100
+ else:
101
+ # ensure we can capture this backward pass in a CUDA graph
102
+ a.grad.assign(g)
99
103
  self.const_gradients.add(a)
100
104
 
101
105
  # run launches backwards
@@ -104,6 +108,17 @@ class Tape:
104
108
  launch()
105
109
 
106
110
  else:
111
+ # kernel option takes precedence over module option
112
+ kernel_enable_backward = launch[0].options.get("enable_backward")
113
+ if kernel_enable_backward is False:
114
+ msg = f"Running the tape backwards may produce incorrect gradients because recorded kernel {launch[0].key} is configured with the option 'enable_backward=False'."
115
+ wp.utils.warn(msg)
116
+ elif kernel_enable_backward is None:
117
+ module_enable_backward = launch[0].module.options.get("enable_backward")
118
+ if module_enable_backward is False:
119
+ msg = f"Running the tape backwards may produce incorrect gradients because recorded kernel {launch[0].key} is defined in a module with the option 'enable_backward=False' set."
120
+ wp.utils.warn(msg)
121
+
107
122
  kernel = launch[0]
108
123
  dim = launch[1]
109
124
  max_blocks = launch[2]
@@ -0,0 +1,23 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ """This file is used to test importing a user-defined function with a custom gradient"""
9
+
10
+ import warp as wp
11
+
12
+ wp.init()
13
+
14
+
15
@wp.func
def aux_custom_fn(x: float, y: float):
    # Two-output helper whose gradient is overridden by the @wp.func_grad(aux_custom_fn) function in this file.
    first = x * 3.0 + y / 3.0
    second = y**2.5
    return first, second
18
+
19
+
20
@wp.func_grad(aux_custom_fn)
def aux_custom_fn_grad(x: float, y: float, adj_ret0: float, adj_ret1: float):
    # Custom adjoint for aux_custom_fn: adj_ret0 / adj_ret1 are the incoming adjoints of its two return values.
    # These coefficients are not the analytic derivatives of aux_custom_fn (e.g. d(ret0)/dx is 3.0),
    # presumably so a test can tell this custom gradient apart from the automatic one -- confirm against
    # test_grad_customs.py.
    wp.adjoint[x] += x * adj_ret0 * 42.0 + y * adj_ret1 * 10.0
    wp.adjoint[y] += y * adj_ret1 * 3.0