warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (166) hide show
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,375 @@
1
+ # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ ###########################################################################
9
+ # Example Image Multilayer Perceptron (MLP)
10
+ #
11
+ # Shows how to train a coordinate-based MLP on an image to predict the RGB
12
+ # color at a given input position. By default, a positional encoding is
13
+ # applied to the input coordinates to improve the ability of the MLP to
14
+ # represent higher-frequency content. (The encoding is always applied; this
15
+ # example has no command-line option to disable it.)
16
+ #
17
+ # References:
18
+ # Ben Mildenhall et al. 2021. NeRF: representing scenes
19
+ # as neural radiance fields for view synthesis. Commun. ACM 65, 1
20
+ # (January 2022), 99–106. https://doi.org/10.1145/3503250
21
+ #
22
+ ###########################################################################
23
+
24
+ import math
25
+ import os
26
+
27
+ import numpy as np
28
+ from PIL import Image
29
+
30
+ import warp as wp
31
+ import warp.examples
32
+ import warp.optim
33
+
34
# Shared random generator used for parameter initialization and for shuffling
# the training pixel order; the fixed seed makes runs reproducible.
rng = np.random.default_rng(seed=45)
35
+
36
+
37
def create_layer(dim_in, dim_hid, dtype=float):
    """Create the trainable parameters of one fully connected layer.

    Weights have shape ``(dim_hid, dim_in)`` and the bias has shape
    ``(dim_hid, 1)``. Both are drawn uniformly from
    ``[-1/sqrt(dim_in), 1/sqrt(dim_in)]`` using the module-level ``rng``
    (weights first, then bias).

    Returns:
        A ``(weights, bias)`` tuple of Warp arrays created with
        ``requires_grad=True`` and element type ``dtype``.
    """
    bound = 1.0 / np.sqrt(dim_in)
    weights_np = rng.uniform(-bound, bound, (dim_hid, dim_in))
    bias_np = rng.uniform(-bound, bound, (dim_hid, 1))

    return (
        wp.array(weights_np, dtype=dtype, requires_grad=True),
        wp.array(bias_np, dtype=dtype, requires_grad=True),
    )
45
+
46
+
47
def create_array(dim_in, dim_hid, dtype=float):
    """Create a ``(dim_hid, dim_in)`` trainable Warp array with uniform random entries.

    Entries are drawn from ``[-1/sqrt(dim_in), 1/sqrt(dim_in)]`` using the
    module-level ``rng``; the array is created with ``requires_grad=True``.
    """
    limit = 1.0 / np.sqrt(dim_in)
    samples = rng.uniform(-limit, limit, (dim_hid, dim_in))
    return wp.array(samples, dtype=dtype, requires_grad=True)
52
+
53
+
54
# number of frequencies for the positional encoding
NUM_FREQ = wp.constant(8)

# length of the encoded input feature vector
DIM_IN = wp.constant(4 * NUM_FREQ)  # sin,cos for both x,y at each frequency
# width of each hidden layer
DIM_HID = 32
# number of network outputs (one per RGB channel)
DIM_OUT = 3

# threads per block; passed as `block_dim` to every launch of `compute`, and
# used there as the column count when broadcasting the bias tiles
NUM_THREADS = 32

# resolution the reference image is resized to before training
IMG_WIDTH = 512
IMG_HEIGHT = 512

# pixels per training launch: at most 1024, and at most 1/8 of the image
BATCH_SIZE = min(1024, int((IMG_WIDTH * IMG_HEIGHT) / 8))

# dtype for our weights and bias matrices (also used for the encoded features)
dtype = wp.float16
71
+
72
+
73
# ReLU activation, applied element-wise (through wp.tile_map) after every
# layer in `compute`. Returns max(x, 0) in the module-level `dtype`.
@wp.func
def relu(x: dtype):
    return wp.max(x, dtype(0.0))
