warp-lang 1.6.0__py3-none-manylinux2014_x86_64.whl → 1.6.1__py3-none-manylinux2014_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

@@ -29,17 +29,20 @@ from warp.sim.model import PARTICLE_FLAG_ACTIVE
29
29
 
30
30
  @wp.kernel
31
31
  def initialize_rotation(
32
- particle_indices_to_rot: wp.array(dtype=wp.int32),
32
+ # input
33
+ vertex_indices_to_rot: wp.array(dtype=wp.int32),
33
34
  pos: wp.array(dtype=wp.vec3),
34
35
  rot_centers: wp.array(dtype=wp.vec3),
35
36
  rot_axes: wp.array(dtype=wp.vec3),
37
+ t: wp.array(dtype=float),
38
+ # output
36
39
  roots: wp.array(dtype=wp.vec3),
37
40
  roots_to_ps: wp.array(dtype=wp.vec3),
38
41
  ):
39
42
  tid = wp.tid()
40
- particle_index = particle_indices_to_rot[wp.tid()]
43
+ v_index = vertex_indices_to_rot[wp.tid()]
41
44
 
42
- p = pos[particle_index]
45
+ p = pos[v_index]
43
46
  rot_center = rot_centers[tid]
44
47
  rot_axis = rot_axes[tid]
45
48
  op = p - rot_center
@@ -51,21 +54,31 @@ def initialize_rotation(
51
54
  roots[tid] = root
52
55
  roots_to_ps[tid] = root_to_p
53
56
 
57
+ if tid == 0:
58
+ t[0] = 0.0
59
+
54
60
 
55
61
  @wp.kernel
56
62
  def apply_rotation(
57
- time: float,
58
- angular_velocity: float,
59
- particle_indices_to_rot: wp.array(dtype=wp.int32),
60
- rot_centers: wp.array(dtype=wp.vec3),
63
+ # input
64
+ vertex_indices_to_rot: wp.array(dtype=wp.int32),
61
65
  rot_axes: wp.array(dtype=wp.vec3),
62
66
  roots: wp.array(dtype=wp.vec3),
63
67
  roots_to_ps: wp.array(dtype=wp.vec3),
68
+ t: wp.array(dtype=float),
69
+ angular_velocity: float,
70
+ dt: float,
71
+ end_time: float,
72
+ # output
64
73
  pos_0: wp.array(dtype=wp.vec3),
65
74
  pos_1: wp.array(dtype=wp.vec3),
66
75
  ):
76
+ cur_t = t[0]
77
+ if cur_t > end_time:
78
+ return
79
+
67
80
  tid = wp.tid()
68
- particle_index = particle_indices_to_rot[wp.tid()]
81
+ v_index = vertex_indices_to_rot[wp.tid()]
69
82
 
70
83
  rot_axis = rot_axes[tid]
71
84
 
@@ -73,7 +86,7 @@ def apply_rotation(
73
86
  uy = rot_axis[1]
74
87
  uz = rot_axis[2]
75
88
 
76
- theta = time * angular_velocity
89
+ theta = cur_t * angular_velocity
77
90
 
78
91
  R = wp.mat33(
79
92
  wp.cos(theta) + ux * ux * (1.0 - wp.cos(theta)),
@@ -92,24 +105,29 @@ def apply_rotation(
92
105
  root_to_p_rot = R * root_to_p
93
106
  p_rot = root + root_to_p_rot
94
107
 
95
- pos_0[particle_index] = p_rot
96
- pos_1[particle_index] = p_rot
108
+ pos_0[v_index] = p_rot
109
+ pos_1[v_index] = p_rot
110
+
111
+ if tid == 0:
112
+ t[0] = cur_t + dt
97
113
 
98
114
 
99
115
  class Example:
100
- def __init__(self, stage_path="example_cloth_self_contact.usd", num_frames=1500):
116
+ def __init__(self, stage_path="example_cloth_self_contact.usd", num_frames=600):
101
117
  fps = 60
102
118
  self.frame_dt = 1.0 / fps
119
+ # must be an even number when using CUDA Graph
103
120
  self.num_substeps = 10
104
- self.iterations = 10
121
+ self.iterations = 4
105
122
  self.dt = self.frame_dt / self.num_substeps
106
123
 
107
124
  self.num_frames = num_frames
108
125
  self.sim_time = 0.0
109
126
  self.profiler = {}
110
127
 
111
- self.rot_angular_velocity = math.pi / 6
112
- self.rot_end_time = 21
128
+ self.rot_angular_velocity = math.pi / 3
129
+ self.rot_end_time = 10
130
+ self.use_cuda_graph = wp.get_device().is_cuda
113
131
 
114
132
  usd_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "square_cloth.usd"))
115
133
  usd_geom = UsdGeom.Mesh(usd_stage.GetPrimAtPath("/root/cloth/cloth"))
@@ -134,7 +152,8 @@ class Example:
134
152
  density=0.02,
135
153
  tri_ke=1.0e5,
136
154
  tri_ka=1.0e5,
137
- tri_kd=3.0e-5,
155
+ tri_kd=2.0e-6,
156
+ edge_ke=10,
138
157
  )
139
158
  builder.color()
140
159
  self.model = builder.finalize()
@@ -170,6 +189,7 @@ class Example:
170
189
  rot_axes = [[1, 0, 0]] * len(right_side) + [[-1, 0, 0]] * len(left_side)
171
190
 
172
191
  self.rot_point_indices = wp.array(rot_point_indices, dtype=int)
192
+ self.t = wp.zeros((1,), dtype=float)
173
193
  self.rot_centers = wp.zeros(len(rot_point_indices), dtype=wp.vec3)
174
194
  self.rot_axes = wp.array(rot_axes, dtype=wp.vec3)
175
195
 
@@ -184,41 +204,75 @@ class Example:
184
204
  self.state0.particle_q,
185
205
  self.rot_centers,
186
206
  self.rot_axes,
207
+ self.t,
208
+ ],
209
+ outputs=[
187
210
  self.roots,
188
211
  self.roots_to_ps,
189
212
  ],
