warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
@@ -13,23 +13,21 @@
13
13
  #
14
14
  ##############################################################################
15
15
 
16
- import matplotlib.pyplot as plt
17
- import matplotlib.image as img
18
- import matplotlib.animation as animation
19
- from pxr import Usd, UsdGeom
16
+ import math
17
+ import os
20
18
 
21
- import warp as wp
22
19
  import numpy as np
20
+ from pxr import Usd, UsdGeom
23
21
 
24
- import os
25
- import math
22
+ import warp as wp
23
+ from warp.optim import SGD
26
24
 
27
25
  wp.init()
28
26
 
29
27
 
30
28
  class RenderMode:
31
29
  """Rendering modes
32
- grayscale: lambertian shading from multiple directional lights
30
+ grayscale: Lambertian shading from multiple directional lights
33
31
  texture: 2D texture map
34
32
  normal_map: mesh normal computed from interpolated vertex normals
35
33
  """
@@ -78,6 +76,208 @@ class DirectionalLights:
78
76
  num_lights: int
79
77
 
80
78
 
79
+ @wp.kernel
80
+ def vertex_normal_sum_kernel(
81
+ verts: wp.array(dtype=wp.vec3), indices: wp.array(dtype=int), normal_sums: wp.array(dtype=wp.vec3)
82
+ ):
83
+ tid = wp.tid()
84
+
85
+ i = indices[tid * 3]
86
+ j = indices[tid * 3 + 1]
87
+ k = indices[tid * 3 + 2]
88
+
89
+ a = verts[i]
90
+ b = verts[j]
91
+ c = verts[k]
92
+
93
+ ab = b - a
94
+ ac = c - a
95
+
96
+ area_normal = wp.cross(ab, ac)
97
+ wp.atomic_add(normal_sums, i, area_normal)
98
+ wp.atomic_add(normal_sums, j, area_normal)
99
+ wp.atomic_add(normal_sums, k, area_normal)
100
+
101
+
102
+ @wp.kernel
103
+ def normalize_kernel(
104
+ normal_sums: wp.array(dtype=wp.vec3),
105
+ vertex_normals: wp.array(dtype=wp.vec3),
106
+ ):
107
+ tid = wp.tid()
108
+ vertex_normals[tid] = wp.normalize(normal_sums[tid])
109
+
110
+
111
+ @wp.func
112
+ def texture_interpolation(tex_interp: wp.vec2, texture: wp.array2d(dtype=wp.vec3)):
113
+ tex_width = texture.shape[1]
114
+ tex_height = texture.shape[0]
115
+ tex = wp.vec2(tex_interp[0] * float(tex_width - 1), (1.0 - tex_interp[1]) * float(tex_height - 1))
116
+
117
+ x0 = int(tex[0])
118
+ x1 = x0 + 1
119
+ alpha_x = tex[0] - float(x0)
120
+ y0 = int(tex[1])
121
+ y1 = y0 + 1
122
+ alpha_y = tex[1] - float(y0)
123
+ c00 = texture[y0, x0]
124
+ c10 = texture[y0, x1]
125
+ c01 = texture[y1, x0]
126
+ c11 = texture[y1, x1]
127
+ lower = (1.0 - alpha_x) * c00 + alpha_x * c10
128
+ upper = (1.0 - alpha_x) * c01 + alpha_x * c11
129
+ color = (1.0 - alpha_y) * lower + alpha_y * upper
130
+
131
+ return color
132
+
133
+
134
+ @wp.kernel
135
+ def draw_kernel(
136
+ mesh: RenderMesh,
137
+ camera: Camera,
138
+ texture: wp.array2d(dtype=wp.vec3),
139
+ rays_width: int,
140
+ rays_height: int,
141
+ rays: wp.array(dtype=wp.vec3),
142
+ lights: DirectionalLights,
143
+ mode: int,
144
+ ):
145
+ tid = wp.tid()
146
+
147
+ x = tid % rays_width
148
+ y = rays_height - tid // rays_width
149
+
150
+ sx = 2.0 * float(x) / float(rays_width) - 1.0
151
+ sy = 2.0 * float(y) / float(rays_height) - 1.0
152
+
153
+ # compute view ray in world space
154
+ ro_world = camera.pos
155
+ rd_world = wp.normalize(wp.quat_rotate(camera.rot, wp.vec3(sx * camera.tan * camera.aspect, sy * camera.tan, -1.0)))
156
+
157
+ # compute view ray in mesh space
158
+ inv = wp.transform_inverse(wp.transform(mesh.pos[0], mesh.rot[0]))
159
+ ro = wp.transform_point(inv, ro_world)
160
+ rd = wp.transform_vector(inv, rd_world)
161
+
162
+ t = float(0.0)
163
+ ur = float(0.0)
164
+ vr = float(0.0)
165
+ sign = float(0.0)
166
+ n = wp.vec3()
167
+ f = int(0)
168
+
169
+ color = wp.vec3(0.0, 0.0, 0.0)
170
+
171
+ if wp.mesh_query_ray(mesh.id, ro, rd, 1.0e6, t, ur, vr, sign, n, f):
172
+ i = mesh.indices[f * 3]
173
+ j = mesh.indices[f * 3 + 1]
174
+ k = mesh.indices[f * 3 + 2]
175
+
176
+ a = mesh.vertices[i]
177
+ b = mesh.vertices[j]
178
+ c = mesh.vertices[k]
179
+
180
+ p = wp.mesh_eval_position(mesh.id, f, ur, vr)
181
+
182
+ # barycentric coordinates
183
+ tri_area = wp.length(wp.cross(b - a, c - a))
184
+ w = wp.length(wp.cross(b - a, p - a)) / tri_area
185
+ v = wp.length(wp.cross(p - a, c - a)) / tri_area
186
+ u = 1.0 - w - v
187
+
188
+ a_n = mesh.vertex_normals[i]
189
+ b_n = mesh.vertex_normals[j]
190
+ c_n = mesh.vertex_normals[k]
191
+
192
+ # vertex normal interpolation
193
+ normal = u * a_n + v * b_n + w * c_n
194
+
195
+ if mode == 0 or mode == 1:
196
+ if mode == 0: # grayscale
197
+ color = wp.vec3(1.0)
198
+
199
+ elif mode == 1: # texture interpolation
200
+ tex_a = mesh.tex_coords[mesh.tex_indices[f * 3]]
201
+ tex_b = mesh.tex_coords[mesh.tex_indices[f * 3 + 1]]
202
+ tex_c = mesh.tex_coords[mesh.tex_indices[f * 3 + 2]]
203
+
204
+ tex = u * tex_a + v * tex_b + w * tex_c
205
+
206
+ color = texture_interpolation(tex, texture)
207
+
208
+ # lambertian directional lighting
209
+ lambert = float(0.0)
210
+ for i in range(lights.num_lights):
211
+ dir = wp.transform_vector(inv, lights.dirs[i])
212
+ val = lights.intensities[i] * wp.dot(normal, dir)
213
+ if val < 0.0:
214
+ val = 0.0
215
+ lambert = lambert + val
216
+
217
+ color = lambert * color
218
+
219
+ elif mode == 2: # normal map
220
+ color = normal * 0.5 + wp.vec3(0.5, 0.5, 0.5)
221
+
222
+ if color[0] > 1.0:
223
+ color = wp.vec3(1.0, color[1], color[2])
224
+ if color[1] > 1.0:
225
+ color = wp.vec3(color[0], 1.0, color[2])
226
+ if color[2] > 1.0:
227
+ color = wp.vec3(color[0], color[1], 1.0)
228
+
229
+ rays[tid] = color
230
+
231
+
232
+ @wp.kernel
233
+ def downsample_kernel(
234
+ rays: wp.array(dtype=wp.vec3), pixels: wp.array(dtype=wp.vec3), rays_width: int, num_samples: int
235
+ ):
236
+ tid = wp.tid()
237
+
238
+ pixels_width = rays_width / num_samples
239
+ px = tid % pixels_width
240
+ py = tid // pixels_width
241
+ start_idx = py * num_samples * rays_width + px * num_samples
242
+
243
+ color = wp.vec3(0.0, 0.0, 0.0)
244
+
245
+ for i in range(0, num_samples):
246
+ for j in range(0, num_samples):
247
+ ray = rays[start_idx + i * rays_width + j]
248
+ color = wp.vec3(color[0] + ray[0], color[1] + ray[1], color[2] + ray[2])
249
+
250
+ num_samples_sq = float(num_samples * num_samples)
251
+ color = wp.vec3(color[0] / num_samples_sq, color[1] / num_samples_sq, color[2] / num_samples_sq)
252
+ pixels[tid] = color
253
+
254
+
255
+ @wp.kernel
256
+ def loss_kernel(pixels: wp.array(dtype=wp.vec3), target_pixels: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
257
+ tid = wp.tid()
258
+
259
+ pixel = pixels[tid]
260
+ target_pixel = target_pixels[tid]
261
+
262
+ diff = target_pixel - pixel
263
+
264
+ # pseudo Huber loss
265
+ delta = 1.0
266
+ x = delta * delta * (wp.sqrt(1.0 + (diff[0] / delta) * (diff[0] / delta)) - 1.0)
267
+ y = delta * delta * (wp.sqrt(1.0 + (diff[1] / delta) * (diff[1] / delta)) - 1.0)
268
+ z = delta * delta * (wp.sqrt(1.0 + (diff[2] / delta) * (diff[2] / delta)) - 1.0)
269
+ sum = x + y + z
270
+
271
+ wp.atomic_add(loss, 0, sum)
272
+
273
+
274
+ @wp.kernel
275
+ def normalize(x: wp.array(dtype=wp.quat)):
276
+ tid = wp.tid()
277
+
278
+ x[tid] = wp.normalize(x[tid])
279
+
280
+
81
281
  class Example:
82
282
  """A basic differentiable ray tracer
83
283
 
@@ -103,7 +303,10 @@ class Example:
103
303
  render_mesh.tex_coords: 2D texture coordinates
104
304
  """
105
305
 
106
- def __init__(self):
306
+ def __init__(self, stage=None, rot_array=[0.0, 0.0, 0.0, 1.0], verbose=False):
307
+ self.device = wp.get_device()
308
+
309
+ self.verbose = verbose
107
310
  cam_pos = wp.vec3(0.0, 0.75, 7.0)
108
311
  cam_rot = wp.quat(0.0, 0.0, 0.0, 1.0)
109
312
  horizontal_aperture = 36.0
@@ -143,328 +346,126 @@ class Example:
143
346
 
144
347
  # set training iterations
145
348
  self.train_rate = 3.0e-8
146
- self.train_iters = 300
349
+ self.train_rate = 5.00e-8
350
+ self.momentum = 0.5
351
+ self.dampening = 0.1
352
+ self.weight_decay = 0.0
353
+ self.train_iters = 150
147
354
  self.period = 10
355
+ self.iter = 0
148
356
 
149
357
  # storage for training animation
150
358
  self.images = np.zeros((self.height, self.width, 3, int(self.train_iters / self.period)))
359
+ self.image_counter = 0
151
360
 
152
- with wp.ScopedDevice(device="cuda:0"):
153
- # construct RenderMesh
154
- self.render_mesh = RenderMesh()
155
- self.mesh = wp.Mesh(
156
- points=wp.array(points, dtype=wp.vec3, requires_grad=True), indices=wp.array(indices, dtype=int)
157
- )
158
- self.render_mesh.id = self.mesh.id
159
- self.render_mesh.vertices = self.mesh.points
160
- self.render_mesh.indices = self.mesh.indices
161
- self.render_mesh.tex_coords = wp.array(tex_coords, dtype=wp.vec2, requires_grad=True)
162
- self.render_mesh.tex_indices = wp.array(tex_indices, dtype=int)
163
- self.normal_sums = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
164
- self.render_mesh.vertex_normals = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
165
- self.render_mesh.pos = wp.zeros(1, dtype=wp.vec3, requires_grad=True)
166
- self.render_mesh.rot = wp.array(np.array([0.0, 0.0, 0.0, 1.0]), dtype=wp.quat, requires_grad=True)
167
-
168
- # compute vertex normals
169
- wp.launch(
170
- kernel=Example.vertex_normal_sum_kernel,
171
- dim=num_faces,
172
- inputs=[self.render_mesh.vertices, self.render_mesh.indices, self.normal_sums],
173
- )
174
- wp.launch(
175
- kernel=Example.normalize_kernel,
176
- dim=num_points,
177
- inputs=[self.normal_sums, self.render_mesh.vertex_normals],
178
- )
179
-
180
- # construct camera
181
- self.camera = Camera()
182
- self.camera.horizontal = horizontal_aperture
183
- self.camera.vertical = vertical_aperture
184
- self.camera.aspect = aspect
185
- self.camera.e = focal_length
186
- self.camera.tan = vertical_aperture / (2.0 * focal_length)
187
- self.camera.pos = cam_pos
188
- self.camera.rot = cam_rot
189
-
190
- # construct texture
191
- self.texture = wp.array2d(texture_host, dtype=wp.vec3, requires_grad=True)
192
-
193
- # construct lights
194
- self.lights = DirectionalLights()
195
- self.lights.dirs = wp.array(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), dtype=wp.vec3, requires_grad=True)
196
- self.lights.intensities = wp.array(np.array([2.0, 0.2]), dtype=float, requires_grad=True)
197
- self.lights.num_lights = 2
198
-
199
- # construct rays
200
- self.rays_width = self.width * pow(2, self.num_samples)
201
- self.rays_height = self.height * pow(2, self.num_samples)
202
- self.num_rays = self.rays_width * self.rays_height
203
- self.rays = wp.zeros(self.num_rays, dtype=wp.vec3, requires_grad=True)
204
-
205
- # construct pixels
206
- self.pixels = wp.zeros(self.num_pixels, dtype=wp.vec3, requires_grad=True)
207
- self.target_pixels = wp.zeros(self.num_pixels, dtype=wp.vec3)
208
-
209
- # loss array
210
- self.loss = wp.zeros(1, dtype=float, requires_grad=True)
211
-
212
- def update(self):
213
- pass
214
-
215
- def render(self, is_live=False):
216
- with wp.ScopedDevice("cuda:0"):
217
- # raycast
218
- wp.launch(
219
- kernel=Example.draw_kernel,
220
- dim=self.num_rays,
221
- inputs=[
222
- self.render_mesh,
223
- self.camera,
224
- self.texture,
225
- self.rays_width,
226
- self.rays_height,
227
- self.rays,
228
- self.lights,
229
- self.render_mode,
230
- ],
231
- )
232
-
233
- # downsample
234
- wp.launch(
235
- kernel=Example.downsample_kernel,
236
- dim=self.num_pixels,
237
- inputs=[self.rays, self.pixels, self.rays_width, pow(2, self.num_samples)],
238
- )
239
-
240
- @wp.kernel
241
- def vertex_normal_sum_kernel(
242
- verts: wp.array(dtype=wp.vec3), indices: wp.array(dtype=int), normal_sums: wp.array(dtype=wp.vec3)
243
- ):
244
- tid = wp.tid()
245
-
246
- i = indices[tid * 3]
247
- j = indices[tid * 3 + 1]
248
- k = indices[tid * 3 + 2]
249
-
250
- a = verts[i]
251
- b = verts[j]
252
- c = verts[k]
253
-
254
- ab = b - a
255
- ac = c - a
256
-
257
- area_normal = wp.cross(ab, ac)
258
- wp.atomic_add(normal_sums, i, area_normal)
259
- wp.atomic_add(normal_sums, j, area_normal)
260
- wp.atomic_add(normal_sums, k, area_normal)
261
-
262
- @wp.kernel
263
- def normalize_kernel(
264
- normal_sums: wp.array(dtype=wp.vec3),
265
- vertex_normals: wp.array(dtype=wp.vec3),
266
- ):
267
- tid = wp.tid()
268
- vertex_normals[tid] = wp.normalize(normal_sums[tid])
269
-
270
- @wp.func
271
- def texture_interpolation(tex_interp: wp.vec2, texture: wp.array2d(dtype=wp.vec3)):
272
- tex_width = texture.shape[1]
273
- tex_height = texture.shape[0]
274
- tex = wp.vec2(tex_interp[0] * float(tex_width - 1), (1.0 - tex_interp[1]) * float(tex_height - 1))
275
-
276
- x0 = int(tex[0])
277
- x1 = x0 + 1
278
- alpha_x = tex[0] - float(x0)
279
- y0 = int(tex[1])
280
- y1 = y0 + 1
281
- alpha_y = tex[1] - float(y0)
282
- c00 = texture[y0, x0]
283
- c10 = texture[y0, x1]
284
- c01 = texture[y1, x0]
285
- c11 = texture[y1, x1]
286
- lower = (1.0 - alpha_x) * c00 + alpha_x * c10
287
- upper = (1.0 - alpha_x) * c01 + alpha_x * c11
288
- color = (1.0 - alpha_y) * lower + alpha_y * upper
289
-
290
- return color
291
-
292
- @wp.kernel
293
- def draw_kernel(
294
- mesh: RenderMesh,
295
- camera: Camera,
296
- texture: wp.array2d(dtype=wp.vec3),
297
- rays_width: int,
298
- rays_height: int,
299
- rays: wp.array(dtype=wp.vec3),
300
- lights: DirectionalLights,
301
- mode: int,
302
- ):
303
- tid = wp.tid()
304
-
305
- x = tid % rays_width
306
- y = rays_height - tid // rays_width
307
-
308
- sx = 2.0 * float(x) / float(rays_width) - 1.0
309
- sy = 2.0 * float(y) / float(rays_height) - 1.0
310
-
311
- # compute view ray in world space
312
- ro_world = camera.pos
313
- rd_world = wp.normalize(
314
- wp.quat_rotate(camera.rot, wp.vec3(sx * camera.tan * camera.aspect, sy * camera.tan, -1.0))
361
+ # construct RenderMesh
362
+ self.render_mesh = RenderMesh()
363
+ self.mesh = wp.Mesh(
364
+ points=wp.array(points, dtype=wp.vec3, requires_grad=True), indices=wp.array(indices, dtype=int)
365
+ )
366
+ self.render_mesh.id = self.mesh.id
367
+ self.render_mesh.vertices = self.mesh.points
368
+ self.render_mesh.indices = self.mesh.indices
369
+ self.render_mesh.tex_coords = wp.array(tex_coords, dtype=wp.vec2, requires_grad=True)
370
+ self.render_mesh.tex_indices = wp.array(tex_indices, dtype=int)
371
+ self.normal_sums = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
372
+ self.render_mesh.vertex_normals = wp.zeros(num_points, dtype=wp.vec3, requires_grad=True)
373
+ self.render_mesh.pos = wp.zeros(1, dtype=wp.vec3, requires_grad=True)
374
+ self.render_mesh.rot = wp.array(np.array(rot_array), dtype=wp.quat, requires_grad=True)
375
+
376
+ # compute vertex normals
377
+ wp.launch(
378
+ kernel=vertex_normal_sum_kernel,
379
+ dim=num_faces,
380
+ inputs=[self.render_mesh.vertices, self.render_mesh.indices, self.normal_sums],
381
+ )
382
+ wp.launch(
383
+ kernel=normalize_kernel,
384
+ dim=num_points,
385
+ inputs=[self.normal_sums, self.render_mesh.vertex_normals],
315
386
  )
316
387
 
317
- # compute view ray in mesh space
318
- inv = wp.transform_inverse(wp.transform(mesh.pos[0], mesh.rot[0]))
319
- ro = wp.transform_point(inv, ro_world)
320
- rd = wp.transform_vector(inv, rd_world)
321
-
322
- t = float(0.0)
323
- ur = float(0.0)
324
- vr = float(0.0)
325
- sign = float(0.0)
326
- n = wp.vec3()
327
- f = int(0)
328
-
329
- color = wp.vec3(0.0, 0.0, 0.0)
330
-
331
- if wp.mesh_query_ray(mesh.id, ro, rd, 1.0e6, t, ur, vr, sign, n, f):
332
- i = mesh.indices[f * 3]
333
- j = mesh.indices[f * 3 + 1]
334
- k = mesh.indices[f * 3 + 2]
335
-
336
- a = mesh.vertices[i]
337
- b = mesh.vertices[j]
338
- c = mesh.vertices[k]
339
-
340
- p = wp.mesh_eval_position(mesh.id, f, ur, vr)
341
-
342
- # barycentric coordinates
343
- tri_area = wp.length(wp.cross(b - a, c - a))
344
- w = wp.length(wp.cross(b - a, p - a)) / tri_area
345
- v = wp.length(wp.cross(p - a, c - a)) / tri_area
346
- u = 1.0 - w - v
347
-
348
- a_n = mesh.vertex_normals[i]
349
- b_n = mesh.vertex_normals[j]
350
- c_n = mesh.vertex_normals[k]
351
-
352
- # vertex normal interpolation
353
- normal = u * a_n + v * b_n + w * c_n
354
-
355
- if mode == 0 or mode == 1:
356
- if mode == 0: # grayscale
357
- color = wp.vec3(1.0)
358
-
359
- elif mode == 1: # texture interpolation
360
- tex_a = mesh.tex_coords[mesh.tex_indices[f * 3]]
361
- tex_b = mesh.tex_coords[mesh.tex_indices[f * 3 + 1]]
362
- tex_c = mesh.tex_coords[mesh.tex_indices[f * 3 + 2]]
363
-
364
- tex = u * tex_a + v * tex_b + w * tex_c
365
-
366
- color = Example.texture_interpolation(tex, texture)
367
-
368
- # lambertian directional lighting
369
- lambert = float(0.0)
370
- for i in range(lights.num_lights):
371
- dir = wp.transform_vector(inv, lights.dirs[i])
372
- val = lights.intensities[i] * wp.dot(normal, dir)
373
- if val < 0.0:
374
- val = 0.0
375
- lambert = lambert + val
376
-
377
- color = lambert * color
378
-
379
- elif mode == 2: # normal map
380
- color = normal * 0.5 + wp.vec3(0.5, 0.5, 0.5)
381
-
382
- if color[0] > 1.0:
383
- color = wp.vec3(1.0, color[1], color[2])
384
- if color[1] > 1.0:
385
- color = wp.vec3(color[0], 1.0, color[2])
386
- if color[2] > 1.0:
387
- color = wp.vec3(color[0], color[1], 1.0)
388
-
389
- rays[tid] = color
390
-
391
- @wp.kernel
392
- def downsample_kernel(
393
- rays: wp.array(dtype=wp.vec3), pixels: wp.array(dtype=wp.vec3), rays_width: int, num_samples: int
394
- ):
395
- tid = wp.tid()
396
-
397
- pixels_width = rays_width / num_samples
398
- px = tid % pixels_width
399
- py = tid // pixels_width
400
- start_idx = py * num_samples * rays_width + px * num_samples
401
-
402
- color = wp.vec3(0.0, 0.0, 0.0)
403
-
404
- for i in range(0, num_samples):
405
- for j in range(0, num_samples):
406
- ray = rays[start_idx + i * rays_width + j]
407
- color = wp.vec3(color[0] + ray[0], color[1] + ray[1], color[2] + ray[2])
408
-
409
- num_samples_sq = float(num_samples * num_samples)
410
- color = wp.vec3(color[0] / num_samples_sq, color[1] / num_samples_sq, color[2] / num_samples_sq)
411
- pixels[tid] = color
412
-
413
- @wp.kernel
414
- def loss_kernel(
415
- pixels: wp.array(dtype=wp.vec3), target_pixels: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)
416
- ):
417
- tid = wp.tid()
418
-
419
- pixel = pixels[tid]
420
- target_pixel = target_pixels[tid]
421
-
422
- diff = target_pixel - pixel
423
-
424
- # pseudo Huber loss
425
- delta = 1.0
426
- x = delta * delta * (wp.sqrt(1.0 + (diff[0] / delta) * (diff[0] / delta)) - 1.0)
427
- y = delta * delta * (wp.sqrt(1.0 + (diff[1] / delta) * (diff[1] / delta)) - 1.0)
428
- z = delta * delta * (wp.sqrt(1.0 + (diff[2] / delta) * (diff[2] / delta)) - 1.0)
429
- sum = x + y + z
430
-
431
- wp.atomic_add(loss, 0, sum)
432
-
433
- @wp.kernel
434
- def step_kernel(x: wp.array(dtype=wp.quat), grad: wp.array(dtype=wp.quat), alpha: float):
435
- tid = wp.tid()
436
-
437
- # projected gradient descent
438
- x[tid] = wp.normalize(wp.sub(x[tid], wp.mul(grad[tid], alpha)))
439
-
440
- def compute_loss(self):
441
- self.render()
442
- wp.launch(self.loss_kernel, dim=self.num_pixels, inputs=[self.pixels, self.target_pixels, self.loss])
443
-
444
- def train_graph(self):
445
- with wp.ScopedDevice("cuda:0"):
446
- # capture graph
447
- wp.capture_begin()
448
- tape = wp.Tape()
449
- with tape:
388
+ # construct camera
389
+ self.camera = Camera()
390
+ self.camera.horizontal = horizontal_aperture
391
+ self.camera.vertical = vertical_aperture
392
+ self.camera.aspect = aspect
393
+ self.camera.e = focal_length
394
+ self.camera.tan = vertical_aperture / (2.0 * focal_length)
395
+ self.camera.pos = cam_pos
396
+ self.camera.rot = cam_rot
397
+
398
+ # construct texture
399
+ self.texture = wp.array2d(texture_host, dtype=wp.vec3, requires_grad=True)
400
+
401
+ # construct lights
402
+ self.lights = DirectionalLights()
403
+ self.lights.dirs = wp.array(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), dtype=wp.vec3, requires_grad=True)
404
+ self.lights.intensities = wp.array(np.array([2.0, 0.2]), dtype=float, requires_grad=True)
405
+ self.lights.num_lights = 2
406
+
407
+ # construct rays
408
+ self.rays_width = self.width * pow(2, self.num_samples)
409
+ self.rays_height = self.height * pow(2, self.num_samples)
410
+ self.num_rays = self.rays_width * self.rays_height
411
+ self.rays = wp.zeros(self.num_rays, dtype=wp.vec3, requires_grad=True)
412
+
413
+ # construct pixels
414
+ self.pixels = wp.zeros(self.num_pixels, dtype=wp.vec3, requires_grad=True)
415
+ self.target_pixels = wp.zeros(self.num_pixels, dtype=wp.vec3)
416
+
417
+ # loss array
418
+ self.loss = wp.zeros(1, dtype=float, requires_grad=True)
419
+
420
+ # capture graph
421
+ wp.capture_begin(self.device)
422
+ try:
423
+ self.tape = wp.Tape()
424
+ with self.tape:
450
425
  self.compute_loss()
