warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for details (the original "more details" link was not preserved in this copy).

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/space/topology.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional, Tuple, Type
16
+ from typing import ClassVar, Optional, Tuple, Type
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
@@ -39,6 +39,12 @@ class SpaceTopology:
39
39
  .. note:: This will change to be defined per-element in future versions
40
40
  """
41
41
 
42
+ _dynamic_attribute_constructors: ClassVar = {
43
+ "element_node_count": lambda obj: obj._make_constant_element_node_count(),
44
+ "element_node_sign": lambda obj: obj._make_constant_element_node_sign(),
45
+ "side_neighbor_node_counts": lambda obj: obj._make_constant_side_neighbor_node_counts(),
46
+ }
47
+
42
48
  @wp.struct
43
49
  class TopologyArg:
44
50
  """Structure containing arguments to be passed to device functions"""
@@ -51,8 +57,7 @@ class SpaceTopology:
51
57
  self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
52
58
  self.ElementArg = geometry.CellArg
53
59
 
54
- self._make_constant_element_node_count()
55
- self._make_constant_element_node_sign()
60
+ cache.setup_dynamic_attributes(self, cls=__class__)
56
61
 
57
62
  @property
58
63
  def geometry(self) -> Geometry:
@@ -67,6 +72,9 @@ class SpaceTopology:
67
72
  """Value of the topology argument structure to be passed to device functions"""
68
73
  return SpaceTopology.TopologyArg()
69
74
 
75
+ def fill_topo_arg(self, arg, device):
76
+ pass
77
+
70
78
  @property
71
79
  def name(self):
72
80
  return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"
@@ -182,6 +190,11 @@ class SpaceTopology:
182
190
  ):
183
191
  return NODES_PER_ELEMENT
184
192
 
193
+ return constant_element_node_count
194
+
195
+ def _make_constant_side_neighbor_node_counts(self):
196
+ NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)
197
+
185
198
  @cache.dynamic_func(suffix=self.name)
186
199
  def constant_side_neighbor_node_counts(
187
200
  side_arg: self.geometry.SideArg,
@@ -189,8 +202,7 @@ class SpaceTopology:
189
202
  ):
190
203
  return NODES_PER_ELEMENT, NODES_PER_ELEMENT
191
204
 
192
- self.element_node_count = constant_element_node_count
193
- self.side_neighbor_node_counts = constant_side_neighbor_node_counts
205
+ return constant_side_neighbor_node_counts
194
206
 
195
207
  def _make_constant_element_node_sign(self):
196
208
  @cache.dynamic_func(suffix=self.name)
@@ -202,12 +214,21 @@ class SpaceTopology:
202
214
  ):
203
215
  return 1.0
204
216
 
205
- self.element_node_sign = constant_element_node_sign
217
+ return constant_element_node_sign
206
218
 
207
219
 
208
220
  class TraceSpaceTopology(SpaceTopology):
209
221
  """Auto-generated trace topology defining the node indices associated to the geometry sides"""
210
222
 
223
+ _dynamic_attribute_constructors: ClassVar = {
224
+ "inner_cell_index": lambda obj: obj._make_inner_cell_index(),
225
+ "outer_cell_index": lambda obj: obj._make_outer_cell_index(),
226
+ "neighbor_cell_index": lambda obj: obj._make_neighbor_cell_index(),
227
+ "element_node_index": lambda obj: obj._make_element_node_index(),
228
+ "element_node_count": lambda obj: obj._make_element_node_count(),
229
+ "element_node_sign": lambda obj: obj._make_element_node_sign(),
230
+ }
231
+
211
232
  def __init__(self, topo: SpaceTopology):
212
233
  self._topo = topo
213
234
 
@@ -218,14 +239,10 @@ class TraceSpaceTopology(SpaceTopology):
218
239
 
219
240
  self.TopologyArg = topo.TopologyArg
220
241
  self.topo_arg_value = topo.topo_arg_value
242
+ self.fill_topo_arg = topo.fill_topo_arg
221
243
 
222
- self.inner_cell_index = self._make_inner_cell_index()
223
- self.outer_cell_index = self._make_outer_cell_index()
224
- self.neighbor_cell_index = self._make_neighbor_cell_index()
225
-
226
- self.element_node_index = self._make_element_node_index()
227
- self.element_node_count = self._make_element_node_count()
228
244
  self.side_neighbor_node_counts = None
245
+ cache.setup_dynamic_attributes(self, cls=__class__)
229
246
 
230
247
  def node_count(self) -> int:
231
248
  return self._topo.node_count()
@@ -354,21 +371,29 @@ class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin,
354
371
 
355
372
 
356
373
  class DeformedGeometrySpaceTopology(SpaceTopology):
374
+ _dynamic_attribute_constructors: ClassVar = {
375
+ "element_node_index": lambda obj: obj._make_element_node_index(),
376
+ "element_node_count": lambda obj: obj._make_element_node_count(),
377
+ "element_node_sign": lambda obj: obj._make_element_node_sign(),
378
+ "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
379
+ }
380
+
357
381
  def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
358
382
  self.base = base_topology
359
383
  super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)
360
384
 
361
385
  self.node_count = self.base.node_count
362
386
  self.topo_arg_value = self.base.topo_arg_value
387
+ self.fill_topo_arg = self.base.fill_topo_arg
363
388
  self.TopologyArg = self.base.TopologyArg
364
389
 
365
- self._make_passthrough_functions()
390
+ cache.setup_dynamic_attributes(self, cls=__class__)
366
391
 
367
392
  @property
368
393
  def name(self):
369
394
  return f"{self.base.name}_{self.geometry.field.name}"
370
395
 
371
- def _make_passthrough_functions(self):
396
+ def _make_element_node_index(self):
372
397
  @cache.dynamic_func(suffix=self.name)
373
398
  def element_node_index(
374
399
  elt_arg: self.geometry.CellArg,
@@ -376,16 +401,22 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
376
401
  element_index: ElementIndex,
377
402
  node_index_in_elt: int,
378
403
  ):
379
- return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
404
+ return self.base.element_node_index(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
405
+
406
+ return element_node_index
380
407
 
408
+ def _make_element_node_count(self):
381
409
  @cache.dynamic_func(suffix=self.name)
382
410
  def element_node_count(
383
411
  elt_arg: self.geometry.CellArg,
384
412
  topo_arg: self.TopologyArg,
385
413
  element_count: ElementIndex,
386
414
  ):
387
- return self.base.element_node_count(elt_arg.elt_arg, topo_arg, element_count)
415
+ return self.base.element_node_count(elt_arg.base_arg, topo_arg, element_count)
416
+
417
+ return element_node_count
388
418
 
419
+ def _make_side_neighbor_node_counts(self):
389
420
  @cache.dynamic_func(suffix=self.name)
390
421
  def side_neighbor_node_counts(
391
422
  side_arg: self.geometry.SideArg,
@@ -394,6 +425,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
394
425
  inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
395
426
  return inner_count, outer_count
396
427
 
428
+ return side_neighbor_node_counts
429
+
430
+ def _make_element_node_sign(self):
397
431
  @cache.dynamic_func(suffix=self.name)
398
432
  def element_node_sign(
399
433
  elt_arg: self.geometry.CellArg,
@@ -401,12 +435,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
401
435
  element_index: ElementIndex,
402
436
  node_index_in_elt: int,
403
437
  ):
404
- return self.base.element_node_sign(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
438
+ return self.base.element_node_sign(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
405
439
 
406
- self.element_node_index = element_node_index
407
- self.element_node_count = element_node_count
408
- self.element_node_sign = element_node_sign
409
- self.side_neighbor_node_counts = side_neighbor_node_counts
440
+ return element_node_sign
410
441
 
411
442
 
412
443
  def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
warp/fem/space/trimesh_function_space.py CHANGED
@@ -50,12 +50,15 @@ class TrimeshSpaceTopology(SpaceTopology):
50
50
  @cache.cached_arg_value
51
51
  def topo_arg_value(self, device):
52
52
  arg = TrimeshTopologyArg()
53
+ self.fill_topo_arg(arg, device)
54
+ return arg
55
+
56
+ def fill_topo_arg(self, arg: TrimeshTopologyArg, device):
53
57
  arg.tri_edge_indices = self._tri_edge_indices.to(device)
54
58
  arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
55
59
 
56
60
  arg.vertex_count = self._mesh.vertex_count()
57
61
  arg.edge_count = self._mesh.side_count()
58
- return arg
59
62
 
60
63
  def _compute_tri_edge_indices(self):
61
64
  self._tri_edge_indices = wp.empty(
warp/fem/utils.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
19
19
 
20
20
  import warp as wp
21
21
  import warp.fem.cache as cache
22
+ import warp.types
22
23
  from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
23
24
  array_axpy,
24
25
  inverse_qr,
@@ -28,6 +29,57 @@ from warp.fem.types import NULL_NODE_INDEX
28
29
  from warp.utils import array_scan, radix_sort_pairs, runlength_encode
29
30
 
30
31
 
32
+ def type_zero_element(dtype):
33
+ suffix = warp.types.get_type_code(dtype)
34
+
35
+ if dtype in warp.types.scalar_types:
36
+
37
+ @cache.dynamic_func(suffix=suffix)
38
+ def zero_element():
39
+ return dtype(0.0)
40
+
41
+ return zero_element
42
+
43
+ @cache.dynamic_func(suffix=suffix)
44
+ def zero_element():
45
+ return dtype()
46
+
47
+ return zero_element
48
+
49
+
50
+ def type_basis_element(dtype):
51
+ suffix = warp.types.get_type_code(dtype)
52
+
53
+ if dtype in warp.types.scalar_types:
54
+
55
+ @cache.dynamic_func(suffix=suffix)
56
+ def basis_element(coord: int):
57
+ return dtype(1.0)
58
+
59
+ return basis_element
60
+
61
+ if warp.types.type_is_matrix(dtype):
62
+ cols = dtype._shape_[1]
63
+
64
+ @cache.dynamic_func(suffix=suffix)
65
+ def basis_element(coord: int):
66
+ v = dtype()
67
+ i = coord // cols
68
+ j = coord - i * cols
69
+ v[i, j] = v.dtype(1.0)
70
+ return v
71
+
72
+ return basis_element
73
+
74
+ @cache.dynamic_func(suffix=suffix)
75
+ def basis_element(coord: int):
76
+ v = dtype()
77
+ v[coord] = v.dtype(1.0)
78
+ return v
79
+
80
+ return basis_element
81
+
82
+
31
83
  def compress_node_indices(
32
84
  node_count: int,
33
85
  node_indices: wp.array(dtype=int),
@@ -126,14 +178,7 @@ def host_read_at_index(array: wp.array, index: int = -1, temporary_store: cache.
126
178
 
127
179
  if index < 0:
128
180
  index += array.shape[0]
129
-
130
- if array.device.is_cuda:
131
- temp = cache.borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
132
- wp.copy(dest=temp.array, src=array, src_offset=index, count=1)
133
- wp.synchronize_stream(wp.get_stream(array.device))
134
- return temp.array.numpy()[0]
135
-
136
- return array.numpy()[index]
181
+ return array[index : index + 1].numpy()[0]
137
182
 
138
183
 
139
184
  def masked_indices(
warp/jax.py CHANGED
@@ -182,6 +182,5 @@ def from_jax(jax_array, dtype=None) -> warp.array:
182
182
  Returns:
183
183
  warp.array: The converted Warp array.
184
184
  """
185
- import jax.dlpack
186
185
 
187
- return warp.from_dlpack(jax.dlpack.to_dlpack(jax_array), dtype=dtype)
186
+ return warp.from_dlpack(jax_array, dtype=dtype)