warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (192):
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
@@ -42,7 +42,7 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
42
42
 
43
43
  # The Python function to call.
44
44
  # Note the argument annotations, just like Warp kernels.
45
- def example_func(
45
+ def scale_func(
46
46
  # inputs
47
47
  a: wp.array(dtype=float),
48
48
  b: wp.array(dtype=wp.vec2),
@@ -55,8 +55,23 @@ def example_func(
55
55
  wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
56
56
 
57
57
 
58
+ @wp.kernel
59
+ def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
60
+ tid = wp.tid()
61
+ b[tid] += a[tid]
62
+
63
+
64
+ def in_out_func(
65
+ a: wp.array(dtype=float), # input only
66
+ b: wp.array(dtype=float), # input and output
67
+ c: wp.array(dtype=float), # output only
68
+ ):
69
+ wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
70
+ wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
71
+
72
+
58
73
  def example1():
59
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
74
+ jax_func = jax_callable(scale_func, num_outputs=2)
60
75
 
61
76
  @jax.jit
62
77
  def f():
@@ -78,7 +93,7 @@ def example1():
78
93
 
79
94
 
80
95
  def example2():
81
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
96
+ jax_func = jax_callable(scale_func, num_outputs=2)
82
97
 
83
98
  # NOTE: scalar arguments must be static compile-time constants
84
99
  @partial(jax.jit, static_argnames=["s"])
@@ -100,11 +115,26 @@ def example2():
100
115
  print(r2)
101
116
 
102
117
 
118
+ def example3():
119
+ # Using input-output arguments
120
+
121
+ jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
122
+
123
+ f = jax.jit(jax_func)
124
+
125
+ a = jnp.ones(10, dtype=jnp.float32)
126
+ b = jnp.arange(10, dtype=jnp.float32)
127
+
128
+ b, c = f(a, b)
129
+ print(b)
130
+ print(c)
131
+
132
+
103
133
  def main():
104
134
  wp.init()
105
135
  wp.load_module(device=wp.get_device())
106
136
 
107
- examples = [example1, example2]
137
+ examples = [example1, example2, example3]
108
138
 
109
139
  for example in examples:
110
140
  print("\n===========================================================================")
@@ -45,11 +45,11 @@ def example1():
45
45
  # the Python function to call
46
46
  def print_args(inputs, outputs, attrs, ctx):
47
47
  def buffer_to_string(b):
48
- return str(b.dtype) + str(list(b.shape)) + " @%x" % b.data
48
+ return f"{b.dtype}{list(b.shape)} @{b.data:x}"
49
49
 
50
50
  print("Inputs: ", ", ".join([buffer_to_string(b) for b in inputs]))
51
51
  print("Outputs: ", ", ".join([buffer_to_string(b) for b in outputs]))
52
- print("Attributes: ", "".join(["\n %s: %s" % (k, str(v)) for k, v in attrs.items()]))
52
+ print("Attributes: ", "".join([f"\n {k}: {str(v)}" for k, v in attrs.items()])) # noqa: RUF010
53
53
 
54
54
  # register callback
55
55
  register_ffi_callback("print_args", print_args)
@@ -72,6 +72,17 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
72
72
  output[tid] = a[tid] * s
73
73
 
74
74
 
75
+ @wp.kernel
76
+ def in_out_kernel(
77
+ a: wp.array(dtype=float), # input only
78
+ b: wp.array(dtype=float), # input and output
79
+ c: wp.array(dtype=float), # output only
80
+ ):
81
+ tid = wp.tid()
82
+ b[tid] += a[tid]
83
+ c[tid] = 2.0 * a[tid]
84
+
85
+
75
86
  def example1():
76
87
  # two inputs and one output
77
88
  jax_add = jax_kernel(add_kernel)
@@ -189,11 +200,26 @@ def example7():
189
200
  print(f())
190
201
 
191
202
 
203
+ def example8():
204
+ # Using input-output arguments
205
+
206
+ jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
207
+
208
+ f = jax.jit(jax_func)
209
+
210
+ a = jnp.ones(10, dtype=jnp.float32)
211
+ b = jnp.arange(10, dtype=jnp.float32)
212
+
213
+ b, c = f(a, b)
214
+ print(b)
215
+ print(c)
216
+
217
+
192
218
  def main():
193
219
  wp.init()
194
220
  wp.load_module(device=wp.get_device())
195
221
 
196
- examples = [example1, example2, example3, example4, example5, example6, example7]
222
+ examples = [example1, example2, example3, example4, example5, example6, example7, example8]
197
223
 
198
224
  for example in examples:
199
225
  print("\n===========================================================================")
@@ -708,7 +708,7 @@ class Example:
708
708
  self.tape.zero()
709
709
 
710
710
  def step(self):
711
- if self.frame % int((self.num_frames / len(self.targets))) == 0:
711
+ if self.frame % int(self.num_frames / len(self.targets)) == 0:
712
712
  if self.verbose:
713
713
  print(f"Choosing new flight target: {self.target_idx + 1}")
714
714
 
@@ -219,7 +219,7 @@ if __name__ == "__main__":
219
219
  example.render()
220
220
 
221
221
  frame_times = example.profiler["step"]
222
- print("\nAverage frame sim time: {:.2f} ms".format(sum(frame_times) / len(frame_times)))
222
+ print(f"\nAverage frame sim time: {sum(frame_times) / len(frame_times):.2f} ms")
223
223
 
224
224
  if example.renderer:
225
225
  example.renderer.save()
@@ -128,6 +128,10 @@ class Example:
128
128
  self.num_substeps = 10
129
129
  self.iterations = 4
130
130
  self.dt = self.frame_dt / self.num_substeps
131
+ # the BVH used by VBDIntegrator will be rebuilt every self.bvh_rebuild_frames
132
+ # When the simulated object deforms significantly, simply refitting the BVH can lead to deterioration of the BVH's
133
+ # quality, in this case we need to completely rebuild the tree to achieve better query efficiency.
134
+ self.bvh_rebuild_frames = 10
131
135
 
132
136
  self.num_frames = num_frames
133
137
  self.sim_time = 0.0
@@ -227,69 +231,62 @@ class Example:
227
231
  self.cuda_graph = None
228
232
  if self.use_cuda_graph:
229
233
  with wp.ScopedCapture() as capture:
230
- for _ in range(self.num_substeps):
231
- wp.launch(
232
- kernel=apply_rotation,
233
- dim=self.rot_point_indices.shape[0],
234
- inputs=[
235
- self.rot_point_indices,
236
- self.rot_axes,
237
- self.roots,
238
- self.roots_to_ps,
239
- self.t,
240
- self.rot_angular_velocity,
241
- self.dt,
242
- self.rot_end_time,
243
- ],
244
- outputs=[
245
- self.state0.particle_q,
246
- self.state1.particle_q,
247
- ],
248
- )
249
-
250
- self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
251
- (self.state0, self.state1) = (self.state1, self.state0)
234
+ self.integrate_frame_substeps()
252
235
 
253
236
  self.cuda_graph = capture.graph
254
237
 
255
- def step(self):
238
+ def integrate_frame_substeps(self):
239
+ for _ in range(self.num_substeps):
240
+ wp.launch(
241
+ kernel=apply_rotation,
242
+ dim=self.rot_point_indices.shape[0],
243
+ inputs=[
244
+ self.rot_point_indices,
245
+ self.rot_axes,
246
+ self.roots,
247
+ self.roots_to_ps,
248
+ self.t,
249
+ self.rot_angular_velocity,
250
+ self.dt,
251
+ self.rot_end_time,
252
+ ],
253
+ outputs=[
254
+ self.state0.particle_q,
255
+ self.state1.particle_q,
256
+ ],
257
+ )
258
+
259
+ self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
260
+ (self.state0, self.state1) = (self.state1, self.state0)
261
+
262
+ def advance_frame(self):
256
263
  with wp.ScopedTimer("step", print=False, dict=self.profiler):
257
264
  if self.use_cuda_graph:
258
265
  wp.capture_launch(self.cuda_graph)
259
266
  else:
260
- for _ in range(self.num_substeps):
261
- wp.launch(
262
- kernel=apply_rotation,
263
- dim=self.rot_point_indices.shape[0],
264
- inputs=[
265
- self.rot_point_indices,
266
- self.rot_axes,
267
- self.roots,
268
- self.roots_to_ps,
269
- self.t,
270
- self.rot_angular_velocity,
271
- self.dt,
272
- self.rot_end_time,
273
- ],
274
- outputs=[
275
- self.state0.particle_q,
276
- self.state1.particle_q,
277
- ],
278
- )
279
- self.integrator.simulate(self.model, self.state0, self.state1, self.dt)
280
-
281
- (self.state0, self.state1) = (self.state1, self.state0)
267
+ self.integrate_frame_substeps()
282
268
 
283
269
  self.sim_time += self.dt
284
270
 
271
+ def run(self):
272
+ for i in range(self.num_frames):
273
+ self.advance_frame()
274
+ self.render()
275
+ print(f"[{i:4d}/{self.num_frames}]")
276
+
277
+ if i != 0 and not i % self.bvh_rebuild_frames and self.use_cuda_graph:
278
+ self.integrator.rebuild_bvh(self.state0)
279
+ with wp.ScopedCapture() as capture:
280
+ self.integrate_frame_substeps()
281
+ self.cuda_graph = capture.graph
282
+
285
283
  def render(self):
286
284
  if self.renderer is None:
287
285
  return
288
286
 
289
- with wp.ScopedTimer("render", print=False):
290
- self.renderer.begin_frame(self.sim_time)
291
- self.renderer.render(self.state0)
292
- self.renderer.end_frame()
287
+ self.renderer.begin_frame(self.sim_time)
288
+ self.renderer.render(self.state0)
289
+ self.renderer.end_frame()
293
290
 
294
291
 
295
292
  if __name__ == "__main__":
@@ -310,13 +307,10 @@ if __name__ == "__main__":
310
307
  with wp.ScopedDevice(args.device):
311
308
  example = Example(stage_path=args.stage_path, num_frames=args.num_frames)
312
309
 
313
- for i in range(example.num_frames):
314
- example.step()
315
- example.render()
316
- print(f"[{i:4d}/{example.num_frames}]")
310
+ example.run()
317
311
 
318
312
  frame_times = example.profiler["step"]
319
- print("\nAverage frame sim time: {:.2f} ms".format(sum(frame_times) / len(frame_times)))
313
+ print(f"\nAverage frame sim time: {sum(frame_times) / len(frame_times):.2f} ms")
320
314
 
321
315
  if example.renderer:
322
316
  example.renderer.save()