451
- tape.backward(self.loss)
452
- self.graph = wp.capture_end()
426
+ self.tape.backward(self.loss)
427
+ finally:
428
+ self.graph = wp.capture_end(self.device)
429
+
430
+ self.optimizer = SGD(
431
+ [self.render_mesh.rot],
432
+ self.train_rate,
433
+ momentum=self.momentum,
434
+ dampening=self.dampening,
435
+ weight_decay=self.weight_decay,
436
+ )
453
437
 
454
- # train
455
- image_counter = 0
456
- for i in range(self.train_iters):
457
- wp.capture_launch(self.graph)
458
- rot_grad = tape.gradients[self.render_mesh.rot]
459
- wp.launch(Example.step_kernel, dim=1, inputs=[self.render_mesh.rot, rot_grad, self.train_rate])
438
+ def ray_trace(self, is_live=False):
439
+ # raycast
440
+ wp.launch(
441
+ kernel=draw_kernel,
442
+ dim=self.num_rays,
443
+ inputs=[
444
+ self.render_mesh,
445
+ self.camera,
446
+ self.texture,
447
+ self.rays_width,
448
+ self.rays_height,
449
+ self.rays,
450
+ self.lights,
451
+ self.render_mode,
452
+ ],
453
+ device=self.device,
454
+ )
460
455
 
