warp-lang 1.5.0__py3-none-manylinux2014_x86_64.whl → 1.6.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory for more details.

Files changed (132)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/native/warp.h CHANGED
@@ -67,21 +67,23 @@ extern "C"
67
67
  WP_API void memtile_host(void* dest, const void* src, size_t srcsize, size_t n);
68
68
  WP_API void memtile_device(void* context, void* dest, const void* src, size_t srcsize, size_t n);
69
69
 
70
- WP_API uint64_t bvh_create_host(wp::vec3* lowers, wp::vec3* uppers, int num_items);
70
+ WP_API uint64_t bvh_create_host(wp::vec3* lowers, wp::vec3* uppers, int num_items, int constructor_type);
71
71
  WP_API void bvh_destroy_host(uint64_t id);
72
72
  WP_API void bvh_refit_host(uint64_t id);
73
73
 
74
- WP_API uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items);
74
+ WP_API uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items, int constructor_type);
75
75
  WP_API void bvh_destroy_device(uint64_t id);
76
76
  WP_API void bvh_refit_device(uint64_t id);
77
77
 
78
78
  // create a user-accessible copy of the mesh, it is the
79
79
  // users responsibility to keep-alive the points/tris data for the duration of the mesh lifetime
80
- WP_API uint64_t mesh_create_host(wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number);
80
+ WP_API uint64_t mesh_create_host(wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris,
81
+ int num_points, int num_tris, int support_winding_number, int constructor_type);
81
82
  WP_API void mesh_destroy_host(uint64_t id);
82
83
  WP_API void mesh_refit_host(uint64_t id);
83
84
 
84
- WP_API uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities, wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number);
85
+ WP_API uint64_t mesh_create_device(void* context, wp::array_t<wp::vec3> points, wp::array_t<wp::vec3> velocities,
86
+ wp::array_t<int> tris, int num_points, int num_tris, int support_winding_number, int constructor_type);
85
87
  WP_API void mesh_destroy_device(uint64_t id);
86
88
  WP_API void mesh_refit_device(uint64_t id);
87
89
 
@@ -159,6 +161,9 @@ extern "C"
159
161
  WP_API void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n);
160
162
  WP_API void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n);
161
163
 
164
+ WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n);
165
+ WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n);
166
+
162
167
  WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
163
168
  WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
164
169
 
@@ -266,6 +271,7 @@ extern "C"
266
271
  WP_API int cuda_device_get_pci_device_id(int ordinal);
267
272
  WP_API int cuda_device_is_uva(int ordinal);
268
273
  WP_API int cuda_device_is_mempool_supported(int ordinal);
274
+ WP_API int cuda_device_is_ipc_supported(int ordinal);
269
275
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
270
276
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
271
277
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
@@ -294,6 +300,13 @@ extern "C"
294
300
  WP_API int cuda_is_mempool_access_enabled(int target_ordinal, int peer_ordinal);
295
301
  WP_API int cuda_set_mempool_access_enabled(int target_ordinal, int peer_ordinal, int enable);
296
302
 
303
+ // inter-process communication
304
+ WP_API void cuda_ipc_get_mem_handle(void* ptr, char* out_buffer);
305
+ WP_API void* cuda_ipc_open_mem_handle(void* context, char* handle);
306
+ WP_API void cuda_ipc_close_mem_handle(void* ptr);
307
+ WP_API void cuda_ipc_get_event_handle(void* context, void* event, char* out_buffer);
308
+ WP_API void* cuda_ipc_open_event_handle(void* context, char* handle);
309
+
297
310
  WP_API void* cuda_stream_create(void* context, int priority);
298
311
  WP_API void cuda_stream_destroy(void* context, void* stream);
299
312
  WP_API void cuda_stream_register(void* context, void* stream);
@@ -317,9 +330,10 @@ extern "C"
317
330
  WP_API bool cuda_graph_launch(void* graph, void* stream);
