warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/native/warp.h CHANGED
@@ -67,21 +67,23 @@ extern "C"
67
67
  WP_API void memtile_host(void* dest, const void* src, size_t srcsize, size_t n);
68
68
  WP_API void memtile_device(void* context, void* dest, const void* src, size_t srcsize, size_t n);
69
69
 
70
- WP_API uint64_t bvh_create_host(wp::vec3* lowers, wp::vec3* uppers, int num_items);
70
+ WP_API uint64_t bvh_create_host(wp::vec3* lowers, wp::vec3* uppers, int num_items, int constructor_type);
71
71
  WP_API void bvh_destroy_host(uint64_t id);
72
72
  WP_API void bvh_refit_host(uint64_t id);
73
73
 
74
- WP_API uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items);
74
+ WP_API uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items, int constructor_type);
75
75
  WP_API void bvh_destroy_device(uint64_t id);
76
76
  WP_API void bvh_refit_device(uint64_t id);
77
77
 
78
78
  // create a user-accessible copy of the mesh, it is the
79
79
  // users responsibility to keep-alive the points/tris data for the duration of the mesh lifetime
80
- WP_API uint64_t mesh_create_host(wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number);
80
+ WP_API uint64_t mesh_create_host(wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris,
81
+ int num_points, int num_tris, int support_winding_number, int constructor_type);
81
82
  WP_API void mesh_destroy_host(uint64_t id);
82
83
  WP_API void mesh_refit_host(uint64_t id);
83
84
 
84
- WP_API uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number);
85
+ WP_API uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities,
86
+ wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number, int constructor_type);
85
87
  WP_API void mesh_destroy_device(uint64_t id);
86
88
  WP_API void mesh_refit_device(uint64_t id);
87
89
 
@@ -159,6 +161,9 @@ extern "C"
159
161
  WP_API void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n);
160
162
  WP_API void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n);
161
163
 
164
+ WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n);
165
+ WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n);
166
+
162
167
  WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
163
168
  WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
164
169
 
@@ -266,6 +271,7 @@ extern "C"
266
271
  WP_API int cuda_device_get_pci_device_id(int ordinal);
267
272
  WP_API int cuda_device_is_uva(int ordinal);
268
273
  WP_API int cuda_device_is_mempool_supported(int ordinal);
274
+ WP_API int cuda_device_is_ipc_supported(int ordinal);
269
275
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
270
276
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
271
277
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
@@ -294,6 +300,13 @@ extern "C"
294
300
  WP_API int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal);
295
301
  WP_API int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable);
296
302
 
303
+ // inter-process communication
304
+ WP_API void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer);
305
+ WP_API void* cuda_ipc_open_mem_handle(void* context, char* handle);
306
+ WP_API void cuda_ipc_close_mem_handle(void* ptr);
307
+ WP_API void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer);
308
+ WP_API void* cuda_ipc_open_event_handle(void* context, char* handle);
309
+
297
310
  WP_API void* cuda_stream_create(void* context, int priority);
298
311
  WP_API void cuda_stream_destroy(void* context, void* stream);
299
312
  WP_API void cuda_stream_register(void* context, void* stream);
@@ -317,9 +330,10 @@ extern "C"
317
330
  WP_API bool cuda_graph_launch(void* graph, void* stream);
318
331
  WP_API bool cuda_graph_destroy(void* context, void* graph);
319
332
 
320
- WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
333
+ WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types);
321
334
  WP_API bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size);
322
335
  WP_API bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads);
336
+ WP_API bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads);
323
337
 
324
338
  WP_API void* cuda_load_module(void* context, const char* ptx);
325
339
  WP_API void cuda_unload_module(void* context, void* module);
warp/optim/linear.py CHANGED
@@ -84,10 +84,7 @@ def aslinearoperator(A: _Matrix) -> LinearOperator:
84
84
  sparse.bsr_mv(A, x, z, alpha, beta)
85
85
 
86
86
  def dense_mv(x, y, z, alpha, beta):
87
- x = x.reshape((x.shape[0], 1))
88
- y = y.reshape((y.shape[0], 1))
89
- z = z.reshape((y.shape[0], 1))
90
- wp.matmul(A, x, y, z, alpha, beta)
87
+ wp.launch(_dense_mv_kernel, dim=A.shape[1], device=A.device, inputs=[A, x, y, z, alpha, beta])
91
88
 