190
213
  )
191
214
 
192
215
  if stage_path:
193
- self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=40.0)
216
+ self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=1)
194
217
  else:
195
218
  self.renderer = None
219
+ self.cuda_graph = None
220
+ if self.use_cuda_graph:
221
+ with wp.ScopedCapture() as capture:
222
+ for _ in range(self.num_substeps):
223
+ wp.launch(
224
+ kernel=apply_rotation,
225
+ dim=self.rot_point_indices.shape[0],
226
+ inputs=[
227
+ self.rot_point_indices,
228
+ self.rot_axes,
229
+ self.roots,
230
+ self.roots_to_ps,
231
+ self.t,
232
+ self.rot_angular_velocity,
233
+ self.dt,
234
+ self.rot_end_time,
235
+ ],
236
+ outputs=[
237
+ self.state0.particle_q,
238
+ self.state1.particle_q,
239
+ ],
240
+ )
241
+
242
+ self.integrator.simulate(self.model, self.state0, self.state1, self.dt, None)
243
+ (self.state0, self.state1) = (self.state1, self.state0)
244
+
245
+ self.cuda_graph = capture.graph
196
246
 
197
247
  def step(self):
198
248
  with wp.ScopedTimer("step", print=False, dict=self.profiler):
199
- for _ in range(self.num_substeps):
200
- if self.sim_time < self.rot_end_time:
249
+ if self.use_cuda_graph:
250
+ wp.capture_launch(self.cuda_graph)
251
+ else:
252
+ for _ in range(self.num_substeps):
201
253
  wp.launch(
202
254
  kernel=apply_rotation,
203
255
  dim=self.rot_point_indices.shape[0],
204
256
  inputs=[
205
- self.sim_time,
206
- self.rot_angular_velocity,
207
257
  self.rot_point_indices,
208
- self.rot_centers,
209
258
  self.rot_axes,
210
259
  self.roots,
211
260
  self.roots_to_ps,
261
+ self.t,
262
+ self.rot_angular_velocity,
263
+ self.dt,
264
+ self.rot_end_time,
265
+ ],
266
+ outputs=[
212
267
  self.state0.particle_q,
213
268
  self.state1.particle_q,
214
269
  ],
215
270
  )
