warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (179) hide show
  1. warp/__init__.py +7 -1
  2. warp/bin/libwarp-clang.dylib +0 -0
  3. warp/bin/libwarp.dylib +0 -0
  4. warp/build.py +410 -0
  5. warp/build_dll.py +6 -14
  6. warp/builtins.py +452 -362
  7. warp/codegen.py +179 -119
  8. warp/config.py +42 -6
  9. warp/context.py +490 -271
  10. warp/dlpack.py +8 -6
  11. warp/examples/assets/nonuniform.usd +0 -0
  12. warp/examples/assets/nvidia_logo.png +0 -0
  13. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  14. warp/examples/core/example_sample_mesh.py +300 -0
  15. warp/examples/fem/example_apic_fluid.py +1 -1
  16. warp/examples/fem/example_burgers.py +2 -2
  17. warp/examples/fem/example_deformed_geometry.py +1 -1
  18. warp/examples/fem/example_distortion_energy.py +1 -1
  19. warp/examples/fem/example_magnetostatics.py +6 -6
  20. warp/examples/fem/utils.py +9 -3
  21. warp/examples/interop/example_jax_callable.py +116 -0
  22. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  23. warp/examples/interop/example_jax_kernel.py +205 -0
  24. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  25. warp/examples/tile/example_tile_matmul.py +2 -4
  26. warp/fem/__init__.py +11 -1
  27. warp/fem/adaptivity.py +4 -4
  28. warp/fem/field/nodal_field.py +22 -68
  29. warp/fem/field/virtual.py +62 -23
  30. warp/fem/geometry/adaptive_nanogrid.py +9 -10
  31. warp/fem/geometry/closest_point.py +1 -1
  32. warp/fem/geometry/deformed_geometry.py +5 -2
  33. warp/fem/geometry/geometry.py +5 -0
  34. warp/fem/geometry/grid_2d.py +12 -12
  35. warp/fem/geometry/grid_3d.py +12 -15
  36. warp/fem/geometry/hexmesh.py +5 -7
  37. warp/fem/geometry/nanogrid.py +9 -11
  38. warp/fem/geometry/quadmesh.py +13 -13
  39. warp/fem/geometry/tetmesh.py +3 -4
  40. warp/fem/geometry/trimesh.py +3 -8
  41. warp/fem/integrate.py +262 -93
  42. warp/fem/linalg.py +5 -5
  43. warp/fem/quadrature/pic_quadrature.py +37 -22
  44. warp/fem/quadrature/quadrature.py +194 -25
  45. warp/fem/space/__init__.py +1 -1
  46. warp/fem/space/basis_function_space.py +4 -2
  47. warp/fem/space/basis_space.py +25 -18
  48. warp/fem/space/hexmesh_function_space.py +2 -2
  49. warp/fem/space/partition.py +6 -2
  50. warp/fem/space/quadmesh_function_space.py +8 -8
  51. warp/fem/space/shape/cube_shape_function.py +23 -23
  52. warp/fem/space/shape/square_shape_function.py +12 -12
  53. warp/fem/space/shape/triangle_shape_function.py +1 -1
  54. warp/fem/space/tetmesh_function_space.py +3 -3
  55. warp/fem/space/trimesh_function_space.py +2 -2
  56. warp/fem/utils.py +12 -6
  57. warp/jax.py +14 -1
  58. warp/jax_experimental/__init__.py +16 -0
  59. warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
  60. warp/jax_experimental/ffi.py +698 -0
  61. warp/jax_experimental/xla_ffi.py +602 -0
  62. warp/math.py +89 -0
  63. warp/native/array.h +13 -0
  64. warp/native/builtin.h +29 -3
  65. warp/native/bvh.cpp +3 -1
  66. warp/native/bvh.cu +42 -14
  67. warp/native/bvh.h +2 -1
  68. warp/native/clang/clang.cpp +30 -3
  69. warp/native/cuda_util.cpp +14 -0
  70. warp/native/cuda_util.h +2 -0
  71. warp/native/exports.h +68 -63
  72. warp/native/intersect.h +26 -26
  73. warp/native/intersect_adj.h +33 -33
  74. warp/native/marching.cu +1 -1
  75. warp/native/mat.h +513 -9
  76. warp/native/mesh.h +10 -10
  77. warp/native/quat.h +99 -11
  78. warp/native/rand.h +6 -0
  79. warp/native/sort.cpp +122 -59
  80. warp/native/sort.cu +152 -15
  81. warp/native/sort.h +8 -1
  82. warp/native/sparse.cpp +43 -22
  83. warp/native/sparse.cu +52 -17
  84. warp/native/svd.h +116 -0
  85. warp/native/tile.h +301 -105
  86. warp/native/tile_reduce.h +46 -3
  87. warp/native/vec.h +68 -7
  88. warp/native/volume.cpp +85 -113
  89. warp/native/volume_builder.cu +25 -10
  90. warp/native/volume_builder.h +6 -0
  91. warp/native/warp.cpp +5 -6
  92. warp/native/warp.cu +99 -10
  93. warp/native/warp.h +19 -10
  94. warp/optim/linear.py +10 -10
  95. warp/sim/articulation.py +4 -4
  96. warp/sim/collide.py +21 -10
  97. warp/sim/import_mjcf.py +449 -155
  98. warp/sim/import_urdf.py +32 -12
  99. warp/sim/integrator_euler.py +5 -5
  100. warp/sim/integrator_featherstone.py +3 -10
  101. warp/sim/integrator_vbd.py +207 -2
  102. warp/sim/integrator_xpbd.py +5 -5
  103. warp/sim/model.py +42 -13
  104. warp/sim/utils.py +2 -2
  105. warp/sparse.py +642 -555
  106. warp/stubs.py +216 -19
  107. warp/tests/__main__.py +0 -15
  108. warp/tests/cuda/__init__.py +0 -0
  109. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  110. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  111. warp/tests/geometry/__init__.py +0 -0
  112. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  113. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  114. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  115. warp/tests/interop/__init__.py +0 -0
  116. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  117. warp/tests/sim/__init__.py +0 -0
  118. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  119. warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
  120. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  121. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  122. warp/tests/sim/test_vbd.py +597 -0
  123. warp/tests/test_bool.py +1 -1
  124. warp/tests/test_examples.py +28 -36
  125. warp/tests/test_fem.py +23 -4
  126. warp/tests/test_linear_solvers.py +0 -11
  127. warp/tests/test_mat.py +233 -79
  128. warp/tests/test_mat_scalar_ops.py +4 -4
  129. warp/tests/test_overwrite.py +0 -60
  130. warp/tests/test_quat.py +67 -46
  131. warp/tests/test_rand.py +44 -37
  132. warp/tests/test_sparse.py +47 -6
  133. warp/tests/test_spatial.py +75 -0
  134. warp/tests/test_static.py +1 -1
  135. warp/tests/test_utils.py +84 -4
  136. warp/tests/test_vec.py +46 -34
  137. warp/tests/tile/__init__.py +0 -0
  138. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  139. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
  140. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  141. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  142. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  143. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  144. warp/tests/unittest_serial.py +1 -0
  145. warp/tests/unittest_suites.py +45 -59
  146. warp/tests/unittest_utils.py +2 -1
  147. warp/thirdparty/unittest_parallel.py +3 -1
  148. warp/types.py +110 -658
  149. warp/utils.py +137 -72
  150. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
  151. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
  152. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
  153. warp/examples/optim/example_walker.py +0 -317
  154. warp/native/cutlass_gemm.cpp +0 -43
  155. warp/native/cutlass_gemm.cu +0 -382
  156. warp/tests/test_matmul.py +0 -511
  157. warp/tests/test_matmul_lite.py +0 -411
  158. warp/tests/test_vbd.py +0 -386
  159. warp/tests/unused_test_misc.py +0 -77
  160. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  161. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  162. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  163. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  164. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  165. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  166. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  167. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  168. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  169. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  170. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  171. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  172. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  173. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  174. /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
  175. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  176. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  177. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  178. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
  179. {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
@@ -267,11 +267,21 @@ __device__ std::enable_if_t<nanovdb::BuildTraits<typename Node::BuildType>::is_i
267
267
  {
268
268
  }
269
269
 
270
+ template <typename T>
271
+ struct alignas(alignof(T)) AlignedProxy
272
+ {
273
+ char data[sizeof(T)];
274
+ };
275
+
270
276
  template <typename Tree, typename NodeT>
271
277
  __global__ void setInternalBBoxAndBackgroundValue(Tree *tree, const typename Tree::BuildType background_value)
272
278
  {
273
279
  using BBox = nanovdb::math::BBox<typename NodeT::CoordT>;
274
- __shared__ BBox bbox;
280
+ using BBoxProxy = AlignedProxy<BBox>;
281
+
282
+ __shared__ BBoxProxy bbox_mem;
283
+
284
+ BBox& bbox = reinterpret_cast<BBox&>(bbox_mem);
275
285
 
276
286
  const unsigned node_count = tree->mNodeCount[NodeT::LEVEL];
277
287
  const unsigned node_id = blockIdx.x;
@@ -281,7 +291,7 @@ __global__ void setInternalBBoxAndBackgroundValue(Tree *tree, const typename Tre
281
291
 
282
292
  if (threadIdx.x == 0)
283
293
  {
284
- bbox = BBox();
294
+ new(&bbox) BBox();
285
295
  }
286
296
 
287
297
  __syncthreads();
@@ -313,14 +323,17 @@ __global__ void setRootBBoxAndBackgroundValue(nanovdb::Grid<Tree> *grid,
313
323
  const typename Tree::BuildType background_value)
314
324
  {
315
325
  using BBox = typename Tree::RootNodeType::BBoxType;
316
- __shared__ BBox bbox;
326
+ using BBoxProxy = AlignedProxy<BBox>;
327
+ __shared__ BBoxProxy bbox_mem;
328
+
329
+ BBox& bbox = reinterpret_cast<BBox&>(bbox_mem);
317
330
 
318
331
  Tree &tree = grid->tree();
319
332
  const unsigned upper_count = tree.mNodeCount[2];
320
333
 
321
334
  if (threadIdx.x == 0)
322
335
  {
323
- bbox = BBox();
336
+ new(&bbox) BBox();
324
337
  }
325
338
 
326
339
  __syncthreads();
@@ -450,12 +463,14 @@ void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<BuildT>> *&out_grid,
450
463
  grid_handle.buffer().detachDeviceData();
451
464
  }
452
465
 
453
- template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<float>> *&, size_t &, const void *, size_t, bool,
454
- const BuildGridParams<float> &);
455
- template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<nanovdb::Vec3f>> *&, size_t &, const void *,
456
- size_t, bool, const BuildGridParams<nanovdb::Vec3f> &);
457
- template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<int32_t>> *&, size_t &, const void *, size_t, bool,
458
- const BuildGridParams<int32_t> &);
466
+
467
+ #define EXPAND_BUILDER_TYPE(type) \
468
+ template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<type>> *&, size_t &, const void *, size_t, bool, \
469
+ const BuildGridParams<type> &);
470
+
471
+ WP_VOLUME_BUILDER_INSTANTIATE_TYPES
472
+ #undef EXPAND_BUILDER_TYPE
473
+
459
474
  template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<nanovdb::ValueIndex>> *&, size_t &, const void *,