318
331
  WP_API bool cuda_graph_destroy(void* context, void* graph);
319
332
 
320
- WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
333
+ WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes, int* ltoir_input_types);
321
334
  WP_API bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size);
322
335
  WP_API bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads);
336
+ WP_API bool cuda_compile_solver(const char* fatbin_output_path, const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int function, int precision, int fill_mode, int num_threads);
323
337
 
324
338
  WP_API void* cuda_load_module(void* context, const char* ptx);
325
339
  WP_API void cuda_unload_module(void* context, void* module);
warp/optim/linear.py CHANGED
@@ -84,10 +84,7 @@ def aslinearoperator(A: _Matrix) -> LinearOperator:
84
84
  sparse.bsr_mv(A, x, z, alpha, beta)
85
85
 
86
86
  def dense_mv(x, y, z, alpha, beta):
87
- x = x.reshape((x.shape[0], 1))
88
- y = y.reshape((y.shape[0], 1))
89
- z = z.reshape((y.shape[0], 1))
90
- wp.matmul(A, x, y, z, alpha, beta)
87
+ wp.launch(_dense_mv_kernel, dim=A.shape[1], device=A.device, inputs=[A, x, y, z, alpha, beta])
91
88
 
92
89
  def diag_mv(x, y, z, alpha, beta):
93
90
  scalar_type = wp.types.type_scalar_type(A.dtype)
