warp-lang 1.4.2__py3-none-manylinux2014_x86_64.whl → 1.5.1__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +4 -0
- warp/autograd.py +43 -8
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +21 -2
- warp/build_dll.py +23 -6
- warp/builtins.py +1819 -7
- warp/codegen.py +197 -61
- warp/config.py +2 -2
- warp/context.py +379 -107
- warp/examples/assets/pixel.jpg +0 -0
- warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
- warp/examples/benchmarks/benchmark_gemm.py +121 -0
- warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
- warp/examples/benchmarks/benchmark_tile.py +179 -0
- warp/examples/fem/example_adaptive_grid.py +37 -10
- warp/examples/fem/example_apic_fluid.py +3 -2
- warp/examples/fem/example_convection_diffusion_dg.py +4 -5
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion_3d.py +47 -4
- warp/examples/fem/example_distortion_energy.py +220 -0
- warp/examples/fem/example_magnetostatics.py +127 -85
- warp/examples/fem/example_nonconforming_contact.py +5 -5
- warp/examples/fem/example_stokes.py +3 -1
- warp/examples/fem/example_streamlines.py +12 -19
- warp/examples/fem/utils.py +38 -15
- warp/examples/sim/example_cloth.py +4 -25
- warp/examples/sim/example_quadruped.py +2 -1
- warp/examples/tile/example_tile_convolution.py +58 -0
- warp/examples/tile/example_tile_fft.py +47 -0
- warp/examples/tile/example_tile_filtering.py +105 -0
- warp/examples/tile/example_tile_matmul.py +79 -0
- warp/examples/tile/example_tile_mlp.py +375 -0
- warp/fem/__init__.py +8 -0
- warp/fem/cache.py +16 -12
- warp/fem/dirichlet.py +1 -1
- warp/fem/domain.py +44 -1
- warp/fem/field/__init__.py +1 -2
- warp/fem/field/field.py +31 -19
- warp/fem/field/nodal_field.py +101 -49
- warp/fem/field/virtual.py +794 -0
- warp/fem/geometry/__init__.py +2 -2
- warp/fem/geometry/deformed_geometry.py +3 -105
- warp/fem/geometry/element.py +13 -0
- warp/fem/geometry/geometry.py +165 -7
- warp/fem/geometry/grid_2d.py +3 -6
- warp/fem/geometry/grid_3d.py +31 -28
- warp/fem/geometry/hexmesh.py +3 -46
- warp/fem/geometry/nanogrid.py +3 -2
- warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
- warp/fem/geometry/tetmesh.py +2 -43
- warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
- warp/fem/integrate.py +683 -261
- warp/fem/linalg.py +404 -0
- warp/fem/operator.py +101 -18
- warp/fem/polynomial.py +5 -5
- warp/fem/quadrature/quadrature.py +45 -21
- warp/fem/space/__init__.py +45 -11
- warp/fem/space/basis_function_space.py +451 -0
- warp/fem/space/basis_space.py +58 -11
- warp/fem/space/function_space.py +146 -5
- warp/fem/space/grid_2d_function_space.py +80 -66
- warp/fem/space/grid_3d_function_space.py +113 -68
- warp/fem/space/hexmesh_function_space.py +96 -108
- warp/fem/space/nanogrid_function_space.py +62 -110
- warp/fem/space/quadmesh_function_space.py +208 -0
- warp/fem/space/shape/__init__.py +45 -7
- warp/fem/space/shape/cube_shape_function.py +328 -54
- warp/fem/space/shape/shape_function.py +10 -1
- warp/fem/space/shape/square_shape_function.py +328 -60
- warp/fem/space/shape/tet_shape_function.py +269 -19
- warp/fem/space/shape/triangle_shape_function.py +238 -19
- warp/fem/space/tetmesh_function_space.py +69 -37
- warp/fem/space/topology.py +38 -0
- warp/fem/space/trimesh_function_space.py +179 -0
- warp/fem/utils.py +6 -331
- warp/jax_experimental.py +3 -1
- warp/native/array.h +15 -0
- warp/native/builtin.h +66 -26
- warp/native/bvh.h +4 -0
- warp/native/coloring.cpp +604 -0
- warp/native/cuda_util.cpp +68 -51
- warp/native/cuda_util.h +2 -1
- warp/native/fabric.h +8 -0
- warp/native/hashgrid.h +4 -0
- warp/native/marching.cu +8 -0
- warp/native/mat.h +14 -3
- warp/native/mathdx.cpp +59 -0
- warp/native/mesh.h +4 -0
- warp/native/range.h +13 -1
- warp/native/reduce.cpp +9 -1
- warp/native/reduce.cu +7 -0
- warp/native/runlength_encode.cpp +9 -1
- warp/native/runlength_encode.cu +7 -1
- warp/native/scan.cpp +8 -0
- warp/native/scan.cu +8 -0
- warp/native/scan.h +8 -1
- warp/native/sparse.cpp +8 -0
- warp/native/sparse.cu +8 -0
- warp/native/temp_buffer.h +7 -0
- warp/native/tile.h +1854 -0
- warp/native/tile_gemm.h +341 -0
- warp/native/tile_reduce.h +210 -0
- warp/native/volume_builder.cu +8 -0
- warp/native/volume_builder.h +8 -0
- warp/native/warp.cpp +10 -2
- warp/native/warp.cu +369 -15
- warp/native/warp.h +12 -2
- warp/optim/adam.py +39 -4
- warp/paddle.py +29 -12
- warp/render/render_opengl.py +140 -67
- warp/sim/graph_coloring.py +292 -0
- warp/sim/import_urdf.py +8 -8
- warp/sim/integrator_euler.py +4 -2
- warp/sim/integrator_featherstone.py +115 -44
- warp/sim/integrator_vbd.py +6 -0
- warp/sim/model.py +109 -32
- warp/sparse.py +1 -1
- warp/stubs.py +569 -4
- warp/tape.py +12 -7
- warp/tests/assets/pixel.npy +0 -0
- warp/tests/aux_test_instancing_gc.py +18 -0
- warp/tests/test_array.py +39 -0
- warp/tests/test_codegen.py +81 -1
- warp/tests/test_codegen_instancing.py +30 -0
- warp/tests/test_collision.py +110 -0
- warp/tests/test_coloring.py +251 -0
- warp/tests/test_context.py +34 -0
- warp/tests/test_examples.py +21 -5
- warp/tests/test_fem.py +453 -113
- warp/tests/test_func.py +34 -4
- warp/tests/test_generics.py +52 -0
- warp/tests/test_iter.py +68 -0
- warp/tests/test_lerp.py +13 -87
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_matmul.py +6 -9
- warp/tests/test_matmul_lite.py +6 -11
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_module_hashing.py +23 -0
- warp/tests/test_overwrite.py +45 -0
- warp/tests/test_paddle.py +27 -87
- warp/tests/test_print.py +56 -1
- warp/tests/test_smoothstep.py +17 -83
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_static.py +3 -3
- warp/tests/test_tile.py +744 -0
- warp/tests/test_tile_mathdx.py +144 -0
- warp/tests/test_tile_mlp.py +383 -0
- warp/tests/test_tile_reduce.py +374 -0
- warp/tests/test_tile_shared_memory.py +190 -0
- warp/tests/test_vbd.py +12 -20
- warp/tests/test_volume.py +43 -0
- warp/tests/unittest_suites.py +19 -2
- warp/tests/unittest_utils.py +4 -2
- warp/types.py +340 -74
- warp/utils.py +23 -3
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
- warp/fem/field/test.py +0 -180
- warp/fem/field/trial.py +0 -183
- warp/fem/space/collocated_function_space.py +0 -102
- warp/fem/space/quadmesh_2d_function_space.py +0 -261
- warp/fem/space/trimesh_2d_function_space.py +0 -153
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
# Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
###########################################################################
|
|
9
|
+
# Example Image Multilayer Perceptron (MLP)
|
|
10
|
+
#
|
|
11
|
+
# Shows how to train a coordinate-based MLP on an image to predict the RGB
|
|
12
|
+
# color at a given input position. By default, a positional encoding is
|
|
13
|
+
# applied to the input coordinates to improve the ability of the MLP to
|
|
14
|
+
# represent higher-frequency content. This can be disabled by passing the
|
|
15
|
+
# '--no_encoding' option.
|
|
16
|
+
#
|
|
17
|
+
# References:
|
|
18
|
+
# Ben Mildenhall et al. 2021. NeRF: representing scenes
|
|
19
|
+
# as neural radiance fields for view synthesis. Commun. ACM 65, 1
|
|
20
|
+
# (January 2022), 99–106. https://doi.org/10.1145/3503250
|
|
21
|
+
#
|
|
22
|
+
###########################################################################
|
|
23
|
+
|
|
24
|
+
import math
|
|
25
|
+
import os
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
from PIL import Image
|
|
29
|
+
|
|
30
|
+
import warp as wp
|
|
31
|
+
import warp.examples
|
|
32
|
+
import warp.optim
|
|
33
|
+
|
|
34
|
+
rng = np.random.default_rng(45)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_layer(dim_in, dim_hid, dtype=float):
|
|
38
|
+
w = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
|
|
39
|
+
b = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, 1))
|
|
40
|
+
|
|
41
|
+
weights = wp.array(w, dtype=dtype, requires_grad=True)
|
|
42
|
+
bias = wp.array(b, dtype=dtype, requires_grad=True)
|
|
43
|
+
|
|
44
|
+
return (weights, bias)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def create_array(dim_in, dim_hid, dtype=float):
|
|
48
|
+
s = rng.uniform(-1.0 / np.sqrt(dim_in), 1.0 / np.sqrt(dim_in), (dim_hid, dim_in))
|
|
49
|
+
a = wp.array(s, dtype=dtype, requires_grad=True)
|
|
50
|
+
|
|
51
|
+
return a
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# number of frequencies for the positional encoding
|
|
55
|
+
NUM_FREQ = wp.constant(8)
|
|
56
|
+
|
|
57
|
+
DIM_IN = wp.constant(4 * NUM_FREQ) # sin,cos for both x,y at each frequenecy
|
|
58
|
+
DIM_HID = 32
|
|
59
|
+
DIM_OUT = 3
|
|
60
|
+
|
|
61
|
+
# threads per-block
|
|
62
|
+
NUM_THREADS = 32
|
|
63
|
+
|
|
64
|
+
IMG_WIDTH = 512
|
|
65
|
+
IMG_HEIGHT = 512
|
|
66
|
+
|
|
67
|
+
BATCH_SIZE = min(1024, int((IMG_WIDTH * IMG_HEIGHT) / 8))
|
|
68
|
+
|
|
69
|
+
# dtype for our weights and bias matrices
|
|
70
|
+
dtype = wp.float16
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@wp.func
|
|
74
|
+
def relu(x: dtype):
|
|
75
|
+
return wp.max(x, dtype(0.0))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@wp.kernel
|
|
79
|
+
def compute(
|
|
80
|
+
indices: wp.array(dtype=int),
|
|
81
|
+
weights_0: wp.array2d(dtype=dtype),
|
|
82
|
+
bias_0: wp.array2d(dtype=dtype),
|
|
83
|
+
weights_1: wp.array2d(dtype=dtype),
|
|
84
|
+
bias_1: wp.array2d(dtype=dtype),
|
|
85
|
+
weights_2: wp.array2d(dtype=dtype),
|
|
86
|
+
bias_2: wp.array2d(dtype=dtype),
|
|
87
|
+
weights_3: wp.array2d(dtype=dtype),
|
|
88
|
+
bias_3: wp.array2d(dtype=dtype),
|
|
89
|
+
reference: wp.array2d(dtype=float),
|
|
90
|
+
loss: wp.array1d(dtype=float),
|
|
91
|
+
out: wp.array2d(dtype=float),
|
|
92
|
+
):
|
|
93
|
+
# batch indices
|
|
94
|
+
linear = indices[wp.tid()]
|
|
95
|
+
|
|
96
|
+
row = linear / IMG_WIDTH
|
|
97
|
+
col = linear % IMG_WIDTH
|
|
98
|
+
|
|
99
|
+
# normalize input coordinates to [-1, 1]
|
|
100
|
+
x = (float(row) / float(IMG_WIDTH) - 0.5) * 2.0
|
|
101
|
+
y = (float(col) / float(IMG_HEIGHT) - 0.5) * 2.0
|
|
102
|
+
|
|
103
|
+
local = wp.vector(dtype=dtype, length=DIM_IN)
|
|
104
|
+
|
|
105
|
+
# construct positional encoding
|
|
106
|
+
for s in range(NUM_FREQ):
|
|
107
|
+
scale = wp.pow(2.0, float(s)) * wp.pi
|
|
108
|
+
|
|
109
|
+
# x-coord
|
|
110
|
+
local[s * 4 + 0] = dtype(wp.sin(x * scale))
|
|
111
|
+
local[s * 4 + 1] = dtype(wp.cos(x * scale))
|
|
112
|
+
# y-coord
|
|
113
|
+
local[s * 4 + 2] = dtype(wp.sin(y * scale))
|
|
114
|
+
local[s * 4 + 3] = dtype(wp.cos(y * scale))
|
|
115
|
+
|
|
116
|
+
# tile feature vectors across the block, returns [dim(f), NUM_THREADS]
|
|
117
|
+
f = wp.tile(local)
|
|
118
|
+
|
|
119
|
+
# input layer
|
|
120
|
+
w0 = wp.tile_load(weights_0, 0, 0, m=DIM_HID, n=DIM_IN)
|
|
121
|
+
b0 = wp.tile_load(bias_0, 0, 0, m=DIM_HID, n=1)
|
|
122
|
+
z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, m=DIM_HID, n=NUM_THREADS))
|
|
123
|
+
|
|
124
|
+
# hidden layer
|
|
125
|
+
w1 = wp.tile_load(weights_1, 0, 0, m=DIM_HID, n=DIM_HID)
|
|
126
|
+
b1 = wp.tile_load(bias_1, 0, 0, m=DIM_HID, n=1)
|
|
127
|
+
z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, m=DIM_HID, n=NUM_THREADS))
|
|
128
|
+
|
|
129
|
+
w2 = wp.tile_load(weights_2, 0, 0, m=DIM_HID, n=DIM_HID)
|
|
130
|
+
b2 = wp.tile_load(bias_2, 0, 0, m=DIM_HID, n=1)
|
|
131
|
+
z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, m=DIM_HID, n=NUM_THREADS))
|
|
132
|
+
|
|
133
|
+
# output layer
|
|
134
|
+
w3 = wp.tile_load(weights_3, 0, 0, m=DIM_OUT, n=DIM_HID)
|
|
135
|
+
b3 = wp.tile_load(bias_3, 0, 0, m=DIM_OUT, n=1)
|
|
136
|
+
o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, m=DIM_OUT, n=NUM_THREADS))
|
|
137
|
+
|
|
138
|
+
# untile back to SIMT
|
|
139
|
+
output = wp.untile(o)
|
|
140
|
+
|
|
141
|
+
# compute error
|
|
142
|
+
error = wp.vec3(
|
|
143
|
+
float(output[0]) - reference[0, linear],
|
|
144
|
+
float(output[1]) - reference[1, linear],
|
|
145
|
+
float(output[2]) - reference[2, linear],
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
# write MSE loss
|
|
149
|
+
if loss:
|
|
150
|
+
wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))
|
|
151
|
+
|
|
152
|
+
# write image output
|
|
153
|
+
if out:
|
|
154
|
+
for i in range(DIM_OUT):
|
|
155
|
+
out[i, linear] = float(output[i])
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class Example:
|
|
159
|
+
def __init__(self, train_iters):
|
|
160
|
+
self.weights_0, self.bias_0 = create_layer(DIM_IN, DIM_HID, dtype=dtype)
|
|
161
|
+
self.weights_1, self.bias_1 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
|
|
162
|
+
self.weights_2, self.bias_2 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
|
|
163
|
+
self.weights_3, self.bias_3 = create_layer(DIM_HID, DIM_OUT, dtype=dtype)
|
|
164
|
+
|
|
165
|
+
# reference
|
|
166
|
+
reference_path = os.path.join(wp.examples.get_asset_directory(), "pixel.jpg")
|
|
167
|
+
with Image.open(reference_path) as im:
|
|
168
|
+
reference_image = np.asarray(im.resize((IMG_WIDTH, IMG_HEIGHT)).convert("RGB")) / 255.0
|
|
169
|
+
self.reference = wp.array(reference_image.reshape(IMG_WIDTH * IMG_HEIGHT, 3).T, dtype=float)
|
|
170
|
+
|
|
171
|
+
# create randomized batch indices
|
|
172
|
+
indices = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
|
|
173
|
+
rng.shuffle(indices)
|
|
174
|
+
self.indices = wp.array(indices)
|
|
175
|
+
|
|
176
|
+
self.num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
|
|
177
|
+
self.max_iters = train_iters
|
|
178
|
+
self.max_epochs = max(1, int(self.max_iters / self.num_batches))
|
|
179
|
+
|
|
180
|
+
def train_warp(self):
|
|
181
|
+
params = [
|
|
182
|
+
self.weights_0,
|
|
183
|
+
self.bias_0,
|
|
184
|
+
self.weights_1,
|
|
185
|
+
self.bias_1,
|
|
186
|
+
self.weights_2,
|
|
187
|
+
self.bias_2,
|
|
188
|
+
self.weights_3,
|
|
189
|
+
self.bias_3,
|
|
190
|
+
]
|
|
191
|
+
|
|
192
|
+
optimizer_grads = [p.grad.flatten() for p in params]
|
|
193
|
+
optimizer_inputs = [p.flatten() for p in params]
|
|
194
|
+
optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)
|
|
195
|
+
|
|
196
|
+
loss = wp.zeros(1, dtype=float, requires_grad=True)
|
|
197
|
+
output = create_array(IMG_WIDTH * IMG_HEIGHT, DIM_OUT)
|
|
198
|
+
|
|
199
|
+
# capture graph for whole epoch
|
|
200
|
+
wp.capture_begin()
|
|
201
|
+
|
|
202
|
+
for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
|
|
203
|
+
loss.zero_()
|
|
204
|
+
|
|
205
|
+
with wp.Tape() as tape:
|
|
206
|
+
wp.launch(
|
|
207
|
+
compute,
|
|
208
|
+
dim=[BATCH_SIZE],
|
|
209
|
+
inputs=[
|
|
210
|
+
self.indices[b : b + BATCH_SIZE],
|
|
211
|
+
self.weights_0,
|
|
212
|
+
self.bias_0,
|
|
213
|
+
self.weights_1,
|
|
214
|
+
self.bias_1,
|
|
215
|
+
self.weights_2,
|
|
216
|
+
self.bias_2,
|
|
217
|
+
self.weights_3,
|
|
218
|
+
self.bias_3,
|
|
219
|
+
self.reference,
|
|
220
|
+
loss,
|
|
221
|
+
None,
|
|
222
|
+
],
|
|
223
|
+
block_dim=NUM_THREADS,
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
tape.backward(loss)
|
|
227
|
+
optimizer.step(optimizer_grads)
|
|
228
|
+
tape.zero()
|
|
229
|
+
|
|
230
|
+
graph = wp.capture_end()
|
|
231
|
+
|
|
232
|
+
with wp.ScopedTimer("Training"):
|
|
233
|
+
for i in range(self.max_epochs):
|
|
234
|
+
with wp.ScopedTimer("Epoch"):
|
|
235
|
+
wp.capture_launch(graph)
|
|
236
|
+
print(f"Epoch: {i} Loss: {loss.numpy()}")
|
|
237
|
+
|
|
238
|
+
# evaluate full image
|
|
239
|
+
wp.launch(
|
|
240
|
+
compute,
|
|
241
|
+
dim=[IMG_WIDTH * IMG_HEIGHT],
|
|
242
|
+
inputs=[
|
|
243
|
+
self.indices,
|
|
244
|
+
self.weights_0,
|
|
245
|
+
self.bias_0,
|
|
246
|
+
self.weights_1,
|
|
247
|
+
self.bias_1,
|
|
248
|
+
self.weights_2,
|
|
249
|
+
self.bias_2,
|
|
250
|
+
self.weights_3,
|
|
251
|
+
self.bias_3,
|
|
252
|
+
self.reference,
|
|
253
|
+
loss,
|
|
254
|
+
output,
|
|
255
|
+
],
|
|
256
|
+
block_dim=NUM_THREADS,
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
self.save_image("example_tile_mlp.jpg", output.numpy())
|
|
260
|
+
|
|
261
|
+
def train_torch(self):
|
|
262
|
+
import torch as tc
|
|
263
|
+
|
|
264
|
+
weights_0 = tc.nn.Parameter(wp.to_torch(self.weights_0))
|
|
265
|
+
weights_1 = tc.nn.Parameter(wp.to_torch(self.weights_1))
|
|
266
|
+
weights_2 = tc.nn.Parameter(wp.to_torch(self.weights_2))
|
|
267
|
+
weights_3 = tc.nn.Parameter(wp.to_torch(self.weights_3))
|
|
268
|
+
|
|
269
|
+
bias_0 = tc.nn.Parameter(wp.to_torch(self.bias_0))
|
|
270
|
+
bias_1 = tc.nn.Parameter(wp.to_torch(self.bias_1))
|
|
271
|
+
bias_2 = tc.nn.Parameter(wp.to_torch(self.bias_2))
|
|
272
|
+
bias_3 = tc.nn.Parameter(wp.to_torch(self.bias_3))
|
|
273
|
+
|
|
274
|
+
indices = wp.to_torch(self.indices)
|
|
275
|
+
reference = wp.to_torch(self.reference)
|
|
276
|
+
|
|
277
|
+
optimizer = tc.optim.Adam(
|
|
278
|
+
[weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3],
|
|
279
|
+
capturable=True,
|
|
280
|
+
lr=0.0001,
|
|
281
|
+
betas=(0.9, 0.95),
|
|
282
|
+
eps=1.0e-6,
|
|
283
|
+
)
|
|
284
|
+
|
|
285
|
+
# generate frequency space encoding of pixels
|
|
286
|
+
# based on their linear index in the image
|
|
287
|
+
def encode(linear):
|
|
288
|
+
row = (linear // IMG_WIDTH).float()
|
|
289
|
+
col = (linear % IMG_WIDTH).float()
|
|
290
|
+
|
|
291
|
+
x = (row / float(IMG_WIDTH) - 0.5) * 2.0
|
|
292
|
+
y = (col / float(IMG_HEIGHT) - 0.5) * 2.0
|
|
293
|
+
|
|
294
|
+
encoding = tc.zeros((NUM_FREQ * 4, len(linear)), dtype=tc.float16, device="cuda")
|
|
295
|
+
|
|
296
|
+
for s in range(NUM_FREQ):
|
|
297
|
+
scale = math.pow(2.0, float(s)) * math.pi
|
|
298
|
+
|
|
299
|
+
# Directly write the computed values into the encoding tensor
|
|
300
|
+
encoding[s * 4 + 0, :] = tc.sin(scale * x)
|
|
301
|
+
encoding[s * 4 + 1, :] = tc.cos(scale * x)
|
|
302
|
+
encoding[s * 4 + 2, :] = tc.sin(scale * y)
|
|
303
|
+
encoding[s * 4 + 3, :] = tc.cos(scale * y)
|
|
304
|
+
|
|
305
|
+
return encoding
|
|
306
|
+
|
|
307
|
+
stream = tc.cuda.Stream()
|
|
308
|
+
graph = tc.cuda.CUDAGraph()
|
|
309
|
+
|
|
310
|
+
# warm-up
|
|
311
|
+
with tc.cuda.stream(stream):
|
|
312
|
+
f = tc.rand((NUM_FREQ * 4, BATCH_SIZE), dtype=tc.float16, device="cuda")
|
|
313
|
+
z = tc.relu(weights_0 @ f + bias_0)
|
|
314
|
+
z = tc.relu(weights_1 @ z + bias_1)
|
|
315
|
+
z = tc.relu(weights_2 @ z + bias_2)
|
|
316
|
+
z = tc.relu(weights_3 @ z + bias_3)
|
|
317
|
+
ref = tc.rand((3, BATCH_SIZE), dtype=tc.float16, device="cuda")
|
|
318
|
+
loss = tc.mean((z - ref) ** 2)
|
|
319
|
+
optimizer.zero_grad()
|
|
320
|
+
loss.backward()
|
|
321
|
+
optimizer.step()
|
|
322
|
+
|
|
323
|
+
with tc.cuda.graph(graph):
|
|
324
|
+
for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
|
|
325
|
+
linear = indices[b : b + BATCH_SIZE]
|
|
326
|
+
|
|
327
|
+
f = encode(linear)
|
|
328
|
+
|
|
329
|
+
z = tc.relu(weights_0 @ f + bias_0)
|
|
330
|
+
z = tc.relu(weights_1 @ z + bias_1)
|
|
331
|
+
z = tc.relu(weights_2 @ z + bias_2)
|
|
332
|
+
z = tc.relu(weights_3 @ z + bias_3)
|
|
333
|
+
|
|
334
|
+
ref = reference[:, linear]
|
|
335
|
+
loss = tc.mean((z - ref) ** 2)
|
|
336
|
+
|
|
337
|
+
optimizer.zero_grad()
|
|
338
|
+
loss.backward()
|
|
339
|
+
optimizer.step()
|
|
340
|
+
|
|
341
|
+
with wp.ScopedTimer("Training (Torch)"):
|
|
342
|
+
for _i in range(self.max_epochs):
|
|
343
|
+
with wp.ScopedTimer("Epoch"):
|
|
344
|
+
graph.replay()
|
|
345
|
+
|
|
346
|
+
print(loss)
|
|
347
|
+
|
|
348
|
+
f = encode(tc.arange(0, IMG_WIDTH * IMG_HEIGHT))
|
|
349
|
+
z = tc.relu(weights_0 @ f + bias_0)
|
|
350
|
+
z = tc.relu(weights_1 @ z + bias_1)
|
|
351
|
+
z = tc.relu(weights_2 @ z + bias_2)
|
|
352
|
+
z = tc.relu(weights_3 @ z + bias_3)
|
|
353
|
+
|
|
354
|
+
self.save_image("example_tile_mlp_torch.jpg", z.detach().cpu().numpy())
|
|
355
|
+
|
|
356
|
+
def save_image(self, name, output):
|
|
357
|
+
predicted_image = output.T.reshape(IMG_WIDTH, IMG_HEIGHT, 3)
|
|
358
|
+
predicted_image = (predicted_image * 255).astype(np.uint8)
|
|
359
|
+
|
|
360
|
+
predicted_image_pil = Image.fromarray(predicted_image)
|
|
361
|
+
predicted_image_pil.save(name)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
if __name__ == "__main__":
|
|
365
|
+
import argparse
|
|
366
|
+
|
|
367
|
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
368
|
+
parser.add_argument("--train_iters", type=int, default=20000, help="Total number of training iterations.")
|
|
369
|
+
|
|
370
|
+
args = parser.parse_known_args()[0]
|
|
371
|
+
|
|
372
|
+
with wp.ScopedDevice("cuda:0"):
|
|
373
|
+
example = Example(args.train_iters)
|
|
374
|
+
example.train_warp()
|
|
375
|
+
# example.train_torch()
|
warp/fem/__init__.py
CHANGED
|
@@ -24,14 +24,17 @@ from .geometry import (
|
|
|
24
24
|
LinearGeometryPartition,
|
|
25
25
|
Nanogrid,
|
|
26
26
|
Quadmesh2D,
|
|
27
|
+
Quadmesh3D,
|
|
27
28
|
Tetmesh,
|
|
28
29
|
Trimesh2D,
|
|
30
|
+
Trimesh3D,
|
|
29
31
|
)
|
|
30
32
|
from .integrate import integrate, interpolate
|
|
31
33
|
from .operator import (
|
|
32
34
|
D,
|
|
33
35
|
at_node,
|
|
34
36
|
average,
|
|
37
|
+
cells,
|
|
35
38
|
curl,
|
|
36
39
|
deformation_gradient,
|
|
37
40
|
degree,
|
|
@@ -50,6 +53,9 @@ from .operator import (
|
|
|
50
53
|
normal,
|
|
51
54
|
outer,
|
|
52
55
|
position,
|
|
56
|
+
to_cell_side,
|
|
57
|
+
to_inner_cell,
|
|
58
|
+
to_outer_cell,
|
|
53
59
|
)
|
|
54
60
|
from .polynomial import Polynomial
|
|
55
61
|
from .quadrature import ExplicitQuadrature, NodalQuadrature, PicQuadrature, Quadrature, RegularQuadrature
|
|
@@ -65,6 +71,8 @@ from .space import (
|
|
|
65
71
|
SpaceTopology,
|
|
66
72
|
SymmetricTensorMapper,
|
|
67
73
|
make_collocated_function_space,
|
|
74
|
+
make_contravariant_function_space,
|
|
75
|
+
make_covariant_function_space,
|
|
68
76
|
make_polynomial_basis_space,
|
|
69
77
|
make_polynomial_space,
|
|
70
78
|
make_space_partition,
|
warp/fem/cache.py
CHANGED
|
@@ -6,6 +6,7 @@ from copy import copy
|
|
|
6
6
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import warp as wp
|
|
9
|
+
from warp.fem.operator import Integrand
|
|
9
10
|
|
|
10
11
|
_kernel_cache = {}
|
|
11
12
|
_struct_cache = {}
|
|
@@ -186,7 +187,7 @@ class ExpandStarredArgumentStruct(ast.NodeTransformer):
|
|
|
186
187
|
|
|
187
188
|
|
|
188
189
|
def get_integrand_function(
|
|
189
|
-
integrand:
|
|
190
|
+
integrand: Integrand,
|
|
190
191
|
suffix: str,
|
|
191
192
|
func=None,
|
|
192
193
|
annotations=None,
|
|
@@ -208,27 +209,30 @@ def get_integrand_function(
|
|
|
208
209
|
|
|
209
210
|
|
|
210
211
|
def get_integrand_kernel(
|
|
211
|
-
integrand:
|
|
212
|
+
integrand: Integrand,
|
|
212
213
|
suffix: str,
|
|
213
214
|
kernel_fn: Optional[Callable] = None,
|
|
214
215
|
kernel_options: Dict[str, Any] = None,
|
|
215
216
|
code_transformers=None,
|
|
216
217
|
):
|
|
217
|
-
|
|
218
|
-
|
|
218
|
+
options = integrand.module.options.copy()
|
|
219
|
+
options.update(integrand.kernel_options)
|
|
220
|
+
if kernel_options is not None:
|
|
221
|
+
options.update(kernel_options)
|
|
219
222
|
|
|
220
|
-
|
|
223
|
+
kernel_key = _make_key(integrand.func, suffix, use_qualified_name=True)
|
|
224
|
+
opts_key = "".join([f"{k}:{v}" for k, v in sorted(options.items())])
|
|
225
|
+
cache_key = kernel_key + opts_key
|
|
221
226
|
|
|
222
|
-
if
|
|
227
|
+
if cache_key not in _kernel_cache:
|
|
223
228
|
if kernel_fn is None:
|
|
224
229
|
return None
|
|
225
230
|
|
|
226
231
|
module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
return _kernel_cache[key]
|
|
232
|
+
_kernel_cache[cache_key] = wp.Kernel(
|
|
233
|
+
func=kernel_fn, key=kernel_key, module=module, code_transformers=code_transformers, options=options
|
|
234
|
+
)
|
|
235
|
+
return _kernel_cache[cache_key]
|
|
232
236
|
|
|
233
237
|
|
|
234
238
|
def cached_arg_value(func: Callable):
|
|
@@ -478,7 +482,7 @@ def borrow_temporary(
|
|
|
478
482
|
if temporary_store is None:
|
|
479
483
|
temporary_store = TemporaryStore._default_store
|
|
480
484
|
|
|
481
|
-
if temporary_store is None:
|
|
485
|
+
if temporary_store is None or (requires_grad and wp.context.runtime.tape is not None):
|
|
482
486
|
return Temporary(
|
|
483
487
|
array=wp.empty(shape=shape, dtype=dtype, pinned=pinned, device=device, requires_grad=requires_grad)
|
|
484
488
|
)
|
warp/fem/dirichlet.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from typing import Any, Optional
|
|
2
2
|
|
|
3
3
|
import warp as wp
|
|
4
|
-
from warp.fem.
|
|
4
|
+
from warp.fem.linalg import array_axpy, symmetric_eigenvalues_qr
|
|
5
5
|
from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
|
|
6
6
|
from warp.types import type_is_matrix, type_length
|
|
7
7
|
|
warp/fem/domain.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional, Union
|
|
1
|
+
from typing import Any, Optional, Set, Union
|
|
2
2
|
|
|
3
3
|
import warp as wp
|
|
4
4
|
import warp.codegen
|
|
@@ -11,6 +11,7 @@ from warp.fem.geometry import (
|
|
|
11
11
|
GeometryPartition,
|
|
12
12
|
WholeGeometryPartition,
|
|
13
13
|
)
|
|
14
|
+
from warp.fem.operator import Operator
|
|
14
15
|
from warp.fem.types import ElementKind
|
|
15
16
|
|
|
16
17
|
GeometryOrPartition = Union[Geometry, GeometryPartition]
|
|
@@ -94,6 +95,10 @@ class GeometryDomain:
|
|
|
94
95
|
element_lookup: wp.Function
|
|
95
96
|
"""Device function returning the sample point corresponding to a world position"""
|
|
96
97
|
|
|
98
|
+
def notify_operator_usage(self, ops: Set[Operator]):
|
|
99
|
+
"""Makes the Domain aware that the operators `ops` will be applied"""
|
|
100
|
+
pass
|
|
101
|
+
|
|
97
102
|
|
|
98
103
|
class Cells(GeometryDomain):
|
|
99
104
|
"""A Domain containing all cells of the geometry or geometry partition"""
|
|
@@ -160,6 +165,17 @@ class Cells(GeometryDomain):
|
|
|
160
165
|
def element_lookup(self) -> wp.Function:
|
|
161
166
|
return self.geometry.cell_lookup
|
|
162
167
|
|
|
168
|
+
@property
|
|
169
|
+
def domain_cell_arg(self) -> wp.Function:
|
|
170
|
+
return Cells._identity_fn
|
|
171
|
+
|
|
172
|
+
def cell_domain(self):
|
|
173
|
+
return self
|
|
174
|
+
|
|
175
|
+
@wp.func
|
|
176
|
+
def _identity_fn(x: Any):
|
|
177
|
+
return x
|
|
178
|
+
|
|
163
179
|
|
|
164
180
|
class Sides(GeometryDomain):
|
|
165
181
|
"""A Domain containing all (interior and boundary) sides of the geometry or geometry partition"""
|
|
@@ -225,6 +241,33 @@ class Sides(GeometryDomain):
|
|
|
225
241
|
def element_normal(self) -> wp.Function:
|
|
226
242
|
return self.geometry.side_normal
|
|
227
243
|
|
|
244
|
+
@property
|
|
245
|
+
def element_inner_cell_index(self) -> wp.Function:
|
|
246
|
+
return self.geometry.side_inner_cell_index
|
|
247
|
+
|
|
248
|
+
@property
|
|
249
|
+
def element_outer_cell_index(self) -> wp.Function:
|
|
250
|
+
return self.geometry.side_outer_cell_index
|
|
251
|
+
|
|
252
|
+
@property
|
|
253
|
+
def element_inner_cell_coords(self) -> wp.Function:
|
|
254
|
+
return self.geometry.side_inner_cell_coords
|
|
255
|
+
|
|
256
|
+
@property
|
|
257
|
+
def element_outer_cell_coords(self) -> wp.Function:
|
|
258
|
+
return self.geometry.side_outer_cell_coords
|
|
259
|
+
|
|
260
|
+
@property
|
|
261
|
+
def cell_to_element_coords(self) -> wp.Function:
|
|
262
|
+
return self.geometry.side_from_cell_coords
|
|
263
|
+
|
|
264
|
+
@property
|
|
265
|
+
def domain_cell_arg(self) -> wp.Function:
|
|
266
|
+
return self.geometry.side_to_cell_arg
|
|
267
|
+
|
|
268
|
+
def cell_domain(self):
|
|
269
|
+
return Cells(self.geometry_partition)
|
|
270
|
+
|
|
228
271
|
|
|
229
272
|
class BoundarySides(Sides):
|
|
230
273
|
"""A Domain containing boundary sides of the geometry or geometry partition"""
|
warp/fem/field/__init__.py
CHANGED
|
@@ -6,8 +6,7 @@ from warp.fem.space import FunctionSpace, SpacePartition, SpaceRestriction, make
|
|
|
6
6
|
from .field import DiscreteField, FieldLike, GeometryField, ImplicitField, NonconformingField, SpaceField, UniformField
|
|
7
7
|
from .nodal_field import NodalField
|
|
8
8
|
from .restriction import FieldRestriction
|
|
9
|
-
from .
|
|
10
|
-
from .trial import TrialField
|
|
9
|
+
from .virtual import LocalTestField, LocalTrialField, TestField, TrialField
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
def make_restriction(
|
warp/fem/field/field.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from typing import Any, Dict, Optional
|
|
1
|
+
from typing import Any, Dict, Optional, Set
|
|
2
2
|
|
|
3
3
|
import warp as wp
|
|
4
4
|
from warp.fem import cache
|
|
5
5
|
from warp.fem.domain import GeometryDomain, Sides
|
|
6
6
|
from warp.fem.geometry import DeformedGeometry, Geometry
|
|
7
|
-
from warp.fem.operator import integrand
|
|
7
|
+
from warp.fem.operator import Operator, integrand
|
|
8
8
|
from warp.fem.space import FunctionSpace, SpacePartition
|
|
9
9
|
from warp.fem.types import NULL_ELEMENT_INDEX, ElementKind, Sample
|
|
10
10
|
|
|
@@ -48,32 +48,32 @@ class FieldLike:
|
|
|
48
48
|
return False
|
|
49
49
|
|
|
50
50
|
@staticmethod
|
|
51
|
-
def eval_inner(args: "ElementEvalArg", s:
|
|
51
|
+
def eval_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
52
52
|
"""Device function evaluating the inner field value at a sample point"""
|
|
53
53
|
raise NotImplementedError
|
|
54
54
|
|
|
55
55
|
@staticmethod
|
|
56
|
-
def eval_grad_inner(args: "ElementEvalArg", s:
|
|
56
|
+
def eval_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
57
57
|
"""Device function evaluating the inner field gradient at a sample point"""
|
|
58
58
|
raise NotImplementedError
|
|
59
59
|
|
|
60
60
|
@staticmethod
|
|
61
|
-
def eval_div_inner(args: "ElementEvalArg", s:
|
|
61
|
+
def eval_div_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
62
62
|
"""Device function evaluating the inner field divergence at a sample point"""
|
|
63
63
|
raise NotImplementedError
|
|
64
64
|
|
|
65
65
|
@staticmethod
|
|
66
|
-
def eval_outer(args: "ElementEvalArg", s:
|
|
66
|
+
def eval_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
67
67
|
"""Device function evaluating the outer field value at a sample point"""
|
|
68
68
|
raise NotImplementedError
|
|
69
69
|
|
|
70
70
|
@staticmethod
|
|
71
|
-
def eval_grad_outer(args: "ElementEvalArg", s:
|
|
71
|
+
def eval_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
72
72
|
"""Device function evaluating the outer field gradient at a sample point"""
|
|
73
73
|
raise NotImplementedError
|
|
74
74
|
|
|
75
75
|
@staticmethod
|
|
76
|
-
def eval_div_outer(args: "ElementEvalArg", s:
|
|
76
|
+
def eval_div_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
77
77
|
"""Device function evaluating the outer field divergence at a sample point"""
|
|
78
78
|
raise NotImplementedError
|
|
79
79
|
|
|
@@ -82,6 +82,10 @@ class FieldLike:
|
|
|
82
82
|
"""Polynomial degree of the field is applicable, or hint for determination of interpolation order"""
|
|
83
83
|
raise NotImplementedError
|
|
84
84
|
|
|
85
|
+
def notify_operator_usage(self, ops: Set[Operator]):
|
|
86
|
+
"""Makes the Domain aware that the operators `ops` will be applied"""
|
|
87
|
+
pass
|
|
88
|
+
|
|
85
89
|
|
|
86
90
|
class GeometryField(FieldLike):
|
|
87
91
|
"""Base class for fields defined over a geometry"""
|
|
@@ -97,12 +101,12 @@ class GeometryField(FieldLike):
|
|
|
97
101
|
raise NotImplementedError
|
|
98
102
|
|
|
99
103
|
@staticmethod
|
|
100
|
-
def eval_reference_grad_inner(args: "ElementEvalArg", s:
|
|
104
|
+
def eval_reference_grad_inner(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
101
105
|
"""Device function evaluating the inner field gradient with respect to reference element coordinates at a sample point"""
|
|
102
106
|
raise NotImplementedError
|
|
103
107
|
|
|
104
108
|
@staticmethod
|
|
105
|
-
def eval_reference_grad_outer(args: "ElementEvalArg", s:
|
|
109
|
+
def eval_reference_grad_outer(args: "ElementEvalArg", s: Sample): # noqa: F821
|
|
106
110
|
"""Device function evaluating the outer field gradient with respect to reference element coordinates at a sample point"""
|
|
107
111
|
raise NotImplementedError
|
|
108
112
|
|
|
@@ -128,6 +132,9 @@ class SpaceField(GeometryField):
|
|
|
128
132
|
self._space = space
|
|
129
133
|
self._space_partition = space_partition
|
|
130
134
|
|
|
135
|
+
self.gradient_valid = self.space.gradient_valid
|
|
136
|
+
self.divergence_valid = self.space.divergence_valid
|
|
137
|
+
|
|
131
138
|
@property
|
|
132
139
|
def geometry(self) -> Geometry:
|
|
133
140
|
return self._space.geometry
|
|
@@ -156,17 +163,22 @@ class SpaceField(GeometryField):
|
|
|
156
163
|
def dof_dtype(self) -> type:
|
|
157
164
|
return self.space.dof_dtype
|
|
158
165
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
166
|
+
@property
|
|
167
|
+
def gradient_dtype(self):
|
|
168
|
+
"""Return type of the gradient operator. Assumes self.gradient_valid()"""
|
|
169
|
+
if wp.types.type_is_vector(self.dtype):
|
|
170
|
+
return cache.cached_mat_type(
|
|
171
|
+
shape=(wp.types.type_length(self.dtype), self.geometry.dimension),
|
|
172
|
+
dtype=wp.types.type_scalar_type(self.dtype),
|
|
173
|
+
)
|
|
174
|
+
return cache.cached_vec_type(length=self.geometry.dimension, dtype=wp.types.type_scalar_type(self.dtype))
|
|
162
175
|
|
|
163
|
-
|
|
164
|
-
|
|
176
|
+
@property
|
|
177
|
+
def divergence_dtype(self):
|
|
178
|
+
"""Return type of the divergence operator. Assumes self.gradient_valid()"""
|
|
165
179
|
if wp.types.type_is_vector(self.dtype):
|
|
166
|
-
return wp.types.
|
|
167
|
-
|
|
168
|
-
return self.dtype._shape_[0] == self.space.geometry.dimension
|
|
169
|
-
return False
|
|
180
|
+
return wp.types.type_scalar_type(self.dtype)
|
|
181
|
+
return cache.cached_vec_type(length=self.dtype._shape_[1], dtype=wp.types.type_scalar_type(self.dtype))
|
|
170
182
|
|
|
171
183
|
def _make_eval_degree(self):
|
|
172
184
|
ORDER = self.space.ORDER
|