460
475
  size_t, bool, const BuildGridParams<nanovdb::ValueIndex> &);
461
476
  template void build_grid_from_points(nanovdb::Grid<nanovdb::NanoTree<nanovdb::ValueOnIndex>> *&, size_t &, const void *,
@@ -19,6 +19,12 @@
19
19
 
20
20
  #include <nanovdb/NanoVDB.h>
21
21
 
22
+ #define WP_VOLUME_BUILDER_INSTANTIATE_TYPES \
23
+ EXPAND_BUILDER_TYPE(int32_t) \
24
+ EXPAND_BUILDER_TYPE(float) \
25
+ EXPAND_BUILDER_TYPE(nanovdb::Vec3f) \
26
+ EXPAND_BUILDER_TYPE(nanovdb::Vec4f) \
27
+
22
28
  template <typename BuildT> struct BuildGridParams
23
29
  {
24
30
  nanovdb::Map map;
warp/native/warp.cpp CHANGED
@@ -151,11 +151,6 @@ int is_cuda_compatibility_enabled()
151
151
  return int(WP_ENABLE_CUDA_COMPATIBILITY);
152
152
  }
153
153
 
154
- int is_cutlass_enabled()
155
- {
156
- return int(WP_ENABLE_CUTLASS);
157
- }
158
-
159
154
  int is_mathdx_enabled()
160
155
  {
161
156
  return int(WP_ENABLE_MATHDX);
@@ -1004,6 +999,8 @@ WP_API int cuda_device_is_mempool_supported(int ordinal) { return 0; }
1004
999
  WP_API int cuda_device_is_ipc_supported(int ordinal) { return 0; }
1005
1000
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold) { return 0; }
1006
1001
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal) { return 0; }
1002
+ WP_API uint64_t cuda_device_get_mempool_used_mem_current(int ordinal) { return 0; }
1003
+ WP_API uint64_t cuda_device_get_mempool_used_mem_high(int ordinal) { return 0; }
1007
1004
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem) {}
1008
1005
 