@@ -803,6 +800,27 @@ def _run_solver_loop(
803
800
  return cur_iter, err, atol
804
801
 
805
802
 
803
+ @wp.func
804
+ def _calc_mv_product(i: wp.int32, A: wp.array2d(dtype=Any), x: wp.array1d(dtype=Any)):
805
+ sum = A.dtype(0)
806
+ for j in range(A.shape[1]):
807
+ sum += A[i, j] * x[j]
808
+ return sum
809
+
810
+
811
+ @wp.kernel
812
+ def _dense_mv_kernel(
813
+ A: wp.array2d(dtype=Any),
814
+ x: wp.array1d(dtype=Any),
815
+ y: wp.array1d(dtype=Any),
816
+ z: wp.array1d(dtype=Any),
817
+ alpha: Any,
818
+ beta: Any,
819
+ ):
820
+ i = wp.tid()
821
+ z[i] = z.dtype(beta) * y[i] + z.dtype(alpha) * _calc_mv_product(i, A, x)
822
+
823
+
806
824
  @wp.kernel
807
825
  def _diag_mv_kernel(
808
826
  A: wp.array(dtype=Any),
@@ -671,6 +671,15 @@ class ShapeInstancer:
671
671
  [3D point, 3D normal, UV texture coordinates]
672
672
  """
673
673
 
674
+ gl = None # Class-level variable to hold the imported module
675
+
676
+ @classmethod
677
+ def initialize_gl(cls):
678
+ if cls.gl is None: # Only import if not already imported
679
+ from pyglet import gl
680
+
681
+ cls.gl = gl
682
+
674
683
  def __new__(cls, *args, **kwargs):
675
684
  instance = super(ShapeInstancer, cls).__new__(cls)
676
685
  instance.instance_transform_gl_buffer = None
@@ -690,8 +699,10 @@ class ShapeInstancer:
690
699
  self.scalings = None
691
700
  self._instance_transform_cuda_buffer = None
692
701
 
702
+ ShapeInstancer.initialize_gl()
703
+
693
704
  def __del__(self):
694
- from pyglet import gl
705
+ gl = ShapeInstancer.gl
695
706
 
696
707
  if self.instance_transform_gl_buffer is not None:
697
708
  try:
@@ -709,7 +720,7 @@ class ShapeInstancer:
709
720
  pass
710
721
 
711
722
  def register_shape(self, vertices, indices, color1=(1.0, 1.0, 1.0), color2=(0.0, 0.0, 0.0)):
712
- from pyglet import gl
723
+ gl = ShapeInstancer.gl
713
724
 
714
725
  if color1 is not None and color2 is None:
715
726
  color2 = np.clip(np.array(color1) + 0.25, 0.0, 1.0)
@@ -750,7 +761,7 @@ class ShapeInstancer:
750
761
  self.face_count = len(indices)
751
762
 
752
763
  def update_colors(self, colors1, colors2):
753
- from pyglet import gl
764
+ gl = ShapeInstancer.gl
754
765
 
755
766
  if colors1 is None:
756
767
  colors1 = np.tile(self.color1, (self.num_instances, 1))
@@ -789,7 +800,7 @@ class ShapeInstancer:
789
800
  gl.glVertexAttribDivisor(8, 1)
790
801
 
791
802
  def allocate_instances(self, positions, rotations=None, colors1=None, colors2=None, scalings=None):
792
- from pyglet import gl
803
+ gl = ShapeInstancer.gl
793
804
 
794
805
  gl.glBindVertexArray(self.vao)
795
806
 
@@ -836,6 +847,7 @@ class ShapeInstancer:
836
847
  vbo_transforms,
837
848
  ],
838
849
  device=self.device,
850
+ record_tape=False,
839
851
  )
840
852
 
841
853
  vbo_transforms = vbo_transforms.numpy()
@@ -864,7 +876,7 @@ class ShapeInstancer:
864
876
  gl.glBindVertexArray(0)
865
877
 
866
878
  def update_instances(self, transforms: wp.array = None, scalings: wp.array = None, colors1=None, colors2=None):
867
- from pyglet import gl
879
+ gl = ShapeInstancer.gl
868
880
 
869
881
  if transforms is not None:
870
882
  if transforms.device.is_cuda:
@@ -897,6 +909,7 @@ class ShapeInstancer:
897
909
  vbo_transforms,
898
910
  ],
899
911
  device=self.device,
912
+ record_tape=False,
900
913
  )
901
914
 
902
915
  self._instance_transform_cuda_buffer.unmap()
@@ -905,7 +918,7 @@ class ShapeInstancer:
905
918
  self.update_colors(colors1, colors2)
906
919
 
907
920
  def render(self):
908
- from pyglet import gl
921
+ gl = ShapeInstancer.gl
909
922
 
910
923
  gl.glUseProgram(self.shape_shader.id)
911
924
 
@@ -915,7 +928,7 @@ class ShapeInstancer:
915
928
 
916
929
  # scope exposes VBO transforms to be set directly by a warp kernel
917
930
  def __enter__(self):
918
- from pyglet import gl
931
+ gl = ShapeInstancer.gl
919
932
 
920
933
  gl.glBindVertexArray(self.vao)
921
934
  self.vbo_transforms = self._instance_transform_cuda_buffer.map(dtype=wp.mat44, shape=(self.num_instances,))
@@ -941,6 +954,15 @@ class OpenGLRenderer:
941
954
  # number of segments to use for rendering spheres, capsules, cones and cylinders
942
955
  default_num_segments = 32
943
956
 
957
+ gl = None # Class-level variable to hold the imported module
958
+
959
+ @classmethod
960
+ def initialize_gl(cls):
961
+ if cls.gl is None: # Only import if not already imported
962
+ from pyglet import gl
963
+
964
+ cls.gl = gl
965
+
944
966
  def __init__(
945
967
  self,
946
968
  title="Warp",
@@ -968,6 +990,7 @@ class OpenGLRenderer:
968
990
  enable_backface_culling=True,
969
991
  enable_mouse_interaction=True,
970
992
  enable_keyboard_interaction=True,
993
+ device=None,
971
994
  ):
972
995
  """
973
996
  Args:
@@ -997,6 +1020,7 @@ class OpenGLRenderer:
997
1020
  enable_backface_culling (bool): Whether to enable backface culling.
998
1021
  enable_mouse_interaction (bool): Whether to enable mouse interaction.
999
1022
  enable_keyboard_interaction (bool): Whether to enable keyboard interaction.
1023
+ device (Devicelike): Where to store the internal data.
1000
1024
 
1001
1025
  Note:
1002
1026
 
@@ -1021,9 +1045,11 @@ class OpenGLRenderer:
1021
1045
  # disable error checking for performance
1022
1046
  pyglet.options["debug_gl"] = False
1023
1047
 
1024
- from pyglet import gl
1025
1048
  from pyglet.graphics.shader import Shader, ShaderProgram
1026
1049
  from pyglet.math import Vec3 as PyVec3
1050
+
1051
+ OpenGLRenderer.initialize_gl()
1052
+ gl = OpenGLRenderer.gl
1027
1053
  except ImportError as e:
1028
1054
  raise Exception("OpenGLRenderer requires pyglet (version >= 2.0) to be installed.") from e
1029
1055
 
@@ -1040,7 +1066,11 @@ class OpenGLRenderer:
1040
1066
  self.render_depth = render_depth
1041
1067
  self.enable_backface_culling = enable_backface_culling
1042
1068
 
1043
- self._device = wp.get_cuda_device()
1069
+ if device is None:
1070
+ self._device = wp.get_preferred_device()
1071
+ else:
1072
+ self._device = wp.get_device(device)
1073
+
1044
1074
  self._title = title
1045
1075
 
1046
1076
  self.window = pyglet.window.Window(
@@ -1052,9 +1082,8 @@ class OpenGLRenderer:
1052
1082
  self.headless = headless
1053
1083
  self.app = pyglet.app
1054
1084
 
1055
- if not headless:
1056
- # making window current opengl rendering context
1057
- self.window.switch_to()
1085
+ # making window current opengl rendering context
1086
+ self.window.switch_to()
1058
1087
 
1059
1088
  self.screen_width, self.screen_height = self.window.get_framebuffer_size()
1060
1089
 
@@ -1379,7 +1408,6 @@ class OpenGLRenderer:
1379
1408
 
1380
1409
  Window._enable_event_queue = False
1381
1410
 
1382
- self.window.switch_to()
1383
1411
  self.window.dispatch_pending_events()
1384
1412
 
1385
1413
  platform_event_loop = self.app.platform_event_loop
@@ -1405,7 +1433,9 @@ class OpenGLRenderer:
1405
1433
  return self.app.event_loop.has_exit
1406
1434
 
1407
1435
  def clear(self):
1408
- from pyglet import gl
1436
+ gl = OpenGLRenderer.gl
1437
+
1438
+ self.window.switch_to()
1409
1439
 
1410
1440
  if not self.headless:
1411
1441
  self.app.event_loop.dispatch_event("on_exit")
@@ -1525,9 +1555,9 @@ class OpenGLRenderer:
1525
1555
  if rescale_window:
1526
1556
  self.window.set_size(self._tile_width * self._tile_ncols, self._tile_height * self._tile_nrows)
1527
1557
  else:
1528
- assert (
1529
- len(tile_positions) == n and len(tile_sizes) == n
1530
- ), "Number of tiles does not match number of instances."
1558
+ assert len(tile_positions) == n and len(tile_sizes) == n, (
1559
+ "Number of tiles does not match number of instances."
1560
+ )
1531
1561
  self._tile_ncols = None
1532
1562
  self._tile_nrows = None
1533
1563
  self._tile_width = None
@@ -1599,7 +1629,9 @@ class OpenGLRenderer:
1599
1629
  self._tile_viewports[tile_id] = (x, y, w, h)
1600
1630
 
1601
1631
  def _setup_framebuffer(self):
1602
- from pyglet import gl
1632
+ gl = OpenGLRenderer.gl
1633
+
1634
+ self.window.switch_to()
1603
1635
 
1604
1636
  if self._frame_texture is None:
1605
1637
  self._frame_texture = gl.GLuint()
@@ -1767,7 +1799,9 @@ class OpenGLRenderer:
1767
1799
  return np.array((scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, scaling, 0, 0, 0, 0, 1), dtype=np.float32)
1768
1800
 
1769
1801
  def update_model_matrix(self, model_matrix: Optional[Mat44] = None):
1770
- from pyglet import gl
1802
+ gl = OpenGLRenderer.gl
1803
+
1804
+ self.window.switch_to()
1771
1805
 
1772
1806
  if model_matrix is None:
1773
1807
  self._model_matrix = self.compute_model_matrix(self._camera_axis, self._scaling)
@@ -1862,7 +1896,9 @@ class OpenGLRenderer:
1862
1896
  self._draw()
1863
1897
 
1864
1898
  def _draw(self):
1865
- from pyglet import gl
1899
+ gl = OpenGLRenderer.gl
1900
+
1901
+ self.window.switch_to()
1866
1902
 
1867
1903
  if not self.headless:
1868
1904
  # catch key hold events
@@ -1961,7 +1997,9 @@ Instances: {len(self._instances)}"""
1961
1997
  cb()
1962
1998
 
1963
1999
  def _draw_grid(self, is_tiled=False):
1964
- from pyglet import gl
2000
+ gl = OpenGLRenderer.gl
2001
+
2002
+ self.window.switch_to()
1965
2003
 
1966
2004
  if not is_tiled:
1967
2005
  gl.glUseProgram(self._grid_shader.id)
@@ -1974,7 +2012,9 @@ Instances: {len(self._instances)}"""
1974
2012
  gl.glBindVertexArray(0)
1975
2013
 
1976
2014
  def _draw_sky(self, is_tiled=False):
1977
- from pyglet import gl
2015
+ gl = OpenGLRenderer.gl
2016
+
2017
+ self.window.switch_to()
1978
2018
 
1979
2019
  if not is_tiled:
1980
2020
  gl.glUseProgram(self._sky_shader.id)
@@ -1988,7 +2028,9 @@ Instances: {len(self._instances)}"""
1988
2028
  gl.glBindVertexArray(0)
1989
2029
 
1990
2030
  def _render_scene(self):
1991
- from pyglet import gl
2031
+ gl = OpenGLRenderer.gl
2032
+
2033
+ self.window.switch_to()
1992
2034
 
1993
2035
  start_instance_idx = 0
1994
2036
 
@@ -2011,7 +2053,9 @@ Instances: {len(self._instances)}"""
2011
2053
  gl.glBindVertexArray(0)
2012
2054
 
2013
2055
  def _render_scene_tiled(self):
2014
- from pyglet import gl
2056
+ gl = OpenGLRenderer.gl
2057
+
2058
+ self.window.switch_to()
2015
2059
 
2016
2060
  for i, viewport in enumerate(self._tile_viewports):
2017
2061
  projection_matrix_ptr = arr_pointer(self._tile_projection_matrices[i])
@@ -2066,6 +2110,7 @@ Instances: {len(self._instances)}"""
2066
2110
  return
2067
2111
 
2068
2112
  import pyglet
2113
+ from pyglet.math import Vec3 as PyVec3
2069
2114
 
2070
2115
  if buttons & pyglet.window.mouse.LEFT:
2071
2116
  sensitivity = 0.1
@@ -2077,10 +2122,12 @@ Instances: {len(self._instances)}"""
2077
2122
 
2078
2123
  self._pitch = max(min(self._pitch, 89.0), -89.0)
2079
2124
 
2080
- self._camera_front.x = np.cos(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch))
2081
- self._camera_front.y = np.sin(np.deg2rad(self._pitch))
2082
- self._camera_front.z = np.sin(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch))
2083
- self._camera_front = self._camera_front.normalize()
2125
+ self._camera_front = PyVec3(
2126
+ np.cos(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch)),
2127
+ np.sin(np.deg2rad(self._pitch)),
2128
+ np.sin(np.deg2rad(self._yaw)) * np.cos(np.deg2rad(self._pitch)),
2129
+ ).normalize()
2130
+
2084
2131
  self.update_view_matrix()
