warp-lang 1.4.2__py3-none-win_amd64.whl → 1.5.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/fem/linalg.py ADDED
@@ -0,0 +1,404 @@
1
+ from typing import Any
2
+
3
+ import warp as wp
4
+
5
+
6
@wp.func
def generalized_outer(x: Any, y: Any):
    """Outer product ``x y^T`` that also accepts a scalar as first argument.

    The generic overload forwards to ``wp.outer``; the scalar overloads that follow
    reduce to scaling the vector ``y`` by ``x``.
    """
    outer_product = wp.outer(x, y)
    return outer_product


@wp.func
def generalized_outer(x: wp.float32, y: wp.vec2):
    # scalar "outer" vector: simply the scaled vector
    scaled = x * y
    return scaled


@wp.func
def generalized_outer(x: wp.float32, y: wp.vec3):
    # scalar "outer" vector: simply the scaled vector
    scaled = x * y
    return scaled
20
+
21
+
22
@wp.func
def generalized_inner(x: Any, y: Any):
    """Inner product that also accepts a tensor (matrix) as first argument.

    The generic overload forwards to ``wp.dot``; the overloads that follow handle a
    pair of scalars, and contraction of a vector against the rows of a matrix.
    """
    inner_product = wp.dot(x, y)
    return inner_product


@wp.func
def generalized_inner(x: float, y: float):
    # scalar case: plain product
    product = x * y
    return product


@wp.func
def generalized_inner(x: wp.mat22, y: wp.vec2):
    # sum over i of y[i] * (row i of x), i.e. x^T y
    contracted = x[0] * y[0] + x[1] * y[1]
    return contracted


@wp.func
def generalized_inner(x: wp.mat33, y: wp.vec3):
    # sum over i of y[i] * (row i of x), i.e. x^T y
    contracted = x[0] * y[0] + x[1] * y[1] + x[2] * y[2]
    return contracted
41
+
42
+
43
@wp.func
def basis_element(template_type: Any, coord: int):
    """Returns an instance of `template_type` with a single coordinate set to 1 in the canonical basis"""

    unit = type(template_type)(0.0)
    unit[coord] = 1.0
    return unit


@wp.func
def basis_element(template_type: wp.float32, coord: int):
    # a scalar has a single basis element
    return 1.0


@wp.func
def basis_element(template_type: wp.mat22, coord: int):
    # `coord` enumerates the 4 entries in row-major order
    row = coord // 2
    col = coord - 2 * row
    unit = wp.mat22(0.0)
    unit[row, col] = 1.0
    return unit


@wp.func
def basis_element(template_type: wp.mat33, coord: int):
    # `coord` enumerates the 9 entries in row-major order
    row = coord // 3
    col = coord - 3 * row
    unit = wp.mat33(0.0)
    unit[row, col] = 1.0
    return unit
73
+
74
+
75
@wp.func
def basis_coefficient(val: wp.float32, i: int):
    """Returns the `i`-th coefficient of `val`; a scalar is its own single coefficient."""
    return val


@wp.func
def basis_coefficient(val: Any, i: int):
    # Generic case: vectors yield their i-th component, matrices (mat22 and mat33 alike)
    # yield their i-th row.
    return val[i]


@wp.func
def basis_coefficient(val: wp.vec2, i: int, j: int):
    # treat as row vector
    return val[j]


@wp.func
def basis_coefficient(val: wp.vec3, i: int, j: int):
    # treat as row vector
    return val[j]


@wp.func
def basis_coefficient(val: Any, i: int, j: int):
    # Matrix entry at row i, column j
    return val[i, j]
109
+
110
+
111
@wp.func
def symmetric_part(x: Any):
    """Symmetric part ``(x + x^T) / 2`` of a square tensor"""
    x_transposed = wp.transpose(x)
    return 0.5 * (x + x_transposed)