1009
1006
  WP_API void* cuda_context_get_current() { return NULL; }
@@ -1033,6 +1030,7 @@ WP_API void* cuda_ipc_open_event_handle(void* context, char* handle) { return NU
1033
1030
 
1034
1031
  WP_API void* cuda_stream_create(void* context, int priority) { return NULL; }
1035
1032
  WP_API void cuda_stream_destroy(void* context, void* stream) {}
1033
+ WP_API int cuda_stream_query(void* stream) { return 0; }
1036
1034
  WP_API void cuda_stream_register(void* context, void* stream) {}
1037
1035
  WP_API void cuda_stream_unregister(void* context, void* stream) {}
1038
1036
  WP_API void* cuda_stream_get_current() { return NULL; }
@@ -1045,7 +1043,8 @@ WP_API int cuda_stream_get_priority(void* stream) { return 0; }
1045
1043
 
1046
1044
  WP_API void* cuda_event_create(void* context, unsigned flags) { return NULL; }
1047
1045
  WP_API void cuda_event_destroy(void* event) {}
1048
- WP_API void cuda_event_record(void* event, void* stream) {}
1046
+ WP_API int cuda_event_query(void* event) { return 0; }
1047
+ WP_API void cuda_event_record(void* event, void* stream, bool timing) {}
1049
1048
  WP_API void cuda_event_synchronize(void* event) {}
1050
1049
  WP_API float cuda_event_elapsed_time(void* start_event, void* end_event) { return 0.0f; }
1051
1050
 
warp/native/warp.cu CHANGED
@@ -1888,6 +1888,62 @@ uint64_t cuda_device_get_mempool_release_threshold(int ordinal)
1888
1888
  return threshold;
1889
1889
  }