2085
2132
 
2086
2133
  def _scroll_callback(self, x, y, scroll_x, scroll_y):
@@ -2156,7 +2203,9 @@ Instances: {len(self._instances)}"""
2156
2203
  self._setup_framebuffer()
2157
2204
 
2158
2205
  def register_shape(self, geo_hash, vertices, indices, color1=None, color2=None):
2159
- from pyglet import gl
2206
+ gl = OpenGLRenderer.gl
2207
+
2208
+ self.window.switch_to()
2160
2209
 
2161
2210
  shape = len(self._shapes)
2162
2211
  if color1 is None:
@@ -2205,7 +2254,9 @@ Instances: {len(self._instances)}"""
2205
2254
  return shape
2206
2255
 
2207
2256
  def deregister_shape(self, shape):
2208
- from pyglet import gl
2257
+ gl = OpenGLRenderer.gl
2258
+
2259
+ self.window.switch_to()
2209
2260
 
2210
2261
  if shape not in self._shape_gl_buffers:
2211
2262
  return
@@ -2264,7 +2315,9 @@ Instances: {len(self._instances)}"""
2264
2315
  del self._instances[name]
2265
2316
 
2266
2317
  def update_instance_colors(self):
2267
- from pyglet import gl
2318
+ gl = OpenGLRenderer.gl
2319
+
2320
+ self.window.switch_to()
2268
2321
 
2269
2322
  colors1, colors2 = [], []
2270
2323
  all_instances = list(self._instances.values())
@@ -2278,19 +2331,16 @@ Instances: {len(self._instances)}"""
2278
2331
  colors1 = np.array(colors1, dtype=np.float32)