115
+
116
+
117
@wp.func
def spherical_part(x: wp.mat22):
    """Spherical part of a 2x2 tensor: ``tr(x) / 2 * Id``"""
    mean_diagonal = 0.5 * wp.trace(x)
    return mean_diagonal * wp.identity(n=2, dtype=float)


@wp.func
def spherical_part(x: wp.mat33):
    """Spherical part of a 3x3 tensor: ``tr(x) / 3 * Id``"""
    mean_diagonal = wp.trace(x) / 3.0
    return mean_diagonal * wp.identity(n=3, dtype=float)
127
+
128
+
129
@wp.func
def skew_part(x: wp.mat22):
    """Skew part of a 2x2 tensor, expressed as the corresponding rotation angle"""
    antisymmetric_gap = x[1, 0] - x[0, 1]
    return 0.5 * antisymmetric_gap


@wp.func
def skew_part(x: wp.mat33):
    """Skew part of a 3x3 tensor, expressed as the corresponding rotation vector"""
    return wp.vec3(
        0.5 * (x[2, 1] - x[1, 2]),
        0.5 * (x[0, 2] - x[2, 0]),
        0.5 * (x[1, 0] - x[0, 1]),
    )
142
+
143
+
144
@wp.func
def householder_qr_decomposition(A: Any):
    """
    QR decomposition of a square matrix using Householder reflections

    Returns Q and R such that Q R = A, Q orthonormal (such that QQ^T = Id), R upper triangular
    """

    # `x` holds the current Householder vector; it has the type of a row of A
    x = type(A[0])()
    # Q accumulates the product of all reflections applied so far
    Q = wp.identity(n=type(x).length, dtype=A.dtype)

    zero = x.dtype(0.0)
    two = x.dtype(2.0)

    for i in range(type(x).length):
        # Copy column i of A, keeping only rows k >= i (rows above are zeroed).
        # Warp's `wp.select(cond, a, b)` yields `b` when `cond` is true and `a` otherwise.
        for k in range(type(x).length):
            x[k] = wp.select(k < i, A[k, i], zero)

        # alpha takes the sign of x[i] so that x[i] + alpha does not suffer from cancellation
        alpha = wp.length(x) * wp.sign(x[i])
        x[i] += alpha
        # 2 / |x|^2, or 0 (no reflection) when alpha is zero; the select discards the
        # division result in that case
        two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)

        # Left-apply the reflection P = Id - 2 x x^T / |x|^2 to A (x * A is the row vector x^T A)
        A -= wp.outer(two_over_x_sq * x, x * A)
        # Accumulate the reflection on the right of Q: Q <- Q P
        Q -= wp.outer(Q * x, two_over_x_sq * x)

    return Q, A
170
+
171
+
172
@wp.func
def householder_make_hessenberg(A: Any):
    """Transforms a square matrix to Hessenberg form (single lower diagonal) using Householder reflections

    Returns:
        Q and H such that Q H Q^T = A, Q orthonormal, H under Hessenberg form
        If A is symmetric, H will be tridiagonal
    """

    # `x` holds the current Householder vector; it has the type of a row of A
    x = type(A[0])()
    # Q accumulates the product of all reflections applied so far
    Q = wp.identity(n=type(x).length, dtype=A.dtype)

    zero = x.dtype(0.0)
    two = x.dtype(2.0)

    for i in range(1, type(x).length):
        # Householder vector built from column i-1, keeping only rows k >= i, so that the
        # reflection zeroes that column below its first sub-diagonal entry.
        # Warp's `wp.select(cond, a, b)` yields `b` when `cond` is true and `a` otherwise.
        for k in range(type(x).length):
            x[k] = wp.select(k < i, A[k, i - 1], zero)

        # alpha takes the sign of x[i] so that x[i] + alpha does not suffer from cancellation
        alpha = wp.length(x) * wp.sign(x[i])
        x[i] += alpha
        # 2 / |x|^2, or 0 (no reflection) when alpha is zero
        two_over_x_sq = wp.select(alpha == zero, two / wp.length_sq(x), zero)

        # apply on both sides
        # A <- P A, then A <- A P, with P = Id - 2 x x^T / |x|^2 (a similarity transform)
        A -= wp.outer(two_over_x_sq * x, x * A)
        A -= wp.outer(A * x, two_over_x_sq * x)
        # Q <- Q P, so that Q H Q^T reconstructs the original A
        Q -= wp.outer(Q * x, two_over_x_sq * x)

    return Q, A