1890
1890
 
1891
+ uint64_t cuda_device_get_mempool_used_mem_current(int ordinal)
1892
+ {
1893
+ if (ordinal < 0 || ordinal > int(g_devices.size()))
1894
+ {
1895
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1896
+ return 0;
1897
+ }
1898
+
1899
+ if (!g_devices[ordinal].is_mempool_supported)
1900
+ return 0;
1901
+
1902
+ cudaMemPool_t pool;
1903
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1904
+ {
1905
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1906
+ return 0;
1907
+ }
1908
+
1909
+ uint64_t mem_used = 0;
1910
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemCurrent, &mem_used)))
1911
+ {
1912
+ fprintf(stderr, "Warp error: Failed to get amount of currently used memory from the memory pool on device %d\n", ordinal);
1913
+ return 0;
1914
+ }
1915
+
1916
+ return mem_used;
1917
+ }
1918
+
1919
+ uint64_t cuda_device_get_mempool_used_mem_high(int ordinal)
1920
+ {
1921
+ if (ordinal < 0 || ordinal > int(g_devices.size()))
1922
+ {
1923
+ fprintf(stderr, "Invalid device ordinal %d\n", ordinal);
1924
+ return 0;
1925
+ }
1926
+
1927
+ if (!g_devices[ordinal].is_mempool_supported)
1928
+ return 0;
1929
+
1930
+ cudaMemPool_t pool;
1931
+ if (!check_cuda(cudaDeviceGetDefaultMemPool(&pool, ordinal)))
1932
+ {
1933
+ fprintf(stderr, "Warp error: Failed to get memory pool on device %d\n", ordinal);
1934
+ return 0;
1935
+ }
1936
+
1937
+ uint64_t mem_high_water_mark = 0;
1938
+ if (!check_cuda(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &mem_high_water_mark)))
1939
+ {
1940
+ fprintf(stderr, "Warp error: Failed to get memory usage high water mark from the memory pool on device %d\n", ordinal);
1941
+ return 0;
1942
+ }
1943
+
1944
+ return mem_high_water_mark;
1945
+ }
1946
+
1891
1947
  void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem)
1892
1948
  {
1893
1949
  // use temporary storage if user didn't specify pointers
@@ -2371,6 +2427,19 @@ void cuda_stream_destroy(void* context, void* stream)
2371
2427
  check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
2372
2428
  }
2373
2429
 
2430
+ int cuda_stream_query(void* stream)
2431
+ {
2432
+ CUresult res = cuStreamQuery_f(static_cast<CUstream>(stream));
2433
+
2434
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2435
+ {
2436
+ // Abnormal, print out error
2437
+ check_cu(res);
2438
+ }
2439
+
2440
+ return res;
2441
+ }
2442
+
2374
2443
  void cuda_stream_register(void* context, void* stream)
2375
2444
  {
2376
2445
  if (!stream)
@@ -2465,9 +2534,30 @@ void cuda_event_destroy(void* event)
2465
2534
  check_cu(cuEventDestroy_f(static_cast<CUevent>(event)));
2466
2535
  }
2467
2536
 
2468
- void cuda_event_record(void* event, void* stream)
2537
+ int cuda_event_query(void* event)
2538
+ {
2539
+ CUresult res = cuEventQuery_f(static_cast<CUevent>(event));
2540
+
2541
+ if ((res != CUDA_SUCCESS) && (res != CUDA_ERROR_NOT_READY))
2542
+ {
2543
+ // Abnormal, print out error
2544
+ check_cu(res);
2545
+ }
2546
+
2547
+ return res;
2548
+ }
2549
+
2550
+ void cuda_event_record(void* event, void* stream, bool timing)
2469
2551
  {
2470
- check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2552
+ if (timing && !g_captures.empty() && cuda_stream_is_capturing(stream))
2553
+ {
2554
+ // record timing event during graph capture
2555
+ check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
2556
+ }
2557
+ else
2558
+ {
2559
+ check_cu(cuEventRecord_f(static_cast<CUevent>(event), static_cast<CUstream>(stream)));
2560
+ }
2471
2561
  }