76
+
77
+
78
# Forward pass of the image MLP, one pixel per thread, with the block
# cooperating through Warp tiles.
#
# Each thread maps its pixel index to normalized coordinates and builds a
# positional encoding of length DIM_IN. The block then evaluates four dense
# layers (ReLU after every layer, including the output layer) on the
# DIM_IN x NUM_THREADS feature tile.
#   - When `loss` is provided, the squared error scaled by 1 / (3 * BATCH_SIZE)
#     is accumulated into loss[0].
#   - When `out` is provided, the predicted RGB value is written to out[:, linear].
# NOTE(review): assumes the launch's block_dim equals NUM_THREADS, because the
# bias tiles are broadcast to NUM_THREADS columns.
@wp.kernel
def compute(
    indices: wp.array(dtype=int),
    weights_0: wp.array2d(dtype=dtype),
    bias_0: wp.array2d(dtype=dtype),
    weights_1: wp.array2d(dtype=dtype),
    bias_1: wp.array2d(dtype=dtype),
    weights_2: wp.array2d(dtype=dtype),
    bias_2: wp.array2d(dtype=dtype),
    weights_3: wp.array2d(dtype=dtype),
    bias_3: wp.array2d(dtype=dtype),
    reference: wp.array2d(dtype=float),
    loss: wp.array1d(dtype=float),
    out: wp.array2d(dtype=float),
):
    # batch indices
    linear = indices[wp.tid()]

    # pixels are stored row-major: linear = row * IMG_WIDTH + col
    row = linear / IMG_WIDTH
    col = linear % IMG_WIDTH

    # normalize input coordinates to [-1, 1]; rows span the image height and
    # columns span its width (dividing by the wrong extent leaves [-1, 1]
    # for non-square images)
    x = (float(row) / float(IMG_HEIGHT) - 0.5) * 2.0
    y = (float(col) / float(IMG_WIDTH) - 0.5) * 2.0

    local = wp.vector(dtype=dtype, length=DIM_IN)

    # construct positional encoding
    for s in range(NUM_FREQ):
        scale = wp.pow(2.0, float(s)) * wp.pi

        # x-coord
        local[s * 4 + 0] = dtype(wp.sin(x * scale))
        local[s * 4 + 1] = dtype(wp.cos(x * scale))
        # y-coord
        local[s * 4 + 2] = dtype(wp.sin(y * scale))
        local[s * 4 + 3] = dtype(wp.cos(y * scale))

    # tile feature vectors across the block, returns [dim(f), NUM_THREADS]
    f = wp.tile(local)

    # input layer
    w0 = wp.tile_load(weights_0, 0, 0, m=DIM_HID, n=DIM_IN)
    b0 = wp.tile_load(bias_0, 0, 0, m=DIM_HID, n=1)
    z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, m=DIM_HID, n=NUM_THREADS))

    # hidden layer
    w1 = wp.tile_load(weights_1, 0, 0, m=DIM_HID, n=DIM_HID)
    b1 = wp.tile_load(bias_1, 0, 0, m=DIM_HID, n=1)
    z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, m=DIM_HID, n=NUM_THREADS))

    w2 = wp.tile_load(weights_2, 0, 0, m=DIM_HID, n=DIM_HID)
    b2 = wp.tile_load(bias_2, 0, 0, m=DIM_HID, n=1)
    z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, m=DIM_HID, n=NUM_THREADS))

    # output layer
    w3 = wp.tile_load(weights_3, 0, 0, m=DIM_OUT, n=DIM_HID)
    b3 = wp.tile_load(bias_3, 0, 0, m=DIM_OUT, n=1)
    o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, m=DIM_OUT, n=NUM_THREADS))

    # untile back to SIMT
    output = wp.untile(o)

    # compute error
    error = wp.vec3(
        float(output[0]) - reference[0, linear],
        float(output[1]) - reference[1, linear],
        float(output[2]) - reference[2, linear],
    )

    # write MSE loss (skipped when no loss array is passed)
    if loss:
        wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))

    # write image output (skipped when no output array is passed)
    if out:
        for i in range(DIM_OUT):
            out[i, linear] = float(output[i])
