warp-lang 1.4.2__py3-none-win_amd64.whl → 1.5.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/paddle.py CHANGED
@@ -17,41 +17,58 @@ import warp.context
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  import paddle
20
+ from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
20
21
 
21
22
 
22
23
  # return the warp device corresponding to a paddle device
23
- def device_from_paddle(paddle_device: Union[paddle.base.libpaddle.Place, str]) -> warp.context.Device:
24
+ def device_from_paddle(paddle_device: Union[Place, CPUPlace, CUDAPinnedPlace, CUDAPlace, str]) -> warp.context.Device:
24
25
  """Return the Warp device corresponding to a Paddle device.
25
26
 
26
27
  Args:
27
- paddle_device (`paddle.base.libpaddle.Place` or `str`): Paddle device identifier
28
+ paddle_device (`Place`, `CPUPlace`, `CUDAPinnedPlace`, `CUDAPlace`, or `str`): Paddle device identifier
28
29
 
29
30
  Raises:
30
31
  RuntimeError: Paddle device does not have a corresponding Warp device
31
32
  """
32
33
  if type(paddle_device) is str:
34
+ if paddle_device.startswith("gpu:"):
35
+ paddle_device = paddle_device.replace("gpu:", "cuda:")
33
36
  warp_device = warp.context.runtime.device_map.get(paddle_device)
34
37
  if warp_device is not None:
35
38
  return warp_device
36
- elif paddle_device.startswith("gpu"):
39
+ elif paddle_device == "gpu":
37
40
  return warp.context.runtime.get_current_cuda_device()
38
41
  else:
39
42
  raise RuntimeError(f"Unsupported Paddle device {paddle_device}")
40
43
  else:
41
- import paddle
42
-
43
44
  try:
44
- if paddle_device.is_gpu_place():
45
- return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()]
46
- elif paddle_device.is_cpu_place():
45
+ from paddle.base.libpaddle import CPUPlace, CUDAPinnedPlace, CUDAPlace, Place
46
+
47
+ if isinstance(paddle_device, Place):
48
+ if paddle_device.is_gpu_place():
49
+ return warp.context.runtime.cuda_devices[paddle_device.gpu_device_id()]
50
+ elif paddle_device.is_cpu_place():
51
+ return warp.context.runtime.cpu_device
52
+ else:
53
+ raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
54
+ elif isinstance(paddle_device, (CPUPlace, CUDAPinnedPlace)):
47
55
  return warp.context.runtime.cpu_device
56
+ elif isinstance(paddle_device, CUDAPlace):
57
+ return warp.context.runtime.cuda_devices[paddle_device.get_device_id()]
48
58
  else:
49
59
  raise RuntimeError(f"Unsupported Paddle device type {paddle_device}")
60
+ except ModuleNotFoundError as e:
61
+ raise ModuleNotFoundError("Please install paddlepaddle first.") from e
50
62
  except Exception as e:
51
- import paddle
52
-
53
- if not isinstance(paddle_device, paddle.base.libpaddle.Place):
54
- raise ValueError("Argument must be a paddle.base.libpaddle.Place object or a string") from e
63
+ if not isinstance(paddle_device, (Place, CPUPlace, CUDAPinnedPlace, CUDAPlace)):
64
+ raise TypeError(
65
+ "device_from_paddle() received an invalid argument - "
66
+ f"got {paddle_device}({type(paddle_device)}), but expected one of:\n"
67
+ "* paddle.base.libpaddle.Place\n"
68
+ "* paddle.CPUPlace\n"
69
+ "* paddle.CUDAPinnedPlace\n"
70
+ "* paddle.CUDAPlace or 'gpu' or 'gpu:x'(x means device id)"
71
+ ) from e
55
72
  raise
56
73
 
57
74
 