2472
2562
 
2473
2563
  void cuda_event_synchronize(void* event)
@@ -2814,6 +2904,12 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
2814
2904
  opts.push_back("--define-macro=WP_VERIFY_FP");
2815
2905
  else
2816
2906
  opts.push_back("--undefine-macro=WP_VERIFY_FP");
2907
+
2908
+ #if WP_ENABLE_MATHDX
2909
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=1");
2910
+ #else
2911
+ opts.push_back("--define-macro=WP_ENABLE_MATHDX=0");
2912
+ #endif
2817
2913
 
2818
2914
  if (fast_math)
2819
2915
  opts.push_back("--use_fast_math");
@@ -2823,10 +2919,6 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
2823
2919
  else
2824
2920
  opts.push_back("--fmad=false");
2825
2921
 
2826
- char include_cutlass[max_path];
2827
- sprintf(include_cutlass, "--include-path=%s/cutlass/include", include_dir);
2828
- opts.push_back(include_cutlass);
2829
-
2830
2922
  std::vector<std::string> cuda_include_opt;
2831
2923
  for(int i = 0; i < num_cuda_include_dirs; i++)
2832
2924
  {
@@ -3182,7 +3274,7 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
3182
3274
  std::vector<char> lto(lto_size);
3183
3275
  CHECK_CUSOLVER(cusolverGetLTOIR(h, lto.size(), lto.data()));
3184
3276
 
3185
- // This fatbin is universal, ie it is the same for any instantations of a cusolver device function
3277
+ // This fatbin is universal, ie it is the same for any instantiations of a cusolver device function
3186
3278
  size_t fatbin_size = 0;
3187
3279
  CHECK_CUSOLVER(cusolverGetUniversalFATBINSize(h, &fatbin_size));
3188
3280
 
@@ -3539,9 +3631,6 @@ void cuda_timing_end(timing_result_t* results, int size)
3539
3631
  #include "sparse.cu"
3540
3632
  #include "volume.cu"
3541
3633
  #include "volume_builder.cu"
3542
- #if WP_ENABLE_CUTLASS
3543
- #include "cutlass_gemm.cu"
3544
- #endif
3545
3634
 
3546
3635
  //#include "spline.inl"
3547
3636
  //#include "volume.inl"
warp/native/warp.h CHANGED
@@ -41,8 +41,6 @@ extern "C"
41
41
  WP_API int is_cuda_enabled();
42
42
  // whether Warp was compiled with enhanced CUDA compatibility
43
43
  WP_API int is_cuda_compatibility_enabled();
44
- // whether Warp was compiled with CUTLASS support
45
- WP_API int is_cutlass_enabled();
46
44
  // whether Warp was compiled with MathDx support
47
45
  WP_API int is_mathdx_enabled();
48
46
  // whether Warp was compiled with debug support
@@ -112,10 +110,6 @@ extern "C"
112
110
  WP_API void hash_grid_destroy_device(uint64_t id);
113
111
  WP_API void hash_grid_update_device(uint64_t id, float cell_width, const wp::array_t<wp::vec3>* points);
114
112
 
115
- WP_API bool cutlass_gemm(void* context, int compute_capability, int m, int n, int k, const char* datatype,
116
- const void* a, const void* b, const void* c, void* d, float alpha, float beta,
117
- bool row_major_a, bool row_major_b, bool allow_tf32x3_arith, int batch_count);
118
-
119
113
  WP_API uint64_t volume_create_host(void* buf, uint64_t size, bool copy, bool owner);
120
114
  WP_API void volume_get_tiles_host(uint64_t id, void* buf);
121
115
  WP_API void volume_get_voxels_host(uint64_t id, void* buf);
@@ -126,9 +120,7 @@ extern "C"
126
120
  WP_API void volume_get_voxels_device(uint64_t id, void* buf);
127
121
  WP_API void volume_destroy_device(uint64_t id);
128
122
 
129
- WP_API uint64_t volume_f_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, float bg_value);
130
- WP_API uint64_t volume_v_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, float bg_value[3]);
131
- WP_API uint64_t volume_i_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, int bg_value);
123
+ WP_API uint64_t volume_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space, const void* bg_value, uint32_t bg_value_size, const char* bg_value_type);
132
124
  WP_API uint64_t volume_index_from_tiles_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space);
133
125
  WP_API uint64_t volume_from_active_voxels_device(void* context, void* points, int num_points, float transform[9], float translation[3], bool points_in_world_space);
134
126
 
@@ -173,6 +165,15 @@ extern "C"
173
165
  WP_API void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n);