201
+
202
+
203
@wp.func
def solve_triangular(R: Any, b: Any):
    """Solves for R x = b by back-substitution, where R is an upper triangular matrix

    Rows whose diagonal entry is exactly zero produce a zero component in x.

    Returns x
    """
    zero = b.dtype(0)
    x = type(b)(zero)
    for row_end in range(b.length, 0, -1):
        row = row_end - 1
        # Components of x that are not solved yet are still zero, so the dot product
        # only involves the already solved ones (columns > row)
        residual = b[row] - wp.dot(R[row], x)
        x[row] = wp.select(R[row, row] == zero, residual / R[row, row], zero)

    return x
217
+
218
+
219
@wp.func
def inverse_qr(A: Any):
    # Computes a square matrix inverse using QR factorization.
    # With A = Q R, A^-1 = R^-1 Q^T: column j of A^-1 solves R c = Q^T e_j = (row j of Q).
    Q, R = householder_qr_decomposition(A)

    inverse_transposed = type(A)()
    for j in range(type(A[0]).length):
        # Storing column j of A^-1 as row j, hence the final transpose
        inverse_transposed[j] = solve_triangular(R, Q[j])

    return wp.transpose(inverse_transposed)
230
+
231
+
232
@wp.func
def _wilkinson_shift(a: Any, b: Any, c: Any, tol: Any):
    # Wilkinson shift: eigenvalue of the symmetric 2x2 matrix [[a, c], [c, b]] that is closest to b.
    # `tol` is only used to obtain the scalar type of the 0.5 constant.
    half_gap = (a - b) * type(tol)(0.5)
    radius = wp.sqrt(half_gap * half_gap + c * c)
    return b + half_gap - wp.sign(half_gap) * radius
237
+
238
+
239
@wp.func
def _givens_rotation(a: Any, b: Any):
    # Givens rotation [[c -s], [s c]] such that s*a + c*b = 0; returns (c, s)
    zero = type(a)(0.0)
    one = type(a)(1.0)

    b_squared = b * b
    if b_squared != zero:
        inv_radius = one / wp.sqrt(a * a + b_squared)
        return a * inv_radius, -b * inv_radius

    # nothing to eliminate: identity rotation
    return one, zero