461
- if i % self.period == 0:
462
- print(f"Iter: {i} Loss: {self.loss}")
463
- self.images[:, :, :, image_counter] = self.get_image()
464
- image_counter += 1
456
+ # downsample
457
+ wp.launch(
458
+ kernel=downsample_kernel,
459
+ dim=self.num_pixels,
460
+ inputs=[self.rays, self.pixels, self.rays_width, pow(2, self.num_samples)],
461
+ device=self.device,
462
+ )
465
463
 
466
- tape.zero()
467
- self.loss.zero_()
464
+ def compute_loss(self):
465
+ self.ray_trace()
466
+ wp.launch(
467
+ loss_kernel, dim=self.num_pixels, inputs=[self.pixels, self.target_pixels, self.loss], device=self.device
468
+ )
468
469
 
469
470
  def get_image(self):
470
471
  return self.pixels.numpy().reshape((self.height, self.width, 3))
@@ -483,33 +484,59 @@ class Example:
483
484
  ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)
484
485
  return ani
485
486
 
487
+ def update(self):
488
+ wp.capture_launch(self.graph)
489
+ rot_grad = self.tape.gradients[self.render_mesh.rot]
490
+ self.optimizer.step([rot_grad])
491
+ wp.launch(normalize, dim=1, inputs=[self.render_mesh.rot])
492
+
493
+ if self.verbose and self.iter % self.period == 0:
494
+ print(f"Iter: {self.iter} Loss: {self.loss}")
495
+
496
+ self.tape.zero()
497
+ self.loss.zero_()
498
+
499
+ self.iter = self.iter + 1
500
+
501
+ def render(self):
502
+ self.images[:, :, :, self.image_counter] = self.get_image()
503
+ self.image_counter += 1
504
+
505
+ def train_graph(self):
506
+ # train
507
+ for i in range(self.train_iters):
508
+ self.update()
509
+
510
+ if i % self.period == 0:
511
+ self.render()
512
+
486
513
 