174
166
  WP_API void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n);
175
167
 
168
+ WP_API void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n);
169
+ WP_API void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n);
170
+
171
+ WP_API void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
172
+ WP_API void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
173
+
174
+ WP_API void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
175
+ WP_API void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments);
176
+
176
177
  WP_API void runlength_encode_int_host(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
177
178
  WP_API void runlength_encode_int_device(uint64_t values, uint64_t run_values, uint64_t run_lengths, uint64_t run_count, int n);
178
179
 
@@ -185,6 +186,7 @@ extern "C"
185
186
  int* tpl_columns,
186
187
  void* tpl_values,
187
188
  bool prune_numerical_zeros,
189
+ bool masked,
188
190
  int* bsr_offsets,
189
191
  int* bsr_columns,
190
192
  void* bsr_values,
@@ -199,6 +201,7 @@ extern "C"
199
201
  int* tpl_columns,
200
202
  void* tpl_values,
201
203
  bool prune_numerical_zeros,
204
+ bool masked,
202
205
  int* bsr_offsets,
203
206
  int* bsr_columns,
204
207
  void* bsr_values,
@@ -213,6 +216,7 @@ extern "C"
213
216
  int* tpl_columns,
214
217
  void* tpl_values,
215
218
  bool prune_numerical_zeros,
219
+ bool masked,
216
220
  int* bsr_offsets,
217
221
  int* bsr_columns,
218
222
  void* bsr_values,
@@ -227,6 +231,7 @@ extern "C"
227
231
  int* tpl_columns,
228
232
  void* tpl_values,
229
233
  bool prune_numerical_zeros,
234
+ bool masked,
230
235
  int* bsr_offsets,
231
236
  int* bsr_columns,
232
237
  void* bsr_values,
@@ -283,6 +288,8 @@ extern "C"
283
288
  WP_API int cuda_device_is_ipc_supported(int ordinal);
284
289
  WP_API int cuda_device_set_mempool_release_threshold(int ordinal, uint64_t threshold);
285
290
  WP_API uint64_t cuda_device_get_mempool_release_threshold(int ordinal);
291
+ WP_API uint64_t cuda_device_get_mempool_used_mem_current(int ordinal);
292
+ WP_API uint64_t cuda_device_get_mempool_used_mem_high(int ordinal);
286
293
  WP_API void cuda_device_get_memory_info(int ordinal, size_t* free_mem, size_t* total_mem);
287
294
 
288
295
  WP_API void* cuda_context_get_current();
@@ -318,6 +325,7 @@ extern "C"
318
325
 
319
326
  WP_API void* cuda_stream_create(void* context, int priority);
320
327
  WP_API void cuda_stream_destroy(void* context, void* stream);
328
+ WP_API int cuda_stream_query(void* stream);
321
329
  WP_API void cuda_stream_register(void* context, void* stream);
322
330
  WP_API void cuda_stream_unregister(void* context, void* stream);
323
331
  WP_API void* cuda_stream_get_current();
@@ -330,7 +338,8 @@ extern "C"
330
338
 
331
339
  WP_API void* cuda_event_create(void* context, unsigned flags);
332
340
  WP_API void cuda_event_destroy(void* event);
333
- WP_API void cuda_event_record(void* event, void* stream);
341
+ WP_API int cuda_event_query(void* event);
342
+ WP_API void cuda_event_record(void* event, void* stream, bool timing=false);
334
343
  WP_API void cuda_event_synchronize(void* event);
335
344
  WP_API float cuda_event_elapsed_time(void* start_event, void* end_event);
336
345
 
warp/optim/linear.py CHANGED
@@ -866,7 +866,7 @@ def _diag_mv_vec_kernel(
866
866
  def _inverse_diag_coefficient(coeff: Any, use_abs: wp.bool):
867
867
  zero = type(coeff)(0.0)
868
868
  one = type(coeff)(1.0)
869
- return wp.select(coeff == zero, one / wp.select(use_abs, coeff, wp.abs(coeff)), one)
869
+ return wp.where(coeff == zero, one, one / wp.where(use_abs, wp.abs(coeff), coeff))
870
870
 
871
871
 
872
872
  @wp.kernel
@@ -917,7 +917,7 @@ def _cg_kernel_1(
917
917
  ):
918
918
  i = wp.tid()
919
919
 
920
- alpha = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_old[0] / p_Ap[0])
920
+ alpha = wp.where(resid[0] > tol, rz_old[0] / p_Ap[0], rz_old.dtype(0.0))
921
921
 
922
922
  x[i] = x[i] + alpha * p[i]
923
923
  r[i] = r[i] - alpha * Ap[i]
@@ -935,7 +935,7 @@ def _cg_kernel_2(
935
935
  # p = r + (rz_new / rz_old) * p;
936
936
  i = wp.tid()
937
937
 
938
- beta = wp.select(resid[0] > tol, rz_old.dtype(0.0), rz_new[0] / rz_old[0])
938
+ beta = wp.where(resid[0] > tol, rz_new[0] / rz_old[0], rz_old.dtype(0.0))
939
939
 
940
940
  p[i] = z[i] + beta * p[i]
941
941
 
@@ -955,7 +955,7 @@ def _cr_kernel_1(
955
955
  ):
956
956
  i = wp.tid()
957
957
 
958
- alpha = wp.select(resid[0] > tol and y_Ap[0] > 0.0, zAz_old.dtype(0.0), zAz_old[0] / y_Ap[0])
958
+ alpha = wp.where(resid[0] > tol and y_Ap[0] > 0.0, zAz_old[0] / y_Ap[0], zAz_old.dtype(0.0))
959
959
 
960
960
  x[i] = x[i] + alpha * p[i]
961
961
  r[i] = r[i] - alpha * Ap[i]
@@ -976,7 +976,7 @@ def _cr_kernel_2(
976
976
  # p = r + (rz_new / rz_old) * p;
977
977
  i = wp.tid()
978
978
 
979
- beta = wp.select(resid[0] > tol and zAz_old[0] > 0.0, zAz_old.dtype(0.0), zAz_new[0] / zAz_old[0])
979
+ beta = wp.where(resid[0] > tol and zAz_old[0] > 0.0, zAz_new[0] / zAz_old[0], zAz_old.dtype(0.0))
980
980
 
981
981
  p[i] = z[i] + beta * p[i]
982
982
  Ap[i] = Az[i] + beta * Ap[i]
@@ -995,7 +995,7 @@ def _bicgstab_kernel_1(
995
995
  ):
996
996
  i = wp.tid()
997
997
 
998
- alpha = wp.select(resid[0] > tol, rho_old.dtype(0.0), rho_old[0] / r0v[0])
998
+ alpha = wp.where(resid[0] > tol, rho_old[0] / r0v[0], rho_old.dtype(0.0))
999
999
 
1000
1000
  x[i] += alpha * y[i]
1001
1001
  r[i] -= alpha * v[i]
@@ -1014,7 +1014,7 @@ def _bicgstab_kernel_2(
1014
1014
  ):
1015
1015
  i = wp.tid()
1016
1016
 
1017
- omega = wp.select(resid[0] > tol, st.dtype(0.0), st[0] / tt[0])
1017
+ omega = wp.where(resid[0] > tol, st[0] / tt[0], st.dtype(0.0))
1018
1018
 
1019
1019
  x[i] += omega * z[i]
1020
1020
  r[i] -= omega * t[i]
@@ -1034,8 +1034,8 @@ def _bicgstab_kernel_3(
1034
1034
  ):
1035
1035
  i = wp.tid()
1036
1036
 
1037
- beta = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] * tt[0] / (r0v[0] * st[0]))
1038
- beta_omega = wp.select(resid[0] > tol, st.dtype(0.0), rho_new[0] / r0v[0])
1037
+ beta = wp.where(resid[0] > tol, rho_new[0] * tt[0] / (r0v[0] * st[0]), st.dtype(0.0))
1038
+ beta_omega = wp.where(resid[0] > tol, rho_new[0] / r0v[0], st.dtype(0.0))
1039
1039
 
1040
1040
  p[i] = r[i] + beta * p[i] - beta_omega * v[i]
1041
1041
 
@@ -1123,7 +1123,7 @@ def _gmres_arnoldi_normalize_kernel(
1123
1123
  alpha: wp.array(dtype=Any),
1124
1124
  ):
1125
1125
  tid = wp.tid()
1126
- y[tid] = wp.select(alpha[0] == alpha.dtype(0.0), x[tid] / wp.sqrt(alpha[0]), x[tid])
1126
+ y[tid] = wp.where(alpha[0] == alpha.dtype(0.0), x[tid], x[tid] / wp.sqrt(alpha[0]))
1127
1127
 
1128
1128
 
1129
1129
  @wp.kernel
warp/sim/articulation.py CHANGED
@@ -30,7 +30,7 @@ def compute_2d_rotational_dofs(
30
30
  """
31
31
  Computes the rotation quaternion and 3D angular velocity given the joint axes, coordinates and velocities.
32
32
  """
33
- q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
33
+ q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, wp.cross(axis_0, axis_1)))
34
34
 