252
+
253
+
254
@wp.func
def tridiagonal_symmetric_eigenvalues_qr(D: Any, L: Any, Q: Any, tol: Any):
    """
    Computes the eigenvalues and eigenvectors of a symmetric tridiagonal matrix using the
    Symmetric tridiagonal QR algorithm with implicit Wilkinson shift

    Args:
        D: Main diagonal of the matrix
        L: Lower diagonal of the matrix, indexed such that L[i] = A[i+1, i].
           Must have the same length as D: its last entry is read and rescaled when the
           bulge is chased off the end of the matrix.
        Q: Initialization for the eigenvectors, useful if a pre-transformation has been applied, otherwise may be identity
        tol: Tolerance for the diagonalization residual. An off-diagonal term L[k] is treated as
           negligible when |L[k]| <= tol * (|D[k]| + |D[k+1]|).

    Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P


    Ref: Arbenz P, Numerical Methods for Solving Large Scale Eigenvalue Problems, Chapter 4 (QR algorithm, Mar 13, 2018)
    """

    two = D.dtype(2.0)

    # so that we can use the type length in expressions
    # this will prevent unrolling by warp, but should be ok for native code
    m = int(0)
    for _ in range(type(D).length):
        m += 1

    # Declared up front, presumably so that their types are fixed before being updated in the loops
    start = int(0)
    y = D.dtype(0.0)  # moving bulge
    x = D.dtype(0.0)  # coeff atop bulge

    for _ in range(32 * m):  # failsafe, usually converges faster than that
        # Iterate over all independent (deflated) blocks
        # end == -1 means that no unreduced block has been found during this sweep
        end = int(-1)

        for k in range(m - 1):
            if k >= end:
                # Check if new block is starting
                # (skip the row closing the previous block, and negligible off-diagonal terms)
                if k == end or wp.abs(L[k]) <= tol * (wp.abs(D[k]) + wp.abs(D[k + 1])):
                    continue

                # Find end of block: extend while the next off-diagonal term is not negligible
                start = k
                end = start + 1
                while end + 1 < m:
                    if wp.abs(L[end]) <= tol * (wp.abs(D[end + 1]) + wp.abs(D[end])):
                        break
                    end += 1

                # Wilkinson shift (an eigenvalue of the last 2x2 block)
                shift = _wilkinson_shift(D[end - 1], D[end], L[end - 1], tol)

                # start with eliminating lower diag of first column of shifted matrix
                # (i.e. first step of explicit QR factorization)
                # Then all further steps eliminate the bulge (second diag) of the non-shifted matrix
                x = D[start] - shift
                y = L[start]

            # Rotation eliminating y against x (see _givens_rotation: s*x + c*y = 0)
            c, s = _givens_rotation(x, y)

            # Apply Givens rotation on both sides of tridiagonal matrix

            # middle block
            d = D[k] - D[k + 1]
            z = (two * c * L[k] + d * s) * s
            D[k] -= z
            D[k + 1] += z
            L[k] = d * c * s + (c * c - s * s) * L[k]

            # Past the first row of the block, (x, y) = (L[k-1], bulge): store the rotated
            # off-diagonal term, the bulge itself being eliminated
            if k > start:
                L[k - 1] = c * x - s * y

            x = L[k]
            y = -s * L[k + 1]  # new bulge
            L[k + 1] *= c

            # apply givens rotation on left of Q
            # note: Q is transposed compared to usual impls, as Warp makes it easier to index rows
            Qk0 = Q[k]
            Qk1 = Q[k + 1]
            Q[k] = c * Qk0 - s * Qk1
            Q[k + 1] = c * Qk1 + s * Qk0

        if end <= 0:
            # We did nothing, so diagonalization must have been achieved
            break

    return D, Q
341
+
342
+
343
@wp.func
def symmetric_eigenvalues_qr(A: Any, tol: Any):
    """
    Computes the eigenvalues and eigenvectors of a square symmetric matrix A using the QR algorithm

    A is first reduced to tridiagonal form with Householder reflections, after which the
    shifted tridiagonal QR iteration is run on its diagonal and sub-diagonal.

    Args:
        A: square symmetric matrix
        tol: Tolerance for the diagonalization residual (Linf norm of off-diagonal over diagonal terms)

    Returns a tuple (D: vector of eigenvalues, P: matrix with one eigenvector per row) such that A = P^T D P
    """

    # Hessenberg form of a symmetric matrix is tridiagonal: A = Q H Q^T
    Q, H = householder_make_hessenberg(A)

    # Compact storage of H: main diagonal, and sub-diagonal padded with a trailing zero
    diagonal = wp.get_diag(H)
    sub_diagonal = type(diagonal)(A.dtype(0.0))
    for row in range(1, type(diagonal).length):
        sub_diagonal[row - 1] = H[row, row - 1]

    # The tridiagonal solver stores eigenvectors as rows, hence the transposed basis
    eigenvalues, eigenvectors = tridiagonal_symmetric_eigenvalues_qr(diagonal, sub_diagonal, wp.transpose(Q), tol)
    return eigenvalues, eigenvectors