2279
2332
  colors2 = np.array(colors2, dtype=np.float32)
2280
2333
 
2281
- # create buffer for checkerboard colors
2282
- self._instance_color1_buffer = gl.GLuint()
2283
- gl.glGenBuffers(1, self._instance_color1_buffer)
2284
2334
  gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._instance_color1_buffer)
2285
2335
  gl.glBufferData(gl.GL_ARRAY_BUFFER, colors1.nbytes, colors1.ctypes.data, gl.GL_STATIC_DRAW)
2286
2336
 
2287
- self._instance_color2_buffer = gl.GLuint()
2288
- gl.glGenBuffers(1, self._instance_color2_buffer)
2289
2337
  gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._instance_color2_buffer)
2290
2338
  gl.glBufferData(gl.GL_ARRAY_BUFFER, colors2.nbytes, colors2.ctypes.data, gl.GL_STATIC_DRAW)
2291
2339
 
2292
2340
  def allocate_shape_instances(self):
2293
- from pyglet import gl
2341
+ gl = OpenGLRenderer.gl
2342
+
2343
+ self.window.switch_to()
2294
2344
 
2295
2345
  self._add_shape_instances = False
2296
2346
  self._wp_instance_transforms = wp.array(
@@ -2322,6 +2372,12 @@ Instances: {len(self._instances)}"""
2322
2372
  int(self._instance_transform_gl_buffer.value), self._device
2323
2373
  )
2324
2374
 
2375
+ # create color buffers
2376
+ self._instance_color1_buffer = gl.GLuint()
2377
+ gl.glGenBuffers(1, self._instance_color1_buffer)
2378
+ self._instance_color2_buffer = gl.GLuint()
2379
+ gl.glGenBuffers(1, self._instance_color2_buffer)
2380
+
2325
2381
  self.update_instance_colors()
2326
2382
 
2327
2383
  # set up instance attribute pointers
@@ -2386,7 +2442,9 @@ Instances: {len(self._instances)}"""
2386
2442
  color2: The second color of the checker pattern
2387
2443
  visible: Whether the shape is visible
2388
2444
  """
2389
- from pyglet import gl
2445
+ gl = OpenGLRenderer.gl
2446
+
2447
+ self.window.switch_to()
2390
2448
 
2391
2449
  if name in self._instances:
2392
2450
  i, body, shape, tf, scale, old_color1, old_color2, v = self._instances[name]
@@ -2451,6 +2509,7 @@ Instances: {len(self._instances)}"""
2451
2509
  vbo_transforms,
2452
2510
  ],
