warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fem/space/topology.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from typing import Optional, Tuple, Type
|
|
16
|
+
from typing import ClassVar, Optional, Tuple, Type
|
|
17
17
|
|
|
18
18
|
import warp as wp
|
|
19
19
|
from warp.fem import cache
|
|
@@ -39,6 +39,12 @@ class SpaceTopology:
|
|
|
39
39
|
.. note:: This will change to be defined per-element in future versions
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
43
|
+
"element_node_count": lambda obj: obj._make_constant_element_node_count(),
|
|
44
|
+
"element_node_sign": lambda obj: obj._make_constant_element_node_sign(),
|
|
45
|
+
"side_neighbor_node_counts": lambda obj: obj._make_constant_side_neighbor_node_counts(),
|
|
46
|
+
}
|
|
47
|
+
|
|
42
48
|
@wp.struct
|
|
43
49
|
class TopologyArg:
|
|
44
50
|
"""Structure containing arguments to be passed to device functions"""
|
|
@@ -51,8 +57,7 @@ class SpaceTopology:
|
|
|
51
57
|
self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
|
|
52
58
|
self.ElementArg = geometry.CellArg
|
|
53
59
|
|
|
54
|
-
|
|
55
|
-
self._make_constant_element_node_sign()
|
|
60
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
56
61
|
|
|
57
62
|
@property
|
|
58
63
|
def geometry(self) -> Geometry:
|
|
@@ -67,6 +72,9 @@ class SpaceTopology:
|
|
|
67
72
|
"""Value of the topology argument structure to be passed to device functions"""
|
|
68
73
|
return SpaceTopology.TopologyArg()
|
|
69
74
|
|
|
75
|
+
def fill_topo_arg(self, arg, device):
|
|
76
|
+
pass
|
|
77
|
+
|
|
70
78
|
@property
|
|
71
79
|
def name(self):
|
|
72
80
|
return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"
|
|
@@ -182,6 +190,11 @@ class SpaceTopology:
|
|
|
182
190
|
):
|
|
183
191
|
return NODES_PER_ELEMENT
|
|
184
192
|
|
|
193
|
+
return constant_element_node_count
|
|
194
|
+
|
|
195
|
+
def _make_constant_side_neighbor_node_counts(self):
|
|
196
|
+
NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)
|
|
197
|
+
|
|
185
198
|
@cache.dynamic_func(suffix=self.name)
|
|
186
199
|
def constant_side_neighbor_node_counts(
|
|
187
200
|
side_arg: self.geometry.SideArg,
|
|
@@ -189,8 +202,7 @@ class SpaceTopology:
|
|
|
189
202
|
):
|
|
190
203
|
return NODES_PER_ELEMENT, NODES_PER_ELEMENT
|
|
191
204
|
|
|
192
|
-
|
|
193
|
-
self.side_neighbor_node_counts = constant_side_neighbor_node_counts
|
|
205
|
+
return constant_side_neighbor_node_counts
|
|
194
206
|
|
|
195
207
|
def _make_constant_element_node_sign(self):
|
|
196
208
|
@cache.dynamic_func(suffix=self.name)
|
|
@@ -202,12 +214,21 @@ class SpaceTopology:
|
|
|
202
214
|
):
|
|
203
215
|
return 1.0
|
|
204
216
|
|
|
205
|
-
|
|
217
|
+
return constant_element_node_sign
|
|
206
218
|
|
|
207
219
|
|
|
208
220
|
class TraceSpaceTopology(SpaceTopology):
|
|
209
221
|
"""Auto-generated trace topology defining the node indices associated to the geometry sides"""
|
|
210
222
|
|
|
223
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
224
|
+
"inner_cell_index": lambda obj: obj._make_inner_cell_index(),
|
|
225
|
+
"outer_cell_index": lambda obj: obj._make_outer_cell_index(),
|
|
226
|
+
"neighbor_cell_index": lambda obj: obj._make_neighbor_cell_index(),
|
|
227
|
+
"element_node_index": lambda obj: obj._make_element_node_index(),
|
|
228
|
+
"element_node_count": lambda obj: obj._make_element_node_count(),
|
|
229
|
+
"element_node_sign": lambda obj: obj._make_element_node_sign(),
|
|
230
|
+
}
|
|
231
|
+
|
|
211
232
|
def __init__(self, topo: SpaceTopology):
|
|
212
233
|
self._topo = topo
|
|
213
234
|
|
|
@@ -218,14 +239,10 @@ class TraceSpaceTopology(SpaceTopology):
|
|
|
218
239
|
|
|
219
240
|
self.TopologyArg = topo.TopologyArg
|
|
220
241
|
self.topo_arg_value = topo.topo_arg_value
|
|
242
|
+
self.fill_topo_arg = topo.fill_topo_arg
|
|
221
243
|
|
|
222
|
-
self.inner_cell_index = self._make_inner_cell_index()
|
|
223
|
-
self.outer_cell_index = self._make_outer_cell_index()
|
|
224
|
-
self.neighbor_cell_index = self._make_neighbor_cell_index()
|
|
225
|
-
|
|
226
|
-
self.element_node_index = self._make_element_node_index()
|
|
227
|
-
self.element_node_count = self._make_element_node_count()
|
|
228
244
|
self.side_neighbor_node_counts = None
|
|
245
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
229
246
|
|
|
230
247
|
def node_count(self) -> int:
|
|
231
248
|
return self._topo.node_count()
|
|
@@ -354,21 +371,29 @@ class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin,
|
|
|
354
371
|
|
|
355
372
|
|
|
356
373
|
class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
374
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
375
|
+
"element_node_index": lambda obj: obj._make_element_node_index(),
|
|
376
|
+
"element_node_count": lambda obj: obj._make_element_node_count(),
|
|
377
|
+
"element_node_sign": lambda obj: obj._make_element_node_sign(),
|
|
378
|
+
"side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
|
|
379
|
+
}
|
|
380
|
+
|
|
357
381
|
def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
|
|
358
382
|
self.base = base_topology
|
|
359
383
|
super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)
|
|
360
384
|
|
|
361
385
|
self.node_count = self.base.node_count
|
|
362
386
|
self.topo_arg_value = self.base.topo_arg_value
|
|
387
|
+
self.fill_topo_arg = self.base.fill_topo_arg
|
|
363
388
|
self.TopologyArg = self.base.TopologyArg
|
|
364
389
|
|
|
365
|
-
|
|
390
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
366
391
|
|
|
367
392
|
@property
|
|
368
393
|
def name(self):
|
|
369
394
|
return f"{self.base.name}_{self.geometry.field.name}"
|
|
370
395
|
|
|
371
|
-
def
|
|
396
|
+
def _make_element_node_index(self):
|
|
372
397
|
@cache.dynamic_func(suffix=self.name)
|
|
373
398
|
def element_node_index(
|
|
374
399
|
elt_arg: self.geometry.CellArg,
|
|
@@ -376,16 +401,22 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
376
401
|
element_index: ElementIndex,
|
|
377
402
|
node_index_in_elt: int,
|
|
378
403
|
):
|
|
379
|
-
return self.base.element_node_index(elt_arg.
|
|
404
|
+
return self.base.element_node_index(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
|
|
405
|
+
|
|
406
|
+
return element_node_index
|
|
380
407
|
|
|
408
|
+
def _make_element_node_count(self):
|
|
381
409
|
@cache.dynamic_func(suffix=self.name)
|
|
382
410
|
def element_node_count(
|
|
383
411
|
elt_arg: self.geometry.CellArg,
|
|
384
412
|
topo_arg: self.TopologyArg,
|
|
385
413
|
element_count: ElementIndex,
|
|
386
414
|
):
|
|
387
|
-
return self.base.element_node_count(elt_arg.
|
|
415
|
+
return self.base.element_node_count(elt_arg.base_arg, topo_arg, element_count)
|
|
416
|
+
|
|
417
|
+
return element_node_count
|
|
388
418
|
|
|
419
|
+
def _make_side_neighbor_node_counts(self):
|
|
389
420
|
@cache.dynamic_func(suffix=self.name)
|
|
390
421
|
def side_neighbor_node_counts(
|
|
391
422
|
side_arg: self.geometry.SideArg,
|
|
@@ -394,6 +425,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
394
425
|
inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
|
|
395
426
|
return inner_count, outer_count
|
|
396
427
|
|
|
428
|
+
return side_neighbor_node_counts
|
|
429
|
+
|
|
430
|
+
def _make_element_node_sign(self):
|
|
397
431
|
@cache.dynamic_func(suffix=self.name)
|
|
398
432
|
def element_node_sign(
|
|
399
433
|
elt_arg: self.geometry.CellArg,
|
|
@@ -401,12 +435,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
401
435
|
element_index: ElementIndex,
|
|
402
436
|
node_index_in_elt: int,
|
|
403
437
|
):
|
|
404
|
-
return self.base.element_node_sign(elt_arg.
|
|
438
|
+
return self.base.element_node_sign(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
|
|
405
439
|
|
|
406
|
-
|
|
407
|
-
self.element_node_count = element_node_count
|
|
408
|
-
self.element_node_sign = element_node_sign
|
|
409
|
-
self.side_neighbor_node_counts = side_neighbor_node_counts
|
|
440
|
+
return element_node_sign
|
|
410
441
|
|
|
411
442
|
|
|
412
443
|
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
|
|
@@ -50,12 +50,15 @@ class TrimeshSpaceTopology(SpaceTopology):
|
|
|
50
50
|
@cache.cached_arg_value
|
|
51
51
|
def topo_arg_value(self, device):
|
|
52
52
|
arg = TrimeshTopologyArg()
|
|
53
|
+
self.fill_topo_arg(arg, device)
|
|
54
|
+
return arg
|
|
55
|
+
|
|
56
|
+
def fill_topo_arg(self, arg: TrimeshTopologyArg, device):
|
|
53
57
|
arg.tri_edge_indices = self._tri_edge_indices.to(device)
|
|
54
58
|
arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
|
|
55
59
|
|
|
56
60
|
arg.vertex_count = self._mesh.vertex_count()
|
|
57
61
|
arg.edge_count = self._mesh.side_count()
|
|
58
|
-
return arg
|
|
59
62
|
|
|
60
63
|
def _compute_tri_edge_indices(self):
|
|
61
64
|
self._tri_edge_indices = wp.empty(
|
warp/fem/utils.py
CHANGED
|
@@ -19,6 +19,7 @@ import numpy as np
|
|
|
19
19
|
|
|
20
20
|
import warp as wp
|
|
21
21
|
import warp.fem.cache as cache
|
|
22
|
+
import warp.types
|
|
22
23
|
from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
|
|
23
24
|
array_axpy,
|
|
24
25
|
inverse_qr,
|
|
@@ -28,6 +29,57 @@ from warp.fem.types import NULL_NODE_INDEX
|
|
|
28
29
|
from warp.utils import array_scan, radix_sort_pairs, runlength_encode
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
def type_zero_element(dtype):
|
|
33
|
+
suffix = warp.types.get_type_code(dtype)
|
|
34
|
+
|
|
35
|
+
if dtype in warp.types.scalar_types:
|
|
36
|
+
|
|
37
|
+
@cache.dynamic_func(suffix=suffix)
|
|
38
|
+
def zero_element():
|
|
39
|
+
return dtype(0.0)
|
|
40
|
+
|
|
41
|
+
return zero_element
|
|
42
|
+
|
|
43
|
+
@cache.dynamic_func(suffix=suffix)
|
|
44
|
+
def zero_element():
|
|
45
|
+
return dtype()
|
|
46
|
+
|
|
47
|
+
return zero_element
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def type_basis_element(dtype):
|
|
51
|
+
suffix = warp.types.get_type_code(dtype)
|
|
52
|
+
|
|
53
|
+
if dtype in warp.types.scalar_types:
|
|
54
|
+
|
|
55
|
+
@cache.dynamic_func(suffix=suffix)
|
|
56
|
+
def basis_element(coord: int):
|
|
57
|
+
return dtype(1.0)
|
|
58
|
+
|
|
59
|
+
return basis_element
|
|
60
|
+
|
|
61
|
+
if warp.types.type_is_matrix(dtype):
|
|
62
|
+
cols = dtype._shape_[1]
|
|
63
|
+
|
|
64
|
+
@cache.dynamic_func(suffix=suffix)
|
|
65
|
+
def basis_element(coord: int):
|
|
66
|
+
v = dtype()
|
|
67
|
+
i = coord // cols
|
|
68
|
+
j = coord - i * cols
|
|
69
|
+
v[i, j] = v.dtype(1.0)
|
|
70
|
+
return v
|
|
71
|
+
|
|
72
|
+
return basis_element
|
|
73
|
+
|
|
74
|
+
@cache.dynamic_func(suffix=suffix)
|
|
75
|
+
def basis_element(coord: int):
|
|
76
|
+
v = dtype()
|
|
77
|
+
v[coord] = v.dtype(1.0)
|
|
78
|
+
return v
|
|
79
|
+
|
|
80
|
+
return basis_element
|
|
81
|
+
|
|
82
|
+
|
|
31
83
|
def compress_node_indices(
|
|
32
84
|
node_count: int,
|
|
33
85
|
node_indices: wp.array(dtype=int),
|
|
@@ -126,14 +178,7 @@ def host_read_at_index(array: wp.array, index: int = -1, temporary_store: cache.
|
|
|
126
178
|
|
|
127
179
|
if index < 0:
|
|
128
180
|
index += array.shape[0]
|
|
129
|
-
|
|
130
|
-
if array.device.is_cuda:
|
|
131
|
-
temp = cache.borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
|
|
132
|
-
wp.copy(dest=temp.array, src=array, src_offset=index, count=1)
|
|
133
|
-
wp.synchronize_stream(wp.get_stream(array.device))
|
|
134
|
-
return temp.array.numpy()[0]
|
|
135
|
-
|
|
136
|
-
return array.numpy()[index]
|
|
181
|
+
return array[index : index + 1].numpy()[0]
|
|
137
182
|
|
|
138
183
|
|
|
139
184
|
def masked_indices(
|
warp/jax.py
CHANGED
|
@@ -182,6 +182,5 @@ def from_jax(jax_array, dtype=None) -> warp.array:
|
|
|
182
182
|
Returns:
|
|
183
183
|
warp.array: The converted Warp array.
|
|
184
184
|
"""
|
|
185
|
-
import jax.dlpack
|
|
186
185
|
|
|
187
|
-
return warp.from_dlpack(
|
|
186
|
+
return warp.from_dlpack(jax_array, dtype=dtype)
|
warp/jax_experimental/ffi.py
CHANGED
|
@@ -306,7 +306,6 @@ class FfiCallable:
|
|
|
306
306
|
self.graph_compatible = graph_compatible
|
|
307
307
|
self.output_dims = output_dims
|
|
308
308
|
self.first_array_arg = None
|
|
309
|
-
self.has_static_args = False
|
|
310
309
|
self.call_id = 0
|
|
311
310
|
self.call_descriptors = {}
|
|
312
311
|
|
|
@@ -335,8 +334,6 @@ class FfiCallable:
|
|
|
335
334
|
if arg.is_array:
|
|
336
335
|
if arg_idx < self.num_inputs and self.first_array_arg is None:
|
|
337
336
|
self.first_array_arg = arg_idx
|
|
338
|
-
else:
|
|
339
|
-
self.has_static_args = True
|
|
340
337
|
self.args.append(arg)
|
|
341
338
|
arg_idx += 1
|
|
342
339
|
|
|
@@ -425,14 +422,11 @@ class FfiCallable:
|
|
|
425
422
|
module = wp.get_module(self.func.__module__)
|
|
426
423
|
module.load(device)
|
|
427
424
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
return call(*args, call_id=call_id)
|
|
434
|
-
else:
|
|
435
|
-
return call(*args)
|
|
425
|
+
# save call data to be retrieved by callback
|
|
426
|
+
call_id = self.call_id
|
|
427
|
+
self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
|
|
428
|
+
self.call_id += 1
|
|
429
|
+
return call(*args, call_id=call_id)
|
|
436
430
|
|
|
437
431
|
def ffi_callback(self, call_frame):
|
|
438
432
|
try:
|
|
@@ -454,11 +448,10 @@ class FfiCallable:
|
|
|
454
448
|
)
|
|
455
449
|
return None
|
|
456
450
|
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
call_desc = self.call_descriptors[call_id]
|
|
451
|
+
# retrieve call info
|
|
452
|
+
attrs = decode_attrs(call_frame.contents.attrs)
|
|
453
|
+
call_id = int(attrs["call_id"])
|
|
454
|
+
call_desc = self.call_descriptors[call_id]
|
|
462
455
|
|
|
463
456
|
num_inputs = call_frame.contents.args.size
|
|
464
457
|
inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
|
|
@@ -500,8 +493,10 @@ class FfiCallable:
|
|
|
500
493
|
# call the Python function with reconstructed arguments
|
|
501
494
|
with wp.ScopedStream(stream, sync_enter=False):
|
|
502
495
|
if stream.is_capturing:
|
|
503
|
-
with wp.ScopedCapture(stream=stream, external=True):
|
|
496
|
+
with wp.ScopedCapture(stream=stream, external=True) as capture:
|
|
504
497
|
self.func(*arg_list)
|
|
498
|
+
# keep a reference to the capture object to prevent required modules getting unloaded
|
|
499
|
+
call_desc.capture = capture
|
|
505
500
|
else:
|
|
506
501
|
self.func(*arg_list)
|
|
507
502
|
|
warp/jax_experimental/xla_ffi.py
CHANGED
|
@@ -130,14 +130,14 @@ class XLA_FFI_DataType(enum.IntEnum):
|
|
|
130
130
|
# int64_t* dims; // length == rank
|
|
131
131
|
# };
|
|
132
132
|
class XLA_FFI_Buffer(ctypes.Structure):
|
|
133
|
-
_fields_ =
|
|
133
|
+
_fields_ = (
|
|
134
134
|
("struct_size", ctypes.c_size_t),
|
|
135
135
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
136
136
|
("dtype", ctypes.c_int), # XLA_FFI_DataType
|
|
137
137
|
("data", ctypes.c_void_p),
|
|
138
138
|
("rank", ctypes.c_int64),
|
|
139
139
|
("dims", ctypes.POINTER(ctypes.c_int64)),
|
|
140
|
-
|
|
140
|
+
)
|
|
141
141
|
|
|
142
142
|
|
|
143
143
|
# typedef enum {
|
|
@@ -162,13 +162,13 @@ class XLA_FFI_RetType(enum.IntEnum):
|
|
|
162
162
|
# void** args; // length == size
|
|
163
163
|
# };
|
|
164
164
|
class XLA_FFI_Args(ctypes.Structure):
|
|
165
|
-
_fields_ =
|
|
165
|
+
_fields_ = (
|
|
166
166
|
("struct_size", ctypes.c_size_t),
|
|
167
167
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
168
168
|
("size", ctypes.c_int64),
|
|
169
169
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
|
|
170
170
|
("args", ctypes.POINTER(ctypes.c_void_p)),
|
|
171
|
-
|
|
171
|
+
)
|
|
172
172
|
|
|
173
173
|
|
|
174
174
|
# struct XLA_FFI_Rets {
|
|
@@ -179,13 +179,13 @@ class XLA_FFI_Args(ctypes.Structure):
|
|
|
179
179
|
# void** rets; // length == size
|
|
180
180
|
# };
|
|
181
181
|
class XLA_FFI_Rets(ctypes.Structure):
|
|
182
|
-
_fields_ =
|
|
182
|
+
_fields_ = (
|
|
183
183
|
("struct_size", ctypes.c_size_t),
|
|
184
184
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
185
185
|
("size", ctypes.c_int64),
|
|
186
186
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
|
|
187
187
|
("rets", ctypes.POINTER(ctypes.c_void_p)),
|
|
188
|
-
|
|
188
|
+
)
|
|
189
189
|
|
|
190
190
|
|
|
191
191
|
# typedef struct XLA_FFI_ByteSpan {
|
|
@@ -193,7 +193,10 @@ class XLA_FFI_Rets(ctypes.Structure):
|
|
|
193
193
|
# size_t len;
|
|
194
194
|
# } XLA_FFI_ByteSpan;
|
|
195
195
|
class XLA_FFI_ByteSpan(ctypes.Structure):
|
|
196
|
-
_fields_ =
|
|
196
|
+
_fields_ = (
|
|
197
|
+
("ptr", ctypes.POINTER(ctypes.c_char)),
|
|
198
|
+
("len", ctypes.c_size_t),
|
|
199
|
+
)
|
|
197
200
|
|
|
198
201
|
|
|
199
202
|
# typedef struct XLA_FFI_Scalar {
|
|
@@ -201,7 +204,10 @@ class XLA_FFI_ByteSpan(ctypes.Structure):
|
|
|
201
204
|
# void* value;
|
|
202
205
|
# } XLA_FFI_Scalar;
|
|
203
206
|
class XLA_FFI_Scalar(ctypes.Structure):
|
|
204
|
-
_fields_ =
|
|
207
|
+
_fields_ = (
|
|
208
|
+
("dtype", ctypes.c_int),
|
|
209
|
+
("value", ctypes.c_void_p),
|
|
210
|
+
)
|
|
205
211
|
|
|
206
212
|
|
|
207
213
|
# typedef struct XLA_FFI_Array {
|
|
@@ -210,7 +216,11 @@ class XLA_FFI_Scalar(ctypes.Structure):
|
|
|
210
216
|
# void* data;
|
|
211
217
|
# } XLA_FFI_Array;
|
|
212
218
|
class XLA_FFI_Array(ctypes.Structure):
|
|
213
|
-
_fields_ =
|
|
219
|
+
_fields_ = (
|
|
220
|
+
("dtype", ctypes.c_int),
|
|
221
|
+
("size", ctypes.c_size_t),
|
|
222
|
+
("data", ctypes.c_void_p),
|
|
223
|
+
)
|
|
214
224
|
|
|
215
225
|
|
|
216
226
|
# typedef enum {
|
|
@@ -235,14 +245,14 @@ class XLA_FFI_AttrType(enum.IntEnum):
|
|
|
235
245
|
# void** attrs; // length == size
|
|
236
246
|
# };
|
|
237
247
|
class XLA_FFI_Attrs(ctypes.Structure):
|
|
238
|
-
_fields_ =
|
|
248
|
+
_fields_ = (
|
|
239
249
|
("struct_size", ctypes.c_size_t),
|
|
240
250
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
241
251
|
("size", ctypes.c_int64),
|
|
242
252
|
("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
|
|
243
253
|
("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
|
|
244
254
|
("attrs", ctypes.POINTER(ctypes.c_void_p)),
|
|
245
|
-
|
|
255
|
+
)
|
|
246
256
|
|
|
247
257
|
|
|
248
258
|
# struct XLA_FFI_Api_Version {
|
|
@@ -252,12 +262,12 @@ class XLA_FFI_Attrs(ctypes.Structure):
|
|
|
252
262
|
# int minor_version; // out
|
|
253
263
|
# };
|
|
254
264
|
class XLA_FFI_Api_Version(ctypes.Structure):
|
|
255
|
-
_fields_ =
|
|
265
|
+
_fields_ = (
|
|
256
266
|
("struct_size", ctypes.c_size_t),
|
|
257
267
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
258
268
|
("major_version", ctypes.c_int),
|
|
259
269
|
("minor_version", ctypes.c_int),
|
|
260
|
-
|
|
270
|
+
)
|
|
261
271
|
|
|
262
272
|
|
|
263
273
|
# enum XLA_FFI_Handler_TraitsBits {
|
|
@@ -276,11 +286,11 @@ class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
|
|
|
276
286
|
# XLA_FFI_Handler_Traits traits;
|
|
277
287
|
# };
|
|
278
288
|
class XLA_FFI_Metadata(ctypes.Structure):
|
|
279
|
-
_fields_ =
|
|
289
|
+
_fields_ = (
|
|
280
290
|
("struct_size", ctypes.c_size_t),
|
|
281
291
|
("api_version", XLA_FFI_Api_Version), # XLA_FFI_Extension_Type
|
|
282
292
|
("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
|
|
283
|
-
|
|
293
|
+
)
|
|
284
294
|
|
|
285
295
|
|
|
286
296
|
# struct XLA_FFI_Metadata_Extension {
|
|
@@ -288,7 +298,10 @@ class XLA_FFI_Metadata(ctypes.Structure):
|
|
|
288
298
|
# XLA_FFI_Metadata* metadata;
|
|
289
299
|
# };
|
|
290
300
|
class XLA_FFI_Metadata_Extension(ctypes.Structure):
|
|
291
|
-
_fields_ =
|
|
301
|
+
_fields_ = (
|
|
302
|
+
("extension_base", XLA_FFI_Extension_Base),
|
|
303
|
+
("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
|
|
304
|
+
)
|
|
292
305
|
|
|
293
306
|
|
|
294
307
|
# typedef enum {
|
|
@@ -337,12 +350,12 @@ class XLA_FFI_Error_Code(enum.IntEnum):
|
|
|
337
350
|
# XLA_FFI_Error_Code errc;
|
|
338
351
|
# };
|
|
339
352
|
class XLA_FFI_Error_Create_Args(ctypes.Structure):
|
|
340
|
-
_fields_ =
|
|
353
|
+
_fields_ = (
|
|
341
354
|
("struct_size", ctypes.c_size_t),
|
|
342
355
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
343
356
|
("message", ctypes.c_char_p),
|
|
344
357
|
("errc", ctypes.c_int),
|
|
345
|
-
|
|
358
|
+
) # XLA_FFI_Error_Code
|
|
346
359
|
|
|
347
360
|
|
|
348
361
|
XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
|
|
@@ -355,12 +368,12 @@ XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_
|
|
|
355
368
|
# void* stream; // out
|
|
356
369
|
# };
|
|
357
370
|
class XLA_FFI_Stream_Get_Args(ctypes.Structure):
|
|
358
|
-
_fields_ =
|
|
371
|
+
_fields_ = (
|
|
359
372
|
("struct_size", ctypes.c_size_t),
|
|
360
373
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
361
374
|
("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
|
|
362
375
|
("stream", ctypes.c_void_p),
|
|
363
|
-
|
|
376
|
+
) # // out
|
|
364
377
|
|
|
365
378
|
|
|
366
379
|
XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
|
|
@@ -391,7 +404,7 @@ XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_St
|
|
|
391
404
|
# _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
|
|
392
405
|
# };
|
|
393
406
|
class XLA_FFI_Api(ctypes.Structure):
|
|
394
|
-
_fields_ =
|
|
407
|
+
_fields_ = (
|
|
395
408
|
("struct_size", ctypes.c_size_t),
|
|
396
409
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
397
410
|
("api_version", XLA_FFI_Api_Version),
|
|
@@ -412,7 +425,7 @@ class XLA_FFI_Api(ctypes.Structure):
|
|
|
412
425
|
("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
|
|
413
426
|
("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
|
|
414
427
|
("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
|
|
415
|
-
|
|
428
|
+
)
|
|
416
429
|
|
|
417
430
|
|
|
418
431
|
# struct XLA_FFI_CallFrame {
|
|
@@ -431,7 +444,7 @@ class XLA_FFI_Api(ctypes.Structure):
|
|
|
431
444
|
# XLA_FFI_Future* future; // out
|
|
432
445
|
# };
|
|
433
446
|
class XLA_FFI_CallFrame(ctypes.Structure):
|
|
434
|
-
_fields_ =
|
|
447
|
+
_fields_ = (
|
|
435
448
|
("struct_size", ctypes.c_size_t),
|
|
436
449
|
("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
|
|
437
450
|
("api", ctypes.POINTER(XLA_FFI_Api)),
|
|
@@ -441,7 +454,7 @@ class XLA_FFI_CallFrame(ctypes.Structure):
|
|
|
441
454
|
("rets", XLA_FFI_Rets),
|
|
442
455
|
("attrs", XLA_FFI_Attrs),
|
|
443
456
|
("future", ctypes.c_void_p), # XLA_FFI_Future* // out
|
|
444
|
-
|
|
457
|
+
)
|
|
445
458
|
|
|
446
459
|
|
|
447
460
|
_xla_data_type_to_constructor = {
|