warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,179 @@
1
+ import warp as wp
2
+ from warp.fem import cache
3
+ from warp.fem.geometry import Trimesh
4
+ from warp.fem.types import ElementIndex
5
+
6
+ from .shape import TriangleShapeFunction
7
+ from .topology import SpaceTopology, forward_base_topology
8
+
9
+
10
@wp.struct
class TrimeshTopologyArg:
    # Kernel-side topology data for TrimeshSpaceTopology, filled by its `topo_arg_value`.

    # Two vertex indices per mesh edge. Entry [0] serves as the edge's reference
    # orientation when ordering and signing edge-interior nodes.
    edge_vertex_indices: wp.array(dtype=wp.vec2i)
    # Shape (cell_count, 3): global edge index of each local edge of each triangle,
    # where local edge k joins triangle vertices k and (k + 1) % 3.
    tri_edge_indices: wp.array2d(dtype=int)

    vertex_count: int  # number of mesh vertices (from `mesh.vertex_count()`)
    edge_count: int  # number of mesh edges (from `mesh.side_count()`)
17
+
18
+
19
class TrimeshSpaceTopology(SpaceTopology):
    """Node topology of a function space defined over a triangle mesh.

    Global node indices are laid out in three consecutive ranges:
    nodes attached to mesh vertices, then nodes interior to mesh edges (sides),
    then nodes interior to triangles (cells).
    An edge may be shared by two triangles (``edge_tri_indices`` stores a pair per edge).
    For this reason the local ordering and the sign of edge-interior nodes are adjusted
    according to whether the triangle traverses the edge in the same direction as the
    edge's reference orientation (its first vertex).
    """

    TopologyArg = TrimeshTopologyArg

    def __init__(self, mesh: Trimesh, shape: TriangleShapeFunction):
        # `_shape` is set before the base constructor runs, presumably because
        # `name` (which reads `_shape`) may be needed there -- TODO confirm
        self._shape = shape
        super().__init__(mesh, shape.NODES_PER_ELEMENT)
        self._mesh = mesh

        self._compute_tri_edge_indices()
        self.element_node_index = self._make_element_node_index()
        self.element_node_sign = self._make_element_node_sign()

    @property
    def name(self):
        """Geometry name and shape function name, joined by an underscore.

        Used as the suffix of the generated ``element_node_index`` / ``element_node_sign`` functions.
        """
        return f"{self.geometry.name}_{self._shape.name}"

    @cache.cached_arg_value
    def topo_arg_value(self, device):
        """Build the :class:`TrimeshTopologyArg` for ``device``.

        Topology arrays are moved with ``.to(device)``. Caching of the result is
        delegated to ``cache.cached_arg_value``.
        """
        arg = TrimeshTopologyArg()
        arg.tri_edge_indices = self._tri_edge_indices.to(device)
        arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)

        arg.vertex_count = self._mesh.vertex_count()
        arg.edge_count = self._mesh.side_count()
        return arg

    def _compute_tri_edge_indices(self):
        """Compute ``self._tri_edge_indices``, a (cell_count, 3) array of global edge indices per triangle.

        One kernel thread runs per mesh edge and writes its index into each adjacent triangle.
        The array is allocated with ``wp.empty``, so this assumes every triangle edge appears
        in the mesh edge list; otherwise some entries would stay uninitialized.
        """
        self._tri_edge_indices = wp.empty(
            dtype=int, device=self._mesh.tri_vertex_indices.device, shape=(self._mesh.cell_count(), 3)
        )

        wp.launch(
            kernel=TrimeshSpaceTopology._compute_tri_edge_indices_kernel,
            dim=self._mesh.edge_tri_indices.shape,
            device=self._mesh.tri_vertex_indices.device,
            inputs=[
                self._mesh.edge_tri_indices,
                self._mesh.edge_vertex_indices,
                self._mesh.tri_vertex_indices,
                self._tri_edge_indices,
            ],
        )

    @wp.func
    def _find_edge_index_in_tri(
        edge_vtx: wp.vec2i,
        tri_vtx: wp.vec3i,
    ):
        """Return the local index of edge ``edge_vtx`` in triangle ``tri_vtx``.

        Local edge k joins triangle vertices k and (k + 1) % 3, in either direction.
        Only edges 0 and 1 are tested explicitly; every other case returns 2.
        """
        for k in range(2):
            if (edge_vtx[0] == tri_vtx[k] and edge_vtx[1] == tri_vtx[k + 1]) or (
                edge_vtx[1] == tri_vtx[k] and edge_vtx[0] == tri_vtx[k + 1]
            ):
                return k
        return 2

    @wp.kernel
    def _compute_tri_edge_indices_kernel(
        edge_tri_indices: wp.array(dtype=wp.vec2i),
        edge_vertex_indices: wp.array(dtype=wp.vec2i),
        tri_vertex_indices: wp.array2d(dtype=int),
        tri_edge_indices: wp.array2d(dtype=int),
    ):
        # One thread per mesh edge `e`
        e = wp.tid()

        edge_vtx = edge_vertex_indices[e]
        edge_tris = edge_tri_indices[e]

        # Register the edge in its first adjacent triangle
        t0 = edge_tris[0]
        t0_vtx = wp.vec3i(tri_vertex_indices[t0, 0], tri_vertex_indices[t0, 1], tri_vertex_indices[t0, 2])
        t0_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t0_vtx)
        tri_edge_indices[t0, t0_edge] = e

        # Register it in the second adjacent triangle, if distinct from the first
        # (presumably both entries are equal for boundary edges -- confirm against Trimesh)
        t1 = edge_tris[1]
        if t1 != t0:
            t1_vtx = wp.vec3i(tri_vertex_indices[t1, 0], tri_vertex_indices[t1, 1], tri_vertex_indices[t1, 2])
            t1_edge = TrimeshSpaceTopology._find_edge_index_in_tri(edge_vtx, t1_vtx)
            tri_edge_indices[t1, t1_edge] = e

    def node_count(self) -> int:
        """Total node count: per-vertex, per-edge and per-cell node counts times the mesh entity counts."""
        return (
            self._mesh.vertex_count() * self._shape.VERTEX_NODE_COUNT
            + self._mesh.side_count() * self._shape.EDGE_NODE_COUNT
            + self._mesh.cell_count() * self._shape.INTERIOR_NODE_COUNT
        )

    def _make_element_node_index(self):
        """Generate the Warp function mapping (element, local node) to a global node index."""
        VERTEX_NODE_COUNT = self._shape.VERTEX_NODE_COUNT
        INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT
        INTERIOR_NODES_PER_CELL = self._shape.INTERIOR_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_index(
            geo_arg: self.geometry.CellArg,
            topo_arg: TrimeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Vertex nodes: range [0, vertex_count * VERTEX_NODE_COUNT)
            if wp.static(VERTEX_NODE_COUNT > 0):
                if node_type == TriangleShapeFunction.VERTEX:
                    vertex = type_index // VERTEX_NODE_COUNT
                    vertex_node = type_index - VERTEX_NODE_COUNT * vertex
                    return geo_arg.topology.tri_vertex_indices[element_index][vertex] * VERTEX_NODE_COUNT + vertex_node

            global_offset = topo_arg.vertex_count * VERTEX_NODE_COUNT

            # Edge-interior nodes: indexed by global edge, in the edge's reference orientation
            if wp.static(INTERIOR_NODES_PER_SIDE > 0):
                if node_type == TriangleShapeFunction.EDGE:
                    edge = type_index // INTERIOR_NODES_PER_SIDE
                    edge_node = type_index - INTERIOR_NODES_PER_SIDE * edge

                    global_edge_index = topo_arg.tri_edge_indices[element_index][edge]

                    # Reverse local node order when the triangle's local edge does not start
                    # at the edge's reference first vertex, so neighbors agree on numbering
                    if (
                        topo_arg.edge_vertex_indices[global_edge_index][0]
                        != geo_arg.topology.tri_vertex_indices[element_index][edge]
                    ):
                        edge_node = INTERIOR_NODES_PER_SIDE - 1 - edge_node

                    return global_offset + INTERIOR_NODES_PER_SIDE * global_edge_index + edge_node

                global_offset += INTERIOR_NODES_PER_SIDE * topo_arg.edge_count

            # Cell-interior nodes: after all vertex and edge nodes
            return global_offset + INTERIOR_NODES_PER_CELL * element_index + type_index

        return element_node_index

    def _make_element_node_sign(self):
        """Generate the Warp function returning the orientation sign (+1 or -1) of each element node."""
        INTERIOR_NODES_PER_SIDE = self._shape.EDGE_NODE_COUNT

        @cache.dynamic_func(suffix=self.name)
        def element_node_sign(
            geo_arg: self.geometry.CellArg,
            topo_arg: TrimeshTopologyArg,
            element_index: ElementIndex,
            node_index_in_elt: int,
        ):
            node_type, type_index = self._shape.node_type_and_type_index(node_index_in_elt)

            # Guarded like `element_node_index`: for shapes without edge nodes, no
            # division by a zero constant is generated in the kernel code
            if wp.static(INTERIOR_NODES_PER_SIDE > 0):
                if node_type == TriangleShapeFunction.EDGE:
                    edge = type_index // INTERIOR_NODES_PER_SIDE

                    global_edge_index = topo_arg.tri_edge_indices[element_index][edge]
                    # wp.select(cond, value_if_false, value_if_true):
                    # +1 when the triangle's local edge starts at the edge's reference vertex, -1 otherwise
                    return wp.select(
                        topo_arg.edge_vertex_indices[global_edge_index][0]
                        == geo_arg.topology.tri_vertex_indices[element_index][edge],
                        -1.0,
                        1.0,
                    )

            return 1.0

        return element_node_sign
173
+
174
+
175
def make_trimesh_space_topology(mesh: Trimesh, shape: TriangleShapeFunction):
    """Create the function-space topology for ``shape`` over the triangle mesh ``mesh``.

    Args:
        mesh: Triangle mesh the function space is defined on.
        shape: Shape function; must be a ``TriangleShapeFunction``.

    Returns:
        The result of ``forward_base_topology(TrimeshSpaceTopology, mesh, shape)``.

    Raises:
        ValueError: If ``shape`` is not a ``TriangleShapeFunction``.
    """
    if isinstance(shape, TriangleShapeFunction):
        return forward_base_topology(TrimeshSpaceTopology, mesh, shape)

    # An unsupported object is not guaranteed to expose `.name`; fall back to its type
    # name so the intended ValueError is raised rather than an AttributeError.
    shape_name = getattr(shape, "name", type(shape).__name__)
    raise ValueError(f"Unsupported shape function {shape_name}")
warp/fem/utils.py CHANGED
@@ -1,323 +1,18 @@
1
- from typing import Any, Tuple, Union
1
+ from typing import Tuple, Union
2
2
 
3
3
  import numpy as np
4
4
 
5
5
  import warp as wp
6
6
  import warp.fem.cache as cache
7
+ from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
8
+ array_axpy,
9
+ inverse_qr,
10
+ symmetric_eigenvalues_qr,
11
+ )
7
12
  from warp.fem.types import NULL_NODE_INDEX