92
89
  def diag_mv(x, y, z, alpha, beta):
93
90
  scalar_type = wp.types.type_scalar_type(A.dtype)
@@ -803,6 +800,27 @@ def _run_solver_loop(
803
800
  return cur_iter, err, atol
804
801
 
805
802
 
803
+ @wp.func
804
+ def _calc_mv_product(i: wp.int32, A: wp.array2d(dtype=Any), x: wp.array1d(dtype=Any)):
805
+ sum = A.dtype(0)
806
+ for j in range(A.shape[1]):
807
+ sum += A[i, j] * x[j]
808
+ return sum
809
+
810
+
811
+ @wp.kernel
812
+ def _dense_mv_kernel(
813
+ A: wp.array2d(dtype=Any),
814
+ x: wp.array1d(dtype=Any),
815
+ y: wp.array1d(dtype=Any),
816
+ z: wp.array1d(dtype=Any),
817
+ alpha: Any,
818
+ beta: Any,
819
+ ):
820
+ i = wp.tid()
821
+ z[i] = z.dtype(beta) * y[i] + z.dtype(alpha) * _calc_mv_product(i, A, x)
822
+
823
+
806
824
  @wp.kernel
807
825
  def _diag_mv_kernel(
808
826
  A: wp.array(dtype=Any),
@@ -671,6 +671,15 @@ class ShapeInstancer:
671
671
  [3D point, 3D normal, UV texture coordinates]
672
672
  """
673
673
 
674
+ gl = None # Class-level variable to hold the imported module
675
+
676
+ @classmethod
677
+ def initialize_gl(cls):
678
+ if cls.gl is None: # Only import if not already imported
679
+ from pyglet import gl
680
+
681
+ cls.gl = gl
682
+
674
683
  def __new__(cls, *args, **kwargs):
675
684
  instance = super(ShapeInstancer, cls).__new__(cls)
676
685
  instance.instance_transform_gl_buffer = None
@@ -690,8 +699,10 @@ class ShapeInstancer:
690
699
  self.scalings = None
691
700
  self._instance_transform_cuda_buffer = None
692
701
 
702
+ ShapeInstancer.initialize_gl()
703
+
693
704
  def __del__(self):
694
- from pyglet import gl
705
+ gl = ShapeInstancer.gl
695
706
 
696
707
  if self.instance_transform_gl_buffer is not None:
697
708
  try:
@@ -709,7 +720,7 @@ class ShapeInstancer:
709
720
  pass
710
721
 
711
722
  def register_shape(self, vertices, indices, color1=(1.0, 1.0, 1.0), color2=(0.0, 0.0, 0.0)):
712
- from pyglet import gl
723
+ gl = ShapeInstancer.gl
713
724
 
714
725
  if color1 is not None and color2 is None:
715
726
  color2 = np.clip(np.array(color1) + 0.25, 0.0, 1.0)
@@ -750,7 +761,7 @@ class ShapeInstancer:
750
761
  self.face_count = len(indices)
751
762
 
752
763
  def update_colors(self, colors1, colors2):
753
- from pyglet import gl
764
+ gl = ShapeInstancer.gl
754
765
 
755
766
  if colors1 is None:
756
767
  colors1 = np.tile(self.color1, (self.num_instances, 1))
@@ -789,7 +800,7 @@ class ShapeInstancer:
789
800
  gl.glVertexAttribDivisor(8, 1)
790
801
 
791
802
  def allocate_instances(self, positions, rotations=None, colors1=None, colors2=None, scalings=None):
792
- from pyglet import gl
803
+ gl = ShapeInstancer.gl
793
804
 
794
805
  gl.glBindVertexArray(self.vao)
795
806
 
@@ -836,6 +847,7 @@ class ShapeInstancer:
836
847
  vbo_transforms,
837
848
  ],
838
849
  device=self.device,
850
+ record_tape=False,
839
851
  )
840
852
 
841
853
  vbo_transforms = vbo_transforms.numpy()
@@ -864,7 +876,7 @@ class ShapeInstancer:
864
876
  gl.glBindVertexArray(0)
865
877
 
866
878
  def update_instances(self, transforms: wp.array = None, scalings: wp.array = None, colors1=None, colors2=None):
867
- from pyglet import gl
879
+ gl = ShapeInstancer.gl
868
880
 
869
881
  if transforms is not None:
870
882
  if transforms.device.is_cuda:
@@ -897,6 +909,7 @@ class ShapeInstancer:
897
909
  vbo_transforms,
898
910
  ],
899
911
  device=self.device,
912
+ record_tape=False,
900
913
  )
901
914
 
902
915
  self._instance_transform_cuda_buffer.unmap()
@@ -905,7 +918,7 @@ class ShapeInstancer:
905
918
  self.update_colors(colors1, colors2)
906
919
 
907
920
  def render(self):
908
- from pyglet import gl
921
+ gl = ShapeInstancer.gl
909
922
 
910
923
  gl.glUseProgram(self.shape_shader.id)
911
924
 
@@ -915,7 +928,7 @@ class ShapeInstancer:
915
928
 
916
929
  # scope exposes VBO transforms to be set directly by a warp kernel
917
930
  def __enter__(self):
918
- from pyglet import gl
931
+ gl = ShapeInstancer.gl
919
932
 
920
933
  gl.glBindVertexArray(self.vao)
921
934
  self.vbo_transforms = self._instance_transform_cuda_buffer.map(dtype=wp.mat44, shape=(self.num_instances,))
@@ -941,6 +954,15 @@ class OpenGLRenderer:
941
954
  # number of segments to use for rendering spheres, capsules, cones and cylinders
942
955
  default_num_segments = 32
943
956
 
957
+ gl = None # Class-level variable to hold the imported module
958
+
959
+ @classmethod
960
+ def initialize_gl(cls):
961
+ if cls.gl is None: # Only import if not already imported
962
+ from pyglet import gl
963
+
964
+ cls.gl = gl
965
+
944
966
  def __init__(
945
967
  self,
946
968
  title="Warp",
@@ -968,6 +990,7 @@ class OpenGLRenderer:
968
990
  enable_backface_culling=True,
969
991
  enable_mouse_interaction=True,
970
992
  enable_keyboard_interaction=True,
993
+ device=None,
971
994
  ):
972
995
  """
973
996
  Args:
@@ -997,6 +1020,7 @@ class OpenGLRenderer:
997
1020
  enable_backface_culling (bool): Whether to enable backface culling.
998
1021
  enable_mouse_interaction (bool): Whether to enable mouse interaction.
999
1022
  enable_keyboard_interaction (bool): Whether to enable keyboard interaction.
1023
+ device (Devicelike): Where to store the internal data.
1000
1024
 
1001
1025
  Note:
1002
1026
 
@@ -1021,9 +1045,11 @@ class OpenGLRenderer:
1021
1045
  # disable error checking for performance
1022
1046
  pyglet.options["debug_gl"] = False
1023
1047
 
1024
- from pyglet import gl
1025
1048
  from pyglet.graphics.shader import Shader, ShaderProgram
1026
1049
  from pyglet.math import Vec3 as PyVec3
1050
+
1051
+ OpenGLRenderer.initialize_gl()
1052
+ gl = OpenGLRenderer.gl
1027
1053
  except ImportError as e:
1028
1054
  raise Exception("OpenGLRenderer requires pyglet (version >= 2.0) to be installed.") from e
1029
1055
 
@@ -1040,7 +1066,11 @@ class OpenGLRenderer:
1040
1066
  self.render_depth = render_depth
1041
1067
  self.enable_backface_culling = enable_backface_culling
1042
1068
 
1043
- self._device = wp.get_preferred_device()
1069
+ if device is None:
1070
+ self._device = wp.get_preferred_device()
1071
+ else:
1072
+ self._device = wp.get_device(device)
1073
+
1044
1074
  self._title = title
1045
1075
 
1046
1076
  self.window = pyglet.window.Window(
@@ -1052,9 +1082,8 @@ class OpenGLRenderer:
1052
1082
  self.headless = headless
1053
1083
  self.app = pyglet.app
1054
1084
 
1055
- if not headless:
1056
- # making window current opengl rendering context
1057
- self.window.switch_to()
1085
+ # making window current opengl rendering context
1086
+ self._switch_context()
1058
1087
 
1059
1088
  self.screen_width, self.screen_height = self.window.get_framebuffer_size()
1060
1089
 
@@ -1379,7 +1408,6 @@ class OpenGLRenderer:
1379
1408
 
1380
1409
  Window._enable_event_queue = False
1381
1410
 
1382
- self.window.switch_to()
1383
1411
  self.window.dispatch_pending_events()
1384
1412
 
1385
1413
  platform_event_loop = self.app.platform_event_loop
@@ -1405,7 +1433,9 @@ class OpenGLRenderer:
1405
1433
  return self.app.event_loop.has_exit
1406
1434
 
1407
1435
  def clear(self):
1408
- from pyglet import gl
1436
+ gl = OpenGLRenderer.gl
1437
+
1438
+ self._switch_context()
1409
1439
 
1410
1440
  if not self.headless:
1411
1441
  self.app.event_loop.dispatch_event("on_exit")
@@ -1525,9 +1555,9 @@ class OpenGLRenderer:
1525
1555
  if rescale_window:
1526
1556
  self.window.set_size(self._tile_width * self._tile_ncols, self._tile_height * self._tile_nrows)
1527
1557
  else:
1528
- assert (
1529
- len(tile_positions) == n and len(tile_sizes) == n
1530
- ), "Number of tiles does not match number of instances."
1558
+ assert len(tile_positions) == n and len(tile_sizes) == n, (
1559
+ "Number of tiles does not match number of instances."
1560
+ )
1531
1561
  self._tile_ncols = None
1532
1562
  self._tile_nrows = None
1533
1563
  self._tile_width = None
@@ -1599,7 +1629,9 @@ class OpenGLRenderer:
1599
1629
  self._tile_viewports[tile_id] = (x, y, w, h)
1600
1630
 
1601
1631
  def _setup_framebuffer(self):
1602
- from pyglet import gl
1632
+ gl = OpenGLRenderer.gl
1633
+
1634
+ self._switch_context()
1603
1635
 
1604
1636
  if self._frame_texture is None:
1605
1637
  self._frame_texture = gl.GLuint()
@@ -1767,7 +1799,9 @@ class OpenGLRenderer:
1767
1799
  return np.array((scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, 1), dtype=np.float32)
1768
1800
 
1769
1801
  def update_model_matrix(self, model_matrix: Optional[Mat44] = None):
1770
- from pyglet import gl
1802
+ gl = OpenGLRenderer.gl
1803
+
1804
+ self._switch_context()
1771
1805
 
1772
1806
  if model_matrix is None:
1773
1807
  self._model_matrix = self.compute_model_matrix(self._camera_axis, self._scaling)
@@ -1862,7 +1896,9 @@ class OpenGLRenderer:
1862
1896
  self._draw()
1863
1897
 
1864
1898
  def _draw(self):
1865
- from pyglet import gl
1899
+ gl = OpenGLRenderer.gl
1900
+
1901
+ self._switch_context()
1866
1902
 
1867
1903
  if not self.headless:
1868
1904
  # catch key hold events
@@ -1961,7 +1997,9 @@ Instances: {len(self._instances)}"""
1961
1997
  cb()
1962
1998
 
1963
1999
  def _draw_grid(self, is_tiled=False):
1964
- from pyglet import gl
2000
+ gl = OpenGLRenderer.gl
2001
+
2002
+ self._switch_context()
1965
2003
 
1966
2004
  if not is_tiled:
1967
2005
  gl.glUseProgram(self._grid_shader.id)
@@ -1974,7 +2012,9 @@ Instances: {len(self._instances)}"""
1974
2012
  gl.glBindVertexArray(0)
1975
2013
 
1976
2014
  def _draw_sky(self, is_tiled=False):
1977
- from pyglet import gl
2015
+ gl = OpenGLRenderer.gl
2016
+
2017
+ self._switch_context()
1978
2018
 
1979
2019
  if not is_tiled:
1980
2020
  gl.glUseProgram(self._sky_shader.id)
@@ -1988,7 +2028,9 @@ Instances: {len(self._instances)}"""
1988
2028
  gl.glBindVertexArray(0)
1989
2029
 
1990
2030
  def _render_scene(self):
1991
- from pyglet import gl
2031
+ gl = OpenGLRenderer.gl
2032
+
2033
+ self._switch_context()
1992
2034
 
1993
2035
  start_instance_idx = 0
1994
2036
 
@@ -2011,7 +2053,9 @@ Instances: {len(self._instances)}"""
2011
2053
  gl.glBindVertexArray(0)
2012
2054
 
2013
2055
  def _render_scene_tiled(self):
2014
- from pyglet import gl
2056
+ gl = OpenGLRenderer.gl
2057
+
2058
+ self._switch_context()
2015
2059
 
2016
2060
  for i, viewport in enumerate(self._tile_viewports):
2017
2061
  projection_matrix_ptr = arr_pointer(self._tile_projection_matrices[i])
@@ -2066,6 +2110,7 @@ Instances: {len(self._instances)}"""
2066
2110
  return
2067
2111
 
2068
2112
  import pyglet
2113
+ from pyglet.math import Vec3 as PyVec3
2069
2114
 
2070
2115
  if buttons & pyglet.window.mouse.LEFT:
2071
2116
  sensitivity = 0.1
@@ -2077,10 +2122,12 @@ Instances: {len(self._instances)}"""
2077
2122
 
2078
2123
  self._pitch = max(min(self._pitch, 89.0), -89.0)
2079
2124
 
2080
- self._camera_front.x = np.cos(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch))
2081
- self._camera_front.y = np.sin(np.deg2rad(self._pitch))
2082
- self._camera_front.z = np.sin(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch))
2083
- self._camera_front = self._camera_front.normalize()
2125
+ self._camera_front = PyVec3(
2126
+ np.cos(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch)),
2127
+ np.sin(np.deg2rad(self._pitch)),
2128
+ np.sin(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch)),
2129
+ ).normalize()
2130
+
2084
2131
  self.update_view_matrix()