367
+
368
+
369
def array_axpy(x: wp.array, y: wp.array, alpha: float = 1.0, beta: float = 1.0):
    """Performs y = alpha*x + beta*y (y is updated in place)

    When a tape is recording and either array requires gradients, a custom adjoint is
    recorded on the tape.

    Args:
        x: Input array
        y: Array updated in place; its scalar type determines the type of `alpha` and `beta`.
           The kernel uses a single thread index, so arrays are presumably 1-D -- TODO confirm.
        alpha: Scaling factor applied to `x`
        beta: Scaling factor applied to the previous value of `y`

    Raises:
        ValueError: if `x` and `y` do not share the same shape and device
    """

    if x.shape != y.shape or x.device != y.device:
        raise ValueError("x and y arrays must have the same shape and device")

    dtype = wp.types.type_scalar_type(y.dtype)
    alpha_scalar = dtype(alpha)
    beta_scalar = dtype(beta)

    # array_axpy requires a custom adjoint; unfortunately we cannot use `wp.func_grad`
    # as generic functions are not supported yet. Instead we use a non-differentiable kernel
    # and record a custom adjoint function on the tape.

    # temporarily disable tape to avoid printing warning that kernel is not differentiable;
    # the tape is restored even if the launch raises
    tape = wp.context.runtime.tape
    wp.context.runtime.tape = None
    try:
        wp.launch(
            kernel=_array_axpy_kernel,
            dim=x.shape,
            device=x.device,
            inputs=[x, y, alpha_scalar, beta_scalar],
        )
    finally:
        wp.context.runtime.tape = tape

    if tape is not None and (x.requires_grad or y.requires_grad):

        def backward_axpy():
            # adj_x += adj_y * alpha
            # adj_y = adj_y * beta
            if y.grad is None:
                # No adjoint available for y: nothing to propagate
                return
            if x.grad is not None:
                array_axpy(x=y.grad, y=x.grad, alpha=alpha, beta=1.0)
            # Compare the caller-provided value (not the converted Warp scalar) so that the
            # no-op rescaling is reliably skipped
            if beta != 1.0:
                array_axpy(x=y.grad, y=y.grad, alpha=0.0, beta=beta)

        tape.record_func(backward_axpy, arrays=[x, y])
399
+
400
+
401
@wp.kernel(enable_backward=False)
def _array_axpy_kernel(x: wp.array(dtype=Any), y: wp.array(dtype=Any), alpha: Any, beta: Any):
    # Per-entry update y <- beta * y + alpha * x, with x converted to y's dtype
    tid = wp.tid()
    y[tid] = beta * y[tid] + alpha * y.dtype(x[tid])
warp/fem/operator.py CHANGED
@@ -1,8 +1,8 @@
1
- from typing import Any, Callable
1
+ from typing import Any, Callable, Dict, Optional, Set
2
2
 
3
3
  import warp as wp
4
- from warp.fem import utils
5
- from warp.fem.types import Domain, Field, NodeIndex, Sample
4
+ from warp.fem.linalg import skew_part, symmetric_part
5
+ from warp.fem.types import Coords, Domain, ElementIndex, Field, NodeIndex, Sample, make_free_sample
6
6
 
7
7
 
8
8
  class Integrand:
@@ -10,35 +10,55 @@ class Integrand:
10
10
  It will get transformed to a proper warp.Function by resolving concrete Field types at call time.