8
13
  from warp.utils import array_scan, radix_sort_pairs, runlength_encode
9
14
 
10
15
 
11
- @wp.func
12
- def generalized_outer(x: Any, y: Any):
13
- """Generalized outer product allowing for the first argument to be a scalar"""
14
- return wp.outer(x, y)
15
-
16
-
17
- @wp.func
18
- def generalized_outer(x: wp.float32, y: wp.vec2):
19
- return x * y
20
-
21
-
22
- @wp.func
23
- def generalized_outer(x: wp.float32, y: wp.vec3):
24
- return x * y
25
-
26
-
27
- @wp.func
28
- def generalized_inner(x: Any, y: Any):
29
- """Generalized inner product allowing for the first argument to be a tensor"""
30
- return wp.dot(x, y)
31
-
32
-
33
- @wp.func
34
- def generalized_inner(x: wp.mat22, y: wp.vec2):
35
- return x[0] * y[0] + x[1] * y[1]
36
-
37
-
38
- @wp.func
39
- def generalized_inner(x: wp.mat33, y: wp.vec3):
40
- return x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
41
-
42
-
43
- @wp.func
44
- def unit_element(template_type: Any, coord: int):
45
- """Returns a instance of `template_type` with a single coordinate set to 1 in the canonical basis"""
46
-
47
- t = type(template_type)(0.0)
48
- t[coord] = 1.0
49
- return t
50
-
51
-
52
- @wp.func
53
- def unit_element(template_type: wp.float32, coord: int):
54
- return 1.0
55
-
56
-
57
- @wp.func
58
- def unit_element(template_type: wp.mat22, coord: int):
59
- t = wp.mat22(0.0)
60
- row = coord // 2
61
- col = coord - 2 * row
62
- t[row, col] = 1.0
63
- return t
64
-
65
-
66
- @wp.func
67
- def unit_element(template_type: wp.mat33, coord: int):
68
- t = wp.mat33(0.0)
69
- row = coord // 3
70
- col = coord - 3 * row
71
- t[row, col] = 1.0
72
- return t
73
-
74
-
75
- @wp.func
76
- def symmetric_part(x: Any):
77
- """Symmetric part of a square tensor"""
78
- return 0.5 * (x + wp.transpose(x))
79
-
80
-
81
- @wp.func
82
- def skew_part(x: wp.mat22):
83
- """Skew part of a 2x2 tensor as corresponding rotation angle"""
84
- return 0.5 * (x[1, 0] - x[0, 1])
85
-
86
-
87
- @wp.func
88
- def skew_part(x: wp.mat33):
89
- """Skew part of a 3x3 tensor as the corresponding rotation vector"""
90
- a = 0.5 * (x[2, 1] - x[1, 2])
91
- b = 0.5 * (x[0, 2] - x[2, 0])
92
- c = 0.5 * (x[1, 0] - x[0, 1])
93
- return wp.vec3(a, b, c)
94
-
95
-
96
- @wp.func
97
- def householder_qr_decomposition(A: Any):
98
- """
99
- QR decomposition of a square matrix using Householder reflections
100
-
101
- Returns Q and R such that Q R = A, Q orthonormal (such that QQ^T = Id), R upper triangular
102
- """
103
-
104
- x = type(A[0])()
105
- Q = wp.identity(n=type(x).length, dtype=A.dtype)
106
-
107
- zero = x.dtype(0.0)
108
- two = x.dtype(2.0)
109
-
110
- for i in range(type(x).length):
111
- for k in range(type(x).length):
112
- x[k] = wp.select(k < i, A[k, i], zero)
113
-
114
- alpha = wp.length(x) * wp.sign(x[i])
115
- x[i] += alpha
116
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
117
-
118
- A -= wp.outer(two_over_x_sq * x, x * A)
119
- Q -= wp.outer(Q * x, two_over_x_sq * x)
120
-
121
- return Q, A
122
-
123
-
124
- @wp.func
125
- def householder_make_hessenberg(A: Any):
126
- """Transforms a square matrix to Hessenberg form (single lower diagonal) using Householder reflections
127
-
128
- Returns:
129
- Q and H such that Q H Q^T = A, Q orthonormal, H under Hessenberg form
130
- If A is symmetric, H will be tridiagonal
131
- """
132
-
133
- x = type(A[0])()
134
- Q = wp.identity(n=type(x).length, dtype=A.dtype)
135
-
136
- zero = x.dtype(0.0)
137
- two = x.dtype(2.0)
138
-
139
- for i in range(1, type(x).length):
140
- for k in range(type(x).length):
141
- x[k] = wp.select(k < i, A[k, i - 1], zero)
142
-
143
- alpha = wp.length(x) * wp.sign(x[i])
144
- x[i] += alpha
145
- two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)
146
-
147
- # apply on both sides
148
- A -= wp.outer(two_over_x_sq * x, x * A)
149
- A -= wp.outer(A * x, two_over_x_sq * x)
150
- Q -= wp.outer(Q * x, two_over_x_sq * x)
151
-
152
- return Q, A
153
-
154
-
155
- @wp.func
156
- def solve_triangular(R: Any, b: Any):
157
- """Solves for R x = b where R is an upper triangular matrix
158
-
159
- Returns x
160
- """
161
- zero = b.dtype(0)
162
- x = type(b)(b.dtype(0))
163
- for i in range(b.length, 0, -1):
164
- j = i - 1
165
- r = b[j] - wp.dot(R[j], x)
166
- x[j] = wp.select(R[j, j] == zero, r / R[j, j], zero)
167
-
168
- return x
169
-
170
-
171
- @wp.func
172
- def inverse_qr(A: Any):
173
- # Computes a square matrix inverse using QR factorization
174
-
175
- Q, R = householder_qr_decomposition(A)
176
-
177
- A_inv = type(A)()
178
- for i in range(type(A[0]).length):
179
- A_inv[i] = solve_triangular(R, Q[i]) # ith column of Q^T
180
-
181
- return wp.transpose(A_inv)
182
-
183
-
184
- @wp.func
185
- def _wilkinson_shift(a: Any, b: Any, c: Any, tol: Any):
186
- # Wilkinson shift: estimate eigenvalue of 2x2 symmetric matrix [a, c, c, b]
187
- d = (a - b) * type(tol)(0.5)
188
- return b + d - wp.sign(d) * wp.sqrt(d * d + c * c)
189
-
190
-
191
- @wp.func
192
- def _givens_rotation(a: Any, b: Any):
193
- # Givens rotation [[c -s], [s c]] such that sa+cb =0
194
- zero = type(a)(0.0)
195
- one = type(a)(1.0)
196
-
197
- b2 = b * b
198
- if b2 == zero:
199
- # id rotation
200
- return one, zero
201
-
202
- scale = one / wp.sqrt(a * a + b2)
203
- return a * scale, -b * scale
204
-
205
-
206
- @wp.func
207
- def tridiagonal_symmetric_eigenvalues_qr(D: Any, L: Any, Q: Any, tol: Any):
208
- """
209
- Computes the eigenvalues and eigen vectors of a symmetric tridiagonal matrix using the
210
- Symmetric tridiagonal QR algorithm with implicit Wilkinson shift
211
-
212
- Args:
213
- D: Main diagonal of the matrix
214
- L: Lower diagonal of the matrix, indexed such that L[i] = A[i+1, i]
215
- Q: Initialization for the eigenvectors, useful if a pre-transformation has been applied, otherwise may be identity
216
- tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)
217
-
218
- Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
219
-
220
-
221
- Ref: Arbenz P, Numerical Methods for Solving Large Scale Eigenvalue Problems, Chapter 4 (QR algorithm, Mar 13, 2018)
222
- """
223
-
224
- two = D.dtype(2.0)
225
-
226
- # so that we can use the type length in expressions
227
- # this will prevent unrolling by warp, but should be ok for native code
228
- m = int(0)
229
- for _ in range(type(D).length):
230
- m += 1
231
-
232
- start = int(0)
233
- y = D.dtype(0.0) # moving buldge
234
- x = D.dtype(0.0) # coeff atop buldge
235
-
236
- for _ in range(32 * m): # failsafe, usually converges faster than that
237
- # Iterate over all independent (deflated) blocks
238
- end = int(-1)
239
-
240
- for k in range(m - 1):
241
- if k >= end:
242
- # Check if new block is starting
243
- if k == end or wp.abs(L[k]) <= tol * (wp.abs(D[k]) + wp.abs(D[k + 1])):
244
- continue
245
-
246
- # Find end of block
247
- start = k
248
- end = start + 1
249
- while end + 1 < m:
250
- if wp.abs(L[end]) <= tol * (wp.abs(D[end + 1]) + wp.abs(D[end])):
251
- break
252
- end += 1
253
-
254
- # Wilkinson shift (an eigenvalue of the last 2x2 block)
255
- shift = _wilkinson_shift(D[end - 1], D[end], L[end - 1], tol)
256
-
257
- # start with eliminating lower diag of first column of shifted matrix
258
- # (i.e. first step of excplit QR factorization)
259
- # Then all further steps eliminate the buldge (second diag) of the non-shifted matrix
260
- x = D[start] - shift
261
- y = L[start]
262
-
263
- c, s = _givens_rotation(x, y)
264
-
265
- # Apply Givens rotation on both sides of tridiagonal matrix
266
-
267
- # middle block
268
- d = D[k] - D[k + 1]
269
- z = (two * c * L[k] + d * s) * s
270
- D[k] -= z
271
- D[k + 1] += z
272
- L[k] = d * c * s + (c * c - s * s) * L[k]
273
-
274
- if k > start:
275
- L[k - 1] = c * x - s * y
276
-
277
- x = L[k]
278
- y = -s * L[k + 1] # new buldge
279
- L[k + 1] *= c
280
-
281
- # apply givens rotation on left of Q
282
- # note: Q is transposed compared to usual impls, as Warp makes it easier to index rows
283
- Qk0 = Q[k]
284
- Qk1 = Q[k + 1]
285
- Q[k] = c * Qk0 - s * Qk1
286
- Q[k + 1] = c * Qk1 + s * Qk0
287
-
288
- if end <= 0:
289
- # We did nothing, so diagonalization must have been achieved
290
- break
291
-
292
- return D, Q
293
-
294
-
295
- @wp.func
296
- def symmetric_eigenvalues_qr(A: Any, tol: Any):
297
- """
298
- Computes the eigenvalues and eigen vectors of a square symmetric matrix A using the QR algorithm
299
-
300
- Args:
301
- A: square symmetric matrix
302
- tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)
303
-
304
- Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
305
- """
306
-
307
- # Put A under Hessenberg form (tridiagonal)
308
- Q, H = householder_make_hessenberg(A)
309
-
310
- # tridiagonal storage for H
311
- D = wp.get_diag(H)
312
- L = type(D)(A.dtype(0.0))
313
- for i in range(1, type(D).length):
314
- L[i - 1] = H[i, i - 1]
315
-
316
- Qt = wp.transpose(Q)
317
- ev, P = tridiagonal_symmetric_eigenvalues_qr(D, L, Qt, tol)
318
- return ev, P
319
-
320
-
321
16
  def compress_node_indices(
322
17
  node_count: int,
323
18
  node_indices: wp.array(dtype=int),
@@ -458,20 +153,6 @@ def masked_indices(
458
153
  return indices_temp, offsets_temp
459
154
 
460
155
 
461
- def array_axpy(x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 1.0):
462
- """Performs y = alpha*x + beta*y"""
463
-
464
- dtype = wp.types.type_scalar_type(x.dtype)
465
-
466
- alpha = dtype(alpha)
467
- beta = dtype(beta)
468
-
469
- if not wp.types.types_equal(x.dtype, y.dtype) or x.shape != y.shape or x.device != y.device:
470
- raise ValueError("x and y arrays must have same dat atype, shape and device")
471
-
472
- wp.launch(kernel=_array_axpy_kernel, dim=x.shape, device=x.device, inputs=[x, y, alpha, beta])
473
-
474
-
475
156
  @wp.kernel
476
157
  def _iota_kernel(indices: wp.array(dtype=int), divisor: int):
477
158
  indices[wp.tid()] = wp.tid() // divisor
@@ -515,12 +196,6 @@ def _masked_indices_kernel(
515
196
  masked_to_global[masked_idx] = i
516
197
 
517
198
 
518
- @wp.kernel
519
- def _array_axpy_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any), alpha: Any, beta: Any):
520
- i = wp.tid()
521
- y[i] = beta * y[i] + alpha * x[i]
522
-
523
-
524
199
  def grid_to_tris(Nx: int, Ny: int):
525
200
  """Constructs a triangular mesh topology by dividing each cell of a dense 2D grid into two triangles.
526
201
 
warp/jax_experimental.py CHANGED
@@ -102,7 +102,9 @@ def _warp_custom_callback(stream, buffers, opaque, opaque_len):
102
102
  assert hooks.forward, "Failed to find kernel entry point"
103
103
 
104
104
  # Launch the kernel.
105
- wp.context.runtime.core.cuda_launch_kernel(device.context, hooks.forward, bounds.size, 0, kernel_params, stream)
105
+ wp.context.runtime.core.cuda_launch_kernel(
106
+ device.context, hooks.forward, bounds.size, 0, 256, hooks.forward_smem_bytes, kernel_params, stream
107
+ )
106
108
 
107
109
 
108
110
  # TODO: is there a simpler way of getting the Jax "current" device?
warp/native/array.h CHANGED
@@ -1,3 +1,11 @@
1
+ /** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ * and proprietary rights in and to this software, related documentation
4
+ * and any modifications thereto. Any use, reproduction, disclosure or
5
+ * distribution of this software and related documentation without an express
6
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+ */
8
+
1
9
  #pragma once
2
10
 
3
11
  #include "builtin.h"
@@ -285,6 +293,13 @@ CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
285
293
  template <typename T>
286
294
  CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
287
295
  {
296
+ // if (i < 0 || i >= arr.shape[0])
297
+ // printf("i: %d > arr.shape[0]: %d\n", i, arr.shape[0]);
298
+
299
+ // if (j < 0 || j >= arr.shape[1])
300
+ // printf("j: %d > arr.shape[1]: %d\n", j, arr.shape[1]);
301
+
302
+
288
303
  assert(i >= 0 && i < arr.shape[0]);
289
304
  assert(j >= 0 && j < arr.shape[1]);
290
305
 
warp/native/builtin.h CHANGED
@@ -1145,7 +1145,47 @@ struct launch_bounds_t
1145
1145
  size_t size; // total number of threads
1146
1146
  };
1147
1147
 
1148
- inline CUDA_CALLABLE int tid(size_t index)
1148
// represents coordinate in the launch grid
// (components beyond the launch dimensionality stay 0, see launch_coord())
struct launch_coord_t
{
    int i;
    int j;
    int k;
    int l;
};

// unravels a linear thread index to the corresponding launch grid coord (up to 4d)
// Row-major: the last launch dimension varies fastest. Dimensions are peeled off from
// the innermost (shape[ndim-1]) outward with modulo/divide; the outermost coordinate
// `i` receives whatever remains of `linear`, with no modulo applied. Results are
// narrowed from size_t to int on assignment.
inline CUDA_CALLABLE launch_coord_t launch_coord(size_t linear, const launch_bounds_t& bounds)
{
    launch_coord_t coord = {0, 0, 0, 0};

    // 4th dimension (only present for 4-d launches)
    if (bounds.ndim > 3)
    {
        coord.l = linear%bounds.shape[3];
        linear /= bounds.shape[3];
    }

    // 3rd dimension
    if (bounds.ndim > 2)
    {
        coord.k = linear%bounds.shape[2];
        linear /= bounds.shape[2];
    }

    // 2nd dimension
    if (bounds.ndim > 1)
    {
        coord.j = linear%bounds.shape[1];
        linear /= bounds.shape[1];
    }

    // outermost dimension takes the remaining quotient
    if (bounds.ndim > 0)
    {
        coord.i = linear;
    }

    return coord;
}
1187
+
1188
+ inline CUDA_CALLABLE int tid(size_t index, const launch_bounds_t& bounds)
1149
1189
  {
1150
1190
  // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
1151
1191
  // Only do this in _DEBUG when called from device to avoid excessive register allocation
@@ -1154,40 +1194,33 @@ inline CUDA_CALLABLE int tid(size_t index)
1154
1194
  printf("Warp warning: tid() is returning an overflowed int\n");
1155
1195
  }
1156
1196
  #endif
1157
- return static_cast<int>(index);
1197
+
1198
+ launch_coord_t c = launch_coord(index, bounds);
1199
+ return static_cast<int>(c.i);
1158
1200
  }
1159
1201
 
1160
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
1202
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& bounds)
1161
1203
  {
1162
- const size_t n = launch_bounds.shape[1];
1163
-
1164
- // convert to work item
1165
- i = index/n;
1166
- j = index%n;
1204
+ launch_coord_t c = launch_coord(index, bounds);
1205
+ i = c.i;
1206
+ j = c.j;
1167
1207
  }
1168
1208
 
1169
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
1209
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& bounds)
1170
1210
  {
1171
- const size_t n = launch_bounds.shape[1];
1172
- const size_t o = launch_bounds.shape[2];
1173
-
1174
- // convert to work item
1175
- i = index/(n*o);
1176
- j = index%(n*o)/o;
1177
- k = index%o;
1211
+ launch_coord_t c = launch_coord(index, bounds);
1212
+ i = c.i;
1213
+ j = c.j;
1214
+ k = c.k;
1178
1215
  }
1179
1216
 
1180
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
1217
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& bounds)
1181
1218
  {
1182
- const size_t n = launch_bounds.shape[1];
1183
- const size_t o = launch_bounds.shape[2];
1184
- const size_t p = launch_bounds.shape[3];
1185
-
1186
- // convert to work item
1187
- i = index/(n*o*p);
1188
- j = index%(n*o*p)/(o*p);
1189
- k = index%(o*p)/p;
1190
- l = index%p;
1219
+ launch_coord_t c = launch_coord(index, bounds);
1220
+ i = c.i;
1221
+ j = c.j;
1222
+ k = c.k;
1223
+ l = c.l;
1191
1224
  }
1192
1225
 
1193
1226
  template<typename T>
@@ -1724,3 +1757,10 @@ inline CUDA_CALLABLE void adj_expect_near(const vec3& actual, const vec3& expect
1724
1757
  #include "rand.h"
1725
1758
  #include "noise.h"
1726
1759
  #include "matnn.h"
1760
+
1761
+ // only include in kernels for now
1762
+ #if defined(__CUDACC_RTC__)
1763
+ #include "tile.h"
1764
+ #include "tile_gemm.h"
1765
+ #include "tile_reduce.h"
1766
+ #endif