warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.1__py3-none-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import warp as wp
|
|
2
|
+
from warp.fem import cache
|
|
3
|
+
from warp.fem.geometry import Trimesh
|
|
4
|
+
from warp.fem.types import ElementIndex
|
|
5
|
+
|
|
6
|
+
from .shape import TriangleShapeFunction
|
|
7
|
+
from .topology import SpaceTopology, forward_base_topology
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@wp.struct
|
|
11
|
+
class TrimeshTopologyArg:
|
|
12
|
+
edge_vertex_indices: wp.array(dtype=wp.vec2i)
|
|
13
|
+
tri_edge_indices: wp.array2d(dtype=int)
|
|
14
|
+
|
|
15
|
+
vertex_count: int
|
|
16
|
+
edge_count: int
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TrimeshSpaceTopology(SpaceTopology):
|
|
20
|
+
TopologyArg = TrimeshTopologyArg
|
|
21
|
+
|
|
22
|
+
def __init__(self, mesh: Trimesh, shape: TriangleShapeFunction):
|
|
23
|
+
self._shape = shape
|
|
24
|
+
super().__init__(mesh, shape.NODES_PER_ELEMENT)
|
|
25
|
+
self._mesh = mesh
|
|
26
|
+
|
|
27
|
+
self._compute_tri_edge_indices()
|
|
28
|
+
self.element_node_index = self._make_element_node_index()
|
|
29
|
+
self.element_node_sign = self._make_element_node_sign()
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def name(self):
|
|
33
|
+
return f"{self.geometry.name}_{self._shape.name}"
|
|
34
|
+
|
|
35
|
+
@cache.cached_arg_value
|
|
36
|
+
def topo_arg_value(self, device):
|
|
37
|
+
arg = TrimeshTopologyArg()
|
|
38
|
+
arg.tri_edge_indices = self._tri_edge_indices.to(device)
|
|
39
|
+
arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
|
|
40
|
+
|
|
41
|
+
arg.vertex_count = self._mesh.vertex_count()
|
|
42
|
+
arg.edge_count = self._mesh.side_count()
|
|
43
|
+
return arg
|
|
44
|
+
|
|
45
|
+
def _compute_tri_edge_indices(self):
|
|
46
|
+
self._tri_edge_indices = wp.empty(
|
|
47
|
+
dtype=int, device=self._mesh.tri_vertex_indices.device, shape=(self._mesh.cell_count(), 3)
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
wp.launch(
|
|
51
|
+
kernel=TrimeshSpaceTopology._compute_tri_edge_indices_kernel,
|
|
52
|
+
dim=self._mesh.edge_tri_indices.shape,
|
|
53
|
+
device=self._mesh.tri_vertex_indices.device,
|
|
54
|
+
inputs=[
|
|
55
|
+
self._mesh.edge_tri_indices,
|
|
56
|
+
self._mesh.edge_vertex_indices,
|
|
57
|
+
self._mesh.tri_vertex_indices,
|
|
58
|
+
self._tri_edge_indices,
|
|
59
|
+
],
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
@wp.func
|
|
63
|
+
def _find_edge_index_in_tri(
|
|
64
|
+
edge_vtx: wp.vec2i,
|
|
65
|
+
tri_vtx: wp.vec3i,
|
|
66
|
+
):
|
|
67
|
+
for k in range(2):
|
|
68
|
+
if (edge_vtx[0] == tri_vtx[k] and edge_vtx[1] == tri_vtx[k + 1]) or (
|
|
69
|
+
edge_vtx[1] == tri_vtx[k] and edge_vtx[0] == tri_vtx[k + 1]
|
|
70
|
+
):
|
|
71
|
+
return k
|
|
72
|
+
return 2
|
|
73
|
+
|
|
74
|
+
@wp.kernel
|
|
75
|
+
def _compute_tri_edge_indices_kernel(
|
|
76
|
+
edge_tri_indices: wp.array(dtype=wp.vec2i),
|
|
77
|
+
edge_vertex_indices: wp.array(dtype=wp.vec2i),
|
|
78
|
+
tri_vertex_indices: wp.array2d(dtype=int),
|
|
79
|
+
tri_edge_indices: wp.array2d(dtype=int),
|
|
80
|
+
):
|
|
81
|
+
e = wp.tid()
|
|
82
|
+
|
|
83
|
+
edge_vtx = edge_vertex_indices[e]
|
|
84
|
+
edge_tris = edge_tri_indices[e]
|
|
85
|
+
|
|
86
|
+
t0 = edge_tris[0]
|
|
87
|
+
t0_vtx = wp.vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
|
|
88
|
+
t0_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t0_vtx)
|
|
89
|
+
tri_edge_indices[t0, t0_edge] = e
|
|
90
|
+
|
|
91
|
+
t1 = edge_tris[1]
|
|
92
|
+
if t1 != t0:
|
|
93
|
+
t1_vtx = wp.vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
|
|
94
|
+
t1_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t1_vtx)
|
|
95
|
+
tri_edge_indices[t1, t1_edge] = e
|
|
96
|
+
|
|
97
|
+
def node_count(self) -> int:
|
|
98
|
+
return (
|
|
99
|
+
self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
|
|
100
|
+
+ self._mesh.side_count() * self._shape.EDGE_NODE_COUNT
|
|
101
|
+
+ self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
def _make_element_node_index(self):
|
|
105
|
+
VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
|
|
106
|
+
INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT
|
|
107
|
+
INTERIOR_NODES_PER_CELL = self._shape.INTERIOR_NODE_COUNT
|
|
108
|
+
|
|
109
|
+
@cache.dynamic_func(suffix=self.name)
|
|
110
|
+
def element_node_index(
|
|
111
|
+
geo_arg: self.geometry.CellArg,
|
|
112
|
+
topo_arg: TrimeshTopologyArg,
|
|
113
|
+
element_index: ElementIndex,
|
|
114
|
+
node_index_in_elt: int,
|
|
115
|
+
):
|
|
116
|
+
node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
|
|
117
|
+
|
|
118
|
+
if wp.static(VERTEX_NODE_COUNT > 0):
|
|
119
|
+
if node_type == TriangleShapeFunction.VERTEX:
|
|
120
|
+
vertex = type_index // VERTEX_NODE_COUNT
|
|
121
|
+
vertex_node = type_index - VERTEX_NODE_COUNT * vertex
|
|
122
|
+
return geo_arg.topology.tri_vertex_indices[element_index][vertex] * VERTEX_NODE_COUNT + vertex_node
|
|
123
|
+
|
|
124
|
+
global_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT
|
|
125
|
+
|
|
126
|
+
if wp.static(INTERIOR_NODES_PER_SIDE > 0):
|
|
127
|
+
if node_type == TriangleShapeFunction.EDGE:
|
|
128
|
+
edge = type_index // INTERIOR_NODES_PER_SIDE
|
|
129
|
+
edge_node = type_index - INTERIOR_NODES_PER_SIDE * edge
|
|
130
|
+
|
|
131
|
+
global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
|
|
132
|
+
|
|
133
|
+
if (
|
|
134
|
+
topo_arg.edge_vertex_indices[global_edge_index][0]
|
|
135
|
+
!= geo_arg.topology.tri_vertex_indices[element_index][edge]
|
|
136
|
+
):
|
|
137
|
+
edge_node = INTERIOR_NODES_PER_SIDE - 1 - edge_node
|
|
138
|
+
|
|
139
|
+
return global_offset + INTERIOR_NODES_PER_SIDE * global_edge_index + edge_node
|
|
140
|
+
|
|
141
|
+
global_offset += INTERIOR_NODES_PER_SIDE * topo_arg.edge_count
|
|
142
|
+
|
|
143
|
+
return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index
|
|
144
|
+
|
|
145
|
+
return element_node_index
|
|
146
|
+
|
|
147
|
+
def _make_element_node_sign(self):
|
|
148
|
+
INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT
|
|
149
|
+
|
|
150
|
+
@cache.dynamic_func(suffix=self.name)
|
|
151
|
+
def element_node_sign(
|
|
152
|
+
geo_arg: self.geometry.CellArg,
|
|
153
|
+
topo_arg: TrimeshTopologyArg,
|
|
154
|
+
element_index: ElementIndex,
|
|
155
|
+
node_index_in_elt: int,
|
|
156
|
+
):
|
|
157
|
+
node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)
|
|
158
|
+
|
|
159
|
+
if node_type == TriangleShapeFunction.EDGE:
|
|
160
|
+
edge = type_index // INTERIOR_NODES_PER_SIDE
|
|
161
|
+
|
|
162
|
+
global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
|
|
163
|
+
return wp.select(
|
|
164
|
+
topo_arg.edge_vertex_indices[global_edge_index][0]
|
|
165
|
+
== geo_arg.topology.tri_vertex_indices[element_index][edge],
|
|
166
|
+
-1.0,
|
|
167
|
+
1.0,
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
return 1.0
|
|
171
|
+
|
|
172
|
+
return element_node_sign
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def make_trimesh_space_topology(mesh: Trimesh, shape: TriangleShapeFunction):
|
|
176
|
+
if isinstance(shape, TriangleShapeFunction):
|
|
177
|
+
return forward_base_topology(TrimeshSpaceTopology, mesh, shape)
|
|
178
|
+
|
|
179
|
+
raise ValueError(f"Unsupported shape function {shape.name}")
|
warp/fem/utils.py
CHANGED
|
@@ -1,323 +1,18 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Tuple, Union
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
5
|
import warp as wp
|
|
6
6
|
import warp.fem.cache as cache
|
|
7
|
+
from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
|
|
8
|
+
array_axpy,
|
|
9
|
+
inverse_qr,
|
|
10
|
+
symmetric_eigenvalues_qr,
|
|
11
|
+
)
|
|
7
12
|
from warp.fem.types import NULL_NODE_INDEX
|
|
8
13
|
from warp.utils import array_scan, radix_sort_pairs, runlength_encode
|
|
9
14
|
|
|
10
15
|
|
|
11
|
-
@wp.func
|
|
12
|
-
def generalized_outer(x: Any, y: Any):
|
|
13
|
-
"""Generalized outer product allowing for the first argument to be a scalar"""
|
|
14
|
-
return wp.outer(x, y)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
@wp.func
|
|
18
|
-
def generalized_outer(x: wp.float32, y: wp.vec2):
|
|
19
|
-
return x * y
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
@wp.func
|
|
23
|
-
def generalized_outer(x: wp.float32, y: wp.vec3):
|
|
24
|
-
return x * y
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
@wp.func
|
|
28
|
-
def generalized_inner(x: Any, y: Any):
|
|
29
|
-
"""Generalized inner product allowing for the first argument to be a tensor"""
|
|
30
|
-
return wp.dot(x, y)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
@wp.func
|
|
34
|
-
def generalized_inner(x: wp.mat22, y: wp.vec2):
|
|
35
|
-
return x[0] * y[0] + x[1] * y[1]
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
@wp.func
|
|
39
|
-
def generalized_inner(x: wp.mat33, y: wp.vec3):
|
|
40
|
-
return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
@wp.func
|
|
44
|
-
def unit_element(template_type: Any, coord: int):
|
|
45
|
-
"""Returns a instance of `template_type` with a single coordinate set to 1 in the canonical basis"""
|
|
46
|
-
|
|
47
|
-
t = type(template_type)(0.0)
|
|
48
|
-
t[coord] = 1.0
|
|
49
|
-
return t
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
@wp.func
|
|
53
|
-
def unit_element(template_type: wp.float32, coord: int):
|
|
54
|
-
return 1.0
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
@wp.func
|
|
58
|
-
def unit_element(template_type: wp.mat22, coord: int):
|
|
59
|
-
t = wp.mat22(0.0)
|
|
60
|
-
row = coord // 2
|
|
61
|
-
col = coord - 2 * row
|
|
62
|
-
t[row, col] = 1.0
|
|
63
|
-
return t
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
@wp.func
|
|
67
|
-
def unit_element(template_type: wp.mat33, coord: int):
|
|
68
|
-
t = wp.mat33(0.0)
|
|
69
|
-
row = coord // 3
|
|
70
|
-
col = coord - 3 * row
|
|
71
|
-
t[row, col] = 1.0
|
|
72
|
-
return t
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
@wp.func
|
|
76
|
-
def symmetric_part(x: Any):
|
|
77
|
-
"""Symmetric part of a square tensor"""
|
|
78
|
-
return 0.5 * (x + wp.transpose(x))
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
@wp.func
|
|
82
|
-
def skew_part(x: wp.mat22):
|
|
83
|
-
"""Skew part of a 2x2 tensor as corresponding rotation angle"""
|
|
84
|
-
return 0.5 * (x[1, 0] - x[0, 1])
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
@wp.func
|
|
88
|
-
def skew_part(x: wp.mat33):
|
|
89
|
-
"""Skew part of a 3x3 tensor as the corresponding rotation vector"""
|
|
90
|
-
a = 0.5 * (x[2, 1] - x[1, 2])
|
|
91
|
-
b = 0.5 * (x[0, 2] - x[2, 0])
|
|
92
|
-
c = 0.5 * (x[1, 0] - x[0, 1])
|
|
93
|
-
return wp.vec3(a, b, c)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
@wp.func
|
|
97
|
-
def householder_qr_decomposition(A: Any):
|
|
98
|
-
"""
|
|
99
|
-
QR decomposition of a square matrix using Householder reflections
|
|
100
|
-
|
|
101
|
-
Returns Q and R such that Q R = A, Q orthonormal (such that QQ^T = Id), R upper triangular
|
|
102
|
-
"""
|
|
103
|
-
|
|
104
|
-
x = type(A[0])()
|
|
105
|
-
Q = wp.identity(n=type(x).length, dtype=A.dtype)
|
|
106
|
-
|
|
107
|
-
zero = x.dtype(0.0)
|
|
108
|
-
two = x.dtype(2.0)
|
|
109
|
-
|
|
110
|
-
for i in range(type(x).length):
|
|
111
|
-
for k in range(type(x).length):
|
|
112
|
-
x[k] = wp.select(k < i, A[k, i], zero)
|
|
113
|
-
|
|
114
|
-
alpha = wp.length(x) * wp.sign(x[i])
|
|
115
|
-
x[i] += alpha
|
|
116
|
-
two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
|
|
117
|
-
|
|
118
|
-
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
119
|
-
Q -= wp.outer(Q * x, two_over_x_sq * x)
|
|
120
|
-
|
|
121
|
-
return Q, A
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
@wp.func
|
|
125
|
-
def householder_make_hessenberg(A: Any):
|
|
126
|
-
"""Transforms a square matrix to Hessenberg form (single lower diagonal) using Householder reflections
|
|
127
|
-
|
|
128
|
-
Returns:
|
|
129
|
-
Q and H such that Q H Q^T = A, Q orthonormal, H under Hessenberg form
|
|
130
|
-
If A is symmetric, H will be tridiagonal
|
|
131
|
-
"""
|
|
132
|
-
|
|
133
|
-
x = type(A[0])()
|
|
134
|
-
Q = wp.identity(n=type(x).length, dtype=A.dtype)
|
|
135
|
-
|
|
136
|
-
zero = x.dtype(0.0)
|
|
137
|
-
two = x.dtype(2.0)
|
|
138
|
-
|
|
139
|
-
for i in range(1, type(x).length):
|
|
140
|
-
for k in range(type(x).length):
|
|
141
|
-
x[k] = wp.select(k < i, A[k, i - 1], zero)
|
|
142
|
-
|
|
143
|
-
alpha = wp.length(x) * wp.sign(x[i])
|
|
144
|
-
x[i] += alpha
|
|
145
|
-
two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
|
|
146
|
-
|
|
147
|
-
# apply on both sides
|
|
148
|
-
A -= wp.outer(two_over_x_sq * x, x * A)
|
|
149
|
-
A -= wp.outer(A * x, two_over_x_sq * x)
|
|
150
|
-
Q -= wp.outer(Q * x, two_over_x_sq * x)
|
|
151
|
-
|
|
152
|
-
return Q, A
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
@wp.func
|
|
156
|
-
def solve_triangular(R: Any, b: Any):
|
|
157
|
-
"""Solves for R x = b where R is an upper triangular matrix
|
|
158
|
-
|
|
159
|
-
Returns x
|
|
160
|
-
"""
|
|
161
|
-
zero = b.dtype(0)
|
|
162
|
-
x = type(b)(b.dtype(0))
|
|
163
|
-
for i in range(b.length, 0, -1):
|
|
164
|
-
j = i - 1
|
|
165
|
-
r = b[j] - wp.dot(R[j], x)
|
|
166
|
-
x[j] = wp.select(R[j, j] == zero, r / R[j, j], zero)
|
|
167
|
-
|
|
168
|
-
return x
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
@wp.func
|
|
172
|
-
def inverse_qr(A: Any):
|
|
173
|
-
# Computes a square matrix inverse using QR factorization
|
|
174
|
-
|
|
175
|
-
Q, R = householder_qr_decomposition(A)
|
|
176
|
-
|
|
177
|
-
A_inv = type(A)()
|
|
178
|
-
for i in range(type(A[0]).length):
|
|
179
|
-
A_inv[i] = solve_triangular(R, Q[i]) # ith column of Q^T
|
|
180
|
-
|
|
181
|
-
return wp.transpose(A_inv)
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
@wp.func
|
|
185
|
-
def _wilkinson_shift(a: Any, b: Any, c: Any, tol: Any):
|
|
186
|
-
# Wilkinson shift: estimate eigenvalue of 2x2 symmetric matrix [a, c, c, b]
|
|
187
|
-
d = (a - b) * type(tol)(0.5)
|
|
188
|
-
return b + d - wp.sign(d) * wp.sqrt(d * d + c * c)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
@wp.func
|
|
192
|
-
def _givens_rotation(a: Any, b: Any):
|
|
193
|
-
# Givens rotation [[c -s], [s c]] such that sa+cb =0
|
|
194
|
-
zero = type(a)(0.0)
|
|
195
|
-
one = type(a)(1.0)
|
|
196
|
-
|
|
197
|
-
b2 = b * b
|
|
198
|
-
if b2 == zero:
|
|
199
|
-
# id rotation
|
|
200
|
-
return one, zero
|
|
201
|
-
|
|
202
|
-
scale = one / wp.sqrt(a * a + b2)
|
|
203
|
-
return a * scale, -b * scale
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
@wp.func
|
|
207
|
-
def tridiagonal_symmetric_eigenvalues_qr(D: Any, L: Any, Q: Any, tol: Any):
|
|
208
|
-
"""
|
|
209
|
-
Computes the eigenvalues and eigen vectors of a symmetric tridiagonal matrix using the
|
|
210
|
-
Symmetric tridiagonal QR algorithm with implicit Wilkinson shift
|
|
211
|
-
|
|
212
|
-
Args:
|
|
213
|
-
D: Main diagonal of the matrix
|
|
214
|
-
L: Lower diagonal of the matrix, indexed such that L[i] = A[i+1, i]
|
|
215
|
-
Q: Initialization for the eigenvectors, useful if a pre-transformation has been applied, otherwise may be identity
|
|
216
|
-
tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)
|
|
217
|
-
|
|
218
|
-
Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
Ref: Arbenz P, Numerical Methods for Solving Large Scale Eigenvalue Problems, Chapter 4 (QR algorithm, Mar 13, 2018)
|
|
222
|
-
"""
|
|
223
|
-
|
|
224
|
-
two = D.dtype(2.0)
|
|
225
|
-
|
|
226
|
-
# so that we can use the type length in expressions
|
|
227
|
-
# this will prevent unrolling by warp, but should be ok for native code
|
|
228
|
-
m = int(0)
|
|
229
|
-
for _ in range(type(D).length):
|
|
230
|
-
m += 1
|
|
231
|
-
|
|
232
|
-
start = int(0)
|
|
233
|
-
y = D.dtype(0.0) # moving buldge
|
|
234
|
-
x = D.dtype(0.0) # coeff atop buldge
|
|
235
|
-
|
|
236
|
-
for _ in range(32 * m): # failsafe, usually converges faster than that
|
|
237
|
-
# Iterate over all independent (deflated) blocks
|
|
238
|
-
end = int(-1)
|
|
239
|
-
|
|
240
|
-
for k in range(m - 1):
|
|
241
|
-
if k >= end:
|
|
242
|
-
# Check if new block is starting
|
|
243
|
-
if k == end or wp.abs(L[k]) <= tol * (wp.abs(D[k]) + wp.abs(D[k + 1])):
|
|
244
|
-
continue
|
|
245
|
-
|
|
246
|
-
# Find end of block
|
|
247
|
-
start = k
|
|
248
|
-
end = start + 1
|
|
249
|
-
while end + 1 < m:
|
|
250
|
-
if wp.abs(L[end]) <= tol * (wp.abs(D[end + 1]) + wp.abs(D[end])):
|
|
251
|
-
break
|
|
252
|
-
end += 1
|
|
253
|
-
|
|
254
|
-
# Wilkinson shift (an eigenvalue of the last 2x2 block)
|
|
255
|
-
shift = _wilkinson_shift(D[end - 1], D[end], L[end - 1], tol)
|
|
256
|
-
|
|
257
|
-
# start with eliminating lower diag of first column of shifted matrix
|
|
258
|
-
# (i.e. first step of excplit QR factorization)
|
|
259
|
-
# Then all further steps eliminate the buldge (second diag) of the non-shifted matrix
|
|
260
|
-
x = D[start] - shift
|
|
261
|
-
y = L[start]
|
|
262
|
-
|
|
263
|
-
c, s = _givens_rotation(x, y)
|
|
264
|
-
|
|
265
|
-
# Apply Givens rotation on both sides of tridiagonal matrix
|
|
266
|
-
|
|
267
|
-
# middle block
|
|
268
|
-
d = D[k] - D[k + 1]
|
|
269
|
-
z = (two * c * L[k] + d * s) * s
|
|
270
|
-
D[k] -= z
|
|
271
|
-
D[k + 1] += z
|
|
272
|
-
L[k] = d * c * s + (c * c - s * s) * L[k]
|
|
273
|
-
|
|
274
|
-
if k > start:
|
|
275
|
-
L[k - 1] = c * x - s * y
|
|
276
|
-
|
|
277
|
-
x = L[k]
|
|
278
|
-
y = -s * L[k + 1] # new buldge
|
|
279
|
-
L[k + 1] *= c
|
|
280
|
-
|
|
281
|
-
# apply givens rotation on left of Q
|
|
282
|
-
# note: Q is transposed compared to usual impls, as Warp makes it easier to index rows
|
|
283
|
-
Qk0 = Q[k]
|
|
284
|
-
Qk1 = Q[k + 1]
|
|
285
|
-
Q[k] = c * Qk0 - s * Qk1
|
|
286
|
-
Q[k + 1] = c * Qk1 + s * Qk0
|
|
287
|
-
|
|
288
|
-
if end <= 0:
|
|
289
|
-
# We did nothing, so diagonalization must have been achieved
|
|
290
|
-
break
|
|
291
|
-
|
|
292
|
-
return D, Q
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
@wp.func
|
|
296
|
-
def symmetric_eigenvalues_qr(A: Any, tol: Any):
|
|
297
|
-
"""
|
|
298
|
-
Computes the eigenvalues and eigen vectors of a square symmetric matrix A using the QR algorithm
|
|
299
|
-
|
|
300
|
-
Args:
|
|
301
|
-
A: square symmetric matrix
|
|
302
|
-
tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)
|
|
303
|
-
|
|
304
|
-
Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
|
|
305
|
-
"""
|
|
306
|
-
|
|
307
|
-
# Put A under Hessenberg form (tridiagonal)
|
|
308
|
-
Q, H = householder_make_hessenberg(A)
|
|
309
|
-
|
|
310
|
-
# tridiagonal storage for H
|
|
311
|
-
D = wp.get_diag(H)
|
|
312
|
-
L = type(D)(A.dtype(0.0))
|
|
313
|
-
for i in range(1, type(D).length):
|
|
314
|
-
L[i - 1] = H[i, i - 1]
|
|
315
|
-
|
|
316
|
-
Qt = wp.transpose(Q)
|
|
317
|
-
ev, P = tridiagonal_symmetric_eigenvalues_qr(D, L, Qt, tol)
|
|
318
|
-
return ev, P
|
|
319
|
-
|
|
320
|
-
|
|
321
16
|
def compress_node_indices(
|
|
322
17
|
node_count: int,
|
|
323
18
|
node_indices: wp.array(dtype=int),
|
|
@@ -458,20 +153,6 @@ def masked_indices(
|
|
|
458
153
|
return indices_temp, offsets_temp
|
|
459
154
|
|
|
460
155
|
|
|
461
|
-
def array_axpy(x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 1.0):
|
|
462
|
-
"""Performs y = alpha*x + beta*y"""
|
|
463
|
-
|
|
464
|
-
dtype = wp.types.type_scalar_type(x.dtype)
|
|
465
|
-
|
|
466
|
-
alpha = dtype(alpha)
|
|
467
|
-
beta = dtype(beta)
|
|
468
|
-
|
|
469
|
-
if not wp.types.types_equal(x.dtype, y.dtype) or x.shape != y.shape or x.device != y.device:
|
|
470
|
-
raise ValueError("x and y arrays must have same dat atype, shape and device")
|
|
471
|
-
|
|
472
|
-
wp.launch(kernel=_array_axpy_kernel, dim=x.shape, device=x.device, inputs=[x, y, alpha, beta])
|
|
473
|
-
|
|
474
|
-
|
|
475
156
|
@wp.kernel
|
|
476
157
|
def _iota_kernel(indices: wp.array(dtype=int), divisor: int):
|
|
477
158
|
indices[wp.tid()] = wp.tid() // divisor
|
|
@@ -515,12 +196,6 @@ def _masked_indices_kernel(
|
|
|
515
196
|
masked_to_global[masked_idx] = i
|
|
516
197
|
|
|
517
198
|
|
|
518
|
-
@wp.kernel
|
|
519
|
-
def _array_axpy_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any), alpha: Any, beta: Any):
|
|
520
|
-
i = wp.tid()
|
|
521
|
-
y[i] = beta * y[i] + alpha * x[i]
|
|
522
|
-
|
|
523
|
-
|
|
524
199
|
def grid_to_tris(Nx: int, Ny: int):
|
|
525
200
|
"""Constructs a triangular mesh topology by dividing each cell of a dense 2D grid into two triangles.
|
|
526
201
|
|
warp/jax_experimental.py
CHANGED
|
@@ -102,7 +102,9 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
|
|
|
102
102
|
assert hooks.forward, "Failed to find kernel entry point"
|
|
103
103
|
|
|
104
104
|
# Launch the kernel.
|
|
105
|
-
wp.context.runtime.core.cuda_launch_kernel(
|
|
105
|
+
wp.context.runtime.core.cuda_launch_kernel(
|
|
106
|
+
device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
|
|
107
|
+
)
|
|
106
108
|
|
|
107
109
|
|
|
108
110
|
# TODO: is there a simpler way of getting the Jax "current" device?
|
warp/native/array.h
CHANGED
|
@@ -1,3 +1,11 @@
|
|
|
1
|
+
/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
* NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
* and proprietary rights in and to this software, related documentation
|
|
4
|
+
* and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
* distribution of this software and related documentation without an express
|
|
6
|
+
* license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
*/
|
|
8
|
+
|
|
1
9
|
#pragma once
|
|
2
10
|
|
|
3
11
|
#include "builtin.h"
|
|
@@ -285,6 +293,13 @@ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
|
|
|
285
293
|
template <typename T>
|
|
286
294
|
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
|
|
287
295
|
{
|
|
296
|
+
// if (i < 0 || i >= arr.shape[0])
|
|
297
|
+
// printf("i: %d > arr.shape[0]: %d\n", i, arr.shape[0]);
|
|
298
|
+
|
|
299
|
+
// if (j < 0 || j >= arr.shape[1])
|
|
300
|
+
// printf("j: %d > arr.shape[1]: %d\n", j, arr.shape[1]);
|
|
301
|
+
|
|
302
|
+
|
|
288
303
|
assert(i >= 0 && i < arr.shape[0]);
|
|
289
304
|
assert(j >= 0 && j < arr.shape[1]);
|
|
290
305
|
|
warp/native/builtin.h
CHANGED
|
@@ -1145,7 +1145,47 @@ struct launch_bounds_t
|
|
|
1145
1145
|
size_t size; // total number of threads
|
|
1146
1146
|
};
|
|
1147
1147
|
|
|
1148
|
-
|
|
1148
|
+
// represents coordinate in the launch grid
|
|
1149
|
+
struct launch_coord_t
|
|
1150
|
+
{
|
|
1151
|
+
int i;
|
|
1152
|
+
int j;
|
|
1153
|
+
int k;
|
|
1154
|
+
int l;
|
|
1155
|
+
};
|
|
1156
|
+
|
|
1157
|
+
// unravels a linear thread index to the corresponding launch grid coord (up to 4d)
|
|
1158
|
+
inline CUDA_CALLABLE launch_coord_t launch_coord(size_t linear, const launch_bounds_t& bounds)
|
|
1159
|
+
{
|
|
1160
|
+
launch_coord_t coord = {0, 0, 0, 0};
|
|
1161
|
+
|
|
1162
|
+
if (bounds.ndim > 3)
|
|
1163
|
+
{
|
|
1164
|
+
coord.l = linear%bounds.shape[3];
|
|
1165
|
+
linear /= bounds.shape[3];
|
|
1166
|
+
}
|
|
1167
|
+
|
|
1168
|
+
if (bounds.ndim > 2)
|
|
1169
|
+
{
|
|
1170
|
+
coord.k = linear%bounds.shape[2];
|
|
1171
|
+
linear /= bounds.shape[2];
|
|
1172
|
+
}
|
|
1173
|
+
|
|
1174
|
+
if (bounds.ndim > 1)
|
|
1175
|
+
{
|
|
1176
|
+
coord.j = linear%bounds.shape[1];
|
|
1177
|
+
linear /= bounds.shape[1];
|
|
1178
|
+
}
|
|
1179
|
+
|
|
1180
|
+
if (bounds.ndim > 0)
|
|
1181
|
+
{
|
|
1182
|
+
coord.i = linear;
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
return coord;
|
|
1186
|
+
}
|
|
1187
|
+
|
|
1188
|
+
inline CUDA_CALLABLE int tid(size_t index, const launch_bounds_t& bounds)
|
|
1149
1189
|
{
|
|
1150
1190
|
// For the 1-D tid() we need to warn the user if we're about to provide a truncated index
|
|
1151
1191
|
// Only do this in _DEBUG when called from device to avoid excessive register allocation
|
|
@@ -1154,40 +1194,33 @@ inline CUDA_CALLABLE int tid(size_t index)
|
|
|
1154
1194
|
printf("Warp warning: tid() is returning an overflowed int\n");
|
|
1155
1195
|
}
|
|
1156
1196
|
#endif
|
|
1157
|
-
|
|
1197
|
+
|
|
1198
|
+
launch_coord_t c = launch_coord(index, bounds);
|
|
1199
|
+
return static_cast<int>(c.i);
|
|
1158
1200
|
}
|
|
1159
1201
|
|
|
1160
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t&
|
|
1202
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& bounds)
|
|
1161
1203
|
{
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
|
|
1165
|
-
i = index/n;
|
|
1166
|
-
j = index%n;
|
|
1204
|
+
launch_coord_t c = launch_coord(index, bounds);
|
|
1205
|
+
i = c.i;
|
|
1206
|
+
j = c.j;
|
|
1167
1207
|
}
|
|
1168
1208
|
|
|
1169
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t&
|
|
1209
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& bounds)
|
|
1170
1210
|
{
|
|
1171
|
-
|
|
1172
|
-
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
i = index/(n*o);
|
|
1176
|
-
j = index%(n*o)/o;
|
|
1177
|
-
k = index%o;
|
|
1211
|
+
launch_coord_t c = launch_coord(index, bounds);
|
|
1212
|
+
i = c.i;
|
|
1213
|
+
j = c.j;
|
|
1214
|
+
k = c.k;
|
|
1178
1215
|
}
|
|
1179
1216
|
|
|
1180
|
-
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t&
|
|
1217
|
+
inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& bounds)
|
|
1181
1218
|
{
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
i = index/(n*o*p);
|
|
1188
|
-
j = index%(n*o*p)/(o*p);
|
|
1189
|
-
k = index%(o*p)/p;
|
|
1190
|
-
l = index%p;
|
|
1219
|
+
launch_coord_t c = launch_coord(index, bounds);
|
|
1220
|
+
i = c.i;
|
|
1221
|
+
j = c.j;
|
|
1222
|
+
k = c.k;
|
|
1223
|
+
l = c.l;
|
|
1191
1224
|
}
|
|
1192
1225
|
|
|
1193
1226
|
template<typename T>
|
|
@@ -1724,3 +1757,10 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect
|
|
|
1724
1757
|
#include "rand.h"
|
|
1725
1758
|
#include "noise.h"
|
|
1726
1759
|
#include "matnn.h"
|
|
1760
|
+
|
|
1761
|
+
// only include in kernels for now
|
|
1762
|
+
#if defined(__CUDACC_RTC__)
|
|
1763
|
+
#include "tile.h"
|
|
1764
|
+
#include "tile_gemm.h"
|
|
1765
|
+
#include "tile_reduce.h"
|
|
1766
|
+
#endif
|