warp-lang 1.6.2__py3-none-macosx_10_13_universal2.whl → 1.7.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/native/quat.h CHANGED
@@ -375,12 +375,14 @@ inline CUDA_CALLABLE mat_t<3,3,Type> quat_to_matrix(const quat_t<Type>& q)
     vec_t<3,Type> c2 = quat_rotate(q, vec_t<3,Type>(0.0, 1.0, 0.0));
     vec_t<3,Type> c3 = quat_rotate(q, vec_t<3,Type>(0.0, 0.0, 1.0));
 
-    return mat_t<3,3,Type>(c1, c2, c3);
+    return matrix_from_cols<Type>(c1, c2, c3);
 }
 
-template<typename Type>
-inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<3,3,Type>& m)
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<Rows,Cols,Type>& m)
 {
+    static_assert((Rows == 3 && Cols == 3) || (Rows == 4 && Cols == 4), "Non-square matrix");
+
     const Type tr = m.data[0][0] + m.data[1][1] + m.data[2][2];
     Type x, y, z, w, h = Type(0);
 
@@ -498,37 +500,98 @@ inline CUDA_CALLABLE void adj_indexref(quat_t<Type>* q, int idx,
 
 
 template<typename Type>
-inline CUDA_CALLABLE void augassign_add(quat_t<Type>& q, int idx, Type value)
+inline CUDA_CALLABLE void add_inplace(quat_t<Type>& q, int idx, Type value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     q[idx] += value;
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void adj_augassign_add(quat_t<Type>& q, int idx, Type value,
+inline CUDA_CALLABLE void adj_add_inplace(quat_t<Type>& q, int idx, Type value,
                                             quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     adj_value += adj_q[idx];
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void augassign_sub(quat_t<Type>& q, int idx, Type value)
+inline CUDA_CALLABLE void sub_inplace(quat_t<Type>& q, int idx, Type value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     q[idx] -= value;
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void adj_augassign_sub(quat_t<Type>& q, int idx, Type value,
+inline CUDA_CALLABLE void adj_sub_inplace(quat_t<Type>& q, int idx, Type value,
                                             quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     adj_value -= adj_q[idx];
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE quat_t<Type> assign(quat_t<Type>& q, int idx, Type value)
+inline CUDA_CALLABLE void assign_inplace(quat_t<Type>& q, int idx, Type value)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    q[idx] = value;
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_assign_inplace(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    adj_value += adj_q[idx];
+}
+
+
+template<typename Type>
+inline CUDA_CALLABLE quat_t<Type> assign_copy(quat_t<Type>& q, int idx, Type value)
 {
 #ifndef NDEBUG
     if (idx < 0 || idx > 3)
@@ -544,7 +607,7 @@ inline CUDA_CALLABLE quat_t<Type> assign(quat_t<Type>& q, int idx, Type value)
 }
 
 template<typename Type>
-inline CUDA_CALLABLE void adj_assign(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value, const quat_t<Type>& adj_ret)
+inline CUDA_CALLABLE void adj_assign_copy(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value, const quat_t<Type>& adj_ret)
 {
 #ifndef NDEBUG
     if (idx < 0 || idx > 3)
@@ -562,6 +625,7 @@ inline CUDA_CALLABLE void adj_assign(quat_t<Type>& q, int idx, Type value, quat_
     }
 }
 
+
 template<typename Type>
 CUDA_CALLABLE inline quat_t<Type> lerp(const quat_t<Type>& a, const quat_t<Type>& b, Type t)
 {
@@ -1048,9 +1112,11 @@ inline CUDA_CALLABLE void adj_quat_to_matrix(const quat_t<Type>& q, quat_t<Type>
     adj_quat_rotate(q, vec_t<3,Type>(0.0, 0.0, 1.0), adj_q, t, adj_ret.get_col(2));
 }
 
-template<typename Type>
-inline CUDA_CALLABLE void adj_quat_from_matrix(const mat_t<3,3,Type>& m, mat_t<3,3,Type>& adj_m, const quat_t<Type>& adj_ret)
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_quat_from_matrix(const mat_t<Rows,Cols,Type>& m, mat_t<Rows,Cols,Type>& adj_m, const quat_t<Type>& adj_ret)
 {
+    static_assert((Rows == 3 && Cols == 3) || (Rows == 4 && Cols == 4), "Non-square matrix");
+
     const Type tr = m.data[0][0] + m.data[1][1] + m.data[2][2];
     Type x, y, z, w, h = Type(0);
 
@@ -1280,4 +1346,26 @@ CUDA_CALLABLE inline void adj_len(const quat_t<Type>& x, quat_t<Type>& adj_x, co
 {
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void expect_near(const quat_t<Type>& actual, const quat_t<Type>& expected, const Type& tolerance)
+{
+    Type diff(0);
+    for(size_t i = 0; i < 4; ++i)
+    {
+        diff = max(diff, abs(actual[i] - expected[i]));
+    }
+    if (diff > tolerance)
+    {
+        printf("Error, expect_near() failed with tolerance "); print(tolerance);
+        printf("\t Expected: "); print(expected);
+        printf("\t Actual: "); print(actual);
+    }
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_expect_near(const quat_t<Type>& actual, const quat_t<Type>& expected, Type tolerance, quat_t<Type>& adj_actual, quat_t<Type>& adj_expected, Type adj_tolerance)
+{
+    // nop
+}
+
 } // namespace wp
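
The renames above (augassign_add/augassign_sub to add_inplace/sub_inplace, assign to assign_copy, plus the new assign_inplace) add debug-mode bounds checks and appear to mirror similar changes listed for vec.h and mat.h in this release. The quat_from_matrix() generalization is easiest to see from a call site: the same template now instantiates for 3x3 and 4x4 matrices, reading only the upper-left 3x3 rotation block. A minimal sketch under assumptions (the include path, main() harness, and sample values are illustrative, not from the package):

#include "warp/native/builtin.h"   // assumed entry header for wp::quat_t / wp::mat_t

using namespace wp;

int main()
{
    quat_t<float> q(0.0f, 0.0f, 0.7071068f, 0.7071068f);   // 90 degrees about Z

    // 3x3: round trip through the column-built rotation matrix
    mat_t<3,3,float> r = quat_to_matrix(q);
    quat_t<float> q3 = quat_from_matrix(r);                 // Rows=3, Cols=3

    // 4x4: embed the rotation in a rigid transform; quat_from_matrix() only
    // reads m.data[0..2][0..2], so the translation entries are ignored
    mat_t<4,4,float> t = mat_t<4,4,float>();
    for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j)
            t.data[i][j] = r.data[i][j];
    t.data[3][3] = 1.0f;
    quat_t<float> q4 = quat_from_matrix(t);                 // Rows=4, Cols=4 now compiles

    expect_near(q3, q4, 1e-6f);   // the new helper added at the bottom of this file
    return 0;
}

Note that despite its "Non-square matrix" message, the static_assert rejects every shape other than 3x3 and 4x4, so a 2x2 or 3x4 instantiation still fails at compile time.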
warp/native/rand.h CHANGED
@@ -53,6 +53,9 @@ inline CUDA_CALLABLE uint32 rand_init(int seed, int offset) { return rand_pcg(ui
 inline CUDA_CALLABLE int randi(uint32& state) { state = rand_pcg(state); return int(state); }
 inline CUDA_CALLABLE int randi(uint32& state, int min, int max) { state = rand_pcg(state); return state % (max - min) + min; }
 
+inline CUDA_CALLABLE uint32 randu(uint32& state) { state = rand_pcg(state); return state; }
+inline CUDA_CALLABLE uint32 randu(uint32& state, uint32 min, uint32 max) { state = rand_pcg(state); return state % (max - min) + min; }
+
 /*
  * We want to ensure randf adheres to a uniform distribution over [0,1). The set of all possible float32 (IEEE 754 standard) values is not uniformly distributed however.
  * On the other hand, for a given sign and exponent, the mantissa of the float32 representation is uniformly distributed.
@@ -74,6 +77,9 @@ inline CUDA_CALLABLE void adj_rand_init(int seed, int offset, int& adj_seed, int
 inline CUDA_CALLABLE void adj_randi(uint32& state, uint32& adj_state, float adj_ret) {}
 inline CUDA_CALLABLE void adj_randi(uint32& state, int min, int max, uint32& adj_state, int& adj_min, int& adj_max, float adj_ret) {}
 
+inline CUDA_CALLABLE void adj_randu(uint32& state, uint32& adj_state, float adj_ret) {}
+inline CUDA_CALLABLE void adj_randu(uint32& state, uint32 min, uint32 max, uint32& adj_state, uint32& adj_min, uint32& adj_max, float adj_ret) {}
+
 inline CUDA_CALLABLE void adj_randf(uint32& state, uint32& adj_state, float adj_ret) {}
 inline CUDA_CALLABLE void adj_randf(uint32& state, float min, float max, uint32& adj_state, float& adj_min, float& adj_max, float adj_ret) {}
 
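
The new randu() mirrors the existing randi(): it advances the PCG state once and returns the raw unsigned word, and the ranged overload maps into [min, max) by modulo, which carries the same slight modulo bias as randi() whenever max - min does not evenly divide 2^32. A standalone sketch of the pattern; rand_pcg() below is a stand-in using the well-known PCG-hash constants, since the package's own implementation lives elsewhere in rand.h and is not part of this diff:

#include <cstdint>
#include <cstdio>

// Stand-in for wp::rand_pcg (illustrative): LCG state advance plus permuted output.
uint32_t rand_pcg(uint32_t state)
{
    uint32_t s = state * 747796405u + 2891336453u;
    uint32_t word = ((s >> ((s >> 28u) + 4u)) ^ s) * 277803737u;
    return (word >> 22u) ^ word;
}

// Same shape as the new overloads above (lo/hi correspond to min/max).
uint32_t randu(uint32_t& state) { state = rand_pcg(state); return state; }
uint32_t randu(uint32_t& state, uint32_t lo, uint32_t hi) { state = rand_pcg(state); return state % (hi - lo) + lo; }

int main()
{
    uint32_t state = 42u;                       // would normally come from rand_init(seed)
    printf("%u\n", randu(state));               // full 32-bit range
    printf("%u\n", randu(state, 10u, 20u));     // in [10, 20), slightly modulo-biased
    return 0;
}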
warp/native/sort.cpp CHANGED
@@ -21,69 +21,75 @@
 
 #include <cstdint>
 
-void radix_sort_pairs_host(int* keys, int* values, int n)
+// Only integer keys (bit count 32 or 64) are supported. Floats need to get converted into int first. see radix_float_to_int.
+template <typename KeyType>
+void radix_sort_pairs_host(KeyType* keys, int* values, int n, int offset_to_scratch_memory)
 {
-    static int tables[2][1 << 16];
+    const int numPasses = sizeof(KeyType) / 2;
+    static int tables[numPasses][1 << 16];
     memset(tables, 0, sizeof(tables));
-
-    int* auxKeys = keys + n;
-    int* auxValues = values + n;
-
+
     // build histograms
-    for (int i=0; i < n; ++i)
-    {
-        const unsigned short low = keys[i] & 0xffff;
-        const unsigned short high = keys[i] >> 16;
-
-        ++tables[0][low];
-        ++tables[1][high];
+    for (int p = 0; p < numPasses; ++p)
+    {
+        for (int i=0; i < n; ++i)
+        {
+            const int shift = p * 16;
+            const int b = (keys[i] >> shift) & 0xffff;
+
+            ++tables[p][b];
+        }
     }
 
-    // convert histograms to offset tables in-place
-    int offlow = 0;
-    int offhigh = 0;
-
-    for (int i=0; i < 65536; ++i)
+    // convert histograms to offset tables in-place
+    for (int p = 0; p < numPasses; ++p)
     {
-        const int newofflow = offlow + tables[0][i];
-        const int newoffhigh = offhigh + tables[1][i];
-
-        tables[0][i] = offlow;
-        tables[1][i] = offhigh;
-
-        offlow = newofflow;
-        offhigh = newoffhigh;
+        int off = 0;
+        for (int i = 0; i < 65536; ++i)
+        {
+            const int newoff = off + tables[p][i];
+
+            tables[p][i] = off;
+
+            off = newoff;
+        }
     }
-
-    // pass 1 - sort by low 16 bits
-    for (int i=0; i < n; ++i)
-    {
-        // lookup offset of input
-        const int k = keys[i];
-        const int v = values[i];
-        const int b = k & 0xffff;
-
-        // find offset and increment
-        const int offset = tables[0][b]++;
-
-        auxKeys[offset] = k;
-        auxValues[offset] = v;
-    }
-
-    // pass 2 - sort by high 16 bits
-    for (int i=0; i < n; ++i)
-    {
-        // lookup offset of input
-        const int k = auxKeys[i];
-        const int v = auxValues[i];
+
+    for (int p = 0; p < numPasses; ++p)
+    {
+        int flipFlop = p % 2;
+        KeyType* readKeys = keys + offset_to_scratch_memory * flipFlop;
+        int* readValues = values + offset_to_scratch_memory * flipFlop;
+        KeyType* writeKeys = keys + offset_to_scratch_memory * (1 - flipFlop);
+        int* writeValues = values + offset_to_scratch_memory * (1 - flipFlop);
+
+        // pass 1 - sort by low 16 bits
+        for (int i=0; i < n; ++i)
+        {
+            // lookup offset of input
+            const KeyType k = readKeys[i];
+            const int v = readValues[i];
+
+            const int shift = p * 16;
+            const int b = (k >> shift) & 0xffff;
+
+            // find offset and increment
+            const int offset = tables[p][b]++;
+
+            writeKeys[offset] = k;
+            writeValues[offset] = v;
+        }
+    }
+}
 
-        const int b = k >> 16;
-
-        const int offset = tables[1][b]++;
-
-        keys[offset] = k;
-        values[offset] = v;
-    }
+void radix_sort_pairs_host(int* keys, int* values, int n)
+{
+    radix_sort_pairs_host<int>(keys, values, n, n);
+}
+
+void radix_sort_pairs_host(int64_t* keys, int* values, int n)
+{
+    radix_sort_pairs_host<int64_t>(keys, values, n, n);
 }
 
 //http://stereopsis.com/radix.html
@@ -94,13 +100,13 @@ inline unsigned int radix_float_to_int(float f)
     return i ^ mask;
 }
 
-void radix_sort_pairs_host(float* keys, int* values, int n)
+void radix_sort_pairs_host(float* keys, int* values, int n, int offset_to_scratch_memory)
 {
     static unsigned int tables[2][1 << 16];
     memset(tables, 0, sizeof(tables));
 
-    float* auxKeys = keys + n;
-    int* auxValues = values + n;
+    float* auxKeys = keys + offset_to_scratch_memory;
+    int* auxValues = values + offset_to_scratch_memory;
 
     // build histograms
     for (int i=0; i < n; ++i)
@@ -162,14 +168,46 @@ void radix_sort_pairs_host(float* keys, int* values, int n)
     }
 }
 
+void radix_sort_pairs_host(float* keys, int* values, int n)
+{
+    radix_sort_pairs_host(keys, values, n, n);
+}
+
+void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    for (int i = 0; i < num_segments; ++i)
+    {
+        const int start = segment_start_indices[i];
+        const int end = segment_end_indices[i];
+        radix_sort_pairs_host(keys + start, values + start, end - start, n);
+    }
+}
+
+void segmented_sort_pairs_host(int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    for (int i = 0; i < num_segments; ++i)
+    {
+        const int start = segment_start_indices[i];
+        const int end = segment_end_indices[i];
+        radix_sort_pairs_host(keys + start, values + start, end - start, n);
+    }
+}
+
+
 #if !WP_ENABLE_CUDA
 
 void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {}
 
 void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {}
 
+void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n) {}
+
 void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {}
 
+void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
+
+void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
+
 #endif // !WP_ENABLE_CUDA
 
 
@@ -180,9 +218,34 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n)
         reinterpret_cast<int *>(values), n);
 }
 
+void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_host(
+        reinterpret_cast<int64_t *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
 void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n)
 {
     radix_sort_pairs_host(
         reinterpret_cast<float *>(keys),
         reinterpret_cast<int *>(values), n);
-}
+}
+
+void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_host(
+        reinterpret_cast<float *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices), num_segments);
+}
+
+void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_host(
+        reinterpret_cast<int *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices), num_segments);
+}
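
The host sort is now a single template running sizeof(KeyType)/2 passes of 16 bits each, ping-ponging between the front half of the key/value arrays and the scratch region starting at offset_to_scratch_memory (the public wrappers pass n, so callers are expected to size both arrays for 2*n entries). Since 32- and 64-bit keys take an even number of passes, the final pass always writes the sorted result back to the front half. A self-contained sketch of one such counting pass, chained twice to sort 32-bit keys; this follows the same histogram/offset/scatter structure but is illustrative code, not the package's:

#include <cstdio>
#include <cstring>

// One stable 16-bit counting pass: histogram, exclusive prefix offsets, scatter.
// Assumes non-negative keys, like the simple integer path above.
void radix_pass_16(const int* in_keys, const int* in_vals,
                   int* out_keys, int* out_vals, int n, int shift)
{
    static int table[1 << 16];
    memset(table, 0, sizeof(table));

    for (int i = 0; i < n; ++i)                         // histogram of the 16-bit digit
        ++table[(in_keys[i] >> shift) & 0xffff];

    int off = 0;                                        // histogram -> offsets in place
    for (int b = 0; b < (1 << 16); ++b)
    {
        const int count = table[b];
        table[b] = off;
        off += count;
    }

    for (int i = 0; i < n; ++i)                         // stable scatter
    {
        const int slot = table[(in_keys[i] >> shift) & 0xffff]++;
        out_keys[slot] = in_keys[i];
        out_vals[slot] = in_vals[i];
    }
}

int main()
{
    int keys[] = {70000, 3, 65536, 2};                  // digits differ in both halves
    int vals[] = {0, 1, 2, 3};
    int k1[4], v1[4], k2[4], v2[4];
    radix_pass_16(keys, vals, k1, v1, 4, 0);            // pass 0: low 16 bits
    radix_pass_16(k1, v1, k2, v2, 4, 16);               // pass 1: high 16 bits
    for (int i = 0; i < 4; ++i)
        printf("%d:%d ", k2[i], v2[i]);                 // 2:3 3:1 65536:2 70000:0
    printf("\n");
    return 0;
}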
warp/native/sort.cu CHANGED
@@ -36,11 +36,12 @@ struct RadixSortTemp
 static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
 
 
-void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
+template <typename KeyType>
+void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
 {
     ContextGuard guard(context);
 
-    cub::DoubleBuffer<int> d_keys;
+    cub::DoubleBuffer<KeyType> d_keys;
     cub::DoubleBuffer<int> d_values;
 
     // compute temporary memory required
@@ -50,7 +51,7 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
         sort_temp_size,
         d_keys,
         d_values,
-        n, 0, 32,
+        n, 0, sizeof(KeyType)*8,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (!context)
@@ -71,15 +72,21 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
     *size_out = temp.size;
 }
 
-void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
+void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
+{
+    radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
+}
+
+template <typename KeyType>
+void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
 {
     ContextGuard guard(context);
 
-    cub::DoubleBuffer<int> d_keys(keys, keys + n);
+    cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
     cub::DoubleBuffer<int> d_values(values, values + n);
 
     RadixSortTemp temp;
-    radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
+    radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
 
     // sort
     check_cuda(cub::DeviceRadixSort::SortPairs(
@@ -87,16 +94,31 @@ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
         temp.size,
         d_keys,
         d_values,
-        n, 0, 32,
+        n, 0, sizeof(KeyType)*8,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (d_keys.Current() != keys)
-        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(int)*n);
+        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
 
     if (d_values.Current() != values)
         memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
 }
 
+void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
+{
+    radix_sort_pairs_device<int>(context, keys, values, n);
+}
+
+void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
+{
+    radix_sort_pairs_device<float>(context, keys, values, n);
+}
+
+void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
+{
+    radix_sort_pairs_device<int64_t>(context, keys, values, n);
+}
+
 void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
 {
     radix_sort_pairs_device(
@@ -105,7 +127,69 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
         reinterpret_cast<int *>(values), n);
 }
 
-void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
+void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<float *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
+void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<int64_t *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
+void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
+{
+    ContextGuard guard(context);
+
+    cub::DoubleBuffer<int> d_keys;
+    cub::DoubleBuffer<int> d_values;
+
+    int* start_indices = NULL;
+    int* end_indices = NULL;
+
+    // compute temporary memory required
+    size_t sort_temp_size;
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
+        NULL,
+        sort_temp_size,
+        d_keys,
+        d_values,
+        n,
+        num_segments,
+        start_indices,
+        end_indices,
+        0,
+        32,
+        (cudaStream_t)cuda_stream_get_current()));
+
+    if (!context)
+        context = cuda_context_get_current();
+
+    RadixSortTemp& temp = g_radix_sort_temp_map[context];
+
+    if (sort_temp_size > temp.size)
+    {
+        free_device(WP_CURRENT_CONTEXT, temp.mem);
+        temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
+        temp.size = sort_temp_size;
+    }
+
+    if (mem_out)
+        *mem_out = temp.mem;
+    if (size_out)
+        *size_out = temp.size;
+}
+
+// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
+// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
+// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
+void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
 {
     ContextGuard guard(context);
 
@@ -113,15 +197,20 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
     cub::DoubleBuffer<int> d_values(values, values + n);
 
     RadixSortTemp temp;
-    radix_sort_reserve(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
+    segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
 
     // sort
-    check_cuda(cub::DeviceRadixSort::SortPairs(
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
         temp.mem,
         temp.size,
         d_keys,
         d_values,
-        n, 0, 32,
+        n,
+        num_segments,
+        segment_start_indices,
+        segment_end_indices,
+        0,
+        32,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (d_keys.Current() != keys)
@@ -131,10 +220,58 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
         memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
 }
 
-void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
+void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
 {
-    radix_sort_pairs_device(
+    segmented_sort_pairs_device(
         WP_CURRENT_CONTEXT,
         reinterpret_cast<float *>(keys),
-        reinterpret_cast<int *>(values), n);
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices),
+        num_segments);
+}
+
+// segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
+// The end of a segment is given by segment_indices[i+1]
+// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
+void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    ContextGuard guard(context);
+
+    cub::DoubleBuffer<int> d_keys(keys, keys + n);
+    cub::DoubleBuffer<int> d_values(values, values + n);
+
+    RadixSortTemp temp;
+    segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
+
+    // sort
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
+        temp.mem,
+        temp.size,
+        d_keys,
+        d_values,
+        n,
+        num_segments,
+        segment_start_indices,
+        segment_end_indices,
+        0,
+        32,
+        (cudaStream_t)cuda_stream_get_current()));
+
+    if (d_keys.Current() != keys)
+        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
+
+    if (d_values.Current() != values)
+        memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
+}
+
+void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<int *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices),
+        num_segments);
 }
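
Both device entry points take explicit segment_start_indices and segment_end_indices arrays of length num_segments, which is the form cub::DeviceSegmentedRadixSort::SortPairs consumes; the second comment block above describes CUB's alternative packed layout, a single offsets array of length num_segments + 1 that can back both arguments at once. A small host-side sketch of how the two layouts relate (values are illustrative):

#include <cassert>

int main()
{
    // Two segments over n = 5 pairs: {9,3,7 | 5,1}.
    const int num_segments = 2;

    int starts[num_segments]      = {0, 3};      // first element of each segment
    int ends[num_segments]        = {3, 5};      // one past the last element
    int offsets[num_segments + 1] = {0, 3, 5};   // packed form

    // offsets and offsets + 1 alias the packed array as starts/ends views.
    for (int i = 0; i < num_segments; ++i)
    {
        assert(starts[i] == offsets[i]);
        assert(ends[i] == (offsets + 1)[i]);
    }
    return 0;
}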
warp/native/sort.h CHANGED
@@ -22,5 +22,12 @@
 void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
 void radix_sort_pairs_host(int* keys, int* values, int n);
 void radix_sort_pairs_host(float* keys, int* values, int n);
+void radix_sort_pairs_host(int64_t* keys, int* values, int n);
 void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
-void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
+void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
+void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
+
+void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
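
One caller-side contract worth spelling out: as the host implementation in sort.cpp suggests, the pair sorts are in-place at the API level but ping-pong through scratch space in the second half of each buffer, so keys and values should be allocated with capacity for 2*n entries even though only the first n hold the result. A hedged usage sketch against the host declaration (sample data is illustrative; linking against Warp's native library is assumed):

#include <cstdio>
#include <vector>

// Host declaration from sort.h above.
void radix_sort_pairs_host(float* keys, int* values, int n);

int main()
{
    const int n = 4;
    std::vector<float> keys = {0.5f, -1.0f, 2.5f, 0.0f};
    std::vector<int> values = {0, 1, 2, 3};

    keys.resize(2 * n);     // back half is scratch for the ping-pong passes
    values.resize(2 * n);

    radix_sort_pairs_host(keys.data(), values.data(), n);

    for (int i = 0; i < n; ++i)
        printf("%g -> %d\n", keys[i], values[i]);   // -1 -> 1, 0 -> 3, 0.5 -> 0, 2.5 -> 2
    return 0;
}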