2453
2511
  device=self._device,
2512
+ record_tape=False,
2454
2513
  )
2455
2514
 
2456
2515
  self._instance_transform_cuda_buffer.unmap()
@@ -2497,29 +2556,30 @@ Instances: {len(self._instances)}"""
2497
2556
  Returns:
2498
2557
  bool: Whether the pixels were successfully read.
2499
2558
  """
2500
- from pyglet import gl
2559
+ gl = OpenGLRenderer.gl
2560
+
2561
+ self.window.switch_to()
2501
2562
 
2502
2563
  channels = 3 if mode == "rgb" else 1
2503
2564
 
2504
2565
  if split_up_tiles:
2505
- assert (
2506
- self._tile_width is not None and self._tile_height is not None
2507
- ), "Tile width and height are not set, tiles must all have the same size"
2508
- assert all(
2509
- vp[2] == self._tile_width for vp in self._tile_viewports
2510
- ), "Tile widths do not all equal global tile_width, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2511
- assert all(
2512
- vp[3] == self._tile_height for vp in self._tile_viewports
2513
- ), "Tile heights do not all equal global tile_height, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2514
- assert (
2515
- target_image.shape
2516
- == (
2517
- self.num_tiles,
2518
- self._tile_height,
2519
- self._tile_width,
2520
- channels,
2521
- )
2522
- ), f"Shape of `target_image` array does not match {self.num_tiles} x {self._tile_height} x {self._tile_width} x {channels}"
2566
+ assert self._tile_width is not None and self._tile_height is not None, (
2567
+ "Tile width and height are not set, tiles must all have the same size"
2568
+ )
2569
+ assert all(vp[2] == self._tile_width for vp in self._tile_viewports), (
2570
+ "Tile widths do not all equal global tile_width, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2571
+ )
2572
+ assert all(vp[3] == self._tile_height for vp in self._tile_viewports), (
2573
+ "Tile heights do not all equal global tile_height, use `get_tile_pixels` instead to retrieve pixels for a single tile"
2574
+ )
2575
+ assert target_image.shape == (
2576
+ self.num_tiles,
2577
+ self._tile_height,
2578
+ self._tile_width,
2579
+ channels,
2580
+ ), (
2581
+ f"Shape of `target_image` array does not match {self.num_tiles} x {self._tile_height} x {self._tile_width} x {channels}"
2582
+ )
2523
2583
  else:
