warp-lang: warp_lang-1.7.2rc1-py3-none-win_amd64.whl → warp_lang-1.8.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
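
The largest single addition is warp/native/tile_radix_sort.h (entry 105 above), which implements the block-level bitonic and radix sorts behind Warp's tile sort builtin and is exercised by the new benchmark_tile_sort.py and test_tile_sort.py. As a rough, non-authoritative sketch of how the feature is driven from Python, the snippet below assumes that wp.tile_sort() performs an in-place key/value sort of a tile; the tile size, dtypes, and launch parameters are illustrative only.

    import numpy as np
    import warp as wp

    TILE_SIZE = 1024  # one tile of key/value pairs sorted per block

    @wp.kernel
    def sort_tile(keys: wp.array(dtype=wp.int32), values: wp.array(dtype=wp.int32)):
        # load both arrays into tiles, sort by key in place, then write them back
        k = wp.tile_load(keys, shape=TILE_SIZE)
        v = wp.tile_load(values, shape=TILE_SIZE)
        wp.tile_sort(k, v)
        wp.tile_store(keys, k)
        wp.tile_store(values, v)

    rng = np.random.default_rng(42)
    keys = wp.array(rng.integers(0, 1 << 20, TILE_SIZE).astype(np.int32))
    vals = wp.array(np.arange(TILE_SIZE, dtype=np.int32))

    # tile kernels are launched with an explicit block size
    wp.launch_tiled(sort_tile, dim=[1], inputs=[keys, vals], block_dim=128)
    print(keys.numpy()[:8])

With TILE_SIZE = 1024 the native code would take the bitonic path; sizes above BITONIC_SORT_THRESHOLD (2048) fall through to the block-level radix sort shown in the header below.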
warp/native/tile_radix_sort.h (new file)
@@ -0,0 +1,1108 @@
+ /*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #pragma once
+
+ #include "tile.h"
+
+ #if defined(__clang__)
+ // disable warnings related to C++17 extensions on CPU JIT builds
+ #pragma clang diagnostic push
+ #pragma clang diagnostic ignored "-Wc++17-extensions"
+ #endif
+
+ namespace wp
+ {
+
+
+ // After this threshold, using segmented_sort from cub is faster
+ // The threshold must be a power of 2
+ // The radix sort in this file is consistently slower than the bitonic sort
+ #define BITONIC_SORT_THRESHOLD 2048
+
+ struct UintKeyToUint
+ {
+ inline CUDA_CALLABLE uint32_t convert(uint32 value)
+ {
+ return value;
+ }
+
+ inline CUDA_CALLABLE uint32_t max_possible_key_value()
+ {
+ return 0xFFFFFFFF;
+ }
+ };
+
+ struct IntKeyToUint
+ {
+ inline CUDA_CALLABLE uint32_t convert(int value)
+ {
+ // Flip the sign bit: ensures negative numbers come before positive numbers
+ return static_cast<uint32_t>(value) ^ 0x80000000;
+ }
+
+ inline CUDA_CALLABLE int max_possible_key_value()
+ {
+ return 2147483647;
+ }
+ };
+
+ struct FloatKeyToUint
+ {
+ //http://stereopsis.com/radix.html
+ inline CUDA_CALLABLE uint32_t convert(float value)
+ {
+ unsigned int i = reinterpret_cast<unsigned int&>(value);
+ unsigned int mask = (unsigned int)(-(int)(i >> 31)) | 0x80000000;
+ return i ^ mask;
+ }
+
+ inline CUDA_CALLABLE float max_possible_key_value()
+ {
+ return FLT_MAX;
+ }
+ };
+
+
+ constexpr inline CUDA_CALLABLE bool is_power_of_two(int x)
+ {
+ return (x & (x - 1)) == 0;
+ }
+
+ constexpr inline CUDA_CALLABLE int next_higher_pow2(int input)
+ {
+ if (input <= 0) return 1; // Smallest power of 2 is 1
+
+ input--; // Decrement to handle values that are already a power of 2
+ input |= input >> 1;
+ input |= input >> 2;
+ input |= input >> 4;
+ input |= input >> 8;
+ input |= input >> 16;
+ input++; // Next power of 2
+
+ return input;
+ }
+
+
+ #if defined(__CUDA_ARCH__)
+
+
+ // Bitonic sort fast pass for small arrays
+
+ template<typename T>
+ inline CUDA_CALLABLE T shfl_xor(unsigned int thread_id, T* sh_mem, unsigned int lane_mask)
+ {
+ unsigned int source_lane = thread_id ^ lane_mask;
+ return sh_mem[source_lane];
+ }
+
+ template<typename K, typename V, int num_loops>
+ inline CUDA_CALLABLE void bitonic_sort_single_stage_full_thread_block(int k, unsigned int thread_id, unsigned int stride, K* key_sh_mem, V* val_sh_mem, int length, K max_key_value,
+ K* key_register, V* val_register)
+ {
+ __syncthreads();
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+
+ key_register[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : max_key_value;
+ val_register[loop_id] = thread_id2 < length ? val_sh_mem[thread_id2] : 0;
+ }
+
+ __syncthreads();
+
+ K s_key[num_loops];
+ V s_val[num_loops];
+ bool swap[num_loops];
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+
+ if(thread_id2 < length)
+ {
+ s_key[loop_id] = shfl_xor(thread_id2, key_sh_mem, stride);
+ s_val[loop_id] = shfl_xor(thread_id2, val_sh_mem, stride);
+ swap[loop_id] = (((thread_id2 & stride) != 0 ? key_register[loop_id] > s_key[loop_id] : key_register[loop_id] < s_key[loop_id])) ^ ((thread_id2 & k) == 0);
+ }
+ }
+
+ __syncthreads();
+
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+ if (thread_id2 < length)
+ {
+ key_sh_mem[thread_id2] = swap[loop_id] ? s_key[loop_id] : key_register[loop_id];
+ val_sh_mem[thread_id2] = swap[loop_id] ? s_val[loop_id] : val_register[loop_id];
+ }
+ }
+ __syncthreads();
+ }
+
+ //stride can be 1, 2, 4, 8, 16
+ template<typename K, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_single_stage_full_warp(int k, unsigned int thread_id, int stride, K& key, V& val)
+ {
+ auto s_key = __shfl_xor_sync(0xFFFFFFFFu, key, stride);
+ auto s_val = __shfl_xor_sync(0xFFFFFFFFu, val, stride);
+ auto swap = (((thread_id & stride) != 0 ? key > s_key : key < s_key)) ^ ((thread_id & k) == 0);
+ key = swap ? s_key : key;
+ val = swap ? s_val : val;
+ }
+
+
+ //Sorts 32 elements according to keys
+ template<typename K, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_single_warp(unsigned int thread_id, K& key, V& val)
+ {
+ #pragma unroll
+ for (int k = 2; k <= 32; k <<= 1)
+ {
+ #pragma unroll
+ for (int stride = k / 2; stride > 0; stride >>= 1)
+ {
+ bitonic_sort_single_stage_full_warp(k, thread_id, stride, key, val);
+ }
+ }
+ }
+
+ template<typename K, typename V, typename KeyToUint>
+ inline CUDA_CALLABLE void bitonic_sort_single_warp(int thread_id,
+ K* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ KeyToUint key_converter;
+
+ __syncwarp();
+
+ K key = thread_id < num_elements_to_sort ? keys_input[thread_id] : key_converter.max_possible_key_value();
+ V value;
+ if(thread_id < num_elements_to_sort)
+ value = values_input[thread_id];
+
+ __syncwarp();
+ bitonic_sort_single_warp(thread_id, key, value);
+ __syncwarp();
+
+ if(thread_id < num_elements_to_sort)
+ {
+ keys_input[thread_id] = key;
+ values_input[thread_id] = value;
+ }
+ __syncwarp();
+ }
+
+
+ //Sorts according to keys
+ template<int max_num_elements, typename K, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_pow2_length(unsigned int thread_id, K* key_sh_mem, V* val_sh_mem, int length, K key_max_possible_value)
+ {
+ constexpr int num_loops = (max_num_elements + WP_TILE_BLOCK_DIM - 1) / WP_TILE_BLOCK_DIM;
+ K key[num_loops];
+ V val[num_loops];
+
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+ key[loop_id] = thread_id2 < length ? key_sh_mem[thread_id2] : key_max_possible_value;
+ if (thread_id2 < length)
+ val[loop_id] = val_sh_mem[thread_id2];
+ }
+
+ __syncthreads();
+ bool full_block_sort_active = false;
+
+ for (int k = 2; k <= length; k <<= 1)
+ {
+ for (int stride = k / 2; stride > 0; stride >>= 1)
+ {
+ if (stride <= 16) //no inter-warp communication needed up to stride 16
+ {
+ if(full_block_sort_active)
+ {
+ __syncthreads();
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+
+ //Switch from shared mem to registers
+ if (thread_id2 < length)
+ {
+ key[loop_id] = key_sh_mem[thread_id2];
+ val[loop_id] = val_sh_mem[thread_id2];
+ }
+ }
+ full_block_sort_active = false;
+ __syncthreads();
+ }
+
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+ bitonic_sort_single_stage_full_warp(k, thread_id2, stride, key[loop_id], val[loop_id]);
+ }
+ }
+ else
+ {
+ if (!full_block_sort_active)
+ {
+ __syncthreads();
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+
+ //Switch from registers to shared mem
+ if (thread_id2 < length)
+ {
+ key_sh_mem[thread_id2] = key[loop_id];
+ val_sh_mem[thread_id2] = val[loop_id];
+ }
+ }
+ full_block_sort_active = true;
+ __syncthreads();
+ }
+
+ bitonic_sort_single_stage_full_thread_block<K, V, num_loops>(k, thread_id, (unsigned int)stride, key_sh_mem, val_sh_mem, length, key_max_possible_value, key, val);
+ }
+ }
+ }
+
+ if (!full_block_sort_active)
+ {
+ #pragma unroll
+ for (int loop_id = 0; loop_id < num_loops; ++loop_id)
+ {
+ int thread_id2 = loop_id * WP_TILE_BLOCK_DIM + thread_id;
+ //Switch from registers to shared mem
+ if (thread_id2 < length)
+ {
+ key_sh_mem[thread_id2] = key[loop_id];
+ val_sh_mem[thread_id2] = val[loop_id];
+ }
+ }
+ full_block_sort_active = true;
+ __syncthreads();
+ }
+ }
+
+ //Allocates shared memory to buffer the arrays that need to be sorted
+ template <int max_num_elements, typename K, typename V, typename KeyToUint>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
+ int thread_id,
+ K* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ if constexpr(max_num_elements < 32)
+ {
+ //Fast track - single warp sort
+ if (thread_id < 32)
+ bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
+ __syncthreads();
+ }
+ else
+ {
+ KeyToUint key_converter;
+ const K key_max_possible_value = key_converter.max_possible_key_value();
+
+ constexpr int shared_mem_count = next_higher_pow2(max_num_elements);
+
+ __shared__ K keys_shared_mem[shared_mem_count]; //TODO: This shared memory can be avoided if keys_input is already shared memory
+ __shared__ V values_shared_mem[shared_mem_count]; //TODO: This shared memory can be avoided if values_input is already shared memory
+
+ for(int i = thread_id; i < shared_mem_count; i += WP_TILE_BLOCK_DIM)
+ {
+ if (i < num_elements_to_sort)
+ {
+ keys_shared_mem[i] = keys_input[i];
+ values_shared_mem[i] = values_input[i];
+ }
+ else
+ keys_shared_mem[i] = key_max_possible_value;
+ }
+ __syncthreads();
+
+ bitonic_sort_pow2_length<shared_mem_count, K, V>((unsigned int)thread_id, keys_shared_mem, values_shared_mem, shared_mem_count, key_max_possible_value);
+
+ __syncthreads();
+
+ for (int i = thread_id; i < num_elements_to_sort; i += WP_TILE_BLOCK_DIM)
+ {
+ keys_input[i] = keys_shared_mem[i];
+ values_input[i] = values_shared_mem[i];
+ }
+ __syncthreads();
+ }
+ }
+
+
+ // Specialization for int keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
+ int thread_id,
+ int* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_shared_mem<max_num_elements, int, V, IntKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+ // Specialization for unsigned int keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
+ int thread_id,
+ unsigned int* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_shared_mem<max_num_elements, unsigned int, V, UintKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+ // Specialization for float keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_shared_mem(
+ int thread_id,
+ float* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_shared_mem<max_num_elements, float, V, FloatKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+
+
+ // Ideally keys_input and values_input point into fast memory (shared memory)
+ template <int max_num_elements, typename K, typename V, typename KeyToUint>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
+ int thread_id,
+ K* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ if constexpr(max_num_elements < 32)
+ {
+ //Fast track - single warp sort
+ if (thread_id < 32)
+ bitonic_sort_single_warp<K, V, KeyToUint>(thread_id, keys_input, values_input, num_elements_to_sort);
+ __syncthreads();
+ }
+ else
+ {
+ assert(num_elements_to_sort <= num_threads);
+
+ KeyToUint key_converter;
+ const K key_max_possible_value = key_converter.max_possible_key_value();
+
+ bitonic_sort_pow2_length<max_num_elements, K, V>((unsigned int)thread_id, keys_input, values_input, num_elements_to_sort, key_max_possible_value);
+ }
+ }
+
+ // Specialization for int keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
+ int thread_id,
+ int* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_direct<max_num_elements, int, V, IntKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+ // Specialization for unsigned int keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
+ int thread_id,
+ unsigned int* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_direct<max_num_elements, unsigned int, V, UintKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+ // Specialization for float keys
+ template <int max_num_elements, typename V>
+ inline CUDA_CALLABLE void bitonic_sort_thread_block_direct(
+ int thread_id,
+ float* keys_input,
+ V* values_input,
+ int num_elements_to_sort)
+ {
+ bitonic_sort_thread_block_direct<max_num_elements, float, V, FloatKeyToUint>(
+ thread_id, keys_input, values_input, num_elements_to_sort);
+ }
+
+ // End bitonic sort
+
+ inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int ballot_mask)
+ {
+ uint32_t mask = ((1u << (lane + 1)) - 1);
+ return __popc(ballot_mask & mask);
+ }
+
+ inline CUDA_CALLABLE int warp_scan_inclusive(int lane, unsigned int mask, bool thread_contributes_element)
+ {
+ return warp_scan_inclusive(lane, __ballot_sync(mask, thread_contributes_element));
+ }
+
+ template<typename T>
+ inline CUDA_CALLABLE T warp_scan_inclusive(int lane, T value)
+ {
+ //Computes an inclusive cumulative sum
+ #pragma unroll
+ for (int i = 1; i <= 32; i *= 2)
+ {
+ auto n = __shfl_up_sync(0xffffffffu, value, i, 32);
+
+ if (lane >= i)
+ value = value + n;
+ }
+ return value;
+ }
+
+ template<typename T>
+ inline CUDA_CALLABLE T warp_scan_exclusive(int lane, T value)
+ {
+ T scan = warp_scan_inclusive(lane, value);
+ return scan - value;
+ }
+
+ template <int num_warps, int num_threads, typename K, typename V, typename KeyToUint>
+ inline CUDA_CALLABLE void radix_sort_thread_block_core(
+ int thread_id,
+ K* keys_input, K* keys_tmp,
+ V* values_input, V* values_tmp,
+ int num_elements_to_sort)
+ {
+ KeyToUint key_converter;
+
+ int num_bits_to_sort = 32; //Sort all bits because that's what the bitonic fast pass does as well
+
+ const int warp_id = thread_id / 32;
+ const int lane_id = thread_id & 31;
+
+ const int bits_per_pass = 4; //Values higher than 5 are currently not supported - 2^5=32 is the warp size and still fine
+ const int lowest_bits_mask = (1 << bits_per_pass) - 1;
+ const int num_scan_buckets = (1 << bits_per_pass);
+
+ const int num_warp_passes = (num_scan_buckets + num_warps - 1) / num_warps;
+
+ __shared__ int buckets[num_scan_buckets];
+ __shared__ int buckets2[num_scan_buckets];
+ __shared__ int buckets_cumulative_sum[num_scan_buckets];
+ __shared__ int shared_mem[num_warps][num_scan_buckets];
+
+ const int num_passes = (num_bits_to_sort + bits_per_pass - 1) / bits_per_pass;
+ const int num_inner_loops = (num_elements_to_sort + num_threads - 1) / num_threads;
+
+ for (int pass_id = 0; pass_id < num_passes; ++pass_id)
+ {
+ __syncthreads();
+ if (thread_id < num_scan_buckets)
+ {
+ buckets[lane_id] = 0;
+ buckets2[lane_id] = 0;
+ }
+ __syncthreads();
+
+ int shift = pass_id * bits_per_pass;
+
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
+ {
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
+
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ bool contributes = digit == b;
+ int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);
+
+ if (lane_id == 31)
+ shared_mem[warp_id][b] = sum_per_warp;
+ }
+ __syncthreads();
+
+ for(int b=warp_id;b< num_warp_passes * num_warps;b += num_warps)
+ {
+ int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
+ f = warp_scan_inclusive(lane_id, f);
+ if (lane_id == 31)
+ buckets[b] += f;
+ }
+ __syncthreads();
+ }
+
+ #if VALIDATE_SORT
+ if (thread_id == 0)
+ {
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ int bucket_sum = 0;
+ for (int j = 0; j < num_elements_to_sort; j++)
+ {
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
+ if (digit == b)
+ ++bucket_sum;
+ }
+ assert(buckets[b] == bucket_sum);
+ }
+ }
+ __syncthreads();
+ #endif
+
+ if (warp_id == 0)
+ {
+ int value = lane_id < num_scan_buckets ? buckets[lane_id] : 0;
+ value = warp_scan_exclusive(lane_id, value);
+ if (lane_id < num_scan_buckets)
+ buckets_cumulative_sum[lane_id] = value;
+
+ if (lane_id == num_scan_buckets - 1)
+ assert(debug + value == num_elements_to_sort);
+ }
+
+ __syncthreads();
+
+ #if VALIDATE_SORT
+ if(thread_id == 0)
+ {
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ int bucket_sum = 0;
+ for(int j=0; j<num_elements_to_sort; j++)
+ {
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
+ if (digit == b)
+ ++bucket_sum;
+ }
+ assert(buckets[b] == bucket_sum);
+ }
+
+ int exclusive_bucket_sum = 0;
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ assert(exclusive_bucket_sum == buckets_cumulative_sum[b]);
+ exclusive_bucket_sum += buckets[b];
+ }
+ assert(exclusive_bucket_sum == num_elements_to_sort);
+ }
+ __syncthreads();
+ #endif
+
+ //Now buckets holds num_scan_buckets inclusive cumulative sums (e.g. 16 sums for 4-bit radix sort - 2^4=16)
+ //The problem is that we either store local_offset_per_thread for every array element (potentially many) or recompute it again
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
+ {
+ int digit = j < num_elements_to_sort ? (int)((key_converter.convert(keys_input[j]) >> shift) & lowest_bits_mask) : num_scan_buckets;
+
+ int local_offset_per_thread = 0;
+
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ bool contributes = digit == b;
+ int sum_per_warp = warp_scan_inclusive(lane_id, 0xFFFFFFFF, contributes);
+ if (lane_id == 31)
+ shared_mem[warp_id][b] = sum_per_warp;
+
+ if (contributes)
+ local_offset_per_thread = sum_per_warp - 1; //-1 because the scan is inclusive and local_offset_per_thread needs an exclusive scan
+ }
+
+ for (int b = 0; b < num_scan_buckets; b++)
+ {
+ __syncthreads();
+ int global_offset = buckets2[b];
+ __syncthreads();
+
+ int f = lane_id < num_warps ? shared_mem[lane_id][b] : 0;
+ int inclusive_scan = warp_scan_inclusive(lane_id, f);
+ if (lane_id == 31 && warp_id == 0)
+ {
+ buckets2[b] += inclusive_scan;
+ }
+
+ int warp_offset = __shfl_sync(0xFFFFFFFF, inclusive_scan - f, warp_id); //-f because warp_offset needs to be an exclusive scan
+
+ bool contributes = digit == b;
+ if (contributes)
+ {
+ local_offset_per_thread += global_offset + warp_offset;
+
+ #if VALIDATE_SORT
+ int curr = buckets_cumulative_sum[b];
+ int next = b + 1 < num_scan_buckets ? buckets_cumulative_sum[b + 1] : num_elements_to_sort;
+ assert(local_offset_per_thread < next - curr && local_offset_per_thread >= 0);
+ #endif
+ }
+ }
+ __syncthreads();
+
+ if (j < num_elements_to_sort)
+ {
+ int final_offset = buckets_cumulative_sum[digit] + local_offset_per_thread;
+
+ keys_tmp[final_offset] = keys_input[j];
+ values_tmp[final_offset] = values_input[j];
+ }
+ }
+
+ __syncthreads();
+
+ #if VALIDATE_SORT
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
+ {
+ if(j>0 && j < num_elements_to_sort)
+ {
+ int digit1 = (int)((keys_tmp[j-1] >> shift) & lowest_bits_mask);
+ int digit2 = (int)((keys_tmp[j] >> shift) & lowest_bits_mask);
+
+ assert(digit1<=digit2);
+ }
+ }
+ __syncthreads();
+ #endif
+
+ auto tmp = keys_tmp;
+ keys_tmp = keys_input;
+ keys_input = tmp;
+
+ auto tmp2 = values_tmp;
+ values_tmp = values_input;
+ values_input = tmp2;
+ }
+
+ //For an odd number of passes, the result is in the wrong array - copy it over
+ if (num_passes % 2 != 0)
+ {
+ for (int j = thread_id; j < num_inner_loops * num_threads; j += num_threads)
+ {
+ if (j < num_elements_to_sort)
+ {
+ keys_tmp[j] = keys_input[j];
+ values_tmp[j] = values_input[j];
+ }
+ }
+
+ auto tmp = keys_tmp;
+ keys_tmp = keys_input;
+ keys_input = tmp;
+
+ auto tmp2 = values_tmp;
+ values_tmp = values_input;
+ values_input = tmp2;
+ }
+ }
+
+
+
+
+ template <int num_warps, int num_threads, typename V>
+ inline CUDA_CALLABLE void radix_sort_thread_block(
+ int thread_id,
+ int* keys_input, int* keys_tmp,
+ V* values_input, V* values_tmp,
+ int num_elements_to_sort)
+ {
+ radix_sort_thread_block_core<num_warps, num_threads, int, V, IntKeyToUint>(
+ thread_id, keys_input, keys_tmp,
+ values_input, values_tmp, num_elements_to_sort);
+ }
+
+ template <int num_warps, int num_threads, typename V>
+ inline CUDA_CALLABLE void radix_sort_thread_block(
+ int thread_id,
+ unsigned int* keys_input, unsigned int* keys_tmp,
+ V* values_input, V* values_tmp,
+ int num_elements_to_sort)
+ {
+ radix_sort_thread_block_core<num_warps, num_threads, unsigned int, V, UintKeyToUint>(
+ thread_id, keys_input, keys_tmp,
+ values_input, values_tmp,
+ num_elements_to_sort);
+ }
+
+ template <int num_warps, int num_threads, typename V>
+ inline CUDA_CALLABLE void radix_sort_thread_block(
+ int thread_id,
+ float* keys_input, float* keys_tmp,
+ V* values_input, V* values_tmp,
+ int num_elements_to_sort)
+ {
+ radix_sort_thread_block_core<num_warps, num_threads, float, V, FloatKeyToUint>(
+ thread_id, keys_input, keys_tmp,
+ values_input, values_tmp,
+ num_elements_to_sort);
+ }
+
+
+ template <typename TileK, typename TileV>
+ void tile_sort(TileK& t, TileV& t2)
+ {
+ using T = typename TileK::Type;
+ using V = typename TileV::Type;
+
+ constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
+ T* keys = &t.data(0);
+ V* values = &t2.data(0);
+
+ //Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
+ if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
+ {
+ if constexpr(is_power_of_two(num_elements_to_sort))
+ bitonic_sort_thread_block_direct<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
+ else
+ bitonic_sort_thread_block_shared_mem<num_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
+ }
+ else
+ {
+ __shared__ T keys_tmp[num_elements_to_sort];
+ __shared__ V values_tmp[num_elements_to_sort];
+
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
+
+ radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
+ values, values_tmp, num_elements_to_sort);
+ }
+
+ WP_TILE_SYNC();
+ }
+
+ template <typename TileK, typename TileV>
+ void tile_sort(TileK& t, TileV& t2, int start, int length)
+ {
+ using T = typename TileK::Type;
+ using V = typename TileV::Type;
+
+ constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
+ const int num_elements_to_sort = length;
+ T* keys = &t.data(start);
+ V* values = &t2.data(start);
+
+ if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
+ {
+ if (is_power_of_two(num_elements_to_sort))
+ bitonic_sort_thread_block_direct<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
+ else
+ bitonic_sort_thread_block_shared_mem<max_elements_to_sort, V>(WP_TILE_THREAD_IDX, keys, values, num_elements_to_sort);
+ }
+ else
+ {
+ if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
+ {
+ __shared__ T keys_tmp[max_elements_to_sort];
+ __shared__ V values_tmp[max_elements_to_sort];
+
+ constexpr int warp_count = (WP_TILE_BLOCK_DIM + WP_TILE_WARP_SIZE - 1)/WP_TILE_WARP_SIZE;
+
+ radix_sort_thread_block<warp_count, WP_TILE_BLOCK_DIM, V>(WP_TILE_THREAD_IDX, keys, keys_tmp,
+ values, values_tmp, num_elements_to_sort);
+ }
+ }
+
+ WP_TILE_SYNC();
+ }
+
+ #else
+
+ // CPU implementation
+
+ template <typename K>
+ void swap_elements(K& a, K& b)
+ {
+ K tmp = a;
+ a = b;
+ b = tmp;
+ }
+
+ // length must be a power of two
+ template <typename K, typename V>
+ void bitonic_sort_pairs_pow2_length_cpu(K* keys, V* values, int length)
+ {
+ for (int k = 2; k <= length; k *= 2)
+ {
+ for (int stride = k / 2; stride > 0; stride /= 2)
+ {
+ for (int i = 0; i < length; i++)
+ {
+ int swap_idx = i ^ stride;
+ if (swap_idx > i)
+ {
+ bool ascending = ((i & k) == 0);
+ if ((ascending && keys[i] > keys[swap_idx]) || (!ascending && keys[i] < keys[swap_idx]))
+ {
+ swap_elements(keys[i], keys[swap_idx]);
+ swap_elements(values[i], values[swap_idx]);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ template <typename K, typename V, int max_size, typename KeyToUint>
+ void bitonic_sort_pairs_general_size_cpu(K* keys, V* values, int length)
+ {
+ constexpr int pow2_size = next_higher_pow2(max_size);
+
+ K keys_tmp[pow2_size];
+ V values_tmp[pow2_size];
+
+ KeyToUint converter;
+ K max_key = converter.max_possible_key_value();
+
+ for(int i=0; i<pow2_size; ++i)
+ {
+ keys_tmp[i] = i < length ? keys[i] : max_key;
+ if(i < length)
+ values_tmp[i] = values[i];
+ }
+
+ bitonic_sort_pairs_pow2_length_cpu(keys_tmp, values_tmp, pow2_size);
+
+ for(int i=0; i<length; ++i)
+ {
+ keys[i] = keys_tmp[i];
+ values[i] = values_tmp[i];
+ }
+ }
+
+ template <typename V, int max_size>
+ void bitonic_sort_pairs_general_size_cpu(unsigned int* keys, V* values, int length)
+ {
+ bitonic_sort_pairs_general_size_cpu<unsigned int, V, max_size, UintKeyToUint>(keys, values, length);
+ }
+
+ template <typename V, int max_size>
+ void bitonic_sort_pairs_general_size_cpu(int* keys, V* values, int length)
+ {
+ bitonic_sort_pairs_general_size_cpu<int, V, max_size, IntKeyToUint>(keys, values, length);
+ }
+
+ template <typename V, int max_size>
+ void bitonic_sort_pairs_general_size_cpu(float* keys, V* values, int length)
+ {
+ bitonic_sort_pairs_general_size_cpu<float, V, max_size, FloatKeyToUint>(keys, values, length);
+ }
+
+
+
+ template <typename K, typename V, typename KeyToUint>
+ void radix_sort_pairs_cpu_core(K* keys, K* aux_keys, V* values, V* aux_values, int n)
+ {
+ KeyToUint converter;
+ static unsigned int tables[2][1 << 16];
+ memset(tables, 0, sizeof(tables));
+
+ // build histograms
+ for (int i=0; i < n; ++i)
+ {
+ const unsigned int k = converter.convert(keys[i]);
+ const unsigned short low = k & 0xffff;
+ const unsigned short high = k >> 16;
+
+ ++tables[0][low];
+ ++tables[1][high];
+ }
+
+ // convert histograms to offset tables in-place
+ unsigned int offlow = 0;
+ unsigned int offhigh = 0;
+
+ for (int i=0; i < 65536; ++i)
+ {
+ const unsigned int newofflow = offlow + tables[0][i];
+ const unsigned int newoffhigh = offhigh + tables[1][i];
+
+ tables[0][i] = offlow;
+ tables[1][i] = offhigh;
+
+ offlow = newofflow;
+ offhigh = newoffhigh;
+ }
+
+ // pass 1 - sort by low 16 bits
+ for (int i=0; i < n; ++i)
+ {
+ // lookup offset of input
+ const K f = keys[i];
+ const unsigned int k = converter.convert(f);
+ const V v = values[i];
+ const unsigned int b = k & 0xffff;
+
+ // find offset and increment
+ const unsigned int offset = tables[0][b]++;
+
+ aux_keys[offset] = f;
+ aux_values[offset] = v;
+ }
+
+ // pass 2 - sort by high 16 bits
+ for (int i=0; i < n; ++i)
+ {
+ // lookup offset of input
+ const K f = aux_keys[i];
+ const unsigned int k = converter.convert(f);
+ const V v = aux_values[i];
+
+ const unsigned int b = k >> 16;
+
+ const unsigned int offset = tables[1][b]++;
+
+ keys[offset] = f;
+ values[offset] = v;
+ }
+ }
+
+ template <typename V>
+ inline void radix_sort_pairs_cpu(
+ int* keys_input,
+ int* keys_aux,
+ V* values_input,
+ V* values_aux,
+ int num_elements_to_sort)
+ {
+ radix_sort_pairs_cpu_core<int, V, IntKeyToUint>(
+ keys_input, keys_aux,
+ values_input, values_aux,
+ num_elements_to_sort);
+ }
+
+ template <typename V>
+ inline void radix_sort_pairs_cpu(
+ unsigned int* keys_input,
+ unsigned int* keys_aux,
+ V* values_input,
+ V* values_aux,
+ int num_elements_to_sort)
+ {
+ radix_sort_pairs_cpu_core<unsigned int, V, UintKeyToUint>(
+ keys_input, keys_aux,
+ values_input, values_aux,
+ num_elements_to_sort);
+ }
+
+ template <typename V>
+ inline void radix_sort_pairs_cpu(
+ float* keys_input,
+ float* keys_aux,
+ V* values_input,
+ V* values_aux,
+ int num_elements_to_sort)
+ {
+ radix_sort_pairs_cpu_core<float, V, FloatKeyToUint>(
+ keys_input, keys_aux,
+ values_input, values_aux,
+ num_elements_to_sort);
+ }
+
+
+
+ template <typename TileK, typename TileV>
+ void tile_sort(TileK& t, TileV& t2)
+ {
+ using T = typename TileK::Type;
+ using V = typename TileV::Type;
+
+ constexpr int num_elements_to_sort = TileK::Layout::Shape::size();
+ T* keys = &t.data(0);
+ V* values = &t2.data(0);
+
+ //Trim away the code that won't be used - possible because the number of elements to sort is known at compile time
+ if constexpr (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
+ {
+ if constexpr(is_power_of_two(num_elements_to_sort))
+ bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
+ else
+ bitonic_sort_pairs_general_size_cpu<V, num_elements_to_sort>(keys, values, num_elements_to_sort);
+ }
+ else
+ {
+ T keys_tmp[num_elements_to_sort];
+ V values_tmp[num_elements_to_sort];
+
+ radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
+ }
+
+ WP_TILE_SYNC();
+ }
+
+ template <typename TileK, typename TileV>
+ void tile_sort(TileK& t, TileV& t2, int start, int length)
+ {
+ using T = typename TileK::Type;
+ using V = typename TileV::Type;
+
+ constexpr int max_elements_to_sort = TileK::Layout::Shape::size();
+ const int num_elements_to_sort = length;
+ T* keys = &t.data(start);
+ V* values = &t2.data(start);
+
+ if (num_elements_to_sort <= BITONIC_SORT_THRESHOLD)
+ {
+ if (is_power_of_two(num_elements_to_sort))
+ bitonic_sort_pairs_pow2_length_cpu<T, V>(keys, values, num_elements_to_sort);
+ else
+ bitonic_sort_pairs_general_size_cpu<V, max_elements_to_sort>(keys, values, num_elements_to_sort);
+ }
+ else
+ {
+ if constexpr (max_elements_to_sort > BITONIC_SORT_THRESHOLD)
+ {
+ T keys_tmp[max_elements_to_sort];
+ V values_tmp[max_elements_to_sort];
+
+ radix_sort_pairs_cpu<V>(keys, keys_tmp, values, values_tmp, num_elements_to_sort);
+ }
+ }
+
+ WP_TILE_SYNC();
+ }
+
+
+ #endif // !defined(__CUDA_ARCH__)
+
+
+ template <typename TileK, typename TileV>
+ inline void adj_tile_sort(TileK& t, TileV& t2, TileK& adj_t1, TileV& adj_t2)
+ {
+ // todo: general purpose sort gradients not implemented
+ }
+
+ template <typename TileK, typename TileV>
+ inline void adj_tile_sort(TileK& t, TileV& t2, int start, int length, TileK& adj_t1, TileV& adj_t2, int adj_start, int adj_length)
+ {
+ // todo: general purpose sort gradients not implemented
+ }
+
+ } // namespace wp
+
+ #if defined(__clang__)
+ #pragma clang diagnostic pop
+ #endif
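
As a side note on the key-conversion functors defined near the top of this header (IntKeyToUint and FloatKeyToUint, the latter following the stereopsis radix trick linked in the code), the short host-side sketch below is an illustration only and not part of the package; it checks in plain Python that the bit manipulations preserve ordering, which is what lets the radix passes sort signed integer and float keys with unsigned comparisons.

    import struct

    def int_key_to_uint(value: int) -> int:
        # flip the sign bit so negative ints map below positive ints
        return (value & 0xFFFFFFFF) ^ 0x80000000

    def float_key_to_uint(value: float) -> int:
        # reinterpret the float's bits as a uint32
        i = struct.unpack("<I", struct.pack("<f", value))[0]
        # negative floats: flip all bits; non-negative floats: flip only the sign bit
        mask = 0xFFFFFFFF if (i >> 31) else 0x80000000
        return i ^ mask

    ints = [7, -3, 0, -2**31, 2**31 - 1]
    floats = [1.25, -3.5, 0.0, -0.0, float("inf"), -float("inf")]
    assert sorted(ints, key=int_key_to_uint) == sorted(ints)
    assert sorted(floats, key=float_key_to_uint) == sorted(floats)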