35
35
  # body local axes
36
36
  local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
@@ -60,7 +60,7 @@ def invert_2d_rotational_dofs(
60
60
  """
61
61
  Computes generalized joint position and velocity coordinates for a 2D rotational joint given the joint axes, relative orientations and angular velocity differences between the two bodies the joint connects.
62
62
  """
63
- q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, wp.cross(axis_0, axis_1)))
63
+ q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, wp.cross(axis_0, axis_1)))
64
64
  q_pc = wp.quat_inverse(q_off) * wp.quat_inverse(q_p) * q_c * q_off
65
65
 
66
66
  # decompose to a compound rotation each axis
@@ -106,7 +106,7 @@ def compute_3d_rotational_dofs(
106
106
  """
107
107
  Computes the rotation quaternion and 3D angular velocity given the joint axes, coordinates and velocities.
108
108
  """
109
- q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
109
+ q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, axis_2))
110
110
 
111
111
  # body local axes
112
112
  local_0 = wp.quat_rotate(q_off, wp.vec3(1.0, 0.0, 0.0))
@@ -136,7 +136,7 @@ def invert_3d_rotational_dofs(
136
136
  """
137
137
  Computes generalized joint position and velocity coordinates for a 3D rotational joint given the joint axes, relative orientations and angular velocity differences between the two bodies the joint connects.
138
138
  """
139
- q_off = wp.quat_from_matrix(wp.mat33(axis_0, axis_1, axis_2))
139
+ q_off = wp.quat_from_matrix(wp.matrix_from_cols(axis_0, axis_1, axis_2))
140
140
  q_pc = wp.quat_inverse(q_off) * wp.quat_inverse(q_p) * q_c * q_off
141
141
 
142
142
  # decompose to a compound rotation each axis
warp/sim/collide.py CHANGED
@@ -17,10 +17,12 @@
17
17
  Collision handling functions and kernels.
18
18
  """
19
19
 
20
+ from typing import Optional
21
+
20
22
  import numpy as np
21
23
 
22
24
  import warp as wp
23
- from warp.sim.model import Model
25
+ from warp.sim.model import Model, State
24
26
 
25
27
  from .model import PARTICLE_FLAG_ACTIVE, ModelShapeGeometry
26
28
 
@@ -1556,17 +1558,23 @@ def handle_contact_pairs(
1556
1558
  contact_thickness[index] = thickness
1557
1559
 
1558
1560
 
1559
- def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool = True, requires_grad: bool = None):
1560
- """
1561
- Generates contact points for the particles and rigid bodies in the model,
1562
- to be used in the contact dynamics kernel of the integrator.
1561
+ def collide(
1562
+ model: Model,
1563
+ state: State,
1564
+ edge_sdf_iter: int = 10,
1565
+ iterate_mesh_vertices: bool = True,
1566
+ requires_grad: Optional[bool] = None,
1567
+ ) -> None:
1568
+ """Generate contact points for the particles and rigid bodies in the model for use in contact-dynamics kernels.
1563
1569
 
1564
1570
  Args:
1565
- model: the model to be simulated
1566
- state: the state of the model
1567
- edge_sdf_iter: number of search iterations for finding closest contact points between edges and SDF
1568
- iterate_mesh_vertices: whether to iterate over all vertices of a mesh for contact generation (used for capsule/box <> mesh collision)
1569
- requires_grad: whether to duplicate contact arrays for gradient computation (if None uses model.requires_grad)
1571
+ model: The model to be simulated.
1572
+ state: The state of the model.
1573
+ edge_sdf_iter: Number of search iterations for finding closest contact points between edges and SDF.
1574
+ iterate_mesh_vertices: Whether to iterate over all vertices of a mesh for contact generation
1575
+ (used for capsule/box <> mesh collision).
1576
+ requires_grad: Whether to duplicate contact arrays for gradient computation
1577
+ (if ``None``, uses ``model.requires_grad``).
1570
1578
  """
1571
1579
 
1572
1580
  if requires_grad is None:
@@ -1685,13 +1693,16 @@ def collide(model, state, edge_sdf_iter: int = 10, iterate_mesh_vertices: bool =
1685
1693
  model.rigid_contact_tids = wp.zeros_like(model.rigid_contact_tids)
1686
1694
  model.rigid_contact_shape0 = wp.empty_like(model.rigid_contact_shape0)
1687
1695
  model.rigid_contact_shape1 = wp.empty_like(model.rigid_contact_shape1)
1696
+
1688
1697
  if model.rigid_contact_pairwise_counter is not None:
1689
1698
  model.rigid_contact_pairwise_counter = wp.zeros_like(model.rigid_contact_pairwise_counter)
1690
1699
  else:
1691
1700
  model.rigid_contact_count.zero_()
1692
1701
  model.rigid_contact_tids.zero_()
1702
+
1693
1703
  if model.rigid_contact_pairwise_counter is not None:
1694
1704
  model.rigid_contact_pairwise_counter.zero_()
1705
+
1695
1706
  model.rigid_contact_shape0.fill_(-1)
1696
1707
  model.rigid_contact_shape1.fill_(-1)
1697
1708