warp/render/render_opengl.py CHANGED
@@ -310,14 +310,15 @@ def update_vbo_transforms(
310
310
  @wp.kernel
311
311
  def update_vbo_vertices(
312
312
  points: wp.array(dtype=wp.vec3),
313
+ scale: wp.vec3,
313
314
  # outputs
314
315
  vbo_vertices: wp.array(dtype=float, ndim=2),
315
316
  ):
316
317
  tid = wp.tid()
317
318
  p = points[tid]
318
- vbo_vertices[tid, 0] = p[0]
319
- vbo_vertices[tid, 1] = p[1]
320
- vbo_vertices[tid, 2] = p[2]
319
+ vbo_vertices[tid, 0] = p[0] * scale[0]
320
+ vbo_vertices[tid, 1] = p[1] * scale[1]
321
+ vbo_vertices[tid, 2] = p[2] * scale[2]
321
322
 
322
323
 
323
324
  @wp.kernel
@@ -375,13 +376,14 @@ def update_line_transforms(
375
376
  def compute_gfx_vertices(
376
377
  indices: wp.array(dtype=int, ndim=2),
377
378
  vertices: wp.array(dtype=wp.vec3, ndim=1),
379
+ scale: wp.vec3,
378
380
  # outputs
379
381
  gfx_vertices: wp.array(dtype=float, ndim=2),
380
382
  ):
381
383
  tid = wp.tid()
382
- v0 = vertices[indices[tid, 0]]
383
- v1 = vertices[indices[tid, 1]]
384
- v2 = vertices[indices[tid, 2]]
384
+ v0 = vertices[indices[tid, 0]] * scale[0]
385
+ v1 = vertices[indices[tid, 1]] * scale[1]
386
+ v2 = vertices[indices[tid, 2]] * scale[2]
385
387
  i = tid * 3
386
388
  j = i + 1
387
389
  k = i + 2
@@ -410,6 +412,7 @@ def compute_gfx_vertices(
410
412
  def compute_average_normals(
411
413
  indices: wp.array(dtype=int, ndim=2),
412
414
  vertices: wp.array(dtype=wp.vec3),
415
+ scale: wp.vec3,
413
416
  # outputs
414
417
  normals: wp.array(dtype=wp.vec3),
415
418
  faces_per_vertex: wp.array(dtype=int),
@@ -418,9 +421,9 @@ def compute_average_normals(
418
421
  i = indices[tid, 0]
419
422
  j = indices[tid, 1]
420
423
  k = indices[tid, 2]
421
- v0 = vertices[i]
422
- v1 = vertices[j]
423
- v2 = vertices[k]
424
+ v0 = vertices[i] * scale[0]
425
+ v1 = vertices[j] * scale[1]
426
+ v2 = vertices[k] * scale[2]
424
427
  n = wp.normalize(wp.cross(v1 - v0, v2 - v0))
425
428
  wp.atomic_add(normals, i, n)
426
429
  wp.atomic_add(faces_per_vertex, i, 1)
@@ -435,15 +438,16 @@ def assemble_gfx_vertices(
435
438
  vertices: wp.array(dtype=wp.vec3, ndim=1),
436
439
  normals: wp.array(dtype=wp.vec3),
437
440
  faces_per_vertex: wp.array(dtype=int),
441
+ scale: wp.vec3,
438
442
  # outputs
439
443
  gfx_vertices: wp.array(dtype=float, ndim=2),
440
444
  ):
441
445
  tid = wp.tid()
442
446
  v = vertices[tid]
443
447
  n = normals[tid] / float(faces_per_vertex[tid])
444
- gfx_vertices[tid, 0] = v[0]
445
- gfx_vertices[tid, 1] = v[1]
446
- gfx_vertices[tid, 2] = v[2]
448
+ gfx_vertices[tid, 0] = v[0] * scale[0]
449
+ gfx_vertices[tid, 1] = v[1] * scale[1]
450
+ gfx_vertices[tid, 2] = v[2] * scale[2]
447
451
  gfx_vertices[tid, 3] = n[0]
448
452
  gfx_vertices[tid, 4] = n[1]
449
453
  gfx_vertices[tid, 5] = n[2]
@@ -1062,7 +1066,6 @@ class OpenGLRenderer:
1062
1066
  self._camera_axis = up_axis
1063
1067
  else:
1064
1068
  self._camera_axis = "XYZ".index(up_axis.upper())
1065
- self._yaw, self._pitch = -90.0, 0.0
1066
1069
  self._last_x, self._last_y = self.screen_width // 2, self.screen_height // 2
1067
1070
  self._first_mouse = True
1068
1071
  self._left_mouse_pressed = False
@@ -1083,6 +1086,10 @@ class OpenGLRenderer:
1083
1086
  self.update_view_matrix(cam_pos=camera_pos, cam_front=camera_front, cam_up=camera_up)
1084
1087
  self.update_projection_matrix()
1085
1088
 
1089
+ self._camera_front = self._camera_front.normalize()
1090
+ self._pitch = np.rad2deg(np.arcsin(self._camera_front.y))
1091
+ self._yaw = -np.rad2deg(np.arccos(self._camera_front.x / np.cos(np.deg2rad(self._pitch))))
1092
+
1086
1093
  self._frame_dt = 1.0 / fps
1087
1094
  self.time = 0.0
1088
1095
  self._start_time = time.time()
@@ -1146,6 +1153,7 @@ class OpenGLRenderer:
1146
1153
  self.window.push_handlers(on_draw=self._draw)
1147
1154
  self.window.push_handlers(on_resize=self._window_resize_callback)
1148
1155
  self.window.push_handlers(on_key_press=self._key_press_callback)
1156
+ self.window.push_handlers(on_close=self._close_callback)
1149
1157
 
1150
1158
  self._key_handler = pyglet.window.key.KeyStateHandler()
1151
1159
  self.window.push_handlers(self._key_handler)
@@ -2050,6 +2058,9 @@ Instances: {len(self._instances)}"""
2050
2058
 
2051
2059
  gl.glBindVertexArray(0)
2052
2060
 
2061
+ def _close_callback(self):
2062
+ self.close()
2063
+
2053
2064
  def _mouse_drag_callback(self, x, y, dx, dy, buttons, modifiers):
2054
2065
  if not self.enable_mouse_interaction:
2055
2066
  return
@@ -2193,6 +2204,25 @@ Instances: {len(self._instances)}"""
2193
2204
 
2194
2205
  return shape
2195
2206
 
2207
+ def deregister_shape(self, shape):
2208
+ from pyglet import gl
2209
+
2210
+ if shape not in self._shape_gl_buffers:
2211
+ return
2212
+
2213
+ vao, vbo, ebo, _, vertex_cuda_buffer = self._shape_gl_buffers[shape]
2214
+ try:
2215
+ gl.glDeleteVertexArrays(1, vao)
2216
+ gl.glDeleteBuffers(1, vbo)
2217
+ gl.glDeleteBuffers(1, ebo)
2218
+ except gl.GLException:
2219
+ pass
2220
+
2221
+ _, _, _, _, geo_hash = self._shapes[shape]
2222
+ del self._shape_geo_hash[geo_hash]
2223
+ del self._shape_gl_buffers[shape]
2224
+ self._shapes.pop(shape)
2225
+
2196
2226
  def add_shape_instance(
2197
2227
  self,
2198
2228
  name: str,
@@ -2220,6 +2250,19 @@ Instances: {len(self._instances)}"""
2220
2250
  self._instance_count = len(self._instances)
2221
2251
  return instance
2222
2252
 
2253
+ def remove_shape_instance(self, name: str):
2254
+ if name not in self._instances:
2255
+ return
2256
+
2257
+ instance, _, shape, _, _, _, _, _ = self._instances[name]
2258
+
2259
+ self._shape_instances[shape].remove(instance)
2260
+ self._instance_count = len(self._instances)
2261
+ self._add_shape_instances = self._instance_count > 0
2262
+ del self._instance_shape[instance]
2263
+ del self._instance_custom_ids[instance]
2264
+ del self._instances[name]
2265
+
2223
2266
  def update_instance_colors(self):
2224
2267
  from pyglet import gl
2225
2268
 
@@ -2236,15 +2279,13 @@ Instances: {len(self._instances)}"""
2236
2279
  colors2 = np.array(colors2, dtype=np.float32)
2237
2280
 
2238
2281
  # create buffer for checkerboard colors
2239
- if self._instance_color1_buffer is None:
2240
- self._instance_color1_buffer = gl.GLuint()
2241
- gl.glGenBuffers(1, self._instance_color1_buffer)
2282
+ self._instance_color1_buffer = gl.GLuint()
2283
+ gl.glGenBuffers(1, self._instance_color1_buffer)
2242
2284
  gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._instance_color1_buffer)
2243
2285
  gl.glBufferData(gl.GL_ARRAY_BUFFER, colors1.nbytes, colors1.ctypes.data, gl.GL_STATIC_DRAW)
2244
2286
 
2245
- if self._instance_color2_buffer is None:
2246
- self._instance_color2_buffer = gl.GLuint()
2247
- gl.glGenBuffers(1, self._instance_color2_buffer)
2287
+ self._instance_color2_buffer = gl.GLuint()
2288
+ gl.glGenBuffers(1, self._instance_color2_buffer)
2248
2289
  gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._instance_color2_buffer)
2249
2290
  gl.glBufferData(gl.GL_ARRAY_BUFFER, colors2.nbytes, colors2.ctypes.data, gl.GL_STATIC_DRAW)
2250
2291
 
@@ -2362,8 +2403,8 @@ Instances: {len(self._instances)}"""
2362
2403
  shape,
2363
2404
  new_tf,
2364
2405
  scale,
2365
- color1 or old_color1,
2366
- color2 or old_color2,
2406
+ old_color1 if color1 is None else color1,
2407
+ old_color2 if color2 is None else color2,
2367
2408
  visible,
2368
2409
  )
2369
2410
  self._update_shape_instances = True
@@ -2882,56 +2923,87 @@ Instances: {len(self._instances)}"""
2882
2923
  name: A name for the USD prim on the stage
2883
2924
  smooth_shading: Whether to average face normals at each vertex or introduce additional vertices for each face
2884
2925
  """
2885
- if colors is None:
2886
- colors = np.ones((len(points), 3), dtype=np.float32)
2887
- else:
2926
+ if colors is not None:
2888
2927
  colors = np.array(colors, dtype=np.float32)
2889
- points = np.array(points, dtype=np.float32) * np.array(scale, dtype=np.float32)
2928
+
2929
+ points = np.array(points, dtype=np.float32)
2930
+ point_count = len(points)
2931
+
2890
2932
  indices = np.array(indices, dtype=np.int32).reshape((-1, 3))
2933
+ idx_count = len(indices)
2934
+
2935
+ geo_hash = hash((indices.tobytes(),))
2936
+
2891
2937
  if name in self._instances:
2892
- self.update_shape_instance(name, pos, rot)
2938
+ # We've already registered this mesh instance and its associated shape.
2893
2939
  shape = self._instances[name][2]
2894
- self.update_shape_vertices(shape, points)
2895
- return
2896
- geo_hash = hash((points.tobytes(), indices.tobytes(), colors.tobytes()))
2897
- if geo_hash in self._shape_geo_hash:
2898
- shape = self._shape_geo_hash[geo_hash]
2899
- if self.update_shape_instance(name, pos, rot):
2900
- return shape
2901
2940
  else:
2902
- if smooth_shading:
2903
- normals = wp.zeros(len(points), dtype=wp.vec3)
2904
- vertices = wp.array(points, dtype=wp.vec3)
2905
- faces_per_vertex = wp.zeros(len(points), dtype=int)
2906
- wp.launch(
2907
- compute_average_normals,
2908
- dim=len(indices),
2909
- inputs=[wp.array(indices, dtype=int), vertices],
2910
- outputs=[normals, faces_per_vertex],
2911
- )
2912
- gfx_vertices = wp.zeros((len(points), 8), dtype=float)
2913
- wp.launch(
2914
- assemble_gfx_vertices,
2915
- dim=len(points),
2916
- inputs=[vertices, normals, faces_per_vertex],
2917
- outputs=[gfx_vertices],
2918
- )
2919
- gfx_vertices = gfx_vertices.numpy()
2920
- gfx_indices = indices.flatten()
2941
+ if geo_hash in self._shape_geo_hash:
2942
+ # We've only registered the shape, which can happen when `is_template` is `True`.
2943
+ shape = self._shape_geo_hash[geo_hash]
2921
2944
  else:
2922
- gfx_vertices = wp.zeros((len(indices) * 3, 8), dtype=float)
2923
- wp.launch(
2924
- compute_gfx_vertices,
2925
- dim=len(indices),
2926
- inputs=[wp.array(indices, dtype=int), wp.array(points, dtype=wp.vec3)],
2927
- outputs=[gfx_vertices],
2928
- )
2929
- gfx_vertices = gfx_vertices.numpy()
2930
- gfx_indices = np.arange(len(indices) * 3)
2931
- shape = self.register_shape(geo_hash, gfx_vertices, gfx_indices)
2945
+ shape = None
2946
+
2947
+ # Check if we already have that shape registered and can perform
2948
+ # minimal updates since the topology is not changing, before exiting.
2949
+ if not update_topology:
2950
+ if name in self._instances:
2951
+ # Update the instance's transform.
2952
+ self.update_shape_instance(name, pos, rot, color1=colors)
2953
+
2954
+ if shape is not None:
2955
+ # Update the shape's point positions.
2956
+ self.update_shape_vertices(shape, points, scale)
2957
+ return shape
2958
+
2959
+ # No existing shape for the given mesh was found, or its topology may have changed,
2960
+ # so we need to define a new one either way.
2961
+ if smooth_shading:
2962
+ normals = wp.zeros(point_count, dtype=wp.vec3)
2963
+ vertices = wp.array(points, dtype=wp.vec3)
2964
+ faces_per_vertex = wp.zeros(point_count, dtype=int)
2965
+ wp.launch(
2966
+ compute_average_normals,
2967
+ dim=idx_count,
2968
+ inputs=[wp.array(indices, dtype=int), vertices, scale],
2969
+ outputs=[normals, faces_per_vertex],
2970
+ )
2971
+ gfx_vertices = wp.zeros((point_count, 8), dtype=float)
2972
+ wp.launch(
2973
+ assemble_gfx_vertices,
2974
+ dim=point_count,
2975
+ inputs=[vertices, normals, faces_per_vertex, scale],
2976
+ outputs=[gfx_vertices],
2977
+ )
2978
+ gfx_vertices = gfx_vertices.numpy()
2979
+ gfx_indices = indices.flatten()
2980
+ else:
2981
+ gfx_vertices = wp.zeros((idx_count * 3, 8), dtype=float)
2982
+ wp.launch(
2983
+ compute_gfx_vertices,
2984
+ dim=idx_count,
2985
+ inputs=[wp.array(indices, dtype=int), wp.array(points, dtype=wp.vec3), scale],
2986
+ outputs=[gfx_vertices],
2987
+ )
2988
+ gfx_vertices = gfx_vertices.numpy()
2989
+ gfx_indices = np.arange(idx_count * 3)
2990
+
2991
+ # If there was a shape for the given mesh, clean it up.
2992
+ if shape is not None:
2993
+ self.deregister_shape(shape)
2994
+
2995
+ # If there was an instance for the given mesh, clean it up.
2996
+ if name in self._instances:
2997
+ self.remove_shape_instance(name)
2998
+
2999
+ # Register the new shape.
3000
+ shape = self.register_shape(geo_hash, gfx_vertices, gfx_indices)
3001
+
2932
3002
  if not is_template:
3003
+ # Create a new instance if necessary.
2933
3004
  body = self._resolve_body_id(parent_body)
2934
- self.add_shape_instance(name, shape, body, pos, rot)
3005
+ self.add_shape_instance(name, shape, body, pos, rot, color1=colors)
3006
+
2935
3007
  return shape
2936
3008
 
2937
3009
  def render_arrow(
@@ -3096,7 +3168,7 @@ Instances: {len(self._instances)}"""
3096
3168
  lines = np.array(lines)
3097
3169
  self._render_lines(name, lines, color, radius)
3098
3170
 
3099
- def update_shape_vertices(self, shape, points):
3171
+ def update_shape_vertices(self, shape, points, scale):
3100
3172
  if isinstance(points, wp.array):
3101
3173
  wp_points = points.to(self._device)
3102
3174
  else:
@@ -3109,7 +3181,7 @@ Instances: {len(self._instances)}"""
3109
3181
  wp.launch(
3110
3182
  update_vbo_vertices,
3111
3183
  dim=vertices_shape[0],
3112
- inputs=[wp_points],
3184
+ inputs=[wp_points, scale],
3113
3185
  outputs=[vbo_vertices],
3114
3186
  device=self._device,
3115
3187
  )