warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +7 -1
- warp/bin/libwarp-clang.dylib +0 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/fem/space/partition.py
CHANGED
|
@@ -19,7 +19,7 @@ import warp as wp
|
|
|
19
19
|
import warp.fem.cache as cache
|
|
20
20
|
from warp.fem.geometry import GeometryPartition, WholeGeometryPartition
|
|
21
21
|
from warp.fem.types import NULL_NODE_INDEX
|
|
22
|
-
from warp.fem.utils import
|
|
22
|
+
from warp.fem.utils import compress_node_indices
|
|
23
23
|
|
|
24
24
|
from .function_space import FunctionSpace
|
|
25
25
|
from .topology import SpaceTopology
|
|
@@ -87,7 +87,7 @@ class WholeSpacePartition(SpacePartition):
|
|
|
87
87
|
"""Return the global function space indices for nodes in this partition"""
|
|
88
88
|
if self._node_indices is None:
|
|
89
89
|
self._node_indices = cache.borrow_temporary(temporary_store=None, shape=(self.node_count(),), dtype=int)
|
|
90
|
-
wp.launch(kernel=_iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array
|
|
90
|
+
wp.launch(kernel=self._iota_kernel, dim=self.node_count(), inputs=[self._node_indices.array])
|
|
91
91
|
return self._node_indices.array
|
|
92
92
|
|
|
93
93
|
def partition_arg_value(self, device):
|
|
@@ -104,6 +104,10 @@ class WholeSpacePartition(SpacePartition):
|
|
|
104
104
|
def name(self) -> str:
|
|
105
105
|
return "Whole"
|
|
106
106
|
|
|
107
|
+
@wp.kernel
|
|
108
|
+
def _iota_kernel(indices: wp.array(dtype=int)):
|
|
109
|
+
indices[wp.tid()] = wp.tid()
|
|
110
|
+
|
|
107
111
|
|
|
108
112
|
class NodeCategory:
|
|
109
113
|
OWNED_INTERIOR = wp.constant(0)
|
|
@@ -166,10 +166,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
|
|
|
166
166
|
|
|
167
167
|
if wp.static(EDGE_NODE_COUNT > 0):
|
|
168
168
|
# EDGE_X, EDGE_Y
|
|
169
|
-
side_start = wp.
|
|
169
|
+
side_start = wp.where(
|
|
170
170
|
node_type == SquareShapeFunction.EDGE_X,
|
|
171
|
-
wp.
|
|
172
|
-
wp.
|
|
171
|
+
wp.where(type_instance == 0, 0, 2),
|
|
172
|
+
wp.where(type_instance == 0, 3, 1),
|
|
173
173
|
)
|
|
174
174
|
|
|
175
175
|
side_index = topo_arg.quad_edge_indices[element_index, side_start]
|
|
@@ -178,7 +178,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
|
|
|
178
178
|
|
|
179
179
|
# Flip indexing direction
|
|
180
180
|
flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
|
|
181
|
-
index_in_side = wp.
|
|
181
|
+
index_in_side = wp.where(flipped, EDGE_NODE_COUNT - 1 - type_index, type_index)
|
|
182
182
|
|
|
183
183
|
return global_offset + EDGE_NODE_COUNT * side_index + index_in_side
|
|
184
184
|
|
|
@@ -197,10 +197,10 @@ class QuadmeshSpaceTopology(SpaceTopology):
|
|
|
197
197
|
node_type, type_instance, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
|
|
198
198
|
|
|
199
199
|
if node_type == SquareShapeFunction.EDGE_X or node_type == SquareShapeFunction.EDGE_Y:
|
|
200
|
-
side_start = wp.
|
|
200
|
+
side_start = wp.where(
|
|
201
201
|
node_type == SquareShapeFunction.EDGE_X,
|
|
202
|
-
wp.
|
|
203
|
-
wp.
|
|
202
|
+
wp.where(type_instance == 0, 0, 2),
|
|
203
|
+
wp.where(type_instance == 0, 3, 1),
|
|
204
204
|
)
|
|
205
205
|
|
|
206
206
|
side_index = topo_arg.quad_edge_indices[element_index, side_start]
|
|
@@ -209,7 +209,7 @@ class QuadmeshSpaceTopology(SpaceTopology):
|
|
|
209
209
|
|
|
210
210
|
# Flip indexing direction
|
|
211
211
|
flipped = int(side_start >= 2) ^ int(local_vs != global_vs)
|
|
212
|
-
return wp.
|
|
212
|
+
return wp.where(flipped, -1.0, 1.0)
|
|
213
213
|
|
|
214
214
|
return 1.0
|
|
215
215
|
|
|
@@ -63,7 +63,7 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
|
|
|
63
63
|
|
|
64
64
|
self.ORDER = wp.constant(degree)
|
|
65
65
|
self.NODES_PER_ELEMENT = wp.constant((degree + 1) ** 3)
|
|
66
|
-
self.
|
|
66
|
+
self.NODES_PER_SIDE = wp.constant((degree + 1) ** 2)
|
|
67
67
|
|
|
68
68
|
if is_closed(self.family):
|
|
69
69
|
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
@@ -152,13 +152,13 @@ class CubeTripolynomialShapeFunctions(CubeShapeFunction):
|
|
|
152
152
|
):
|
|
153
153
|
i, j, k = self._node_ijk(node_index_in_elt)
|
|
154
154
|
|
|
155
|
-
zi = wp.
|
|
156
|
-
zj = wp.
|
|
157
|
-
zk = wp.
|
|
155
|
+
zi = wp.where(i == 0, 1, 0)
|
|
156
|
+
zj = wp.where(j == 0, 1, 0)
|
|
157
|
+
zk = wp.where(k == 0, 1, 0)
|
|
158
158
|
|
|
159
|
-
mi = wp.
|
|
160
|
-
mj = wp.
|
|
161
|
-
mk = wp.
|
|
159
|
+
mi = wp.where(i == ORDER, 1, 0)
|
|
160
|
+
mj = wp.where(j == ORDER, 1, 0)
|
|
161
|
+
mk = wp.where(k == ORDER, 1, 0)
|
|
162
162
|
|
|
163
163
|
if zi + mi == 1:
|
|
164
164
|
if zj + mj == 1:
|
|
@@ -504,7 +504,7 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
|
|
|
504
504
|
|
|
505
505
|
self.ORDER = wp.constant(degree)
|
|
506
506
|
self.NODES_PER_ELEMENT = wp.constant(8 + 12 * (degree - 1))
|
|
507
|
-
self.
|
|
507
|
+
self.NODES_PER_SIDE = wp.constant(4 * degree)
|
|
508
508
|
|
|
509
509
|
self.VERTEX_NODE_COUNT = wp.constant(1)
|
|
510
510
|
self.EDGE_NODE_COUNT = wp.constant(degree - 1)
|
|
@@ -634,9 +634,9 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
|
|
|
634
634
|
if node_type == CubeSerendipityShapeFunctions.VERTEX:
|
|
635
635
|
node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
|
|
636
636
|
|
|
637
|
-
cx = wp.
|
|
638
|
-
cy = wp.
|
|
639
|
-
cz = wp.
|
|
637
|
+
cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
|
|
638
|
+
cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
|
|
639
|
+
cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
|
|
640
640
|
|
|
641
641
|
w = cx * cy * cz
|
|
642
642
|
|
|
@@ -659,8 +659,8 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
|
|
|
659
659
|
local_coords = Grid3D._world_to_local(axis, coords)
|
|
660
660
|
|
|
661
661
|
w = float(1.0)
|
|
662
|
-
w *= wp.
|
|
663
|
-
w *= wp.
|
|
662
|
+
w *= wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
|
|
663
|
+
w *= wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
|
|
664
664
|
|
|
665
665
|
for k in range(ORDER_PLUS_ONE):
|
|
666
666
|
if k != node_all[0]:
|
|
@@ -690,13 +690,13 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
|
|
|
690
690
|
if node_type == CubeSerendipityShapeFunctions.VERTEX:
|
|
691
691
|
node_ijk = CubeSerendipityShapeFunctions._vertex_coords(type_instance)
|
|
692
692
|
|
|
693
|
-
cx = wp.
|
|
694
|
-
cy = wp.
|
|
695
|
-
cz = wp.
|
|
693
|
+
cx = wp.where(node_ijk[0] == 0, 1.0 - coords[0], coords[0])
|
|
694
|
+
cy = wp.where(node_ijk[1] == 0, 1.0 - coords[1], coords[1])
|
|
695
|
+
cz = wp.where(node_ijk[2] == 0, 1.0 - coords[2], coords[2])
|
|
696
696
|
|
|
697
|
-
gx = wp.
|
|
698
|
-
gy = wp.
|
|
699
|
-
gz = wp.
|
|
697
|
+
gx = wp.where(node_ijk[0] == 0, -1.0, 1.0)
|
|
698
|
+
gy = wp.where(node_ijk[1] == 0, -1.0, 1.0)
|
|
699
|
+
gz = wp.where(node_ijk[2] == 0, -1.0, 1.0)
|
|
700
700
|
|
|
701
701
|
if wp.static(ORDER == 2):
|
|
702
702
|
w = cx + cy + cz - 3.0 + LOBATTO_COORDS[1]
|
|
@@ -728,11 +728,11 @@ class CubeSerendipityShapeFunctions(CubeShapeFunction):
|
|
|
728
728
|
|
|
729
729
|
local_coords = Grid3D._world_to_local(axis, coords)
|
|
730
730
|
|
|
731
|
-
w_long = wp.
|
|
732
|
-
w_lat = wp.
|
|
731
|
+
w_long = wp.where(node_all[1] == 0, 1.0 - local_coords[1], local_coords[1])
|
|
732
|
+
w_lat = wp.where(node_all[2] == 0, 1.0 - local_coords[2], local_coords[2])
|
|
733
733
|
|
|
734
|
-
g_long = wp.
|
|
735
|
-
g_lat = wp.
|
|
734
|
+
g_long = wp.where(node_all[1] == 0, -1.0, 1.0)
|
|
735
|
+
g_lat = wp.where(node_all[2] == 0, -1.0, 1.0)
|
|
736
736
|
|
|
737
737
|
w_alt = LAGRANGE_SCALE[node_all[0]]
|
|
738
738
|
g_alt = float(0.0)
|
|
@@ -461,8 +461,8 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
461
461
|
node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
|
|
462
462
|
|
|
463
463
|
if node_type == SquareSerendipityShapeFunctions.VERTEX:
|
|
464
|
-
cx = wp.
|
|
465
|
-
cy = wp.
|
|
464
|
+
cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
465
|
+
cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
466
466
|
|
|
467
467
|
w = cx * cy
|
|
468
468
|
|
|
@@ -475,7 +475,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
475
475
|
|
|
476
476
|
w = float(1.0)
|
|
477
477
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
478
|
-
w *= wp.
|
|
478
|
+
w *= wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
479
479
|
else:
|
|
480
480
|
for k in range(ORDER_PLUS_ONE):
|
|
481
481
|
if k != node_i:
|
|
@@ -484,7 +484,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
484
484
|
w *= LAGRANGE_SCALE[node_i]
|
|
485
485
|
|
|
486
486
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
487
|
-
w *= wp.
|
|
487
|
+
w *= wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
488
488
|
else:
|
|
489
489
|
for k in range(ORDER_PLUS_ONE):
|
|
490
490
|
if k != node_j:
|
|
@@ -513,11 +513,11 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
513
513
|
node_i, node_j = self._node_lobatto_indices(node_type, type_instance, type_index)
|
|
514
514
|
|
|
515
515
|
if node_type == SquareSerendipityShapeFunctions.VERTEX:
|
|
516
|
-
cx = wp.
|
|
517
|
-
cy = wp.
|
|
516
|
+
cx = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
517
|
+
cy = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
518
518
|
|
|
519
|
-
gx = wp.
|
|
520
|
-
gy = wp.
|
|
519
|
+
gx = wp.where(node_i == 0, -1.0, 1.0)
|
|
520
|
+
gy = wp.where(node_j == 0, -1.0, 1.0)
|
|
521
521
|
|
|
522
522
|
if ORDER == 2:
|
|
523
523
|
w = cx + cy - 2.0 + LOBATTO_COORDS[1]
|
|
@@ -537,7 +537,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
537
537
|
return wp.vec2(grad_x, grad_y) * DEGREE_3_CIRCLE_SCALE
|
|
538
538
|
|
|
539
539
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
540
|
-
prefix_x = wp.
|
|
540
|
+
prefix_x = wp.where(node_j == 0, 1.0 - coords[1], coords[1])
|
|
541
541
|
else:
|
|
542
542
|
prefix_x = LAGRANGE_SCALE[node_j]
|
|
543
543
|
for k in range(ORDER_PLUS_ONE):
|
|
@@ -545,7 +545,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
545
545
|
prefix_x *= coords[1] - LOBATTO_COORDS[k]
|
|
546
546
|
|
|
547
547
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
548
|
-
prefix_y = wp.
|
|
548
|
+
prefix_y = wp.where(node_i == 0, 1.0 - coords[0], coords[0])
|
|
549
549
|
else:
|
|
550
550
|
prefix_y = LAGRANGE_SCALE[node_i]
|
|
551
551
|
for k in range(ORDER_PLUS_ONE):
|
|
@@ -553,7 +553,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
553
553
|
prefix_y *= coords[0] - LOBATTO_COORDS[k]
|
|
554
554
|
|
|
555
555
|
if node_type == SquareSerendipityShapeFunctions.EDGE_X:
|
|
556
|
-
grad_y = wp.
|
|
556
|
+
grad_y = wp.where(node_j == 0, -1.0, 1.0) * prefix_y
|
|
557
557
|
else:
|
|
558
558
|
prefix_y *= LAGRANGE_SCALE[node_j]
|
|
559
559
|
grad_y = float(0.0)
|
|
@@ -564,7 +564,7 @@ class SquareSerendipityShapeFunctions(SquareShapeFunction):
|
|
|
564
564
|
prefix_y *= delta_y
|
|
565
565
|
|
|
566
566
|
if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
|
|
567
|
-
grad_x = wp.
|
|
567
|
+
grad_x = wp.where(node_i == 0, -1.0, 1.0) * prefix_x
|
|
568
568
|
else:
|
|
569
569
|
prefix_x *= LAGRANGE_SCALE[node_i]
|
|
570
570
|
grad_x = float(0.0)
|
|
@@ -196,7 +196,7 @@ class TrianglePolynomialShapeFunctions(TriangleShapeFunction):
|
|
|
196
196
|
def trace_node_quadrature_weight(node_index_in_element: int):
|
|
197
197
|
node_type, type_index = self.node_type_and_type_index(node_index_in_element)
|
|
198
198
|
|
|
199
|
-
return wp.
|
|
199
|
+
return wp.where(node_type == TrianglePolynomialShapeFunctions.VERTEX, VERTEX_WEIGHT, EDGE_WEIGHT)
|
|
200
200
|
|
|
201
201
|
return trace_node_quadrature_weight
|
|
202
202
|
|
|
@@ -244,10 +244,10 @@ class TetmeshSpaceTopology(SpaceTopology):
|
|
|
244
244
|
edge = type_index // INTERIOR_NODES_PER_EDGE
|
|
245
245
|
c1, c2 = TetrahedronShapeFunction.edge_vidx(edge)
|
|
246
246
|
|
|
247
|
-
return wp.
|
|
247
|
+
return wp.where(
|
|
248
248
|
geo_arg.tet_vertex_indices[element_index][c1] > geo_arg.tet_vertex_indices[element_index][c2],
|
|
249
|
-
1.0,
|
|
250
249
|
-1.0,
|
|
250
|
+
1.0,
|
|
251
251
|
)
|
|
252
252
|
|
|
253
253
|
if wp.static(INTERIOR_NODES_PER_FACE > 0):
|
|
@@ -257,7 +257,7 @@ class TetmeshSpaceTopology(SpaceTopology):
|
|
|
257
257
|
global_face_index = topo_arg.tet_face_indices[element_index][face]
|
|
258
258
|
inner = topo_arg.face_tet_indices[global_face_index][0]
|
|
259
259
|
|
|
260
|
-
return wp.
|
|
260
|
+
return wp.where(inner == element_index, 1.0, -1.0)
|
|
261
261
|
|
|
262
262
|
return 1.0
|
|
263
263
|
|
|
@@ -175,11 +175,11 @@ class TrimeshSpaceTopology(SpaceTopology):
|
|
|
175
175
|
edge = type_index // INTERIOR_NODES_PER_SIDE
|
|
176
176
|
|
|
177
177
|
global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
|
|
178
|
-
return wp.
|
|
178
|
+
return wp.where(
|
|
179
179
|
topo_arg.edge_vertex_indices[global_edge_index][0]
|
|
180
180
|
== geo_arg.topology.tri_vertex_indices[element_index][edge],
|
|
181
|
-
-1.0,
|
|
182
181
|
1.0,
|
|
182
|
+
-1.0,
|
|
183
183
|
)
|
|
184
184
|
|
|
185
185
|
return 1.0
|
warp/fem/utils.py
CHANGED
|
@@ -56,13 +56,11 @@ def compress_node_indices(
|
|
|
56
56
|
sorted_node_indices = sorted_node_indices_temp.array
|
|
57
57
|
sorted_array_indices = sorted_array_indices_temp.array
|
|
58
58
|
|
|
59
|
-
wp.copy(dest=sorted_node_indices, src=node_indices, count=index_count)
|
|
60
|
-
|
|
61
59
|
indices_per_element = 1 if node_indices.ndim == 1 else node_indices.shape[-1]
|
|
62
60
|
wp.launch(
|
|
63
|
-
kernel=
|
|
61
|
+
kernel=_prepare_node_sort_kernel,
|
|
64
62
|
dim=index_count,
|
|
65
|
-
inputs=[sorted_array_indices, indices_per_element],
|
|
63
|
+
inputs=[node_indices.flatten(), sorted_node_indices, sorted_array_indices, indices_per_element],
|
|
66
64
|
)
|
|
67
65
|
|
|
68
66
|
# Sort indices
|
|
@@ -169,8 +167,16 @@ def masked_indices(
|
|
|
169
167
|
|
|
170
168
|
|
|
171
169
|
@wp.kernel
|
|
172
|
-
def
|
|
173
|
-
|
|
170
|
+
def _prepare_node_sort_kernel(
|
|
171
|
+
node_indices: wp.array(dtype=int),
|
|
172
|
+
sort_keys: wp.array(dtype=int),
|
|
173
|
+
sort_values: wp.array(dtype=int),
|
|
174
|
+
divisor: int,
|
|
175
|
+
):
|
|
176
|
+
i = wp.tid()
|
|
177
|
+
node = node_indices[i]
|
|
178
|
+
sort_keys[i] = wp.where(node >= 0, node, NULL_NODE_INDEX)
|
|
179
|
+
sort_values[i] = i // divisor
|
|
174
180
|
|
|
175
181
|
|
|
176
182
|
@wp.kernel
|
warp/jax.py
CHANGED
|
@@ -58,6 +58,19 @@ def device_from_jax(jax_device) -> warp.context.Device:
|
|
|
58
58
|
raise RuntimeError(f"Unsupported Jax device platform '{jax_device.platform}'")
|
|
59
59
|
|
|
60
60
|
|
|
61
|
+
def get_jax_device():
|
|
62
|
+
"""Get the current Jax device."""
|
|
63
|
+
import jax
|
|
64
|
+
|
|
65
|
+
# TODO: is there a simpler way of getting the Jax "current" device?
|
|
66
|
+
# check if jax.default_device() context manager is active
|
|
67
|
+
device = jax.config.jax_default_device
|
|
68
|
+
# if default device is not set, use first device
|
|
69
|
+
if device is None:
|
|
70
|
+
device = jax.local_devices()[0]
|
|
71
|
+
return device
|
|
72
|
+
|
|
73
|
+
|
|
61
74
|
def dtype_to_jax(warp_dtype):
|
|
62
75
|
"""Return the Jax dtype corresponding to a Warp dtype.
|
|
63
76
|
|
|
@@ -156,7 +169,7 @@ def to_jax(warp_array):
|
|
|
156
169
|
"""
|
|
157
170
|
import jax.dlpack
|
|
158
171
|
|
|
159
|
-
return jax.dlpack.from_dlpack(
|
|
172
|
+
return jax.dlpack.from_dlpack(warp_array)
|
|
160
173
|
|
|
161
174
|
|
|
162
175
|
def from_jax(jax_array, dtype=None) -> warp.array:
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from .custom_call import jax_kernel
|
|
@@ -15,10 +15,9 @@
|
|
|
15
15
|
|
|
16
16
|
import ctypes
|
|
17
17
|
|
|
18
|
-
import jax
|
|
19
|
-
|
|
20
18
|
import warp as wp
|
|
21
19
|
from warp.context import type_str
|
|
20
|
+
from warp.jax import get_jax_device
|
|
22
21
|
from warp.types import array_t, launch_bounds_t, strides_from_shape
|
|
23
22
|
|
|
24
23
|
_jax_warp_p = None
|
|
@@ -29,35 +28,33 @@ _registered_kernels = [None]
|
|
|
29
28
|
_registered_kernel_to_id = {}
|
|
30
29
|
|
|
31
30
|
|
|
32
|
-
def jax_kernel(
|
|
31
|
+
def jax_kernel(kernel, launch_dims=None):
|
|
33
32
|
"""Create a Jax primitive from a Warp kernel.
|
|
34
33
|
|
|
35
34
|
NOTE: This is an experimental feature under development.
|
|
36
35
|
|
|
37
36
|
Args:
|
|
38
|
-
|
|
37
|
+
kernel: The Warp kernel to be wrapped.
|
|
39
38
|
launch_dims: Optional. Specify the kernel launch dimensions. If None,
|
|
40
39
|
dimensions are inferred from the shape of the first argument.
|
|
41
40
|
This option when set will specify the output dimensions.
|
|
42
41
|
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
- All arrays must be contiguous.
|
|
49
|
-
- Only the CUDA backend is supported.
|
|
42
|
+
Limitations:
|
|
43
|
+
- All kernel arguments must be contiguous arrays.
|
|
44
|
+
- Input arguments are followed by output arguments in the Warp kernel definition.
|
|
45
|
+
- There must be at least one input argument and at least one output argument.
|
|
46
|
+
- Only the CUDA backend is supported.
|
|
50
47
|
"""
|
|
51
48
|
|
|
52
49
|
if _jax_warp_p is None:
|
|
53
50
|
# Create and register the primitive
|
|
54
51
|
_create_jax_warp_primitive()
|
|
55
|
-
if
|
|
52
|
+
if kernel not in _registered_kernel_to_id:
|
|
56
53
|
id = len(_registered_kernels)
|
|
57
|
-
_registered_kernels.append(
|
|
58
|
-
_registered_kernel_to_id[
|
|
54
|
+
_registered_kernels.append(kernel)
|
|
55
|
+
_registered_kernel_to_id[kernel] = id
|
|
59
56
|
else:
|
|
60
|
-
id = _registered_kernel_to_id[
|
|
57
|
+
id = _registered_kernel_to_id[kernel]
|
|
61
58
|
|
|
62
59
|
def bind(*args):
|
|
63
60
|
return _jax_warp_p.bind(*args, kernel=id, launch_dims=launch_dims)
|
|
@@ -102,7 +99,7 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
102
99
|
kernel_params[i + 1] = arg_ptr
|
|
103
100
|
|
|
104
101
|
# Get current device.
|
|
105
|
-
device = wp.device_from_jax(
|
|
102
|
+
device = wp.device_from_jax(get_jax_device())
|
|
106
103
|
|
|
107
104
|
# Get kernel hooks.
|
|
108
105
|
# Note: module was loaded during jit lowering.
|
|
@@ -115,16 +112,6 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
115
112
|
)
|
|
116
113
|
|
|
117
114
|
|
|
118
|
-
# TODO: is there a simpler way of getting the Jax "current" device?
|
|
119
|
-
def _get_jax_device():
|
|
120
|
-
# check if jax.default_device() context manager is active
|
|
121
|
-
device = jax.config.jax_default_device
|
|
122
|
-
# if default device is not set, use first device
|
|
123
|
-
if device is None:
|
|
124
|
-
device = jax.local_devices()[0]
|
|
125
|
-
return device
|
|
126
|
-
|
|
127
|
-
|
|
128
115
|
def _create_jax_warp_primitive():
|
|
129
116
|
from functools import reduce
|
|
130
117
|
|
|
@@ -288,7 +275,7 @@ def _create_jax_warp_primitive():
|
|
|
288
275
|
# TODO This may not be necessary, but it is perhaps better not to be
|
|
289
276
|
# mucking with kernel loading while already running the workload.
|
|
290
277
|
module = wp_kernel.module
|
|
291
|
-
device = wp.device_from_jax(
|
|
278
|
+
device = wp.device_from_jax(get_jax_device())
|
|
292
279
|
if not module.load(device):
|
|
293
280
|
raise Exception("Could not load kernel on device")
|
|
294
281
|
|