warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. /warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,535 @@
1
+ # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
3
+ # and proprietary rights in and to this software, related documentation
4
+ # and any modifications thereto. Any use, reproduction, disclosure or
5
+ # distribution of this software and related documentation without an express
6
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
+
8
+ #############################################################################
9
+ # Example Differentiable Ray Caster
10
+ #
11
+ # Shows how to use the built-in wp.Mesh data structure and wp.mesh_query_ray()
12
+ # function to implement a basic differentiable ray caster
13
+ #
14
+ ##############################################################################
15
+
16
+ import math
17
+ import os
18
+
19
+ import numpy as np
20
+ from pxr import Usd, UsdGeom
21
+
22
+ import warp as wp
23
+ from warp.optim import SGD
24
+
25
+ wp.init()
26
+
27
+
28
class RenderMode:
    """Integer selectors for the shading path taken by ``draw_kernel`` (its ``mode`` argument)."""

    grayscale = 0  # white surface, Lambertian-shaded by the directional lights
    texture = 1  # bilinearly sampled 2D texture map, then Lambertian-shaded
    normal_map = 2  # interpolated vertex normal remapped from [-1, 1] to [0, 1] as RGB
38
+
39
+
40
@wp.struct
class RenderMesh:
    """Triangle mesh to be ray cast, plus its texture data and parent transform.

    Per-vertex normals are not computed here: the caller fills ``vertex_normals``
    (Example.__init__ does so with vertex_normal_sum_kernel + normalize_kernel).
    """

    id: wp.uint64  # handle of the wp.Mesh queried with wp.mesh_query_ray
    vertices: wp.array(dtype=wp.vec3)  # vertex positions in mesh (local) space
    indices: wp.array(dtype=int)  # three vertex indices per triangle
    tex_coords: wp.array(dtype=wp.vec2)  # UV coordinates referenced by tex_indices
    tex_indices: wp.array(dtype=int)  # three tex_coords indices per triangle
    vertex_normals: wp.array(dtype=wp.vec3)  # one normal per vertex
    pos: wp.array(dtype=wp.vec3)  # mesh translation, read as pos[0] (1-element array, presumably so it can carry gradients)
    rot: wp.array(dtype=wp.quat)  # mesh rotation, read as rot[0] (1-element array, presumably so it can carry gradients)
55
+
56
+
57
@wp.struct
class Camera:
    """Basic pinhole camera for ray casting: all rays start at ``pos``."""

    horizontal: float  # horizontal aperture size (not read by draw_kernel)
    vertical: float  # vertical aperture size (not read by draw_kernel)
    aspect: float  # horizontal / vertical aspect ratio; scales the ray x component
    e: float  # focal length (not read by draw_kernel)
    tan: float  # tangent of half the vertical field of view, i.e. vertical / (2 * e)
    pos: wp.vec3  # world-space camera position (ray origin)
    rot: wp.quat  # world-space camera rotation applied to view-space ray directions
68
+
69
+
70
@wp.struct
class DirectionalLights:
    """Stores arrays of directional light directions and intensities."""

    dirs: wp.array(dtype=wp.vec3)  # world-space direction toward each light; contribution is intensity * max(0, dot(normal, dir))
    intensities: wp.array(dtype=float)  # scalar intensity per light
    num_lights: int  # number of entries of dirs / intensities that are used
77
+
78
+
79
# Scatter each triangle's area-weighted face normal onto its three vertices.
# One thread per triangle; `indices` stores three vertex indices per face, and
# `normal_sums` must be zeroed by the caller before launch.
@wp.kernel
def vertex_normal_sum_kernel(
    verts: wp.array(dtype=wp.vec3), indices: wp.array(dtype=int), normal_sums: wp.array(dtype=wp.vec3)
):
    face = wp.tid()
    base = face * 3

    v0 = indices[base]
    v1 = indices[base + 1]
    v2 = indices[base + 2]

    p0 = verts[v0]
    # |cross| is twice the triangle area, so larger faces weigh more in the sum
    face_normal = wp.cross(verts[v1] - p0, verts[v2] - p0)

    wp.atomic_add(normal_sums, v0, face_normal)
    wp.atomic_add(normal_sums, v1, face_normal)
    wp.atomic_add(normal_sums, v2, face_normal)