487
514
  if __name__ == "__main__":
515
+ import matplotlib.animation as animation
516
+ import matplotlib.image as img
517
+ import matplotlib.pyplot as plt
518
+
488
519
  output_dir = os.path.join(os.path.dirname(__file__), "outputs")
489
520
 
490
- example = Example()
521
+ reference_example = Example()
491
522
 
492
523
  # render target rotation
493
- example.render()
494
- with wp.ScopedDevice(device="cuda:0"):
495
- wp.copy(example.target_pixels, example.pixels)
496
- target_image = example.get_image()
524
+ reference_example.ray_trace()
525
+ target_image = reference_example.get_image()
497
526
  img.imsave(output_dir + "/target_image.png", target_image)
498
527
 
499
528
  # offset mesh rotation
500
- with wp.ScopedDevice(device="cuda:0"):
501
- example.render_mesh.rot = wp.array(
502
- np.array(
503
- [0.0, (math.sqrt(3) - 1) / (2.0 * math.sqrt(2.0)), 0.0, (math.sqrt(3) + 1) / (2.0 * math.sqrt(2.0))]
504
- ),
505
- dtype=wp.quat,
506
- requires_grad=True,
507
- )
529
+ rotated_example = Example(
530
+ rot_array=[0.0, (math.sqrt(3) - 1) / (2.0 * math.sqrt(2.0)), 0.0, (math.sqrt(3) + 1) / (2.0 * math.sqrt(2.0))],
531
+ verbose=True,
532
+ )
533
+
534
+ wp.copy(rotated_example.target_pixels, reference_example.pixels)
508
535
 
509
536
  # recover target rotation
510
- example.train_graph()
511
- final_image = example.get_image()
537
+ rotated_example.train_graph()
538
+ final_image = rotated_example.get_image()
512
539
  img.imsave(output_dir + "/final_image.png", final_image)
513
540
 
514
- video = example.get_animation()
515
- video.save(output_dir + "/animation.gif", dpi=300, writer=animation.PillowWriter(fps=15))
541
+ video = rotated_example.get_animation()
542
+ video.save(output_dir + "/animation.gif", dpi=300, writer=animation.PillowWriter(fps=5))