warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/space/topology.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
from typing import Optional, Tuple, Type
|
|
16
|
+
from typing import ClassVar, Optional, Tuple, Type
|
|
17
17
|
|
|
18
18
|
import warp as wp
|
|
19
19
|
from warp.fem import cache
|
|
@@ -39,6 +39,12 @@ class SpaceTopology:
|
|
|
39
39
|
.. note:: This will change to be defined per-element in future versions
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
43
|
+
"element_node_count": lambda obj: obj._make_constant_element_node_count(),
|
|
44
|
+
"element_node_sign": lambda obj: obj._make_constant_element_node_sign(),
|
|
45
|
+
"side_neighbor_node_counts": lambda obj: obj._make_constant_side_neighbor_node_counts(),
|
|
46
|
+
}
|
|
47
|
+
|
|
42
48
|
@wp.struct
|
|
43
49
|
class TopologyArg:
|
|
44
50
|
"""Structure containing arguments to be passed to device functions"""
|
|
@@ -51,8 +57,7 @@ class SpaceTopology:
|
|
|
51
57
|
self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
|
|
52
58
|
self.ElementArg = geometry.CellArg
|
|
53
59
|
|
|
54
|
-
|
|
55
|
-
self._make_constant_element_node_sign()
|
|
60
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
56
61
|
|
|
57
62
|
@property
|
|
58
63
|
def geometry(self) -> Geometry:
|
|
@@ -67,6 +72,9 @@ class SpaceTopology:
|
|
|
67
72
|
"""Value of the topology argument structure to be passed to device functions"""
|
|
68
73
|
return SpaceTopology.TopologyArg()
|
|
69
74
|
|
|
75
|
+
def fill_topo_arg(self, arg, device):
|
|
76
|
+
pass
|
|
77
|
+
|
|
70
78
|
@property
|
|
71
79
|
def name(self):
|
|
72
80
|
return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"
|
|
@@ -182,6 +190,11 @@ class SpaceTopology:
|
|
|
182
190
|
):
|
|
183
191
|
return NODES_PER_ELEMENT
|
|
184
192
|
|
|
193
|
+
return constant_element_node_count
|
|
194
|
+
|
|
195
|
+
def _make_constant_side_neighbor_node_counts(self):
|
|
196
|
+
NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)
|
|
197
|
+
|
|
185
198
|
@cache.dynamic_func(suffix=self.name)
|
|
186
199
|
def constant_side_neighbor_node_counts(
|
|
187
200
|
side_arg: self.geometry.SideArg,
|
|
@@ -189,8 +202,7 @@ class SpaceTopology:
|
|
|
189
202
|
):
|
|
190
203
|
return NODES_PER_ELEMENT, NODES_PER_ELEMENT
|
|
191
204
|
|
|
192
|
-
|
|
193
|
-
self.side_neighbor_node_counts = constant_side_neighbor_node_counts
|
|
205
|
+
return constant_side_neighbor_node_counts
|
|
194
206
|
|
|
195
207
|
def _make_constant_element_node_sign(self):
|
|
196
208
|
@cache.dynamic_func(suffix=self.name)
|
|
@@ -202,12 +214,21 @@ class SpaceTopology:
|
|
|
202
214
|
):
|
|
203
215
|
return 1.0
|
|
204
216
|
|
|
205
|
-
|
|
217
|
+
return constant_element_node_sign
|
|
206
218
|
|
|
207
219
|
|
|
208
220
|
class TraceSpaceTopology(SpaceTopology):
|
|
209
221
|
"""Auto-generated trace topology defining the node indices associated to the geometry sides"""
|
|
210
222
|
|
|
223
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
224
|
+
"inner_cell_index": lambda obj: obj._make_inner_cell_index(),
|
|
225
|
+
"outer_cell_index": lambda obj: obj._make_outer_cell_index(),
|
|
226
|
+
"neighbor_cell_index": lambda obj: obj._make_neighbor_cell_index(),
|
|
227
|
+
"element_node_index": lambda obj: obj._make_element_node_index(),
|
|
228
|
+
"element_node_count": lambda obj: obj._make_element_node_count(),
|
|
229
|
+
"element_node_sign": lambda obj: obj._make_element_node_sign(),
|
|
230
|
+
}
|
|
231
|
+
|
|
211
232
|
def __init__(self, topo: SpaceTopology):
|
|
212
233
|
self._topo = topo
|
|
213
234
|
|
|
@@ -218,14 +239,10 @@ class TraceSpaceTopology(SpaceTopology):
|
|
|
218
239
|
|
|
219
240
|
self.TopologyArg = topo.TopologyArg
|
|
220
241
|
self.topo_arg_value = topo.topo_arg_value
|
|
242
|
+
self.fill_topo_arg = topo.fill_topo_arg
|
|
221
243
|
|
|
222
|
-
self.inner_cell_index = self._make_inner_cell_index()
|
|
223
|
-
self.outer_cell_index = self._make_outer_cell_index()
|
|
224
|
-
self.neighbor_cell_index = self._make_neighbor_cell_index()
|
|
225
|
-
|
|
226
|
-
self.element_node_index = self._make_element_node_index()
|
|
227
|
-
self.element_node_count = self._make_element_node_count()
|
|
228
244
|
self.side_neighbor_node_counts = None
|
|
245
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
229
246
|
|
|
230
247
|
def node_count(self) -> int:
|
|
231
248
|
return self._topo.node_count()
|
|
@@ -354,21 +371,29 @@ class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin,
|
|
|
354
371
|
|
|
355
372
|
|
|
356
373
|
class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
374
|
+
_dynamic_attribute_constructors: ClassVar = {
|
|
375
|
+
"element_node_index": lambda obj: obj._make_element_node_index(),
|
|
376
|
+
"element_node_count": lambda obj: obj._make_element_node_count(),
|
|
377
|
+
"element_node_sign": lambda obj: obj._make_element_node_sign(),
|
|
378
|
+
"side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
|
|
379
|
+
}
|
|
380
|
+
|
|
357
381
|
def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
|
|
358
382
|
self.base = base_topology
|
|
359
383
|
super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)
|
|
360
384
|
|
|
361
385
|
self.node_count = self.base.node_count
|
|
362
386
|
self.topo_arg_value = self.base.topo_arg_value
|
|
387
|
+
self.fill_topo_arg = self.base.fill_topo_arg
|
|
363
388
|
self.TopologyArg = self.base.TopologyArg
|
|
364
389
|
|
|
365
|
-
|
|
390
|
+
cache.setup_dynamic_attributes(self, cls=__class__)
|
|
366
391
|
|
|
367
392
|
@property
|
|
368
393
|
def name(self):
|
|
369
394
|
return f"{self.base.name}_{self.geometry.field.name}"
|
|
370
395
|
|
|
371
|
-
def
|
|
396
|
+
def _make_element_node_index(self):
|
|
372
397
|
@cache.dynamic_func(suffix=self.name)
|
|
373
398
|
def element_node_index(
|
|
374
399
|
elt_arg: self.geometry.CellArg,
|
|
@@ -376,16 +401,22 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
376
401
|
element_index: ElementIndex,
|
|
377
402
|
node_index_in_elt: int,
|
|
378
403
|
):
|
|
379
|
-
return self.base.element_node_index(elt_arg.
|
|
404
|
+
return self.base.element_node_index(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
|
|
405
|
+
|
|
406
|
+
return element_node_index
|
|
380
407
|
|
|
408
|
+
def _make_element_node_count(self):
|
|
381
409
|
@cache.dynamic_func(suffix=self.name)
|
|
382
410
|
def element_node_count(
|
|
383
411
|
elt_arg: self.geometry.CellArg,
|
|
384
412
|
topo_arg: self.TopologyArg,
|
|
385
413
|
element_count: ElementIndex,
|
|
386
414
|
):
|
|
387
|
-
return self.base.element_node_count(elt_arg.
|
|
415
|
+
return self.base.element_node_count(elt_arg.base_arg, topo_arg, element_count)
|
|
416
|
+
|
|
417
|
+
return element_node_count
|
|
388
418
|
|
|
419
|
+
def _make_side_neighbor_node_counts(self):
|
|
389
420
|
@cache.dynamic_func(suffix=self.name)
|
|
390
421
|
def side_neighbor_node_counts(
|
|
391
422
|
side_arg: self.geometry.SideArg,
|
|
@@ -394,6 +425,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
394
425
|
inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
|
|
395
426
|
return inner_count, outer_count
|
|
396
427
|
|
|
428
|
+
return side_neighbor_node_counts
|
|
429
|
+
|
|
430
|
+
def _make_element_node_sign(self):
|
|
397
431
|
@cache.dynamic_func(suffix=self.name)
|
|
398
432
|
def element_node_sign(
|
|
399
433
|
elt_arg: self.geometry.CellArg,
|
|
@@ -401,12 +435,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
|
|
|
401
435
|
element_index: ElementIndex,
|
|
402
436
|
node_index_in_elt: int,
|
|
403
437
|
):
|
|
404
|
-
return self.base.element_node_sign(elt_arg.
|
|
438
|
+
return self.base.element_node_sign(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
|
|
405
439
|
|
|
406
|
-
|
|
407
|
-
self.element_node_count = element_node_count
|
|
408
|
-
self.element_node_sign = element_node_sign
|
|
409
|
-
self.side_neighbor_node_counts = side_neighbor_node_counts
|
|
440
|
+
return element_node_sign
|
|
410
441
|
|
|
411
442
|
|
|
412
443
|
def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
|
|
@@ -50,12 +50,15 @@ class TrimeshSpaceTopology(SpaceTopology):
|
|
|
50
50
|
@cache.cached_arg_value
|
|
51
51
|
def topo_arg_value(self, device):
|
|
52
52
|
arg = TrimeshTopologyArg()
|
|
53
|
+
self.fill_topo_arg(arg, device)
|
|
54
|
+
return arg
|
|
55
|
+
|
|
56
|
+
def fill_topo_arg(self, arg: TrimeshTopologyArg, device):
|
|
53
57
|
arg.tri_edge_indices = self._tri_edge_indices.to(device)
|
|
54
58
|
arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
|
|
55
59
|
|
|
56
60
|
arg.vertex_count = self._mesh.vertex_count()
|
|
57
61
|
arg.edge_count = self._mesh.side_count()
|
|
58
|
-
return arg
|
|
59
62
|
|
|
60
63
|
def _compute_tri_edge_indices(self):
|
|
61
64
|
self._tri_edge_indices = wp.empty(
|
warp/fem/utils.py
CHANGED
|
@@ -19,6 +19,7 @@ import numpy as np
|
|
|
19
19
|
|
|
20
20
|
import warp as wp
|
|
21
21
|
import warp.fem.cache as cache
|
|
22
|
+
import warp.types
|
|
22
23
|
from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
|
|
23
24
|
array_axpy,
|
|
24
25
|
inverse_qr,
|
|
@@ -28,6 +29,57 @@ from warp.fem.types import NULL_NODE_INDEX
|
|
|
28
29
|
from warp.utils import array_scan, radix_sort_pairs, runlength_encode
|
|
29
30
|
|
|
30
31
|
|
|
32
|
+
def type_zero_element(dtype):
|
|
33
|
+
suffix = warp.types.get_type_code(dtype)
|
|
34
|
+
|
|
35
|
+
if dtype in warp.types.scalar_types:
|
|
36
|
+
|
|
37
|
+
@cache.dynamic_func(suffix=suffix)
|
|
38
|
+
def zero_element():
|
|
39
|
+
return dtype(0.0)
|
|
40
|
+
|
|
41
|
+
return zero_element
|
|
42
|
+
|
|
43
|
+
@cache.dynamic_func(suffix=suffix)
|
|
44
|
+
def zero_element():
|
|
45
|
+
return dtype()
|
|
46
|
+
|
|
47
|
+
return zero_element
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def type_basis_element(dtype):
|
|
51
|
+
suffix = warp.types.get_type_code(dtype)
|
|
52
|
+
|
|
53
|
+
if dtype in warp.types.scalar_types:
|
|
54
|
+
|
|
55
|
+
@cache.dynamic_func(suffix=suffix)
|
|
56
|
+
def basis_element(coord: int):
|
|
57
|
+
return dtype(1.0)
|
|
58
|
+
|
|
59
|
+
return basis_element
|
|
60
|
+
|
|
61
|
+
if warp.types.type_is_matrix(dtype):
|
|
62
|
+
cols = dtype._shape_[1]
|
|
63
|
+
|
|
64
|
+
@cache.dynamic_func(suffix=suffix)
|
|
65
|
+
def basis_element(coord: int):
|
|
66
|
+
v = dtype()
|
|
67
|
+
i = coord // cols
|
|
68
|
+
j = coord - i * cols
|
|
69
|
+
v[i, j] = v.dtype(1.0)
|
|
70
|
+
return v
|
|
71
|
+
|
|
72
|
+
return basis_element
|
|
73
|
+
|
|
74
|
+
@cache.dynamic_func(suffix=suffix)
|
|
75
|
+
def basis_element(coord: int):
|
|
76
|
+
v = dtype()
|
|
77
|
+
v[coord] = v.dtype(1.0)
|
|
78
|
+
return v
|
|
79
|
+
|
|
80
|
+
return basis_element
|
|
81
|
+
|
|
82
|
+
|
|
31
83
|
def compress_node_indices(
|
|
32
84
|
node_count: int,
|
|
33
85
|
node_indices: wp.array(dtype=int),
|
|
@@ -126,14 +178,7 @@ def host_read_at_index(array: wp.array, index: int = -1, temporary_store: cache.
|
|
|
126
178
|
|
|
127
179
|
if index < 0:
|
|
128
180
|
index += array.shape[0]
|
|
129
|
-
|
|
130
|
-
if array.device.is_cuda:
|
|
131
|
-
temp = cache.borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
|
|
132
|
-
wp.copy(dest=temp.array, src=array, src_offset=index, count=1)
|
|
133
|
-
wp.synchronize_stream(wp.get_stream(array.device))
|
|
134
|
-
return temp.array.numpy()[0]
|
|
135
|
-
|
|
136
|
-
return array.numpy()[index]
|
|
181
|
+
return array[index : index + 1].numpy()[0]
|
|
137
182
|
|
|
138
183
|
|
|
139
184
|
def masked_indices(
|
warp/jax.py
CHANGED
|
@@ -182,6 +182,5 @@ def from_jax(jax_array, dtype=None) -> warp.array:
|
|
|
182
182
|
Returns:
|
|
183
183
|
warp.array: The converted Warp array.
|
|
184
184
|
"""
|
|
185
|
-
import jax.dlpack
|
|
186
185
|
|
|
187
|
-
return warp.from_dlpack(
|
|
186
|
+
return warp.from_dlpack(jax_array, dtype=dtype)
|