100
+
101
+
102
# Turn each accumulated vertex normal into a unit vector (one thread per vertex).
@wp.kernel
def normalize_kernel(
    normal_sums: wp.array(dtype=wp.vec3),
    vertex_normals: wp.array(dtype=wp.vec3),
):
    vid = wp.tid()
    accumulated = normal_sums[vid]
    vertex_normals[vid] = wp.normalize(accumulated)
109
+
110
+
111
# Bilinearly sample `texture` (rows x cols of RGB) at UV coordinate `tex_interp`.
# UVs are expected in [0, 1]^2; v is flipped so that v = 1 maps to texture row 0.
# The 2x2 texel footprint is clamped to the texture bounds: at u == 1 (or v == 0)
# the "+1" neighbour would otherwise index one past the last column (row).
@wp.func
def texture_interpolation(tex_interp: wp.vec2, texture: wp.array2d(dtype=wp.vec3)):
    tex_width = texture.shape[1]
    tex_height = texture.shape[0]
    tex = wp.vec2(tex_interp[0] * float(tex_width - 1), (1.0 - tex_interp[1]) * float(tex_height - 1))

    # integer texel coordinates and fractional blend weights (computed before
    # clamping so in-range samples are unchanged)
    x_base = int(tex[0])
    alpha_x = tex[0] - float(x_base)
    y_base = int(tex[1])
    alpha_y = tex[1] - float(y_base)

    # clamp the sample footprint to valid indices; a no-op strictly inside the texture
    x0 = wp.clamp(x_base, 0, tex_width - 1)
    x1 = wp.clamp(x_base + 1, 0, tex_width - 1)
    y0 = wp.clamp(y_base, 0, tex_height - 1)
    y1 = wp.clamp(y_base + 1, 0, tex_height - 1)

    c00 = texture[y0, x0]
    c10 = texture[y0, x1]
    c01 = texture[y1, x0]
    c11 = texture[y1, x1]
    lower = (1.0 - alpha_x) * c00 + alpha_x * c10
    upper = (1.0 - alpha_x) * c01 + alpha_x * c11
    color = (1.0 - alpha_y) * lower + alpha_y * upper

    return color