11
11
  """
12
12
 
13
- def __init__(self, func: Callable):
13
+ def __init__(self, func: Callable, kernel_options: Optional[Dict[str, Any]] = None):
14
14
  self.func = func
15
15
  self.name = wp.codegen.make_full_qualified_name(self.func)
16
16
  self.module = wp.get_module(self.func.__module__)
17
17
  self.argspec = wp.codegen.get_full_arg_spec(self.func)
18
+ self.kernel_options = {} if kernel_options is None else kernel_options
19
+
20
+ # Operators for each field argument. This will be populated at first integrate call
21
+ self.operators: Dict[str, Set[Operator]] = None
18
22
 
19
23
 
class Operator:
    """
    Operators provide syntactic sugar over Field and Domain evaluation functions and arguments

    Attributes:
        func: placeholder function carrying the operator's name and docstring
        name: name of `func`
        resolver: callable receiving a concrete Field/Domain and returning the function to call
            (e.g. ``lambda dmn: dmn.element_normal``)
        field_result: optional callable describing the Field/Domain-like result of the operator
    """

    def __init__(self, func: Callable, resolver: Callable, field_result: Callable = None):
        self.func = func
        # cached for lookups by operator name
        self.name = func.__name__
        self.resolver = resolver
        self.field_result = field_result
34
+
35
+
36
def integrand(func: Callable = None, kernel_options: Optional[Dict[str, Any]] = None):
    """Decorator for functions to be integrated (or interpolated) using warp.fem

    May be applied bare (``@integrand``), with arguments (``@integrand(kernel_options={...})``),
    or called directly as ``integrand(func, kernel_options=...)``.

    Args:
        func: Decorated function
        kernel_options: Supplemental code-generation options to be passed to the generated kernel.
    """

    def wrap_integrand(func: Callable):
        itg = Integrand(func, kernel_options)
        itg.__doc__ = func.__doc__
        return itg

    if func is not None:
        # Direct application: forward kernel_options too, rather than silently dropping them
        return wrap_integrand(func)

    return wrap_integrand
55
+
56
+
57
+ def operator(**kwargs):
38
58
  """Decorator for functions operating on Field-like or Domain-like data inside warp.fem integrands"""
39
59
 
40
60
  def wrap_operator(func: Callable):
41
- op = Operator(func, resolver)
61
+ op = Operator(func, **kwargs)
42
62
  op.__doc__ = func.__doc__
43
63
  return op
44
64
 
@@ -56,7 +76,7 @@ def position(domain: Domain, s: Sample):
56
76
 
@operator(resolver=lambda dmn: dmn.element_normal)
def normal(domain: Domain, s: Sample):
    """Evaluates the element normal at the sample point `s`.

    Nonzero if the element is a side, or if the geometry is embedded in a higher-dimensional
    space (e.g. :class:`Trimesh3D`).
    """
    ...
61
81
 
62
82
 
@@ -75,7 +95,7 @@ def lookup(domain: Domain, x: Any) -> Sample:
75
95
  guess: (optional) :class:`Sample` initial guess, may help perform the query
76
96
 
77
97
  Note:
78
- Currently this operator is unsupported for :class:`Hexmesh`, :class:`Quadmesh2D` and deformed geometries.
98
+ Currently this operator is unsupported for :class:`Hexmesh`, :class:`Quadmesh2D`, :class:`Quadmesh3D` and deformed geometries.
79
99
  """
80
100
  pass
81
101
 
@@ -88,10 +108,73 @@ def measure(domain: Domain, s: Sample) -> float:
88
108
 
@operator(resolver=lambda dmn: dmn.element_measure_ratio)
def measure_ratio(domain: Domain, s: Sample) -> float:
    """Returns the maximum ratio between the measure of this element and that of higher-dimensional neighbors."""
    ...
113
+
114
+
115
+ # Operators for evaluating cell-level quantities on domains defined on sides
116
+
117
+
118
@operator(
    resolver=lambda dmn: dmn.domain_cell_arg, field_result=lambda dmn: (dmn.cell_domain(), Domain, dmn.geometry.CellArg)
)
def cells(domain: Domain) -> Domain:
    """Converts a domain defined on geometry sides to a domain defined on cells."""
    pass


# The private operators below produce the element index and coordinates that are passed
# to `make_free_sample`, hence the ElementIndex / Coords return annotations.
# NOTE(review): the *_cell_index stubs declare a `side_coords` parameter, but the calls in this
# module pass only (domain, side_index); confirm against the resolved geometry functions.