2085
2132
 
2086
2133
  def _scroll_callback(self, x, y, scroll_x, scroll_y):
@@ -2156,7 +2203,9 @@ Instances: {len(self._instances)}"""
2156
2203
  self._setup_framebuffer()
2157
2204
 
2158
2205
  def register_shape(self, geo_hash, vertices, indices, color1=None, color2=None):
2159
- from pyglet import gl
2206
+ gl = OpenGLRenderer.gl
2207
+
2208
+ self._switch_context()
2160
2209
 
2161
2210
  shape = len(self._shapes)
2162
2211
  if color1 is None:
@@ -2205,7 +2254,9 @@ Instances: {len(self._instances)}"""
2205
2254
  return shape
2206
2255
 
2207
2256
  def deregister_shape(self, shape):
2208
- from pyglet import gl
2257
+ gl = OpenGLRenderer.gl
2258
+
2259
+ self._switch_context()
2209
2260
 
2210
2261
  if shape not in self._shape_gl_buffers:
2211
2262
  return
@@ -2264,7 +2315,9 @@ Instances: {len(self._instances)}"""
2264
2315
  del self._instances[name]
2265
2316
 
2266
2317
  def update_instance_colors(self):
2267
- from pyglet import gl
2318
+ gl = OpenGLRenderer.gl
2319
+
2320
+ self._switch_context()
2268
2321
 
2269
2322
  colors1, colors2 = [], []
2270
2323
  all_instances = list(self._instances.values())
@@ -2285,7 +2338,9 @@ Instances: {len(self._instances)}"""
2285
2338
  gl.glBufferData(gl.GL_ARRAY_BUFFER, colors2.nbytes, colors2.ctypes.data, gl.GL_STATIC_DRAW)