132
+
133
+
134
# Cast one ray per thread against `mesh` and write the shaded color to rays[tid].
# Rays are stored row-major with `rays_width` columns; `mode` selects the
# shading path (0 = grayscale, 1 = texture, 2 = normal map; see RenderMode).
# Rays that miss the mesh are written as black.
@wp.kernel
def draw_kernel(
    mesh: RenderMesh,
    camera: Camera,
    texture: wp.array2d(dtype=wp.vec3),
    rays_width: int,
    rays_height: int,
    rays: wp.array(dtype=wp.vec3),
    lights: DirectionalLights,
    mode: int,
):
    tid = wp.tid()

    # column index, and a row coordinate that counts down from rays_height
    # (first stored row) to 1 (last stored row), so stored row 0 is the top of the image
    x = tid % rays_width
    y = rays_height - tid // rays_width

    # normalized screen coordinates: sx in [-1, 1), sy in (-1, 1]; each ray lands
    # on the left/top edge of its sub-pixel rather than its center
    sx = 2.0 * float(x) / float(rays_width) - 1.0
    sy = 2.0 * float(y) / float(rays_height) - 1.0

    # compute view ray in world space (the camera looks down -Z in its own frame;
    # camera.tan sets the vertical half-extent at unit distance)
    ro_world = camera.pos
    rd_world = wp.normalize(wp.quat_rotate(camera.rot, wp.vec3(sx * camera.tan * camera.aspect, sy * camera.tan, -1.0)))

    # compute view ray in mesh space, so the query runs against the untransformed vertices
    inv = wp.transform_inverse(wp.transform(mesh.pos[0], mesh.rot[0]))
    ro = wp.transform_point(inv, ro_world)
    rd = wp.transform_vector(inv, rd_world)

    color = wp.vec3(0.0, 0.0, 0.0)

    query = wp.mesh_query_ray(mesh.id, ro, rd, 1.0e6)
    if query.result:
        i = mesh.indices[query.face * 3]
        j = mesh.indices[query.face * 3 + 1]
        k = mesh.indices[query.face * 3 + 2]

        a = mesh.vertices[i]
        b = mesh.vertices[j]
        c = mesh.vertices[k]

        p = wp.mesh_eval_position(mesh.id, query.face, query.u, query.v)

        # barycentric coordinates, recomputed from sub-triangle areas using the
        # vertex positions (rather than using query.u / query.v directly) --
        # presumably so gradients reach mesh.vertices; confirm before changing.
        # u, v, w weight vertices a, b, c respectively. A degenerate triangle
        # (tri_area == 0) would divide by zero here.
        tri_area = wp.length(wp.cross(b - a, c - a))
        w = wp.length(wp.cross(b - a, p - a)) / tri_area
        v = wp.length(wp.cross(p - a, c - a)) / tri_area
        u = 1.0 - w - v

        a_n = mesh.vertex_normals[i]
        b_n = mesh.vertex_normals[j]
        c_n = mesh.vertex_normals[k]

        # vertex normal interpolation (mesh space; not re-normalized)
        normal = u * a_n + v * b_n + w * c_n

        if mode == 0 or mode == 1:
            if mode == 0:  # grayscale
                color = wp.vec3(1.0)

            elif mode == 1:  # texture interpolation
                tex_a = mesh.tex_coords[mesh.tex_indices[query.face * 3]]
                tex_b = mesh.tex_coords[mesh.tex_indices[query.face * 3 + 1]]
                tex_c = mesh.tex_coords[mesh.tex_indices[query.face * 3 + 2]]

                tex = u * tex_a + v * tex_b + w * tex_c

                color = texture_interpolation(tex, texture)

            # lambertian directional lighting: sum of intensity * max(0, dot(normal, dir)),
            # with light directions brought into mesh space to match `normal`.
            # NOTE(review): the loop variable reuses the name `i` of the vertex index
            # above (no longer needed at this point); `dir` shadows the builtin.
            lambert = float(0.0)
            for i in range(lights.num_lights):
                dir = wp.transform_vector(inv, lights.dirs[i])
                val = lights.intensities[i] * wp.dot(normal, dir)
                if val < 0.0:
                    val = 0.0
                lambert = lambert + val

            color = lambert * color

        elif mode == 2:  # normal map: remap [-1, 1] components to [0, 1]
            color = normal * 0.5 + wp.vec3(0.5, 0.5, 0.5)

        # clamp each channel from above at 1.0 (no lower clamp is applied)
        if color[0] > 1.0:
            color = wp.vec3(1.0, color[1], color[2])
        if color[1] > 1.0:
            color = wp.vec3(color[0], 1.0, color[2])
        if color[2] > 1.0:
            color = wp.vec3(color[0], color[1], 1.0)

    rays[tid] = color
224
+
225
+
226
# Box-filter the super-sampled ray image down to pixels: each output pixel is the
# mean of a num_samples x num_samples tile of rays (one thread per pixel). Both
# buffers are row-major; the ray buffer has `rays_width` columns.
@wp.kernel
def downsample_kernel(
    rays: wp.array(dtype=wp.vec3), pixels: wp.array(dtype=wp.vec3), rays_width: int, num_samples: int
):
    pixel = wp.tid()

    out_width = rays_width / num_samples
    col = pixel % out_width
    row = pixel // out_width
    # index of the tile's top-left ray
    corner = row * num_samples * rays_width + col * num_samples

    acc = wp.vec3(0.0, 0.0, 0.0)

    for dy in range(num_samples):
        for dx in range(num_samples):
            sample = rays[corner + dy * rays_width + dx]
            acc = wp.vec3(acc[0] + sample[0], acc[1] + sample[1], acc[2] + sample[2])

    count = float(num_samples * num_samples)
    pixels[pixel] = wp.vec3(acc[0] / count, acc[1] / count, acc[2] / count)