271
+ self.integrator.simulate(self.model, self.state0, self.state1, self.dt)
216
272
 
217
- self.integrator.simulate(self.model, self.state0, self.state1, self.dt)
218
-
219
- (self.state0, self.state1) = (self.state1, self.state0)
273
+ (self.state0, self.state1) = (self.state1, self.state0)
220
274
 
221
- self.sim_time += self.dt
275
+ self.sim_time += self.dt
222
276
 
223
277
  def render(self):
224
278
  if self.renderer is None:
@@ -241,7 +295,7 @@ if __name__ == "__main__":
241
295
  default="example_cloth_self_contact.usd",
242
296
  help="Path to the output USD file.",
243
297
  )
244
- parser.add_argument("--num_frames", type=int, default=1500, help="Total number of frames.")
298
+ parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames.")
245
299
 
246
300
  args = parser.parse_known_args()[0]
247
301
 
@@ -26,7 +26,7 @@ import warp as wp
26
26
 
27
27
  wp.init()
28
28
 
29
- DT = wp.constant(0.01)
29
+ DT = wp.constant(0.016)
30
30
  SOFTENING_SQ = wp.constant(0.1**2) # Softening factor for numerical stability
31
31
  TILE_SIZE = wp.constant(64)
32
32
  PARTICLE_MASS = wp.constant(1.0)