156
+
157
+
158
class Example:
    """Train the image MLP evaluated by the ``compute`` kernel on a reference image.

    Holds the four ``(weights, bias)`` layer pairs, the reference image stored as
    one row per RGB channel (``(3, IMG_WIDTH * IMG_HEIGHT)``), and a shuffled
    permutation of the pixel indices used to form training batches.
    """

    def __init__(self, train_iters):
        """Initialize the network parameters and load the reference image.

        Args:
            train_iters: Requested total number of training iterations (batches).
                This is converted to a whole number of epochs, with at least one.
        """
        self.weights_0, self.bias_0 = create_layer(DIM_IN, DIM_HID, dtype=dtype)
        self.weights_1, self.bias_1 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
        self.weights_2, self.bias_2 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
        self.weights_3, self.bias_3 = create_layer(DIM_HID, DIM_OUT, dtype=dtype)

        # reference: PIL's resize takes (width, height), and the numpy view of
        # the image is (height, width, 3). Flattening therefore gives row-major
        # pixels (linear = row * IMG_WIDTH + col), transposed to one row per channel.
        reference_path = os.path.join(wp.examples.get_asset_directory(), "pixel.jpg")
        with Image.open(reference_path) as im:
            reference_image = np.asarray(im.resize((IMG_WIDTH, IMG_HEIGHT)).convert("RGB")) / 255.0
        self.reference = wp.array(reference_image.reshape(IMG_WIDTH * IMG_HEIGHT, 3).T, dtype=float)

        # create randomized batch indices
        indices = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
        rng.shuffle(indices)
        self.indices = wp.array(indices)

        # number of *whole* batches per epoch (a trailing partial batch is dropped)
        self.num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
        self.max_iters = train_iters
        self.max_epochs = max(1, int(self.max_iters / self.num_batches))

    def train_warp(self):
        """Train the network with Warp tiles and Adam, then save the prediction.

        One epoch (one pass over all whole batches) is captured into a CUDA
        graph that is replayed ``max_epochs`` times. Afterwards the prediction
        for every pixel is written to ``example_tile_mlp.jpg``.
        """
        params = [
            self.weights_0,
            self.bias_0,
            self.weights_1,
            self.bias_1,
            self.weights_2,
            self.bias_2,
            self.weights_3,
            self.bias_3,
        ]

        optimizer_grads = [p.grad.flatten() for p in params]
        optimizer_inputs = [p.flatten() for p in params]
        optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)

        loss = wp.zeros(1, dtype=float, requires_grad=True)
        output = create_array(IMG_WIDTH * IMG_HEIGHT, DIM_OUT)

        # The kernel is always launched with BATCH_SIZE threads. A trailing
        # partial batch would make threads read past the end of their index
        # slice, so only whole batches are trained on.
        num_trained = self.num_batches * BATCH_SIZE

        # capture graph for whole epoch
        wp.capture_begin()

        for b in range(0, num_trained, BATCH_SIZE):
            loss.zero_()

            with wp.Tape() as tape:
                wp.launch(
                    compute,
                    dim=[BATCH_SIZE],
                    inputs=[
                        self.indices[b : b + BATCH_SIZE],
                        self.weights_0,
                        self.bias_0,
                        self.weights_1,
                        self.bias_1,
                        self.weights_2,
                        self.bias_2,
                        self.weights_3,
                        self.bias_3,
                        self.reference,
                        loss,
                        None,
                    ],
                    block_dim=NUM_THREADS,
                )

            tape.backward(loss)
            optimizer.step(optimizer_grads)
            tape.zero()

        graph = wp.capture_end()

        with wp.ScopedTimer("Training"):
            for i in range(self.max_epochs):
                with wp.ScopedTimer("Epoch"):
                    wp.capture_launch(graph)
                    # reading the loss synchronizes, so the epoch timing is meaningful
                    print(f"Epoch: {i} Loss: {loss.numpy()}")

        # evaluate full image
        wp.launch(
            compute,
            dim=[IMG_WIDTH * IMG_HEIGHT],
            inputs=[
                self.indices,
                self.weights_0,
                self.bias_0,
                self.weights_1,
                self.bias_1,
                self.weights_2,
                self.bias_2,
                self.weights_3,
                self.bias_3,
                self.reference,
                loss,
                output,
            ],
            block_dim=NUM_THREADS,
        )

        self.save_image("example_tile_mlp.jpg", output.numpy())

    def train_torch(self):
        """Train an equivalent PyTorch model for comparison and save its prediction.

        The parameters wrap the Warp arrays through ``wp.to_torch``, so training
        starts from whatever values those arrays currently hold. One epoch is
        captured into a ``torch.cuda.CUDAGraph`` and replayed ``max_epochs``
        times. The prediction for every pixel is written to
        ``example_tile_mlp_torch.jpg``.
        """
        import torch as tc

        weights_0 = tc.nn.Parameter(wp.to_torch(self.weights_0))
        weights_1 = tc.nn.Parameter(wp.to_torch(self.weights_1))
        weights_2 = tc.nn.Parameter(wp.to_torch(self.weights_2))
        weights_3 = tc.nn.Parameter(wp.to_torch(self.weights_3))

        bias_0 = tc.nn.Parameter(wp.to_torch(self.bias_0))
        bias_1 = tc.nn.Parameter(wp.to_torch(self.bias_1))
        bias_2 = tc.nn.Parameter(wp.to_torch(self.bias_2))
        bias_3 = tc.nn.Parameter(wp.to_torch(self.bias_3))

        indices = wp.to_torch(self.indices)
        reference = wp.to_torch(self.reference)

        optimizer = tc.optim.Adam(
            [weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3],
            capturable=True,
            lr=0.0001,
            betas=(0.9, 0.95),
            eps=1.0e-6,
        )

        # generate frequency space encoding of pixels
        # based on their linear index in the image (same layout as `compute`)
        def encode(linear):
            row = (linear // IMG_WIDTH).float()
            col = (linear % IMG_WIDTH).float()

            # rows span the image height and columns span its width
            x = (row / float(IMG_HEIGHT) - 0.5) * 2.0
            y = (col / float(IMG_WIDTH) - 0.5) * 2.0

            encoding = tc.zeros((NUM_FREQ * 4, len(linear)), dtype=tc.float16, device="cuda")

            for s in range(NUM_FREQ):
                scale = math.pow(2.0, float(s)) * math.pi

                # Directly write the computed values into the encoding tensor
                encoding[s * 4 + 0, :] = tc.sin(scale * x)
                encoding[s * 4 + 1, :] = tc.cos(scale * x)
                encoding[s * 4 + 2, :] = tc.sin(scale * y)
                encoding[s * 4 + 3, :] = tc.cos(scale * y)

            return encoding

        stream = tc.cuda.Stream()
        graph = tc.cuda.CUDAGraph()

        # warm-up
        with tc.cuda.stream(stream):
            f = tc.rand((NUM_FREQ * 4, BATCH_SIZE), dtype=tc.float16, device="cuda")
            z = tc.relu(weights_0 @ f + bias_0)
            z = tc.relu(weights_1 @ z + bias_1)
            z = tc.relu(weights_2 @ z + bias_2)
            z = tc.relu(weights_3 @ z + bias_3)
            ref = tc.rand((3, BATCH_SIZE), dtype=tc.float16, device="cuda")
            loss = tc.mean((z - ref) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with tc.cuda.graph(graph):
            for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
                linear = indices[b : b + BATCH_SIZE]

                f = encode(linear)

                z = tc.relu(weights_0 @ f + bias_0)
                z = tc.relu(weights_1 @ z + bias_1)
                z = tc.relu(weights_2 @ z + bias_2)
                z = tc.relu(weights_3 @ z + bias_3)

                ref = reference[:, linear]
                loss = tc.mean((z - ref) ** 2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        with wp.ScopedTimer("Training (Torch)"):
            for _i in range(self.max_epochs):
                with wp.ScopedTimer("Epoch"):
                    graph.replay()

                print(loss)

        # evaluate full image: inference only, so no autograd graph is needed,
        # and the pixel indices are generated directly on the GPU
        with tc.no_grad():
            f = encode(tc.arange(0, IMG_WIDTH * IMG_HEIGHT, device="cuda"))
            z = tc.relu(weights_0 @ f + bias_0)
            z = tc.relu(weights_1 @ z + bias_1)
            z = tc.relu(weights_2 @ z + bias_2)
            z = tc.relu(weights_3 @ z + bias_3)

        self.save_image("example_tile_mlp_torch.jpg", z.detach().cpu().numpy())

    def save_image(self, name, output):
        """Write a per-channel prediction to an 8-bit RGB image file.

        Args:
            name: Output file path.
            output: Array of shape ``(3, IMG_WIDTH * IMG_HEIGHT)`` with pixels in
                row-major order (``linear = row * IMG_WIDTH + col``), so the image
                array is ``(IMG_HEIGHT, IMG_WIDTH, 3)``.

        Values are clamped to ``[0, 1]`` before quantization. The network output
        is only bounded below (ReLU), and casting out-of-range floats to uint8
        does not saturate.
        """
        predicted_image = np.clip(output.T.reshape(IMG_HEIGHT, IMG_WIDTH, 3), 0.0, 1.0)
        predicted_image = (predicted_image * 255).astype(np.uint8)

        predicted_image_pil = Image.fromarray(predicted_image)
        predicted_image_pil.save(name)
362
+
363
+
364
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--train_iters", type=int, default=20000, help="Total number of training iterations.")
    # Opt-in replacement for the previously commented-out train_torch() call.
    # train_torch() wraps the same Warp arrays, so it continues from the
    # parameters left by train_warp().
    parser.add_argument(
        "--train_torch",
        action="store_true",
        help="After Warp training, also train the equivalent PyTorch model (requires torch).",
    )

    args = parser.parse_known_args()[0]

    with wp.ScopedDevice("cuda:0"):
        example = Example(args.train_iters)
        example.train_warp()
        if args.train_torch:
            example.train_torch()
warp/fem/__init__.py CHANGED
@@ -24,14 +24,17 @@ from .geometry import (
24
24
  LinearGeometryPartition,
25
25
  Nanogrid,
26
26
  Quadmesh2D,
27
+ Quadmesh3D,
27
28
  Tetmesh,
28
29
  Trimesh2D,
30
+ Trimesh3D,
29
31
  )
30
32
  from .integrate import integrate, interpolate
31
33
  from .operator import (
32
34
  D,
33
35
  at_node,
34
36
  average,
37
+ cells,
35
38
  curl,
36
39
  deformation_gradient,
37
40
  degree,
@@ -50,6 +53,9 @@ from .operator import (
50
53
  normal,
51
54
  outer,
52
55
  position,
56
+ to_cell_side,
57
+ to_inner_cell,
58
+ to_outer_cell,
53
59
  )
54
60
  from .polynomial import Polynomial
55
61
  from .quadrature import ExplicitQuadrature, NodalQuadrature, PicQuadrature, Quadrature, RegularQuadrature
@@ -65,6 +71,8 @@ from .space import (
65
71
  SpaceTopology,
66
72
  SymmetricTensorMapper,
67
73
  make_collocated_function_space,
74
+ make_contravariant_function_space,
75
+ make_covariant_function_space,
68
76
  make_polynomial_basis_space,
69
77
  make_polynomial_space,
70
78
  make_space_partition,
warp/fem/cache.py CHANGED
@@ -6,6 +6,7 @@ from copy import copy
6
6
  from typing import Any, Callable, Dict, Optional, Tuple, Union
7
7
 
8
8
  import warp as wp
9
+ from warp.fem.operator import Integrand
9
10
 
10
11
  _kernel_cache = {}
11
12
  _struct_cache = {}
@@ -186,7 +187,7 @@ class ExpandStarredArgumentStruct(ast.NodeTransformer):
186
187
 
187
188
 
188
189
  def get_integrand_function(
189
- integrand: "warp.fem.operator.Integrand", # noqa: F821
190
+ integrand: Integrand,
190
191
  suffix: str,
191
192
  func=None,
192
193
  annotations=None,
@@ -208,27 +209,30 @@ def get_integrand_function(
208
209
 
209
210
 
210
211
  def get_integrand_kernel(
211
- integrand: "warp.fem.operator.Integrand", # noqa: F821
212
+ integrand: Integrand,
212
213
  suffix: str,
213
214
  kernel_fn: Optional[Callable] = None,
214
215
  kernel_options: Dict[str, Any] = None,
215
216
  code_transformers=None,
216
217
  ):
217
- if kernel_options is None:
218
- kernel_options = {}
218
+ options = integrand.module.options.copy()
219
+ options.update(integrand.kernel_options)
220
+ if kernel_options is not None:
221
+ options.update(kernel_options)
219
222
 
220
- key = _make_key(integrand.func, suffix, use_qualified_name=True)
223
+ kernel_key = _make_key(integrand.func, suffix, use_qualified_name=True)
224
+ opts_key = "".join([f"{k}:{v}" for k, v in sorted(options.items())])
225
+ cache_key = kernel_key + opts_key
221
226
 
222
- if key not in _kernel_cache:
227
+ if cache_key not in _kernel_cache:
223
228
  if kernel_fn is None:
224
229
  return None
225
230
 
226
231
  module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
227
- module.options = copy(integrand.module.options)
228
- module.options.update(kernel_options)
229
-
230
- _kernel_cache[key] = wp.Kernel(func=kernel_fn, key=key, module=module, code_transformers=code_transformers)
231
- return _kernel_cache[key]
232
+ _kernel_cache[cache_key] = wp.Kernel(
233
+ func=kernel_fn, key=kernel_key, module=module, code_transformers=code_transformers, options=options
234
+ )
235
+ return _kernel_cache[cache_key]
232
236
 
233
237
 
234
238
  def cached_arg_value(func: Callable):
@@ -478,7 +482,7 @@ def borrow_temporary(
478
482
  if temporary_store is None:
479
483
  temporary_store = TemporaryStore._default_store
480
484
 
481
- if temporary_store is None:
485
+ if temporary_store is None or (requires_grad and wp.context.runtime.tape is not None):
482
486
  return Temporary(
483
487
  array=wp.empty(shape=shape, dtype=dtype, pinned=pinned, device=device, requires_grad=requires_grad)
484
488
  )
warp/fem/dirichlet.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from typing import Any, Optional
2
2
 
3
3
  import warp as wp
4
- from warp.fem.utils import array_axpy, symmetric_eigenvalues_qr
4
+ from warp.fem.linalg import array_axpy, symmetric_eigenvalues_qr
5
5
  from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
6
6
  from warp.types import type_is_matrix, type_length
7
7
 
warp/fem/domain.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Optional, Union
1
+ from typing import Any, Optional, Set, Union
2
2
 
3
3
  import warp as wp
4
4
  import warp.codegen
@@ -11,6 +11,7 @@ from warp.fem.geometry import (
11
11
  GeometryPartition,
12
12
  WholeGeometryPartition,
13
13
  )
14
+ from warp.fem.operator import Operator
14
15
  from warp.fem.types import ElementKind
15
16
 
16
17
  GeometryOrPartition = Union[Geometry, GeometryPartition]
@@ -94,6 +95,10 @@ class GeometryDomain:
94
95
  element_lookup: wp.Function
95
96
  """Device function returning the sample point corresponding to a world position"""
96
97
 
98
+ def notify_operator_usage(self, ops: Set[Operator]):
99
+ """Makes the Domain aware that the operators `ops` will be applied"""
100
+ pass
101
+
97
102
 
98
103
  class Cells(GeometryDomain):
99
104
  """A Domain containing all cells of the geometry or geometry partition"""
@@ -160,6 +165,17 @@ class Cells(GeometryDomain):
160
165
  def element_lookup(self) -> wp.Function:
161
166
  return self.geometry.cell_lookup
162
167
 
168
+ @property
169
+ def domain_cell_arg(self) -> wp.Function:
170
+ return Cells._identity_fn
171
+
172
+ def cell_domain(self):
173
+ return self
174
+
175
+ @wp.func
176
+ def _identity_fn(x: Any):
177
+ return x
178
+
163
179
 
164
180
  class Sides(GeometryDomain):
165
181
  """A Domain containing all (interior and boundary) sides of the geometry or geometry partition"""
@@ -225,6 +241,33 @@ class Sides(GeometryDomain):
225
241
  def element_normal(self) -> wp.Function:
226
242
  return self.geometry.side_normal
227
243
 
244
+ @property
245
+ def element_inner_cell_index(self) -> wp.Function:
246
+ return self.geometry.side_inner_cell_index
247
+
248
+ @property
249
+ def element_outer_cell_index(self) -> wp.Function:
250
+ return self.geometry.side_outer_cell_index
251
+
252
+ @property
253
+ def element_inner_cell_coords(self) -> wp.Function:
254
+ return self.geometry.side_inner_cell_coords
255
+
256
+ @property
257
+ def element_outer_cell_coords(self) -> wp.Function:
258
+ return self.geometry.side_outer_cell_coords
259
+
260
+ @property
261
+ def cell_to_element_coords(self) -> wp.Function:
262
+ return self.geometry.side_from_cell_coords
263
+
264
+ @property
265
+ def domain_cell_arg(self) -> wp.Function:
266
+ return self.geometry.side_to_cell_arg
267
+
268
+ def cell_domain(self):
269
+ return Cells(self.geometry_partition)
270
+
228
271
 
229
272
  class BoundarySides(Sides):
230
273
  """A Domain containing boundary sides of the geometry or geometry partition"""
@@ -6,8 +6,7 @@ from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction, make
6
6
  from .field import DiscreteField, FieldLike, GeometryField, ImplicitField, NonconformingField, SpaceField, UniformField
7
7
  from .nodal_field import NodalField
8
8
  from .restriction import FieldRestriction
9
- from .test import TestField
10
- from .trial import TrialField
9
+ from .virtual import LocalTestField, LocalTrialField, TestField, TrialField
11
10
 
12
11
 
13
12
  def make_restriction(
warp/fem/field/field.py CHANGED
@@ -1,10 +1,10 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Dict, Optional, Set
2
2
 
3
3
  import warp as wp
4
4
  from warp.fem import cache
5
5
  from warp.fem.domain import GeometryDomain, Sides
6
6
  from warp.fem.geometry import DeformedGeometry, Geometry
7
- from warp.fem.operator import integrand
7
+ from warp.fem.operator import Operator, integrand
8
8
  from warp.fem.space import FunctionSpace, SpacePartition
9
9
  from warp.fem.types import NULL_ELEMENT_INDEX, ElementKind, Sample
10
10
 
@@ -48,32 +48,32 @@ class FieldLike:
48
48
  return False
49
49
 
50
50
  @staticmethod
51
- def eval_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
51
+ def eval_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
52
52
  """Device function evaluating the inner field value at a sample point"""
53
53
  raise NotImplementedError
54
54
 
55
55
  @staticmethod
56
- def eval_grad_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
56
+ def eval_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
57
57
  """Device function evaluating the inner field gradient at a sample point"""
58
58
  raise NotImplementedError
59
59
 
60
60
  @staticmethod
61
- def eval_div_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
61
+ def eval_div_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
62
62
  """Device function evaluating the inner field divergence at a sample point"""
63
63
  raise NotImplementedError
64
64
 
65
65
  @staticmethod
66
- def eval_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
66
+ def eval_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
67
67
  """Device function evaluating the outer field value at a sample point"""
68
68
  raise NotImplementedError
69
69
 
70
70
  @staticmethod
71
- def eval_grad_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
71
+ def eval_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
72
72
  """Device function evaluating the outer field gradient at a sample point"""
73
73
  raise NotImplementedError
74
74
 
75
75
  @staticmethod
76
- def eval_div_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
76
+ def eval_div_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
77
77
  """Device function evaluating the outer field divergence at a sample point"""
78
78
  raise NotImplementedError
79
79
 
@@ -82,6 +82,10 @@ class FieldLike:
82
82
  """Polynomial degree of the field is applicable, or hint for determination of interpolation order"""
83
83
  raise NotImplementedError
84
84
 
85
def notify_operator_usage(self, ops: Set[Operator]):
    """Makes the field aware that the operators `ops` will be applied to it.

    The base implementation does nothing.
    """
    pass
88
+
85
89
 
86
90
  class GeometryField(FieldLike):
87
91
  """Base class for fields defined over a geometry"""
@@ -97,12 +101,12 @@ class GeometryField(FieldLike):
97
101
  raise NotImplementedError
98
102
 
99
103
  @staticmethod
100
- def eval_reference_grad_inner(args: "ElementEvalArg", s: "Sample"): # noqa: F821
104
+ def eval_reference_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
101
105
  """Device function evaluating the inner field gradient with respect to reference element coordinates at a sample point"""
102
106
  raise NotImplementedError
103
107
 
104
108
  @staticmethod
105
- def eval_reference_grad_outer(args: "ElementEvalArg", s: "Sample"): # noqa: F821
109
+ def eval_reference_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
106
110
  """Device function evaluating the outer field gradient with respect to reference element coordinates at a sample point"""
107
111
  raise NotImplementedError
108
112
 
@@ -128,6 +132,9 @@ class SpaceField(GeometryField):
128
132
  self._space = space
129
133
  self._space_partition = space_partition
130
134
 
135
+ self.gradient_valid = self.space.gradient_valid
136
+ self.divergence_valid = self.space.divergence_valid
137
+
131
138
  @property
132
139
  def geometry(self) -> Geometry:
133
140
  return self._space.geometry
@@ -156,17 +163,22 @@ class SpaceField(GeometryField):
156
163
  def dof_dtype(self) -> type:
157
164
  return self.space.dof_dtype
158
165
 
159
- def gradient_valid(self) -> bool:
160
- """Whether gradient operator can be computed. Only for scalar and vector fields as higher-order tensors are not support yet"""
161
- return not wp.types.type_is_matrix(self.dtype)
166
+ @property
167
+ def gradient_dtype(self):
168
+ """Return type of the gradient operator. Assumes self.gradient_valid()"""
169
+ if wp.types.type_is_vector(self.dtype):
170
+ return cache.cached_mat_type(
171
+ shape=(wp.types.type_length(self.dtype), self.geometry.dimension),
172
+ dtype=wp.types.type_scalar_type(self.dtype),
173
+ )
174
+ return cache.cached_vec_type(length=self.geometry.dimension, dtype=wp.types.type_scalar_type(self.dtype))
162
175
 
163
- def divergence_valid(self) -> bool:
164
- """Whether divergence of this field can be computed. Only for vector and tensor fields with same dimension as embedding geometry"""
176
+ @property
177
+ def divergence_dtype(self):
178
+ """Return type of the divergence operator. Assumes self.gradient_valid()"""
165
179
  if wp.types.type_is_vector(self.dtype):
166
- return wp.types.type_length(self.dtype) == self.space.geometry.dimension
167
- if wp.types.type_is_matrix(self.dtype):
168
- return self.dtype._shape_[0] == self.space.geometry.dimension
169
- return False
180
+ return wp.types.type_scalar_type(self.dtype)
181
+ return cache.cached_vec_type(length=self.dtype._shape_[1], dtype=wp.types.type_scalar_type(self.dtype))
170
182
 
171
183
  def _make_eval_degree(self):
172
184
  ORDER = self.space.ORDER