2286
2339
 
2287
2340
  def allocate_shape_instances(self):
2288
- from pyglet import gl
2341
+ gl = OpenGLRenderer.gl
2342
+
2343
+ self._switch_context()
2289
2344
 
2290
2345
  self._add_shape_instances = False
2291
2346
  self._wp_instance_transforms = wp.array(
@@ -2387,7 +2442,9 @@ Instances: {len(self._instances)}"""
2387
2442
  color2: The second color of the checker pattern
2388
2443
  visible: Whether the shape is visible
2389
2444
  """
2390
- from pyglet import gl
2445
+ gl = OpenGLRenderer.gl
2446
+
2447
+ self._switch_context()
2391
2448
 
2392
2449
  if name in self._instances:
2393
2450
  i, body, shape, tf, scale, old_color1, old_color2, v = self._instances[name]
@@ -2452,6 +2509,7 @@ Instances: {len(self._instances)}"""
2452
2509
  vbo_transforms,
2453
2510
  ],
2454
2511
  device=self._device,
2512
+ record_tape=False,
2455
2513
  )
2456
2514
 
2457
2515
  self._instance_transform_cuda_buffer.unmap()
@@ -2498,29 +2556,30 @@ Instances: {len(self._instances)}"""
2498
2556
  Returns:
2499
2557
  bool: Whether the pixels were successfully read.
2500
2558
  """
2501
- from pyglet import gl
2559
+ gl = OpenGLRenderer.gl
2560
+
2561
+ self._switch_context()
2502
2562
 
2503
2563
  channels = 3 if mode == "rgb" else 1
2504
2564
 
2505
2565
  if split_up_tiles:
2506
- assert (
2507
- self._tile_width is not None and self._tile_height is not None
2508
- ), "Tile width and height are not set, tiles must all have the same size"
2509
- assert all(
2510
- vp[2] == self._tile_width for vp in self._tile_viewports
2511
- ), "Tile widths do not all equal global tile_width, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2512
- assert all(
2513
- vp[3] == self._tile_height for vp in self._tile_viewports
2514
- ), "Tile heights do not all equal global tile_height, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2515
- assert (
2516
- target_image.shape
2517
- == (
2518
- self.num_tiles,
2519
- self._tile_height,
2520
- self._tile_width,
2521
- channels,
2522
- )
2523
- ), f"Shape of `target_image` array does not match {self.num_tiles} x {self._tile_height} x {self._tile_width} x {channels}"
2566
+ assert self._tile_width is not None and self._tile_height is not None, (
2567
+ "Tile width and height are not set, tiles must all have the same size"
2568
+ )
2569
+ assert all(vp[2] == self._tile_width for vp in self._tile_viewports), (
2570
+ "Tile widths do not all equal global tile_width, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2571
+ )
2572
+ assert all(vp[3] == self._tile_height for vp in self._tile_viewports), (
2573
+ "Tile heights do not all equal global tile_height, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2574
+ )
2575
+ assert target_image.shape == (
2576
+ self.num_tiles,
2577
+ self._tile_height,
2578
+ self._tile_width,
2579
+ channels,
2580
+ ), (
2581
+ f"Shape of `target_image` array does not match {self.num_tiles} x {self._tile_height} x {self._tile_width} x {channels}"
2582
+ )
2524
2583
  else:
2525
2584
  assert target_image.shape == (
2526
2585
  self.screen_height,
@@ -2784,7 +2843,7 @@ Instances: {len(self._instances)}"""
2784
2843
  up_axis: The axis of the capsule that points up (0: x, 1: y, 2: z)
2785
2844
  color: The color of the capsule
2786
2845
  """
2787
- geo_hash = hash(("capsule", radius, half_height))
2846
+ geo_hash = hash(("capsule", radius, half_height, up_axis))
2788
2847
  if geo_hash in self._shape_geo_hash:
2789
2848
  shape = self._shape_geo_hash[geo_hash]
2790
2849
  if self.update_shape_instance(name, pos, rot):
@@ -2819,7 +2878,7 @@ Instances: {len(self._instances)}"""
2819
2878
  up_axis: The axis of the cylinder that points up (0: x, 1: y, 2: z)
2820
2879
  color: The color of the capsule
2821
2880
  """
2822
- geo_hash = hash(("cylinder", radius, half_height))
2881
+ geo_hash = hash(("cylinder", radius, half_height, up_axis))
2823
2882
  if geo_hash in self._shape_geo_hash:
2824
2883
  shape = self._shape_geo_hash[geo_hash]
2825
2884
  if self.update_shape_instance(name, pos, rot):
@@ -2854,7 +2913,7 @@ Instances: {len(self._instances)}"""
2854
2913
  up_axis: The axis of the cone that points up (0: x, 1: y, 2: z)
2855
2914
  color: The color of the cone
2856
2915
  """
2857
- geo_hash = hash(("cone", radius, half_height))
2916
+ geo_hash = hash(("cone", radius, half_height, up_axis))
2858
2917
  if geo_hash in self._shape_geo_hash:
2859
2918
  shape = self._shape_geo_hash[geo_hash]
2860
2919
  if self.update_shape_instance(name, pos, rot):
@@ -2933,7 +2992,7 @@ Instances: {len(self._instances)}"""
2933
2992
  indices = np.array(indices, dtype=np.int32).reshape((-1, 3))
2934
2993
  idx_count = len(indices)
2935
2994
 
2936
- geo_hash = hash((indices.tobytes(),))
2995
+ geo_hash = hash((points.tobytes(), indices.tobytes()))
2937
2996
 
2938
2997
  if name in self._instances:
2939
2998
  # We've already registered this mesh instance and its associated shape.
@@ -2955,6 +3014,12 @@ Instances: {len(self._instances)}"""
2955
3014
  if shape is not None:
2956
3015
  # Update the shape's point positions.
2957
3016
  self.update_shape_vertices(shape, points, scale)
3017
+
3018
+ if not is_template and name not in self._instances:
3019
+ # Create a new instance.
3020
+ body = self._resolve_body_id(parent_body)
3021
+ self.add_shape_instance(name, shape, body, pos, rot, color1=colors)
3022
+
2958
3023
  return shape
2959
3024
 
2960
3025
  # No existing shape for the given mesh was found, or its topology may have changed,
@@ -3032,16 +3097,16 @@ Instances: {len(self._instances)}"""
3032
3097
  name: A name for the USD prim on the stage
3033
3098
  up_axis: The axis of the arrow that points up (0: x, 1: y, 2: z)
3034
3099
  """
3035
- geo_hash = hash(("arrow", base_radius, base_height, cap_radius, cap_height))
3100
+ geo_hash = hash(("arrow", base_radius, base_height, cap_radius, cap_height, up_axis))
3036
3101
  if geo_hash in self._shape_geo_hash:
3037
3102
  shape = self._shape_geo_hash[geo_hash]
3038
- if self.update_shape_instance(name, pos, rot):
3103
+ if self.update_shape_instance(name, pos, rot, color1=color, color2=color):
3039
3104
  return shape
3040
3105
  else:
3041
3106
  vertices, indices = self._create_arrow_mesh(
3042
3107
  base_radius, base_height, cap_radius, cap_height, up_axis=up_axis
3043
3108
  )
3044
- shape = self.register_shape(geo_hash, vertices, indices)
3109
+ shape = self.register_shape(geo_hash, vertices, indices, color1=color, color2=color)
3045
3110
  if not is_template:
3046
3111
  body = self._resolve_body_id(parent_body)
3047
3112
  self.add_shape_instance(name, shape, body, pos, rot, color1=color, color2=color)
@@ -3432,6 +3497,14 @@ Instances: {len(self._instances)}"""
3432
3497
  # fmt: on
3433
3498
  return np.array(vertices, dtype=np.float32), np.array(indices, dtype=np.uint32)
3434
3499
 
3500
+ def _switch_context(self):
3501
+ try:
3502
+ self.window.switch_to()
3503
+ except AttributeError:
3504
+ # The window could be in the process of being closed, in which case
3505
+ # its corresponding context might have been destroyed and set to `None`.
3506
+ pass
3507
+
3435
3508
 
3436
3509
  if __name__ == "__main__":
3437
3510
  renderer = OpenGLRenderer()
warp/render/render_usd.py CHANGED
@@ -582,7 +582,12 @@ class UsdRenderer:
582
582
  mesh = UsdGeom.Mesh.Get(self.stage, mesh_path)
583
583
  if not mesh:
584
584
  mesh = UsdGeom.Mesh.Define(self.stage, mesh_path)
585
- UsdGeom.Primvar(mesh.GetDisplayColorAttr()).SetInterpolation("vertex")
585
+ if colors is not None and len(colors) == 3:
586
+ color_interp = "constant"
587
+ else:
588
+ color_interp = "vertex"
589
+
590
+ UsdGeom.Primvar(mesh.GetDisplayColorAttr()).SetInterpolation(color_interp)
586
591
  _usd_add_xform(mesh)
587
592
 
588
593
  # force topology update on first frame
@@ -595,7 +600,10 @@ class UsdRenderer:
595
600
  mesh.GetFaceVertexIndicesAttr().Set(idxs, self.time)
596
601
  mesh.GetFaceVertexCountsAttr().Set([3] * len(idxs), self.time)
597
602
 
598
- if colors:
603
+ if colors is not None:
604
+ if len(colors) == 3:
605
+ colors = (colors,)
606
+
599
607
  mesh.GetDisplayColorAttr().Set(colors, self.time)
600
608
 
601
609
  self._shape_constructors[name] = UsdGeom.Mesh
warp/sim/__init__.py CHANGED
@@ -50,4 +50,9 @@ from .model import (
50
50
  ModelShapeMaterials,
51
51
  State,
52
52
  )
53
- from .utils import load_mesh, quat_from_euler, quat_to_euler, velocity_at_point
53
+ from .utils import (
54
+ load_mesh,
55
+ quat_from_euler,
56
+ quat_to_euler,
57
+ velocity_at_point,
58
+ )