2524
2584
  assert target_image.shape == (
2525
2585
  self.screen_height,
@@ -2783,7 +2843,7 @@ Instances: {len(self._instances)}"""
2783
2843
  up_axis: The axis of the capsule that points up (0: x, 1: y, 2: z)
2784
2844
  color: The color of the capsule
2785
2845
  """
2786
- geo_hash = hash(("capsule", radius, half_height))
2846
+ geo_hash = hash(("capsule", radius, half_height, up_axis))
2787
2847
  if geo_hash in self._shape_geo_hash:
2788
2848
  shape = self._shape_geo_hash[geo_hash]
2789
2849
  if self.update_shape_instance(name, pos, rot):
@@ -2818,7 +2878,7 @@ Instances: {len(self._instances)}"""
2818
2878
  up_axis: The axis of the cylinder that points up (0: x, 1: y, 2: z)
2819
2879
  color: The color of the capsule
2820
2880
  """
2821
- geo_hash = hash(("cylinder", radius, half_height))
2881
+ geo_hash = hash(("cylinder", radius, half_height, up_axis))
2822
2882
  if geo_hash in self._shape_geo_hash:
2823
2883
  shape = self._shape_geo_hash[geo_hash]
2824
2884
  if self.update_shape_instance(name, pos, rot):
@@ -2853,7 +2913,7 @@ Instances: {len(self._instances)}"""
2853
2913
  up_axis: The axis of the cone that points up (0: x, 1: y, 2: z)
2854
2914
  color: The color of the cone
2855
2915
  """
2856
- geo_hash = hash(("cone", radius, half_height))
2916
+ geo_hash = hash(("cone", radius, half_height, up_axis))
2857
2917
  if geo_hash in self._shape_geo_hash:
2858
2918
  shape = self._shape_geo_hash[geo_hash]
2859
2919
  if self.update_shape_instance(name, pos, rot):
@@ -2932,7 +2992,7 @@ Instances: {len(self._instances)}"""
2932
2992
  indices = np.array(indices, dtype=np.int32).reshape((-1, 3))
2933
2993
  idx_count = len(indices)
2934
2994
 
2935
- geo_hash = hash((indices.tobytes(),))
2995
+ geo_hash = hash((points.tobytes(), indices.tobytes()))
2936
2996
 
2937
2997
  if name in self._instances:
2938
2998
  # We've already registered this mesh instance and its associated shape.
@@ -2954,6 +3014,12 @@ Instances: {len(self._instances)}"""
2954
3014
  if shape is not None:
2955
3015
  # Update the shape's point positions.
2956
3016
  self.update_shape_vertices(shape, points, scale)
3017
+
3018
+ if not is_template and name not in self._instances:
3019
+ # Create a new instance.
3020
+ body = self._resolve_body_id(parent_body)
3021
+ self.add_shape_instance(name, shape, body, pos, rot, color1=colors)
3022
+
2957
3023
  return shape
2958
3024
 
2959
3025
  # No existing shape for the given mesh was found, or its topology may have changed,
@@ -3031,16 +3097,16 @@ Instances: {len(self._instances)}"""
3031
3097
  name: A name for the USD prim on the stage
3032
3098
  up_axis: The axis of the arrow that points up (0: x, 1: y, 2: z)
3033
3099
  """
3034
- geo_hash = hash(("arrow", base_radius, base_height, cap_radius, cap_height))
3100
+ geo_hash = hash(("arrow", base_radius, base_height, cap_radius, cap_height, up_axis))
3035
3101
  if geo_hash in self._shape_geo_hash:
3036
3102
  shape = self._shape_geo_hash[geo_hash]
3037
- if self.update_shape_instance(name, pos, rot):
3103
+ if self.update_shape_instance(name, pos, rot, color1=color, color2=color):
3038
3104
  return shape
3039
3105
  else:
3040
3106
  vertices, indices = self._create_arrow_mesh(
3041
3107
  base_radius, base_height, cap_radius, cap_height, up_axis=up_axis
3042
3108
  )
3043
- shape = self.register_shape(geo_hash, vertices, indices)
3109
+ shape = self.register_shape(geo_hash, vertices, indices, color1=color, color2=color)
3044
3110
  if not is_template:
3045
3111
  body = self._resolve_body_id(parent_body)
3046
3112
  self.add_shape_instance(name, shape, body, pos, rot, color1=color, color2=color)
warp/sim/__init__.py CHANGED
@@ -50,4 +50,9 @@ from .model import (
50
50
  ModelShapeMaterials,
51
51
  State,
52
52
  )
53
- from .utils import load_mesh, quat_from_euler, quat_to_euler, velocity_at_point
53
+ from .utils import (
54
+ load_mesh,
55
+ quat_from_euler,
56
+ quat_to_euler,
57
+ velocity_at_point,
58
+ )