@@ -75,25 +75,33 @@ def integrate_bodies_tiled(
75
75
 
76
76
 
77
77
  class Example:
78
- def __init__(self, headless=False, num_bodies=1024):
78
+ def __init__(self, headless=False, num_bodies=16384):
79
79
  self.num_bodies = num_bodies
80
80
 
81
81
  rng = np.random.default_rng(42)
82
82
 
83
83
  # Sample the surface of a sphere
84
- r = 10.0 * (num_bodies / 1024) ** (1 / 2) # Scale factor to maintain a constant density
85
- phi = np.arccos(1.0 - 2.0 * rng.uniform(size=self.num_bodies))
84
+ phi = np.arccos(1.0 - 2.0 * rng.uniform(low=0.0, high=1.0, size=self.num_bodies))
86
85
  theta = rng.uniform(low=0.0, high=2.0 * np.pi, size=self.num_bodies)
87
- x = r * np.cos(theta) * np.sin(phi)
88
- y = r * np.sin(theta) * np.sin(phi)
89
- z = r * np.cos(phi)
90
-
91
- self.scale = r
86
+ x = np.cos(theta) * np.sin(phi)
87
+ y = np.sin(theta) * np.sin(phi)
88
+ z = np.cos(phi)
92
89
  init_pos_np = np.stack((x, y, z), axis=1)
93
90
 
91
+ scale = (num_bodies / 1024) ** (1 / 2) # Scale factor to maintain a constant density
92
+ inner = 0.9625 * scale
93
+ outer = 1.54 * scale
94
+ radii = inner + (outer - inner) * rng.uniform(size=(self.num_bodies, 1))
95
+ init_pos_np = init_pos_np * radii
96
+
97
+ axis = np.array([0.0, 0.0, 1.0])
98
+ v_scale = scale * 3.08
99
+ init_vel_np = v_scale * np.cross(init_pos_np, axis)
100
+
101
+ self.graph_scale = np.max(radii) * 5.0
94
102
  self.pos_array_0 = wp.array(init_pos_np, dtype=wp.vec3)
95
103
  self.pos_array_1 = wp.empty_like(self.pos_array_0)
96
- self.vel_array = wp.zeros(self.num_bodies, dtype=wp.vec3)
104
+ self.vel_array = wp.array(init_vel_np, dtype=wp.vec3)
97
105
 
98
106
  if headless:
99
107
  self.scatter_plot = None
@@ -108,13 +116,16 @@ class Example:
108
116
  ax = self.fig.add_subplot(111, projection="3d")
109
117
 
110
118
  # Scatter plot of initial positions
119
+ point_size = 0.05 * self.graph_scale
111
120
  init_pos_np = self.pos_array_0.numpy()
112
- scatter_plot = ax.scatter(init_pos_np[:, 0], init_pos_np[:, 1], init_pos_np[:, 2], c="#76b900", alpha=0.5)
121
+ scatter_plot = ax.scatter(
122
+ init_pos_np[:, 0], init_pos_np[:, 1], init_pos_np[:, 2], s=point_size, c="#76b900", alpha=0.5
123
+ )
113
124
 
114
125
  # Set axis limits
115
- ax.set_xlim(-self.scale, self.scale)
116
- ax.set_ylim(-self.scale, self.scale)
117
- ax.set_zlim(-self.scale, self.scale)
126
+ ax.set_xlim(-self.graph_scale, self.graph_scale)
127
+ ax.set_ylim(-self.graph_scale, self.graph_scale)
128
+ ax.set_zlim(-self.graph_scale, self.graph_scale)
118
129
 
119
130
  return scatter_plot
120
131
 
@@ -149,7 +160,7 @@ if __name__ == "__main__":
149
160
  parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
150
161
  parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
151
162
  parser.add_argument("--num_frames", type=int, default=1000, help="Total number of frames.")
152
- parser.add_argument("-N", help="Number of bodies. Should be a multiple of 64.", type=int, default=1024)
163
+ parser.add_argument("-N", help="Number of bodies. Should be a multiple of 64.", type=int, default=16384)
153
164
  parser.add_argument(
154
165
  "--headless",
155
166
  action="store_true",
@@ -413,7 +413,7 @@ WP_API int load_obj(const char* object_file, const char* module_name)
413
413
  SYMBOL(sinhf), SYMBOL_T(sinh, double(*)(double)),
414
414
  SYMBOL(coshf), SYMBOL_T(cosh, double(*)(double)),
415
415
  SYMBOL(tanhf), SYMBOL_T(tanh, double(*)(double)),
416
- SYMBOL(fmaf),
416
+ SYMBOL(fmaf), SYMBOL_T(fma, double(*)(double, double, double)),
417
417
  SYMBOL(memcpy), SYMBOL(memset), SYMBOL(memmove),
418
418
  SYMBOL(_wp_assert),
419
419
  SYMBOL(_wp_isfinite),
warp/native/crt.h CHANGED
@@ -303,6 +303,7 @@ double cosh(double);
303
303
  float tanhf(float);
304
304
  double tanh(double);
305
305
  float fmaf(float, float, float);
306
+ double fma(double, double, double);
306
307
 
307
308
  // stddef.h
308
309
  #if defined(_WIN32)
warp/native/mat.h CHANGED
@@ -681,6 +681,20 @@ inline CUDA_CALLABLE vec_t<Cols,Type> mul(const vec_t<Rows,Type>& b, const mat_t
681
681
  return r;
682
682
  }
683
683
 
684
+ template<typename T>
685
+ inline CUDA_CALLABLE T muladd(T a, T b, T c) {
686
+ return c + a*b;
687
+ }
688
+ template<>
689
+ inline CUDA_CALLABLE float muladd(float a, float b, float c) {
690
+ return fmaf(a, b, c);
691
+ }
692
+ template<>
693
+ inline CUDA_CALLABLE double muladd(double a, double b, double c) {
694
+ return fma(a, b, c);
695
+ }
696
+
697
+
684
698
  template<unsigned Rows, unsigned Cols, unsigned ColsOut, typename Type>
685
699
  inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a, const mat_t<Cols,ColsOut,Type>& b)
686
700
  {
@@ -693,8 +707,7 @@ inline CUDA_CALLABLE mat_t<Rows,ColsOut,Type> mul(const mat_t<Rows,Cols,Type>& a
693
707
 
694
708
  for (unsigned k=0; k < Cols; ++k)
695
709
  {
696
- //t.data[i][j] += a.data[i][k]*b.data[k][j];
697
- sum = fmaf(a.data[i][k], b.data[k][j], sum);
710
+ sum = muladd<Type>(a.data[i][k], b.data[k][j], sum);
698
711
  }
699
712
 
700
713
  t.data[i][j] = sum;
@@ -713,7 +726,7 @@ inline CUDA_CALLABLE Type ddot(const mat_t<Rows,Cols,Type>& a, const mat_t<Rows,
713
726
  {
714
727
  for (unsigned j=0; j < Cols; ++j)
715
728
  {
716
- r += a.data[i][j] * b.data[i][j];
729
+ r = muladd<Type>(a.data[i][j], b.data[i][j], r);
717
730
  }
718
731
  }
719
732
  return r;
warp/native/tile.h CHANGED
@@ -1079,9 +1079,12 @@ struct tile_shared_t
1079
1079
  constexpr bool contiguous_src = Layout::Stride::dim(lastdim) == 1;
1080
1080
  const bool contiguous_dest = dest.data.strides[lastdim] == sizeof(T);
1081
1081
  const int elements = (dest.data.shape[lastdim] - dest.offset[lastdim]);
1082
- const bool aligned = (elements*sizeof(T))%sizeof(float4) == 0;
1082
+ const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
1083
1083
 
1084
- if (contiguous_dest && contiguous_src && aligned)
1084
+ float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
1085
+ const bool aligned_dst = (uint64_t)(dest128)%sizeof(float4) == 0;
1086
+
1087
+ if (contiguous_dest && contiguous_src && aligned_size && aligned_dst)
1085
1088
  {
1086
1089
  constexpr int M = Layout::Shape::dim(0);
1087
1090
  constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
@@ -1089,10 +1092,9 @@ struct tile_shared_t
1089
1092
  // alias of shared tile with 128bit type
1090
1093
  using SrcLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
1091
1094
  tile_shared_t<float4, SrcLayout> src128((float4*)data.ptr);
1092
- float4* dest128 = (float4*)&dest.data.data[dest.index_from_coord(tile_coord(0,0))];
1093
1095
 
1094
1096
  assert(((uint64_t)(data.ptr))%sizeof(float4) == 0);
1095
- assert(((uint64_t)(ptr))%sizeof(float4) == 0);
1097
+ assert(((uint64_t)(dest128))%sizeof(float4) == 0);
1096
1098
 
1097
1099
  const int stride_i = dest.data.strides[0]/sizeof(float4);
1098
1100
  const int stride_j = 1;
@@ -1164,17 +1166,19 @@ struct tile_shared_t
1164
1166
  constexpr bool contiguous_dest = Layout::Stride::dim(lastdim) == 1;
1165
1167
  const bool contiguous_src = src.data.strides[lastdim] == sizeof(T);
1166
1168
  const int elements = (src.data.shape[lastdim] - src.offset[lastdim]);
1167
- const bool aligned = (elements*sizeof(T))%sizeof(float4) == 0;
1169
+ const bool aligned_size = (elements*sizeof(T))%sizeof(float4) == 0;
1170
+
1171
+ float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
1172
+ const bool aligned_src = (uint64_t)(src128)%sizeof(float4) == 0;
1168
1173
 
1169
- if (contiguous_dest && contiguous_src && aligned)
1174
+ if (contiguous_dest && contiguous_src && aligned_size && aligned_src)
1170
1175
  {
1171
1176
  constexpr int M = Layout::Shape::dim(0);
1172
1177
  constexpr int N = (Layout::Shape::dim(1)*sizeof(T))/sizeof(float4);
1173
1178
 
1174
1179
  // alias of shared tile with 128bit type
1175
1180
  using DestLayout = tile_layout_strided_t<tile_shape_t<M, N>>;
1176
- tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
1177
- float4* src128 = (float4*)&src.data.data[src.index_from_coord(tile_coord(0,0))];
1181
+ tile_shared_t<float4, DestLayout> dest128((float4*)data.ptr);
1178
1182
 
1179
1183
  assert(((uint64_t)(dest128.data.ptr))%sizeof(float4) == 0);
1180
1184
  assert(((uint64_t)(src128))%sizeof(float4) == 0);
@@ -1083,7 +1083,7 @@ class OpenGLRenderer:
1083
1083
  self.app = pyglet.app
1084
1084
 
1085
1085
  # making window current opengl rendering context
1086
- self.window.switch_to()
1086
+ self._switch_context()
1087
1087
 
1088
1088
  self.screen_width, self.screen_height = self.window.get_framebuffer_size()
1089
1089
 
@@ -1435,7 +1435,7 @@ class OpenGLRenderer:
1435
1435
  def clear(self):
1436
1436
  gl = OpenGLRenderer.gl
1437
1437
 
1438
- self.window.switch_to()
1438
+ self._switch_context()
1439
1439
 
1440
1440
  if not self.headless:
1441
1441
  self.app.event_loop.dispatch_event("on_exit")
@@ -1631,7 +1631,7 @@ class OpenGLRenderer:
1631
1631
  def _setup_framebuffer(self):
1632
1632
  gl = OpenGLRenderer.gl
1633
1633
 
1634
- self.window.switch_to()
1634
+ self._switch_context()
1635
1635
 
1636
1636
  if self._frame_texture is None:
1637
1637
  self._frame_texture = gl.GLuint()
@@ -1801,7 +1801,7 @@ class OpenGLRenderer:
1801
1801
  def update_model_matrix(self, model_matrix: Optional[Mat44] = None):
1802
1802
  gl = OpenGLRenderer.gl
1803
1803
 
1804
- self.window.switch_to()
1804
+ self._switch_context()
1805
1805
 
1806
1806
  if model_matrix is None:
1807
1807
  self._model_matrix = self.compute_model_matrix(self._camera_axis, self._scaling)
@@ -1898,7 +1898,7 @@ class OpenGLRenderer:
1898
1898
  def _draw(self):
1899
1899
  gl = OpenGLRenderer.gl
1900
1900
 
1901
- self.window.switch_to()
1901
+ self._switch_context()
1902
1902
 
1903
1903
  if not self.headless:
1904
1904
  # catch key hold events
@@ -1999,7 +1999,7 @@ Instances: {len(self._instances)}"""
1999
1999
  def _draw_grid(self, is_tiled=False):
2000
2000
  gl = OpenGLRenderer.gl
2001
2001
 
2002
- self.window.switch_to()
2002
+ self._switch_context()
2003
2003
 
2004
2004
  if not is_tiled:
2005
2005
  gl.glUseProgram(self._grid_shader.id)
@@ -2014,7 +2014,7 @@ Instances: {len(self._instances)}"""
2014
2014
  def _draw_sky(self, is_tiled=False):
2015
2015
  gl = OpenGLRenderer.gl
2016
2016
 
2017
- self.window.switch_to()
2017
+ self._switch_context()
2018
2018
 
2019
2019
  if not is_tiled:
2020
2020
  gl.glUseProgram(self._sky_shader.id)
@@ -2030,7 +2030,7 @@ Instances: {len(self._instances)}"""
2030
2030
  def _render_scene(self):
2031
2031
  gl = OpenGLRenderer.gl
2032
2032
 
2033
- self.window.switch_to()
2033
+ self._switch_context()
2034
2034
 
2035
2035
  start_instance_idx = 0
2036
2036
 
@@ -2055,7 +2055,7 @@ Instances: {len(self._instances)}"""
2055
2055
  def _render_scene_tiled(self):
2056
2056
  gl = OpenGLRenderer.gl
2057
2057
 
2058
- self.window.switch_to()
2058
+ self._switch_context()
2059
2059
 
2060
2060
  for i, viewport in enumerate(self._tile_viewports):
2061
2061
  projection_matrix_ptr = arr_pointer(self._tile_projection_matrices[i])
@@ -2205,7 +2205,7 @@ Instances: {len(self._instances)}"""
2205
2205
  def register_shape(self, geo_hash, vertices, indices, color1=None, color2=None):
2206
2206
  gl = OpenGLRenderer.gl
2207
2207
 
2208
- self.window.switch_to()
2208
+ self._switch_context()
2209
2209
 
2210
2210
  shape = len(self._shapes)
2211
2211
  if color1 is None:
@@ -2256,7 +2256,7 @@ Instances: {len(self._instances)}"""
2256
2256
  def deregister_shape(self, shape):
2257
2257
  gl = OpenGLRenderer.gl
2258
2258
 
2259
- self.window.switch_to()
2259
+ self._switch_context()
2260
2260
 
2261
2261
  if shape not in self._shape_gl_buffers:
2262
2262
  return
@@ -2317,7 +2317,7 @@ Instances: {len(self._instances)}"""
2317
2317
  def update_instance_colors(self):
2318
2318
  gl = OpenGLRenderer.gl
2319
2319
 
2320
- self.window.switch_to()
2320
+ self._switch_context()
2321
2321
 
2322
2322
  colors1, colors2 = [], []
2323
2323
  all_instances = list(self._instances.values())
@@ -2340,7 +2340,7 @@ Instances: {len(self._instances)}"""
2340
2340
  def allocate_shape_instances(self):
2341
2341
  gl = OpenGLRenderer.gl
2342
2342
 
2343
- self.window.switch_to()
2343
+ self._switch_context()
2344
2344
 
2345
2345
  self._add_shape_instances = False
2346
2346
  self._wp_instance_transforms = wp.array(
@@ -2444,7 +2444,7 @@ Instances: {len(self._instances)}"""
2444
2444
  """
2445
2445
  gl = OpenGLRenderer.gl
2446
2446
 
2447
- self.window.switch_to()
2447
+ self._switch_context()
2448
2448
 
2449
2449
  if name in self._instances:
2450
2450
  i, body, shape, tf, scale, old_color1, old_color2, v = self._instances[name]
@@ -2558,7 +2558,7 @@ Instances: {len(self._instances)}"""
2558
2558
  """
2559
2559
  gl = OpenGLRenderer.gl
2560
2560
 
2561
- self.window.switch_to()
2561
+ self._switch_context()
2562
2562
 
2563
2563
  channels = 3 if mode == "rgb" else 1
2564
2564
 
@@ -3497,6 +3497,14 @@ Instances: {len(self._instances)}"""
3497
3497
  # fmt: on
3498
3498
  return np.array(vertices, dtype=np.float32), np.array(indices, dtype=np.uint32)
3499
3499
 
3500
+ def _switch_context(self):
3501
+ try:
3502
+ self.window.switch_to()
3503
+ except AttributeError:
3504
+ # The window could be in the process of being closed, in which case
3505
+ # its corresponding context might have been destroyed and set to `None`.
3506
+ pass
3507
+
3500
3508
 
3501
3509
  if __name__ == "__main__":
3502
3510
  renderer = OpenGLRenderer()
warp/render/render_usd.py CHANGED
@@ -582,7 +582,12 @@ class UsdRenderer:
582
582
  mesh = UsdGeom.Mesh.Get(self.stage, mesh_path)
583
583
  if not mesh:
584
584
  mesh = UsdGeom.Mesh.Define(self.stage, mesh_path)
585
- UsdGeom.Primvar(mesh.GetDisplayColorAttr()).SetInterpolation("vertex")
585
+ if colors is not None and len(colors) == 3:
586
+ color_interp = "constant"
587
+ else:
588
+ color_interp = "vertex"
589
+
590
+ UsdGeom.Primvar(mesh.GetDisplayColorAttr()).SetInterpolation(color_interp)
586
591
  _usd_add_xform(mesh)
587
592
 
588
593
  # force topology update on first frame
@@ -595,7 +600,10 @@ class UsdRenderer:
595
600
  mesh.GetFaceVertexIndicesAttr().Set(idxs, self.time)
596
601
  mesh.GetFaceVertexCountsAttr().Set([3] * len(idxs), self.time)
597
602
 
598
- if colors:
603
+ if colors is not None:
604
+ if len(colors) == 3:
605
+ colors = (colors,)
606
+
599
607
  mesh.GetDisplayColorAttr().Set(colors, self.time)
600
608
 
601
609
  self._shape_constructors[name] = UsdGeom.Mesh