@operator(resolver=lambda dmn: dmn.element_inner_cell_index)
def _inner_cell_index(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> ElementIndex:
    """Index of the inner cell of side `side_index`."""
    pass


@operator(resolver=lambda dmn: dmn.element_outer_cell_index)
def _outer_cell_index(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> ElementIndex:
    """Index of the outer cell of side `side_index`."""
    pass


@operator(resolver=lambda dmn: dmn.element_inner_cell_coords)
def _inner_cell_coords(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Coords:
    """Coordinates, in the inner cell, of the point at `side_coords` on side `side_index`."""
    pass


@operator(resolver=lambda dmn: dmn.element_outer_cell_coords)
def _outer_cell_coords(domain: Domain, side_index: ElementIndex, side_coords: Coords) -> Coords:
    """Coordinates, in the outer cell, of the point at `side_coords` on side `side_index`."""
    pass


@operator(resolver=lambda dmn: dmn.cell_to_element_coords)
def _cell_to_element_coords(
    domain: Domain, side_index: ElementIndex, cell_index: ElementIndex, cell_coords: Coords
) -> Coords:
    """Coordinates on side `side_index` of the point at `cell_coords` in cell `cell_index`."""
    pass
151
+
152
+
153
@integrand
def to_inner_cell(domain: Domain, s: Sample):
    """Converts a :class:`Sample` defined on a side to a sample defined on the side's inner cell"""
    cell_index = _inner_cell_index(domain, s.element_index)
    cell_coords = _inner_cell_coords(domain, s.element_index, s.element_coords)
    return make_free_sample(cell_index, cell_coords)


@integrand
def to_outer_cell(domain: Domain, s: Sample):
    """Converts a :class:`Sample` defined on a side to a sample defined on the side's outer cell"""
    cell_index = _outer_cell_index(domain, s.element_index)
    cell_coords = _outer_cell_coords(domain, s.element_index, s.element_coords)
    return make_free_sample(cell_index, cell_coords)


@integrand
def to_cell_side(domain: Domain, cell_s: Sample, side_index: ElementIndex):
    """Converts a :class:`Sample` defined on a cell to a sample defined on one of its sides.
    If the result does not lie on the side `side_index`, the resulting coordinates will be set to ``OUTSIDE``."""
    side_coords = _cell_to_element_coords(domain, side_index, cell_s.element_index, cell_s.element_coords)
    return make_free_sample(side_index, side_coords)
176
+
177
+
95
178
  # Field operators
96
179
  # On a side, inner and outer are such that normal goes from inner to outer
97
180
 
@@ -157,13 +240,13 @@ def node_partition_index(f: Field, node_index: NodeIndex):
@integrand
def D(f: Field, s: Sample):
    """Symmetric part of the (inner) gradient of the field at `s`"""
    gradient = grad(f, s)
    return symmetric_part(gradient)
161
244
 
162
245
 
@integrand
def curl(f: Field, s: Sample):
    """Skew part of the (inner) gradient of the field at `s`, as a vector such that ``wp.cross(curl(u), v) = skew(grad(u)) v``"""
    gradient = grad(f, s)
    return skew_part(gradient)
167
250
 
168
251
 
169
252
  @integrand
warp/fem/polynomial.py CHANGED
@@ -7,20 +7,20 @@ import numpy as np
class Polynomial(Enum):
    """Polynomial family defining interpolation nodes over an interval

    Member values are short string identifiers, which are also what ``str()`` returns.
    """

    GAUSS_LEGENDRE = "GL"
    """Gauss--Legendre 1D polynomial family (does not include endpoints)"""

    LOBATTO_GAUSS_LEGENDRE = "LGL"
    """Lobatto--Gauss--Legendre 1D polynomial family (includes endpoints)"""

    EQUISPACED_CLOSED = "closed"
    """Closed 1D polynomial family with uniformly distributed nodes (includes endpoints)"""

    EQUISPACED_OPEN = "open"
    """Open 1D polynomial family with uniformly distributed nodes (does not include endpoints)"""

    def __str__(self):
        # Use the short identifier rather than the member name
        return f"{self.value}"
24
24
 
25
25
 
26
26
  def is_closed(family: Polynomial):