247
+
248
+
249
# Accumulate into loss[0] the pseudo-Huber loss (delta = 1) between target and
# rendered pixels, summed over the RGB channels; one thread per pixel.
@wp.kernel
def loss_kernel(pixels: wp.array(dtype=wp.vec3), target_pixels: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
    tid = wp.tid()

    residual = target_pixels[tid] - pixels[tid]

    # pseudo-Huber per channel: delta^2 * (sqrt(1 + (r / delta)^2) - 1)
    delta = 1.0
    rx = residual[0] / delta
    ry = residual[1] / delta
    rz = residual[2] / delta
    hx = delta * delta * (wp.sqrt(1.0 + rx * rx) - 1.0)
    hy = delta * delta * (wp.sqrt(1.0 + ry * ry) - 1.0)
    hz = delta * delta * (wp.sqrt(1.0 + rz * rz) - 1.0)
    total = hx + hy + hz

    wp.atomic_add(loss, 0, total)
266
+
267
+
268
# Re-normalize quaternions in place; Example.step launches this after each
# optimizer update so the mesh rotation stays a unit quaternion.
@wp.kernel
def normalize(x: wp.array(dtype=wp.quat)):
    idx = wp.tid()
    q = x[idx]
    x[idx] = wp.normalize(q)
273
+
274
+
275
class Example:
    """Recover a mesh rotation by differentiable ray casting.

    The bunny asset is ray cast (draw_kernel), the super-sampled rays are
    averaged into pixels (downsample_kernel), and a pseudo-Huber loss against
    ``target_pixels`` (loss_kernel) is back-propagated to ``render_mesh.rot``,
    which is updated with SGD.

    Non-differentiable variables:
        camera.horizontal: camera horizontal aperture size
        camera.vertical: camera vertical aperture size
        camera.aspect: camera aspect ratio
        camera.e: focal length
        camera.pos: camera displacement
        camera.rot: camera rotation (quaternion)
        width: final image width in pixels
        height: final image height in pixels
        num_samples: anti-aliasing exponent; each pixel averages a
            pow(2, num_samples) x pow(2, num_samples) grid of rays
        lights: directional lights, characterized by intensity (scalar) and direction (vec3)
        render_mesh.indices: mesh vertex indices
        render_mesh.tex_indices: texture indices

    Differentiable variables:
        render_mesh.pos: parent transform displacement
        render_mesh.rot: parent transform rotation (quaternion); the only optimized parameter
        render_mesh.vertices: mesh vertex positions
        render_mesh.vertex_normals: mesh vertex normals
        render_mesh.tex_coords: 2D texture coordinates
    """

    def __init__(self, stage=None, rot_array=(0.0, 0.0, 0.0, 1.0), verbose=False):
        """Build the scene, render buffers, (optional) CUDA graph and optimizer.

        Args:
            stage: Unused by this example.
            rot_array: Initial mesh rotation as an (x, y, z, w) quaternion.
            verbose: When True, ``step`` prints the loss every ``period`` iterations.
        """
        self.verbose = verbose
        cam_pos = wp.vec3(0.0, 0.75, 7.0)
        cam_rot = wp.quat(0.0, 0.0, 0.0, 1.0)
        horizontal_aperture = 36.0
        vertical_aperture = 20.25
        aspect = horizontal_aperture / vertical_aperture
        focal_length = 50.0
        self.height = 1024
        self.width = int(aspect * self.height)
        self.num_pixels = self.width * self.height

        # load the bunny mesh from the bundled USD asset
        asset_stage = Usd.Stage.Open(os.path.join(os.path.dirname(__file__), "../assets/bunny.usd"))
        mesh_geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/bunny/bunny"))

        points = np.array(mesh_geom.GetPointsAttr().Get())
        indices = np.array(mesh_geom.GetFaceVertexIndicesAttr().Get())
        num_points = points.shape[0]
        # the asset is assumed to be triangulated: three indices per face
        num_faces = indices.shape[0] // 3

        # manufacture texture coordinates + indices for this asset: both UV
        # components are the vertex distance from the origin, normalized to [0, 1]
        distance = np.linalg.norm(points, axis=1)
        radius = np.max(distance)
        distance = distance / radius
        tex_coords = np.stack((distance, distance), axis=1)
        tex_indices = indices

        # manufacture a 256x256 gradient texture (R = column, G = row, B = 0)
        x = np.arange(256.0)
        xx, yy = np.meshgrid(x, x)
        zz = np.zeros_like(xx)
        texture_host = np.stack((xx, yy, zz), axis=2) / 255.0

        # anti-aliasing exponent: each pixel averages a 2^n x 2^n grid of rays
        self.num_samples = 1

        # set render mode
        self.render_mode = RenderMode.texture

        # optimizer / training settings
        self.train_rate = 5.00e-8
        self.momentum = 0.5
        self.dampening = 0.1
        self.weight_decay = 0.0
        self.train_iters = 150
        self.period = 10
        self.iter = 0

        # storage for the training animation: one frame every `period` iterations,
        # rounded up so render() never overflows when train_iters is not a multiple
        # of period. float32 matches the rendered vec3 pixel data and halves the
        # memory of this (height x width x 3 x frames) buffer.
        num_frames = (self.train_iters + self.period - 1) // self.period
        self.images = np.zeros((self.height, self.width, 3, num_frames), dtype=np.float32)
        self.image_counter = 0

        # construct RenderMesh
        self.render_mesh = RenderMesh()
        self.mesh = wp.Mesh(
            points=wp.array(points, dtype=wp.vec3, requires_grad=True), indices=wp.array(indices, dtype=int)
        )
        self.render_mesh.id = self.mesh.id
        self.render_mesh.vertices = self.mesh.points
        self.render_mesh.indices = self.mesh.indices
        self.render_mesh.tex_coords = wp.array(tex_coords, dtype=wp.vec2, requires_grad=True)
        self.render_mesh.tex_indices = wp.array(tex_indices, dtype=int)
        self.normal_sums = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.vertex_normals = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.pos = wp.zeros(1, dtype=wp.vec3, requires_grad=True)
        self.render_mesh.rot = wp.array(np.array(rot_array), dtype=wp.quat, requires_grad=True)

        # compute vertex normals: scatter area-weighted face normals, then normalize
        wp.launch(
            kernel=vertex_normal_sum_kernel,
            dim=num_faces,
            inputs=[self.render_mesh.vertices, self.render_mesh.indices, self.normal_sums],
        )
        wp.launch(
            kernel=normalize_kernel,
            dim=num_points,
            inputs=[self.normal_sums, self.render_mesh.vertex_normals],
        )

        # construct camera
        self.camera = Camera()
        self.camera.horizontal = horizontal_aperture
        self.camera.vertical = vertical_aperture
        self.camera.aspect = aspect
        self.camera.e = focal_length
        self.camera.tan = vertical_aperture / (2.0 * focal_length)
        self.camera.pos = cam_pos
        self.camera.rot = cam_rot

        # construct texture
        self.texture = wp.array2d(texture_host, dtype=wp.vec3, requires_grad=True)

        # construct lights
        self.lights = DirectionalLights()
        self.lights.dirs = wp.array(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), dtype=wp.vec3, requires_grad=True)
        self.lights.intensities = wp.array(np.array([2.0, 0.2]), dtype=float, requires_grad=True)
        self.lights.num_lights = 2

        # construct rays (super-sampled image, row-major)
        self.rays_width = self.width * pow(2, self.num_samples)
        self.rays_height = self.height * pow(2, self.num_samples)
        self.num_rays = self.rays_width * self.rays_height
        self.rays = wp.zeros(self.num_rays, dtype=wp.vec3, requires_grad=True)

        # construct pixels
        self.pixels = wp.zeros(self.num_pixels, dtype=wp.vec3, requires_grad=True)
        self.target_pixels = wp.zeros(self.num_pixels, dtype=wp.vec3)

        # loss array
        self.loss = wp.zeros(1, dtype=float, requires_grad=True)

        # on CUDA devices, capture forward + backward once and replay it in step()
        self.use_graph = wp.get_device().is_cuda
        if self.use_graph:
            with wp.ScopedCapture() as capture:
                self.tape = wp.Tape()
                with self.tape:
                    self.forward()
                self.tape.backward(self.loss)
            self.graph = capture.graph

        self.optimizer = SGD(
            [self.render_mesh.rot],
            self.train_rate,
            momentum=self.momentum,
            dampening=self.dampening,
            weight_decay=self.weight_decay,
        )

    def ray_cast(self):
        """Render the mesh into ``self.rays`` and box-filter them into ``self.pixels``."""
        # raycast
        wp.launch(
            kernel=draw_kernel,
            dim=self.num_rays,
            inputs=[
                self.render_mesh,
                self.camera,
                self.texture,
                self.rays_width,
                self.rays_height,
                self.rays,
                self.lights,
                self.render_mode,
            ],
        )

        # downsample
        wp.launch(
            kernel=downsample_kernel,
            dim=self.num_pixels,
            inputs=[self.rays, self.pixels, self.rays_width, pow(2, self.num_samples)],
        )

    def forward(self):
        """Render, then accumulate the pixel loss against ``target_pixels`` into ``self.loss``."""
        self.ray_cast()

        # compute pixel loss
        wp.launch(loss_kernel, dim=self.num_pixels, inputs=[self.pixels, self.target_pixels, self.loss])

    def step(self):
        """Run one optimization iteration.

        Runs forward + backward (replaying the captured graph when available),
        applies an SGD update to ``render_mesh.rot``, re-normalizes the
        quaternion, then clears the tape gradients and the loss.
        """
        if self.use_graph:
            wp.capture_launch(self.graph)
        else:
            self.tape = wp.Tape()
            with self.tape:
                self.forward()
            self.tape.backward(self.loss)

        rot_grad = self.tape.gradients[self.render_mesh.rot]
        self.optimizer.step([rot_grad])
        wp.launch(normalize, dim=1, inputs=[self.render_mesh.rot])

        if self.verbose and self.iter % self.period == 0:
            print(f"Iter: {self.iter} Loss: {self.loss}")

        self.tape.zero()
        self.loss.zero_()

        self.iter = self.iter + 1

    def render(self):
        """Store the current image as the next frame of the training animation."""
        self.images[:, :, :, self.image_counter] = self.get_image()
        self.image_counter += 1

    def get_image(self):
        """Return the current pixels as a (height, width, 3) NumPy array."""
        return self.pixels.numpy().reshape((self.height, self.width, 3))

    def get_animation(self):
        """Return a matplotlib ArtistAnimation over all frames in ``self.images``."""
        # Import locally: the module only imports matplotlib under __main__, so
        # relying on module globals raised NameError when used from elsewhere.
        import matplotlib.animation as animation
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        plt.axis("off")
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)

        frames = []
        for i in range(self.images.shape[3]):
            frame = ax.imshow(self.images[:, :, :, i], animated=True)
            frames.append([frame])

        ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
        return ani
