warp-lang 1.6.2-py3-none-win_amd64.whl → 1.7.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: the registry has flagged this version of warp-lang as potentially problematic.
- warp/__init__.py +7 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +410 -0
- warp/build_dll.py +6 -14
- warp/builtins.py +452 -362
- warp/codegen.py +179 -119
- warp/config.py +42 -6
- warp/context.py +490 -271
- warp/dlpack.py +8 -6
- warp/examples/assets/nonuniform.usd +0 -0
- warp/examples/assets/nvidia_logo.png +0 -0
- warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
- warp/examples/core/example_sample_mesh.py +300 -0
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +2 -2
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_magnetostatics.py +6 -6
- warp/examples/fem/utils.py +9 -3
- warp/examples/interop/example_jax_callable.py +116 -0
- warp/examples/interop/example_jax_ffi_callback.py +132 -0
- warp/examples/interop/example_jax_kernel.py +205 -0
- warp/examples/optim/example_fluid_checkpoint.py +497 -0
- warp/examples/tile/example_tile_matmul.py +2 -4
- warp/fem/__init__.py +11 -1
- warp/fem/adaptivity.py +4 -4
- warp/fem/field/nodal_field.py +22 -68
- warp/fem/field/virtual.py +62 -23
- warp/fem/geometry/adaptive_nanogrid.py +9 -10
- warp/fem/geometry/closest_point.py +1 -1
- warp/fem/geometry/deformed_geometry.py +5 -2
- warp/fem/geometry/geometry.py +5 -0
- warp/fem/geometry/grid_2d.py +12 -12
- warp/fem/geometry/grid_3d.py +12 -15
- warp/fem/geometry/hexmesh.py +5 -7
- warp/fem/geometry/nanogrid.py +9 -11
- warp/fem/geometry/quadmesh.py +13 -13
- warp/fem/geometry/tetmesh.py +3 -4
- warp/fem/geometry/trimesh.py +3 -8
- warp/fem/integrate.py +262 -93
- warp/fem/linalg.py +5 -5
- warp/fem/quadrature/pic_quadrature.py +37 -22
- warp/fem/quadrature/quadrature.py +194 -25
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +4 -2
- warp/fem/space/basis_space.py +25 -18
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +6 -2
- warp/fem/space/quadmesh_function_space.py +8 -8
- warp/fem/space/shape/cube_shape_function.py +23 -23
- warp/fem/space/shape/square_shape_function.py +12 -12
- warp/fem/space/shape/triangle_shape_function.py +1 -1
- warp/fem/space/tetmesh_function_space.py +3 -3
- warp/fem/space/trimesh_function_space.py +2 -2
- warp/fem/utils.py +12 -6
- warp/jax.py +14 -1
- warp/jax_experimental/__init__.py +16 -0
- warp/{jax_experimental.py → jax_experimental/custom_call.py} +14 -27
- warp/jax_experimental/ffi.py +698 -0
- warp/jax_experimental/xla_ffi.py +602 -0
- warp/math.py +89 -0
- warp/native/array.h +13 -0
- warp/native/builtin.h +29 -3
- warp/native/bvh.cpp +3 -1
- warp/native/bvh.cu +42 -14
- warp/native/bvh.h +2 -1
- warp/native/clang/clang.cpp +30 -3
- warp/native/cuda_util.cpp +14 -0
- warp/native/cuda_util.h +2 -0
- warp/native/exports.h +68 -63
- warp/native/intersect.h +26 -26
- warp/native/intersect_adj.h +33 -33
- warp/native/marching.cu +1 -1
- warp/native/mat.h +513 -9
- warp/native/mesh.h +10 -10
- warp/native/quat.h +99 -11
- warp/native/rand.h +6 -0
- warp/native/sort.cpp +122 -59
- warp/native/sort.cu +152 -15
- warp/native/sort.h +8 -1
- warp/native/sparse.cpp +43 -22
- warp/native/sparse.cu +52 -17
- warp/native/svd.h +116 -0
- warp/native/tile.h +301 -105
- warp/native/tile_reduce.h +46 -3
- warp/native/vec.h +68 -7
- warp/native/volume.cpp +85 -113
- warp/native/volume_builder.cu +25 -10
- warp/native/volume_builder.h +6 -0
- warp/native/warp.cpp +5 -6
- warp/native/warp.cu +99 -10
- warp/native/warp.h +19 -10
- warp/optim/linear.py +10 -10
- warp/sim/articulation.py +4 -4
- warp/sim/collide.py +21 -10
- warp/sim/import_mjcf.py +449 -155
- warp/sim/import_urdf.py +32 -12
- warp/sim/integrator_euler.py +5 -5
- warp/sim/integrator_featherstone.py +3 -10
- warp/sim/integrator_vbd.py +207 -2
- warp/sim/integrator_xpbd.py +5 -5
- warp/sim/model.py +42 -13
- warp/sim/utils.py +2 -2
- warp/sparse.py +642 -555
- warp/stubs.py +216 -19
- warp/tests/__main__.py +0 -15
- warp/tests/cuda/__init__.py +0 -0
- warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
- warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
- warp/tests/geometry/__init__.py +0 -0
- warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
- warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
- warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
- warp/tests/interop/__init__.py +0 -0
- warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
- warp/tests/sim/__init__.py +0 -0
- warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
- warp/tests/{test_collision.py → sim/test_collision.py} +2 -2
- warp/tests/{test_model.py → sim/test_model.py} +40 -0
- warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
- warp/tests/sim/test_vbd.py +597 -0
- warp/tests/test_bool.py +1 -1
- warp/tests/test_examples.py +28 -36
- warp/tests/test_fem.py +23 -4
- warp/tests/test_linear_solvers.py +0 -11
- warp/tests/test_mat.py +233 -79
- warp/tests/test_mat_scalar_ops.py +4 -4
- warp/tests/test_overwrite.py +0 -60
- warp/tests/test_quat.py +67 -46
- warp/tests/test_rand.py +44 -37
- warp/tests/test_sparse.py +47 -6
- warp/tests/test_spatial.py +75 -0
- warp/tests/test_static.py +1 -1
- warp/tests/test_utils.py +84 -4
- warp/tests/test_vec.py +46 -34
- warp/tests/tile/__init__.py +0 -0
- warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
- warp/tests/{test_tile_load.py → tile/test_tile_load.py} +1 -1
- warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
- warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
- warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
- warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
- warp/tests/unittest_serial.py +1 -0
- warp/tests/unittest_suites.py +45 -59
- warp/tests/unittest_utils.py +2 -1
- warp/thirdparty/unittest_parallel.py +3 -1
- warp/types.py +110 -658
- warp/utils.py +137 -72
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/METADATA +29 -7
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/RECORD +172 -162
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/WHEEL +1 -1
- warp/examples/optim/example_walker.py +0 -317
- warp/native/cutlass_gemm.cpp +0 -43
- warp/native/cutlass_gemm.cu +0 -382
- warp/tests/test_matmul.py +0 -511
- warp/tests/test_matmul_lite.py +0 -411
- warp/tests/test_vbd.py +0 -386
- warp/tests/unused_test_misc.py +0 -77
- /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
- /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
- /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
- /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
- /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
- /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
- /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
- /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
- /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
- /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
- /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
- /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
- /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
- /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
- /warp/tests/{flaky_test_sim_grad.py → sim/flaky_test_sim_grad.py} +0 -0
- /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
- /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
- /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info/licenses}/LICENSE.md +0 -0
- {warp_lang-1.6.2.dist-info → warp_lang-1.7.0.dist-info}/top_level.txt +0 -0
warp/native/quat.h
CHANGED

(In the reconstructed hunks below, removed lines shown truncated or blank reflect text the diff viewer did not capture.)

@@ -375,12 +375,14 @@ inline CUDA_CALLABLE mat_t<3,3,Type> quat_to_matrix(const quat_t<Type>& q)
     vec_t<3,Type> c2 = quat_rotate(q, vec_t<3,Type>(0.0, 1.0, 0.0));
     vec_t<3,Type> c3 = quat_rotate(q, vec_t<3,Type>(0.0, 0.0, 1.0));
 
-    return
+    return matrix_from_cols<Type>(c1, c2, c3);
 }
 
-template<typename Type>
-inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<Rows,Cols,Type>& m)
 {
+    static_assert((Rows == 3 && Cols == 3) || (Rows == 4 && Cols == 4), "Non-square matrix");
+
     const Type tr = m.data[0][0] + m.data[1][1] + m.data[2][2];
     Type x, y, z, w, h = Type(0);
 
@@ -498,37 +500,98 @@ inline CUDA_CALLABLE void adj_indexref(quat_t<Type>* q, int idx,
 
 
 template<typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void add_inplace(quat_t<Type>& q, int idx, Type value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     q[idx] += value;
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_add_inplace(quat_t<Type>& q, int idx, Type value,
     quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     adj_value += adj_q[idx];
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void sub_inplace(quat_t<Type>& q, int idx, Type value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     q[idx] -= value;
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_sub_inplace(quat_t<Type>& q, int idx, Type value,
     quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
 {
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
     adj_value -= adj_q[idx];
 }
 
 
 template<typename Type>
-inline CUDA_CALLABLE
+inline CUDA_CALLABLE void assign_inplace(quat_t<Type>& q, int idx, Type value)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    q[idx] = value;
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_assign_inplace(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value)
+{
+#ifndef NDEBUG
+    if (idx < 0 || idx > 3)
+    {
+        printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
+        assert(0);
+    }
+#endif
+
+    adj_value += adj_q[idx];
+}
+
+
+template<typename Type>
+inline CUDA_CALLABLE quat_t<Type> assign_copy(quat_t<Type>& q, int idx, Type value)
 {
 #ifndef NDEBUG
     if (idx < 0 || idx > 3)
@@ -544,7 +607,7 @@ inline CUDA_CALLABLE quat_t<Type> assign(quat_t<Type>& q, int idx, Type value)
 }
 
 template<typename Type>
-inline CUDA_CALLABLE void
+inline CUDA_CALLABLE void adj_assign_copy(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value, const quat_t<Type>& adj_ret)
 {
 #ifndef NDEBUG
     if (idx < 0 || idx > 3)
@@ -562,6 +625,7 @@ inline CUDA_CALLABLE void adj_assign(quat_t<Type>& q, int idx, Type value, quat_
     }
 }
 
+
 template<typename Type>
 CUDA_CALLABLE inline quat_t<Type> lerp(const quat_t<Type>& a, const quat_t<Type>& b, Type t)
 {
@@ -1048,9 +1112,11 @@ inline CUDA_CALLABLE void adj_quat_to_matrix(const quat_t<Type>& q, quat_t<Type>
     adj_quat_rotate(q, vec_t<3,Type>(0.0, 0.0, 1.0), adj_q, t, adj_ret.get_col(2));
 }
 
-template<typename Type>
-inline CUDA_CALLABLE void adj_quat_from_matrix(const mat_t<
+template<unsigned Rows, unsigned Cols, typename Type>
+inline CUDA_CALLABLE void adj_quat_from_matrix(const mat_t<Rows,Cols,Type>& m, mat_t<Rows,Cols,Type>& adj_m, const quat_t<Type>& adj_ret)
 {
+    static_assert((Rows == 3 && Cols == 3) || (Rows == 4 && Cols == 4), "Non-square matrix");
+
     const Type tr = m.data[0][0] + m.data[1][1] + m.data[2][2];
     Type x, y, z, w, h = Type(0);
 
@@ -1280,4 +1346,26 @@ CUDA_CALLABLE inline void adj_len(const quat_t<Type>& x, quat_t<Type>& adj_x, co
 {
 }
 
+template<typename Type>
+inline CUDA_CALLABLE void expect_near(const quat_t<Type>& actual, const quat_t<Type>& expected, const Type& tolerance)
+{
+    Type diff(0);
+    for(size_t i = 0; i < 4; ++i)
+    {
+        diff = max(diff, abs(actual[i] - expected[i]));
+    }
+    if (diff > tolerance)
+    {
+        printf("Error, expect_near() failed with tolerance "); print(tolerance);
+        printf("\t Expected: "); print(expected);
+        printf("\t Actual: "); print(actual);
+    }
+}
+
+template<typename Type>
+inline CUDA_CALLABLE void adj_expect_near(const quat_t<Type>& actual, const quat_t<Type>& expected, Type tolerance, quat_t<Type>& adj_actual, quat_t<Type>& adj_expected, Type adj_tolerance)
+{
+    // nop
+}
+
 } // namespace wp
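Note: quat_from_matrix() and adj_quat_from_matrix() are now templated on the matrix dimensions and accept either a 3x3 rotation matrix or a 4x4 transform, in both cases reading only the upper-left 3x3 block (the static_assert rejects other shapes). For reference, here is a standalone sketch of the trace-positive branch of the extraction these functions perform; it is plain C++ written for this note, using raw arrays instead of Warp's mat_t/quat_t, not code copied from the header:

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    // Trace-positive branch of quaternion-from-rotation-matrix extraction.
    // m is row-major and maps column vectors; only m[0..2][0..2] is read,
    // so a 4x4 homogeneous transform can be passed through unchanged.
    void quat_from_rotation(const float m[4][4], float q[4]) // q = {x, y, z, w}
    {
        const float tr = m[0][0] + m[1][1] + m[2][2];
        assert(tr > 0.0f); // the real header also handles the tr <= 0 branches

        const float s = std::sqrt(tr + 1.0f) * 2.0f; // s = 4*w
        q[3] = 0.25f * s;
        q[0] = (m[2][1] - m[1][2]) / s;
        q[1] = (m[0][2] - m[2][0]) / s;
        q[2] = (m[1][0] - m[0][1]) / s;
    }

    int main()
    {
        // 90-degree rotation about Z embedded in a 4x4 transform with translation
        const float xf[4][4] = {
            { 0.0f, -1.0f, 0.0f, 5.0f },
            { 1.0f,  0.0f, 0.0f, 6.0f },
            { 0.0f,  0.0f, 1.0f, 7.0f },
            { 0.0f,  0.0f, 0.0f, 1.0f },
        };

        float q[4];
        quat_from_rotation(xf, q);
        printf("(%f, %f, %f, %f)\n", q[0], q[1], q[2], q[3]); // ~(0, 0, 0.7071, 0.7071)
    }

Because only the 3x3 block is read, a full transform can be passed directly without extracting its rotation part at the call site.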
warp/native/rand.h
CHANGED

@@ -53,6 +53,9 @@ inline CUDA_CALLABLE uint32 rand_init(int seed, int offset) { return rand_pcg(ui
 inline CUDA_CALLABLE int randi(uint32& state) { state = rand_pcg(state); return int(state); }
 inline CUDA_CALLABLE int randi(uint32& state, int min, int max) { state = rand_pcg(state); return state % (max - min) + min; }
 
+inline CUDA_CALLABLE uint32 randu(uint32& state) { state = rand_pcg(state); return state; }
+inline CUDA_CALLABLE uint32 randu(uint32& state, uint32 min, uint32 max) { state = rand_pcg(state); return state % (max - min) + min; }
+
 /*
  * We want to ensure randf adheres to a uniform distribution over [0,1). The set of all possible float32 (IEEE 754 standard) values is not uniformly distributed however.
  * On the other hand, for a given sign and exponent, the mantissa of the float32 representation is uniformly distributed.
@@ -74,6 +77,9 @@ inline CUDA_CALLABLE void adj_rand_init(int seed, int offset, int& adj_seed, int
 inline CUDA_CALLABLE void adj_randi(uint32& state, uint32& adj_state, float adj_ret) {}
 inline CUDA_CALLABLE void adj_randi(uint32& state, int min, int max, uint32& adj_state, int& adj_min, int& adj_max, float adj_ret) {}
 
+inline CUDA_CALLABLE void adj_randu(uint32& state, uint32& adj_state, float adj_ret) {}
+inline CUDA_CALLABLE void adj_randu(uint32& state, uint32 min, uint32 max, uint32& adj_state, uint32& adj_min, uint32& adj_max, float adj_ret) {}
+
 inline CUDA_CALLABLE void adj_randf(uint32& state, uint32& adj_state, float adj_ret) {}
 inline CUDA_CALLABLE void adj_randf(uint32& state, float min, float max, uint32& adj_state, float& adj_min, float& adj_max, float adj_ret) {}
 
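Note: the new randu() overloads mirror randi() but stay in the unsigned 32-bit domain. As with randi(), the ranged overload maps the generator output with a modulo, so ranges that do not evenly divide 2^32 carry a small modulo bias. A self-contained host-side sketch of the same pattern; the rand_pcg() below is a stand-in PCG-style hash written for this note, not the one in warp/native/rand.h:

    #include <cstdint>
    #include <cstdio>

    using uint32 = uint32_t;

    // Stand-in PCG-style hash (state advance + output permutation).
    inline uint32 rand_pcg(uint32 state)
    {
        uint32 b = state * 747796405u + 2891336453u;
        uint32 word = ((b >> ((b >> 28u) + 4u)) ^ b) * 277803737u;
        return (word >> 22u) ^ word;
    }

    // Same shape as the new builtins above.
    inline uint32 randu(uint32& state) { state = rand_pcg(state); return state; }
    inline uint32 randu(uint32& state, uint32 min, uint32 max) { state = rand_pcg(state); return state % (max - min) + min; }

    int main()
    {
        uint32 state = rand_pcg(42u); // seeding, analogous to rand_init(seed)
        for (int i = 0; i < 4; ++i)
            printf("%u\n", randu(state, 10u, 20u)); // values in [10, 20)
    }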
warp/native/sort.cpp
CHANGED

@@ -21,69 +21,75 @@
 
 #include <cstdint>
 
-
+//Only integer keys (bit count 32 or 64) are supported. Floats need to get converted into int first. see radix_float_to_int.
+template <typename KeyType>
+void radix_sort_pairs_host(KeyType* keys, int* values, int n, int offset_to_scratch_memory)
 {
-
+    const int numPasses = sizeof(KeyType) / 2;
+    static int tables[numPasses][1 << 16];
     memset(tables, 0, sizeof(tables));
-
-    int* auxKeys = keys + n;
-    int* auxValues = values + n;
-
+
     // build histograms
-    for (int
-
-
-
-
-
-
+    for (int p = 0; p < numPasses; ++p)
+    {
+        for (int i=0; i < n; ++i)
+        {
+            const int shift = p * 16;
+            const int b = (keys[i] >> shift) & 0xffff;
+
+            ++tables[p][b];
+        }
     }
 
-    // convert histograms to offset tables in-place
-    int
-    int offhigh = 0;
-
-    for (int i=0; i < 65536; ++i)
+    // convert histograms to offset tables in-place
+    for (int p = 0; p < numPasses; ++p)
     {
-
-
-
-
-
-
-
-
+        int off = 0;
+        for (int i = 0; i < 65536; ++i)
+        {
+            const int newoff = off + tables[p][i];
+
+            tables[p][i] = off;
+
+            off = newoff;
+        }
     }
-
-
-
-
-
-
-
-
-
-    //
-
-
-
-
-
-
-
-
-
-
-
+
+    for (int p = 0; p < numPasses; ++p)
+    {
+        int flipFlop = p % 2;
+        KeyType* readKeys = keys + offset_to_scratch_memory * flipFlop;
+        int* readValues = values + offset_to_scratch_memory * flipFlop;
+        KeyType* writeKeys = keys + offset_to_scratch_memory * (1 - flipFlop);
+        int* writeValues = values + offset_to_scratch_memory * (1 - flipFlop);
+
+        // pass 1 - sort by low 16 bits
+        for (int i=0; i < n; ++i)
+        {
+            // lookup offset of input
+            const KeyType k = readKeys[i];
+            const int v = readValues[i];
+
+            const int shift = p * 16;
+            const int b = (k >> shift) & 0xffff;
+
+            // find offset and increment
+            const int offset = tables[p][b]++;
+
+            writeKeys[offset] = k;
+            writeValues[offset] = v;
+        }
+    }
+}
 
-
-
-
-
-
-
-
+void radix_sort_pairs_host(int* keys, int* values, int n)
+{
+    radix_sort_pairs_host<int>(keys, values, n, n);
+}
+
+void radix_sort_pairs_host(int64_t* keys, int* values, int n)
+{
+    radix_sort_pairs_host<int64_t>(keys, values, n, n);
 }
 
 //http://stereopsis.com/radix.html
@@ -94,13 +100,13 @@ inline unsigned int radix_float_to_int(float f)
     return i ^ mask;
 }
 
-void radix_sort_pairs_host(float* keys, int* values, int n)
+void radix_sort_pairs_host(float* keys, int* values, int n, int offset_to_scratch_memory)
 {
     static unsigned int tables[2][1 << 16];
     memset(tables, 0, sizeof(tables));
 
-    float* auxKeys = keys +
-    int* auxValues = values +
+    float* auxKeys = keys + offset_to_scratch_memory;
+    int* auxValues = values + offset_to_scratch_memory;
 
     // build histograms
     for (int i=0; i < n; ++i)
@@ -162,14 +168,46 @@ void radix_sort_pairs_host(float* keys, int* values, int n)
     }
 }
 
+void radix_sort_pairs_host(float* keys, int* values, int n)
+{
+    radix_sort_pairs_host(keys, values, n, n);
+}
+
+void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    for (int i = 0; i < num_segments; ++i)
+    {
+        const int start = segment_start_indices[i];
+        const int end = segment_end_indices[i];
+        radix_sort_pairs_host(keys + start, values + start, end - start, n);
+    }
+}
+
+void segmented_sort_pairs_host(int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    for (int i = 0; i < num_segments; ++i)
+    {
+        const int start = segment_start_indices[i];
+        const int end = segment_end_indices[i];
+        radix_sort_pairs_host(keys + start, values + start, end - start, n);
+    }
+}
+
+
 #if !WP_ENABLE_CUDA
 
 void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out) {}
 
 void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n) {}
 
+void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n) {}
+
 void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n) {}
 
+void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
+
+void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments) {}
+
 #endif // !WP_ENABLE_CUDA
 
 
@@ -180,9 +218,34 @@ void radix_sort_pairs_int_host(uint64_t keys, uint64_t values, int n)
         reinterpret_cast<int *>(values), n);
 }
 
+void radix_sort_pairs_int64_host(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_host(
+        reinterpret_cast<int64_t *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
 void radix_sort_pairs_float_host(uint64_t keys, uint64_t values, int n)
 {
     radix_sort_pairs_host(
         reinterpret_cast<float *>(keys),
         reinterpret_cast<int *>(values), n);
-}
+}
+
+void segmented_sort_pairs_float_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_host(
+        reinterpret_cast<float *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices), num_segments);
+}
+
+void segmented_sort_pairs_int_host(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_host(
+        reinterpret_cast<int *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices), num_segments);
}
warp/native/sort.cu
CHANGED

@@ -36,11 +36,12 @@ struct RadixSortTemp
 static std::map<void*, RadixSortTemp> g_radix_sort_temp_map;
 
 
-
+template <typename KeyType>
+void radix_sort_reserve_internal(void* context, int n, void** mem_out, size_t* size_out)
 {
     ContextGuard guard(context);
 
-    cub::DoubleBuffer<
+    cub::DoubleBuffer<KeyType> d_keys;
     cub::DoubleBuffer<int> d_values;
 
     // compute temporary memory required
@@ -50,7 +51,7 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
         sort_temp_size,
         d_keys,
         d_values,
-        n, 0,
+        n, 0, sizeof(KeyType)*8,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (!context)
@@ -71,15 +72,21 @@ void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
         *size_out = temp.size;
 }
 
-void
+void radix_sort_reserve(void* context, int n, void** mem_out, size_t* size_out)
+{
+    radix_sort_reserve_internal<int>(context, n, mem_out, size_out);
+}
+
+template <typename KeyType>
+void radix_sort_pairs_device(void* context, KeyType* keys, int* values, int n)
 {
     ContextGuard guard(context);
 
-    cub::DoubleBuffer<
+    cub::DoubleBuffer<KeyType> d_keys(keys, keys + n);
     cub::DoubleBuffer<int> d_values(values, values + n);
 
     RadixSortTemp temp;
-
+    radix_sort_reserve_internal<KeyType>(WP_CURRENT_CONTEXT, n, &temp.mem, &temp.size);
 
     // sort
     check_cuda(cub::DeviceRadixSort::SortPairs(
@@ -87,16 +94,31 @@ void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
         temp.size,
         d_keys,
         d_values,
-        n, 0,
+        n, 0, sizeof(KeyType)*8,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (d_keys.Current() != keys)
-        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(
+        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(KeyType)*n);
 
     if (d_values.Current() != values)
         memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
 }
 
+void radix_sort_pairs_device(void* context, int* keys, int* values, int n)
+{
+    radix_sort_pairs_device<int>(context, keys, values, n);
+}
+
+void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
+{
+    radix_sort_pairs_device<float>(context, keys, values, n);
+}
+
+void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n)
+{
+    radix_sort_pairs_device<int64_t>(context, keys, values, n);
+}
+
 void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
 {
     radix_sort_pairs_device(
@@ -105,7 +127,69 @@ void radix_sort_pairs_int_device(uint64_t keys, uint64_t values, int n)
         reinterpret_cast<int *>(values), n);
 }
 
-void
+void radix_sort_pairs_float_device(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<float *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
+void radix_sort_pairs_int64_device(uint64_t keys, uint64_t values, int n)
+{
+    radix_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<int64_t *>(keys),
+        reinterpret_cast<int *>(values), n);
+}
+
+void segmented_sort_reserve(void* context, int n, int num_segments, void** mem_out, size_t* size_out)
+{
+    ContextGuard guard(context);
+
+    cub::DoubleBuffer<int> d_keys;
+    cub::DoubleBuffer<int> d_values;
+
+    int* start_indices = NULL;
+    int* end_indices = NULL;
+
+    // compute temporary memory required
+    size_t sort_temp_size;
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
+        NULL,
+        sort_temp_size,
+        d_keys,
+        d_values,
+        n,
+        num_segments,
+        start_indices,
+        end_indices,
+        0,
+        32,
+        (cudaStream_t)cuda_stream_get_current()));
+
+    if (!context)
+        context = cuda_context_get_current();
+
+    RadixSortTemp& temp = g_radix_sort_temp_map[context];
+
+    if (sort_temp_size > temp.size)
+    {
+        free_device(WP_CURRENT_CONTEXT, temp.mem);
+        temp.mem = alloc_device(WP_CURRENT_CONTEXT, sort_temp_size);
+        temp.size = sort_temp_size;
+    }
+
+    if (mem_out)
+        *mem_out = temp.mem;
+    if (size_out)
+        *size_out = temp.size;
+}
+
+// segment_start_indices and segment_end_indices are arrays of length num_segments, where segment_start_indices[i] is the index of the first element
+// in the i-th segment and segment_end_indices[i] is the index after the last element in the i-th segment
+// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedRadixSort.html
+void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
 {
     ContextGuard guard(context);
 
@@ -113,15 +197,20 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
     cub::DoubleBuffer<int> d_values(values, values + n);
 
     RadixSortTemp temp;
-
+    segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
 
     // sort
-    check_cuda(cub::
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
         temp.mem,
         temp.size,
         d_keys,
         d_values,
-        n,
+        n,
+        num_segments,
+        segment_start_indices,
+        segment_end_indices,
+        0,
+        32,
         (cudaStream_t)cuda_stream_get_current()));
 
     if (d_keys.Current() != keys)
@@ -131,10 +220,58 @@ void radix_sort_pairs_device(void* context, float* keys, int* values, int n)
         memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
 }
 
-void
+void segmented_sort_pairs_float_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
 {
-
+    segmented_sort_pairs_device(
         WP_CURRENT_CONTEXT,
         reinterpret_cast<float *>(keys),
-        reinterpret_cast<int *>(values), n
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices),
+        num_segments);
+}
+
+// segment_indices is an array of length num_segments + 1, where segment_indices[i] is the index of the first element in the i-th segment
+// The end of a segment is given by segment_indices[i+1]
+// https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceSegmentedSort.html#a-simple-example
+void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments)
+{
+    ContextGuard guard(context);
+
+    cub::DoubleBuffer<int> d_keys(keys, keys + n);
+    cub::DoubleBuffer<int> d_values(values, values + n);
+
+    RadixSortTemp temp;
+    segmented_sort_reserve(WP_CURRENT_CONTEXT, n, num_segments, &temp.mem, &temp.size);
+
+    // sort
+    check_cuda(cub::DeviceSegmentedRadixSort::SortPairs(
+        temp.mem,
+        temp.size,
+        d_keys,
+        d_values,
+        n,
+        num_segments,
+        segment_start_indices,
+        segment_end_indices,
+        0,
+        32,
+        (cudaStream_t)cuda_stream_get_current()));
+
+    if (d_keys.Current() != keys)
+        memcpy_d2d(WP_CURRENT_CONTEXT, keys, d_keys.Current(), sizeof(float)*n);
+
+    if (d_values.Current() != values)
+        memcpy_d2d(WP_CURRENT_CONTEXT, values, d_values.Current(), sizeof(int)*n);
+}
+
+void segmented_sort_pairs_int_device(uint64_t keys, uint64_t values, int n, uint64_t segment_start_indices, uint64_t segment_end_indices, int num_segments)
+{
+    segmented_sort_pairs_device(
+        WP_CURRENT_CONTEXT,
+        reinterpret_cast<int *>(keys),
+        reinterpret_cast<int *>(values), n,
+        reinterpret_cast<int *>(segment_start_indices),
+        reinterpret_cast<int *>(segment_end_indices),
+        num_segments);
 }
warp/native/sort.h
CHANGED

@@ -22,5 +22,12 @@
 void radix_sort_reserve(void* context, int n, void** mem_out=NULL, size_t* size_out=NULL);
 void radix_sort_pairs_host(int* keys, int* values, int n);
 void radix_sort_pairs_host(float* keys, int* values, int n);
+void radix_sort_pairs_host(int64_t* keys, int* values, int n);
 void radix_sort_pairs_device(void* context, int* keys, int* values, int n);
-void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
+void radix_sort_pairs_device(void* context, float* keys, int* values, int n);
+void radix_sort_pairs_device(void* context, int64_t* keys, int* values, int n);
+
+void segmented_sort_pairs_host(float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_device(void* context, float* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_host(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);
+void segmented_sort_pairs_device(void* context, int* keys, int* values, int n, int* segment_start_indices, int* segment_end_indices, int num_segments);