500
+
501
+
502
if __name__ == "__main__":
    import matplotlib.animation as animation
    import matplotlib.image as img
    import matplotlib.pyplot as plt

    # os.path.dirname(__file__) is "" when the script is run from its own
    # directory; joining (instead of concatenating "/name") keeps the outputs
    # next to the script rather than writing them to the filesystem root.
    output_dir = os.path.dirname(__file__)

    reference_example = Example()

    # render target rotation (identity quaternion)
    reference_example.ray_cast()
    target_image = reference_example.get_image()
    img.imsave(os.path.join(output_dir, "example_diffray_target_image.png"), target_image)

    # offset mesh rotation: (0, sin 15deg, 0, cos 15deg), i.e. 30 degrees about +y
    example = Example(
        rot_array=[0.0, (math.sqrt(3) - 1) / (2.0 * math.sqrt(2.0)), 0.0, (math.sqrt(3) + 1) / (2.0 * math.sqrt(2.0))],
        verbose=True,
    )

    wp.copy(example.target_pixels, reference_example.pixels)

    # recover target rotation
    for i in range(example.train_iters):
        example.step()

        if i % example.period == 0:
            example.render()

    final_image = example.get_image()
    img.imsave(os.path.join(output_dir, "example_diffray_final_image.png"), final_image)

    video = example.get_animation()
    video.save(
        os.path.join(output_dir, "example_diffray_animation.gif"), dpi=300, writer=animation.PillowWriter(fps=5)
    )